style: apply cargo fmt across entire Rust codebase

- Remove Rust formatting check from CI workflow since formatting is now applied
- Standardize import ordering and organization throughout codebase
- Fix indentation, spacing, and line breaks for consistency
- Clean up trailing whitespace and formatting inconsistencies
- Apply rustfmt to all Rust source files including checkpoint, sandbox, commands, and test modules

This establishes a consistent code style baseline for the project.
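
For reference, the formatting pass and the (now removed) CI verification step come down to two rustfmt invocations. The commands below mirror the deleted workflow step; the in-place run is assumed to use rustfmt's defaults, since this commit adds no rustfmt.toml:

    cd src-tauri
    rustup component add rustfmt
    cargo fmt              # rewrite all Rust sources in place
    cargo fmt -- --check   # verify only; exits non-zero if anything would change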
Author: Mufeed VH
Date: 2025-06-25 03:45:59 +05:30
Parent: bb48a32784
Commit: bcffce0a08
41 changed files with 3617 additions and 2662 deletions

View File

@@ -100,14 +100,6 @@ jobs:
       - name: Build frontend
         run: bun run build

-      # Check Rust formatting
-      - name: Check Rust formatting
-        if: matrix.platform.os == 'ubuntu-latest'
-        working-directory: ./src-tauri
-        run: |
-          rustup component add rustfmt
-          cargo fmt -- --check

       # Run Rust linter
       - name: Run Rust linter
         if: matrix.platform.os == 'ubuntu-latest'

View File

@@ -1,16 +1,16 @@
 use anyhow::{Context, Result};
+use chrono::{DateTime, TimeZone, Utc};
+use log;
 use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
 use std::sync::Arc;
-use chrono::{Utc, TimeZone, DateTime};
 use tokio::sync::RwLock;
-use log;

 use super::{
-    Checkpoint, CheckpointMetadata, FileSnapshot, FileTracker, FileState,
-    CheckpointResult, SessionTimeline, CheckpointStrategy, CheckpointPaths,
-    storage::{CheckpointStorage, self},
+    storage::{self, CheckpointStorage},
+    Checkpoint, CheckpointMetadata, CheckpointPaths, CheckpointResult, CheckpointStrategy,
+    FileSnapshot, FileState, FileTracker, SessionTimeline,
 };

 /// Manages checkpoint operations for a session
@@ -111,34 +111,38 @@ impl CheckpointManager {
         // Read current file state
         let (hash, exists, _size, modified) = if full_path.exists() {
-            let content = fs::read_to_string(&full_path)
-                .unwrap_or_default();
+            let content = fs::read_to_string(&full_path).unwrap_or_default();
             let metadata = fs::metadata(&full_path)?;
-            let modified = metadata.modified()
+            let modified = metadata
+                .modified()
                 .ok()
                 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
-                .map(|d| Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos()).unwrap())
+                .map(|d| {
+                    Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos())
+                        .unwrap()
+                })
                 .unwrap_or_else(Utc::now);
             (
                 storage::CheckpointStorage::calculate_file_hash(&content),
                 true,
                 metadata.len(),
-                modified
+                modified,
             )
         } else {
             (String::new(), false, 0, Utc::now())
         };

         // Check if file has actually changed
-        let is_modified = if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) {
+        let is_modified =
+            if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) {
                 // File is modified if:
                 // 1. Hash has changed
                 // 2. Existence state has changed
                 // 3. It was already marked as modified
-            existing_state.last_hash != hash ||
-            existing_state.exists != exists ||
-            existing_state.is_modified
+                existing_state.last_hash != hash
+                    || existing_state.exists != exists
+                    || existing_state.is_modified
             } else {
                 // New file is always considered modified
                 true
@@ -161,8 +165,8 @@ impl CheckpointManager {
     async fn track_bash_side_effects(&self, command: &str) -> Result<()> {
         // Common file-modifying commands
         let file_commands = [
-            "echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk",
-            "npm", "yarn", "pnpm", "bun", "cargo", "make", "gcc", "g++",
+            "echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk", "npm", "yarn", "pnpm", "bun",
+            "cargo", "make", "gcc", "g++",
         ];

         // Simple heuristic: if command contains file-modifying operations
@@ -190,11 +194,16 @@ impl CheckpointManager {
         let message_index = messages.len().saturating_sub(1);

         // Extract metadata from the last user message
-        let (user_prompt, model_used, total_tokens) = self.extract_checkpoint_metadata(&messages).await?;
+        let (user_prompt, model_used, total_tokens) =
+            self.extract_checkpoint_metadata(&messages).await?;

         // Ensure every file in the project is tracked so new checkpoints include all files
         // Recursively walk the project directory and track each file
-        fn collect_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<std::path::PathBuf>) -> Result<(), std::io::Error> {
+        fn collect_files(
+            dir: &std::path::Path,
+            base: &std::path::Path,
+            files: &mut Vec<std::path::PathBuf>,
+        ) -> Result<(), std::io::Error> {
             for entry in std::fs::read_dir(dir)? {
                 let entry = entry?;
                 let path = entry.path();
@@ -306,7 +315,8 @@ impl CheckpointManager {
             if let Ok(msg) = serde_json::from_str::<serde_json::Value>(msg_str) {
                 // Check for user message
                 if msg.get("type").and_then(|t| t.as_str()) == Some("user") {
-                    if let Some(content) = msg.get("message")
+                    if let Some(content) = msg
+                        .get("message")
                         .and_then(|m| m.get("content"))
                         .and_then(|c| c.as_array())
                     {
@@ -344,10 +354,16 @@ impl CheckpointManager {
                     total_tokens += output;
                 }
                 // Also count cache tokens
-                if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) {
+                if let Some(cache_creation) = usage
+                    .get("cache_creation_input_tokens")
+                    .and_then(|t| t.as_u64())
+                {
                     total_tokens += cache_creation;
                 }
-                if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) {
+                if let Some(cache_read) = usage
+                    .get("cache_read_input_tokens")
+                    .and_then(|t| t.as_u64())
+                {
                     total_tokens += cache_read;
                 }
             }
@@ -362,10 +378,16 @@ impl CheckpointManager {
                     total_tokens += output;
                 }
                 // Also count cache tokens
-                if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) {
+                if let Some(cache_creation) = usage
+                    .get("cache_creation_input_tokens")
+                    .and_then(|t| t.as_u64())
+                {
                     total_tokens += cache_creation;
                 }
-                if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) {
+                if let Some(cache_read) = usage
+                    .get("cache_read_input_tokens")
+                    .and_then(|t| t.as_u64())
+                {
                     total_tokens += cache_read;
                 }
             }
@@ -389,8 +411,7 @@ impl CheckpointManager {
             let full_path = self.project_path.join(rel_path);

             let (content, exists, permissions, size, current_hash) = if full_path.exists() {
-                let content = fs::read_to_string(&full_path)
-                    .unwrap_or_default();
+                let content = fs::read_to_string(&full_path).unwrap_or_default();
                 let current_hash = storage::CheckpointStorage::calculate_file_hash(&content);

                 // Don't skip based on hash - if is_modified is true, we should snapshot it
@@ -430,14 +451,16 @@ impl CheckpointManager {
     /// Restore a checkpoint
     pub async fn restore_checkpoint(&self, checkpoint_id: &str) -> Result<CheckpointResult> {
         // Load checkpoint data
-        let (checkpoint, file_snapshots, messages) = self.storage.load_checkpoint(
-            &self.project_id,
-            &self.session_id,
-            checkpoint_id,
-        )?;
+        let (checkpoint, file_snapshots, messages) =
+            self.storage
+                .load_checkpoint(&self.project_id, &self.session_id, checkpoint_id)?;

         // First, collect all files currently in the project to handle deletions
-        fn collect_all_project_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<std::path::PathBuf>) -> Result<(), std::io::Error> {
+        fn collect_all_project_files(
+            dir: &std::path::Path,
+            base: &std::path::Path,
+            files: &mut Vec<std::path::PathBuf>,
+        ) -> Result<(), std::io::Error> {
             for entry in std::fs::read_dir(dir)? {
                 let entry = entry?;
                 let path = entry.path();
@@ -460,7 +483,8 @@ impl CheckpointManager {
         }

         let mut current_files = Vec::new();
-        let _ = collect_all_project_files(&self.project_path, &self.project_path, &mut current_files);
+        let _ =
+            collect_all_project_files(&self.project_path, &self.project_path, &mut current_files);

         // Create a set of files that should exist after restore
         let mut checkpoint_files = std::collections::HashSet::new();
@@ -484,14 +508,21 @@ impl CheckpointManager {
log::info!("Deleted file not in checkpoint: {:?}", current_file); log::info!("Deleted file not in checkpoint: {:?}", current_file);
} }
Err(e) => { Err(e) => {
warnings.push(format!("Failed to delete {}: {}", current_file.display(), e)); warnings.push(format!(
"Failed to delete {}: {}",
current_file.display(),
e
));
} }
} }
} }
} }
// Clean up empty directories // Clean up empty directories
fn remove_empty_dirs(dir: &std::path::Path, base: &std::path::Path) -> Result<bool, std::io::Error> { fn remove_empty_dirs(
dir: &std::path::Path,
base: &std::path::Path,
) -> Result<bool, std::io::Error> {
if dir == base { if dir == base {
return Ok(false); // Don't remove the base directory return Ok(false); // Don't remove the base directory
} }
@@ -524,8 +555,11 @@ impl CheckpointManager {
         for snapshot in &file_snapshots {
             match self.restore_file_snapshot(snapshot).await {
                 Ok(_) => files_processed += 1,
-                Err(e) => warnings.push(format!("Failed to restore {}: {}",
-                    snapshot.file_path.display(), e)),
+                Err(e) => warnings.push(format!(
+                    "Failed to restore {}: {}",
+                    snapshot.file_path.display(),
+                    e
+                )),
             }
         }
@@ -571,19 +605,16 @@ impl CheckpointManager {
         if snapshot.is_deleted {
             // Delete the file if it exists
             if full_path.exists() {
-                fs::remove_file(&full_path)
-                    .context("Failed to delete file")?;
+                fs::remove_file(&full_path).context("Failed to delete file")?;
             }
         } else {
             // Create parent directories if needed
             if let Some(parent) = full_path.parent() {
-                fs::create_dir_all(parent)
-                    .context("Failed to create parent directories")?;
+                fs::create_dir_all(parent).context("Failed to create parent directories")?;
             }

             // Write file content
-            fs::write(&full_path, &snapshot.content)
-                .context("Failed to write file")?;
+            fs::write(&full_path, &snapshot.content).context("Failed to write file")?;

             // Restore permissions if available
             #[cfg(unix)]
@@ -616,7 +647,10 @@ impl CheckpointManager {
     }

     /// Recursively collect checkpoints from timeline tree
-    fn collect_checkpoints_from_node(node: &super::TimelineNode, checkpoints: &mut Vec<Checkpoint>) {
+    fn collect_checkpoints_from_node(
+        node: &super::TimelineNode,
+        checkpoints: &mut Vec<Checkpoint>,
+    ) {
         checkpoints.push(node.checkpoint.clone());
         for child in &node.children {
             Self::collect_checkpoints_from_node(child, checkpoints);
@@ -630,21 +664,19 @@ impl CheckpointManager {
         description: Option<String>,
     ) -> Result<CheckpointResult> {
         // Load the checkpoint to fork from
-        let (_base_checkpoint, _, _) = self.storage.load_checkpoint(
-            &self.project_id,
-            &self.session_id,
-            checkpoint_id,
-        )?;
+        let (_base_checkpoint, _, _) =
+            self.storage
+                .load_checkpoint(&self.project_id, &self.session_id, checkpoint_id)?;

         // Restore to that checkpoint first
         self.restore_checkpoint(checkpoint_id).await?;

         // Create a new checkpoint with the fork
-        let fork_description = description.unwrap_or_else(|| {
-            format!("Fork from checkpoint {}", &checkpoint_id[..8])
-        });
+        let fork_description =
+            description.unwrap_or_else(|| format!("Fork from checkpoint {}", &checkpoint_id[..8]));

-        self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string())).await
+        self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string()))
+            .await
     }

     /// Check if auto-checkpoint should be triggered
@@ -668,7 +700,11 @@ impl CheckpointManager {
             CheckpointStrategy::PerToolUse => {
                 // Check if message contains tool use
                 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
-                    if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) {
+                    if let Some(content) = msg
+                        .get("message")
+                        .and_then(|m| m.get("content"))
+                        .and_then(|c| c.as_array())
+                    {
                         content.iter().any(|item| {
                             item.get("type").and_then(|t| t.as_str()) == Some("tool_use")
                         })
@@ -682,12 +718,19 @@ impl CheckpointManager {
             CheckpointStrategy::Smart => {
                 // Smart strategy: checkpoint after destructive operations
                 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
-                    if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) {
+                    if let Some(content) = msg
+                        .get("message")
+                        .and_then(|m| m.get("content"))
+                        .and_then(|c| c.as_array())
+                    {
                         content.iter().any(|item| {
                             if item.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
-                                let tool_name = item.get("name").and_then(|n| n.as_str()).unwrap_or("");
-                                matches!(tool_name.to_lowercase().as_str(),
-                                    "write" | "edit" | "multiedit" | "bash" | "rm" | "delete")
+                                let tool_name =
+                                    item.get("name").and_then(|n| n.as_str()).unwrap_or("");
+                                matches!(
+                                    tool_name.to_lowercase().as_str(),
+                                    "write" | "edit" | "multiedit" | "bash" | "rm" | "delete"
+                                )
                             } else {
                                 false
                             }
@@ -715,7 +758,8 @@ impl CheckpointManager {
         // Save updated timeline
         let claude_dir = self.storage.claude_dir.clone();
         let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id);
-        self.storage.save_timeline(&paths.timeline_file, &timeline)?;
+        self.storage
+            .save_timeline(&paths.timeline_file, &timeline)?;

         Ok(())
     }
@@ -723,7 +767,8 @@ impl CheckpointManager {
     /// Get files modified since a given timestamp
     pub async fn get_files_modified_since(&self, since: DateTime<Utc>) -> Vec<PathBuf> {
         let tracker = self.file_tracker.read().await;
-        tracker.tracked_files
+        tracker
+            .tracked_files
             .iter()
             .filter(|(_, state)| state.last_modified > since && state.is_modified)
             .map(|(path, _)| path.clone())
@@ -733,7 +778,8 @@ impl CheckpointManager {
     /// Get the last modification time of any tracked file
     pub async fn get_last_modification_time(&self) -> Option<DateTime<Utc>> {
         let tracker = self.file_tracker.read().await;
-        tracker.tracked_files
+        tracker
+            .tracked_files
             .values()
             .map(|state| state.last_modified)
             .max()

View File

@@ -1,11 +1,11 @@
+use chrono::{DateTime, Utc};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::path::PathBuf;
-use chrono::{DateTime, Utc};

 pub mod manager;
-pub mod storage;
 pub mod state;
+pub mod storage;

 /// Represents a checkpoint in the session timeline
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -191,7 +191,8 @@ impl SessionTimeline {
     /// Find a checkpoint by ID in the timeline tree
     pub fn find_checkpoint(&self, checkpoint_id: &str) -> Option<&TimelineNode> {
-        self.root_node.as_ref()
+        self.root_node
+            .as_ref()
             .and_then(|root| Self::find_in_tree(root, checkpoint_id))
     }
@@ -253,6 +254,9 @@ impl CheckpointPaths {
     #[allow(dead_code)]
     pub fn file_reference_path(&self, checkpoint_id: &str, safe_filename: &str) -> PathBuf {
         // References are stored per checkpoint
-        self.files_dir.join("refs").join(checkpoint_id).join(format!("{}.json", safe_filename))
+        self.files_dir
+            .join("refs")
+            .join(checkpoint_id)
+            .join(format!("{}.json", safe_filename))
     }
 }

View File

@@ -1,8 +1,8 @@
+use anyhow::Result;
 use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::Arc;
 use tokio::sync::RwLock;
-use anyhow::Result;

 use super::manager::CheckpointManager;
@@ -71,12 +71,9 @@ impl CheckpointState {
         };

         // Create new manager
-        let manager = CheckpointManager::new(
-            project_id,
-            session_id.clone(),
-            project_path,
-            claude_dir,
-        ).await?;
+        let manager =
+            CheckpointManager::new(project_id, session_id.clone(), project_path, claude_dir)
+                .await?;

         let manager_arc = Arc::new(manager);
         managers.insert(session_id, Arc::clone(&manager_arc));
@@ -157,18 +154,16 @@ mod tests {
         let project_path = temp_dir.path().join("project");
         std::fs::create_dir_all(&project_path).unwrap();

-        let manager1 = state.get_or_create_manager(
-            session_id.clone(),
-            project_id.clone(),
-            project_path.clone(),
-        ).await.unwrap();
+        let manager1 = state
+            .get_or_create_manager(session_id.clone(), project_id.clone(), project_path.clone())
+            .await
+            .unwrap();

         // Getting the same session should return the same manager
-        let manager2 = state.get_or_create_manager(
-            session_id.clone(),
-            project_id.clone(),
-            project_path.clone(),
-        ).await.unwrap();
+        let manager2 = state
+            .get_or_create_manager(session_id.clone(), project_id.clone(), project_path.clone())
+            .await
+            .unwrap();

         assert!(Arc::ptr_eq(&manager1, &manager2));
         assert_eq!(state.active_count().await, 1);
@@ -179,11 +174,10 @@ mod tests {
         assert_eq!(state.active_count().await, 0);

         // Getting after removal should create a new one
-        let manager3 = state.get_or_create_manager(
-            session_id.clone(),
-            project_id,
-            project_path,
-        ).await.unwrap();
+        let manager3 = state
+            .get_or_create_manager(session_id.clone(), project_id, project_path)
+            .await
+            .unwrap();

         assert!(!Arc::ptr_eq(&manager1, &manager3));
     }

View File

@@ -1,13 +1,12 @@
 use anyhow::{Context, Result};
+use sha2::{Digest, Sha256};
 use std::fs;
 use std::path::{Path, PathBuf};
-use sha2::{Sha256, Digest};
-use zstd::stream::{encode_all, decode_all};
 use uuid::Uuid;
+use zstd::stream::{decode_all, encode_all};

 use super::{
-    Checkpoint, FileSnapshot, SessionTimeline,
-    TimelineNode, CheckpointPaths, CheckpointResult
+    Checkpoint, CheckpointPaths, CheckpointResult, FileSnapshot, SessionTimeline, TimelineNode,
 };

 /// Manages checkpoint storage operations
@@ -32,8 +31,7 @@ impl CheckpointStorage {
         // Create directory structure
         fs::create_dir_all(&paths.checkpoints_dir)
             .context("Failed to create checkpoints directory")?;
-        fs::create_dir_all(&paths.files_dir)
-            .context("Failed to create files directory")?;
+        fs::create_dir_all(&paths.files_dir).context("Failed to create files directory")?;

         // Initialize empty timeline if it doesn't exist
         if !paths.timeline_file.exists() {
@@ -57,15 +55,13 @@ impl CheckpointStorage {
         let checkpoint_dir = paths.checkpoint_dir(&checkpoint.id);

         // Create checkpoint directory
-        fs::create_dir_all(&checkpoint_dir)
-            .context("Failed to create checkpoint directory")?;
+        fs::create_dir_all(&checkpoint_dir).context("Failed to create checkpoint directory")?;

         // Save checkpoint metadata
         let metadata_path = paths.checkpoint_metadata_file(&checkpoint.id);
         let metadata_json = serde_json::to_string_pretty(checkpoint)
             .context("Failed to serialize checkpoint metadata")?;
-        fs::write(&metadata_path, metadata_json)
-            .context("Failed to write checkpoint metadata")?;
+        fs::write(&metadata_path, metadata_json).context("Failed to write checkpoint metadata")?;

         // Save messages (compressed)
         let messages_path = paths.checkpoint_messages_file(&checkpoint.id);
@@ -81,17 +77,16 @@ impl CheckpointStorage {
         for snapshot in &file_snapshots {
             match self.save_file_snapshot(&paths, snapshot) {
                 Ok(_) => files_processed += 1,
-                Err(e) => warnings.push(format!("Failed to save {}: {}",
-                    snapshot.file_path.display(), e)),
+                Err(e) => warnings.push(format!(
+                    "Failed to save {}: {}",
+                    snapshot.file_path.display(),
+                    e
+                )),
             }
         }

         // Update timeline
-        self.update_timeline_with_checkpoint(
-            &paths.timeline_file,
-            checkpoint,
-            &file_snapshots
-        )?;
+        self.update_timeline_with_checkpoint(&paths.timeline_file, checkpoint, &file_snapshots)?;

         Ok(CheckpointResult {
             checkpoint: checkpoint.clone(),
@@ -105,8 +100,7 @@ impl CheckpointStorage {
         // Use content-addressable storage: store files by their hash
         // This prevents duplication of identical file content across checkpoints
         let content_pool_dir = paths.files_dir.join("content_pool");
-        fs::create_dir_all(&content_pool_dir)
-            .context("Failed to create content pool directory")?;
+        fs::create_dir_all(&content_pool_dir).context("Failed to create content pool directory")?;

         // Store the actual content in the content pool
         let content_file = content_pool_dir.join(&snapshot.hash);
@@ -114,7 +108,8 @@ impl CheckpointStorage {
         // Only write the content if it doesn't already exist
         if !content_file.exists() {
             // Compress and save file content
-            let compressed_content = encode_all(snapshot.content.as_bytes(), self.compression_level)
+            let compressed_content =
+                encode_all(snapshot.content.as_bytes(), self.compression_level)
                     .context("Failed to compress file content")?;
             fs::write(&content_file, compressed_content)
                 .context("Failed to write file content to pool")?;
@@ -135,7 +130,8 @@ impl CheckpointStorage {
         });

         // Use a sanitized filename for the reference
-        let safe_filename = snapshot.file_path
+        let safe_filename = snapshot
+            .file_path
             .to_string_lossy()
             .replace('/', "_")
             .replace('\\', "_");
@@ -158,17 +154,18 @@ impl CheckpointStorage {
         // Load checkpoint metadata
         let metadata_path = paths.checkpoint_metadata_file(checkpoint_id);
-        let metadata_json = fs::read_to_string(&metadata_path)
-            .context("Failed to read checkpoint metadata")?;
-        let checkpoint: Checkpoint = serde_json::from_str(&metadata_json)
-            .context("Failed to parse checkpoint metadata")?;
+        let metadata_json =
+            fs::read_to_string(&metadata_path).context("Failed to read checkpoint metadata")?;
+        let checkpoint: Checkpoint =
+            serde_json::from_str(&metadata_json).context("Failed to parse checkpoint metadata")?;

         // Load messages
         let messages_path = paths.checkpoint_messages_file(checkpoint_id);
-        let compressed_messages = fs::read(&messages_path)
-            .context("Failed to read compressed messages")?;
-        let messages = String::from_utf8(decode_all(&compressed_messages[..])
-            .context("Failed to decompress messages")?)
+        let compressed_messages =
+            fs::read(&messages_path).context("Failed to read compressed messages")?;
+        let messages = String::from_utf8(
+            decode_all(&compressed_messages[..]).context("Failed to decompress messages")?,
+        )
         .context("Invalid UTF-8 in messages")?;

         // Load file snapshots
@@ -181,7 +178,7 @@ impl CheckpointStorage {
     fn load_file_snapshots(
         &self,
         paths: &CheckpointPaths,
-        checkpoint_id: &str
+        checkpoint_id: &str,
     ) -> Result<Vec<FileSnapshot>> {
         let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
         if !refs_dir.exists() {
@@ -202,21 +199,23 @@ impl CheckpointStorage {
             }

             // Load reference metadata
-            let ref_json = fs::read_to_string(&path)
-                .context("Failed to read file reference")?;
-            let ref_metadata: serde_json::Value = serde_json::from_str(&ref_json)
-                .context("Failed to parse file reference")?;
+            let ref_json = fs::read_to_string(&path).context("Failed to read file reference")?;
+            let ref_metadata: serde_json::Value =
+                serde_json::from_str(&ref_json).context("Failed to parse file reference")?;

-            let hash = ref_metadata["hash"].as_str()
+            let hash = ref_metadata["hash"]
+                .as_str()
                 .ok_or_else(|| anyhow::anyhow!("Missing hash in reference"))?;

             // Load content from pool
             let content_file = content_pool_dir.join(hash);
             let content = if content_file.exists() {
-                let compressed_content = fs::read(&content_file)
-                    .context("Failed to read file content from pool")?;
-                String::from_utf8(decode_all(&compressed_content[..])
-                    .context("Failed to decompress file content")?)
+                let compressed_content =
+                    fs::read(&content_file).context("Failed to read file content from pool")?;
+                String::from_utf8(
+                    decode_all(&compressed_content[..])
+                        .context("Failed to decompress file content")?,
+                )
                 .context("Invalid UTF-8 in file content")?
             } else {
                 // Handle missing content gracefully
@@ -240,19 +239,17 @@ impl CheckpointStorage {
     /// Save timeline to disk
     pub fn save_timeline(&self, timeline_path: &Path, timeline: &SessionTimeline) -> Result<()> {
-        let timeline_json = serde_json::to_string_pretty(timeline)
-            .context("Failed to serialize timeline")?;
-        fs::write(timeline_path, timeline_json)
-            .context("Failed to write timeline")?;
+        let timeline_json =
+            serde_json::to_string_pretty(timeline).context("Failed to serialize timeline")?;
+        fs::write(timeline_path, timeline_json).context("Failed to write timeline")?;
         Ok(())
     }

     /// Load timeline from disk
     pub fn load_timeline(&self, timeline_path: &Path) -> Result<SessionTimeline> {
-        let timeline_json = fs::read_to_string(timeline_path)
-            .context("Failed to read timeline")?;
-        let timeline: SessionTimeline = serde_json::from_str(&timeline_json)
-            .context("Failed to parse timeline")?;
+        let timeline_json = fs::read_to_string(timeline_path).context("Failed to read timeline")?;
+        let timeline: SessionTimeline =
+            serde_json::from_str(&timeline_json).context("Failed to parse timeline")?;
         Ok(timeline)
     }
@@ -268,9 +265,7 @@ impl CheckpointStorage {
         let new_node = TimelineNode {
             checkpoint: checkpoint.clone(),
             children: Vec::new(),
-            file_snapshot_ids: file_snapshots.iter()
-                .map(|s| s.hash.clone())
-                .collect(),
+            file_snapshot_ids: file_snapshots.iter().map(|s| s.hash.clone()).collect(),
         };

         // If this is the first checkpoint
@@ -301,7 +296,7 @@ impl CheckpointStorage {
     fn add_child_to_node(
         node: &mut TimelineNode,
         parent_id: &str,
-        child: TimelineNode
+        child: TimelineNode,
     ) -> Result<()> {
         if node.checkpoint.id == parent_id {
             node.children.push(child);
@@ -330,14 +325,9 @@ impl CheckpointStorage {
     }

     /// Estimate storage size for a checkpoint
-    pub fn estimate_checkpoint_size(
-        messages: &str,
-        file_snapshots: &[FileSnapshot],
-    ) -> u64 {
+    pub fn estimate_checkpoint_size(messages: &str, file_snapshots: &[FileSnapshot]) -> u64 {
         let messages_size = messages.len() as u64;
-        let files_size: u64 = file_snapshots.iter()
-            .map(|s| s.content.len() as u64)
-            .sum();
+        let files_size: u64 = file_snapshots.iter().map(|s| s.content.len() as u64).sum();

         // Estimate compressed size (typically 20-30% of original for text)
         (messages_size + files_size) / 4
@@ -400,15 +390,13 @@ impl CheckpointStorage {
         // Remove checkpoint metadata directory
         let checkpoint_dir = paths.checkpoint_dir(checkpoint_id);
         if checkpoint_dir.exists() {
-            fs::remove_dir_all(&checkpoint_dir)
-                .context("Failed to remove checkpoint directory")?;
+            fs::remove_dir_all(&checkpoint_dir).context("Failed to remove checkpoint directory")?;
         }

         // Remove file references for this checkpoint
         let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
         if refs_dir.exists() {
-            fs::remove_dir_all(&refs_dir)
-                .context("Failed to remove file references")?;
+            fs::remove_dir_all(&refs_dir).context("Failed to remove file references")?;
         }

         // Note: We don't remove content from the pool here as it might be
@@ -418,11 +406,7 @@ impl CheckpointStorage {
     }

     /// Garbage collect unreferenced content from the content pool
-    pub fn garbage_collect_content(
-        &self,
-        project_id: &str,
-        session_id: &str,
-    ) -> Result<usize> {
+    pub fn garbage_collect_content(&self, project_id: &str, session_id: &str) -> Result<usize> {
         let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
         let content_pool_dir = paths.files_dir.join("content_pool");
         let refs_dir = paths.files_dir.join("refs");
@@ -442,7 +426,9 @@ impl CheckpointStorage {
                 let ref_path = ref_entry?.path();
                 if ref_path.extension().and_then(|e| e.to_str()) == Some("json") {
                     if let Ok(ref_json) = fs::read_to_string(&ref_path) {
-                        if let Ok(ref_metadata) = serde_json::from_str::<serde_json::Value>(&ref_json) {
+                        if let Ok(ref_metadata) =
+                            serde_json::from_str::<serde_json::Value>(&ref_json)
+                        {
                             if let Some(hash) = ref_metadata["hash"].as_str() {
                                 referenced_hashes.insert(hash.to_string());
                             }

View File

@@ -1,12 +1,12 @@
+use anyhow::Result;
+use log::{debug, error, info, warn};
+use serde::{Deserialize, Serialize};
+use std::cmp::Ordering;
 /// Shared module for detecting Claude Code binary installations
 /// Supports NVM installations, aliased paths, and version-based selection
 use std::path::PathBuf;
 use std::process::Command;
-use log::{info, warn, debug, error};
-use anyhow::Result;
-use std::cmp::Ordering;
 use tauri::Manager;
-use serde::{Serialize, Deserialize};

 /// Represents a Claude installation with metadata
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -61,8 +61,10 @@ pub fn find_claude_binary(app_handle: &tauri::AppHandle) -> Result<String, Strin
     // Select the best installation (highest version)
     if let Some(best) = select_best_installation(installations) {
-        info!("Selected Claude installation: path={}, version={:?}, source={}",
-            best.path, best.version, best.source);
+        info!(
+            "Selected Claude installation: path={}, version={:?}, source={}",
+            best.path, best.version, best.source
+        );
         Ok(best.path)
     } else {
         Err("No valid Claude installation found".to_string())
@@ -87,12 +89,12 @@ pub fn discover_claude_installations() -> Vec<ClaudeInstallation> {
                         // If versions are equal, prefer by source
                         source_preference(a).cmp(&source_preference(b))
                     }
-                    other => other
+                    other => other,
                 }
             }
             (Some(_), None) => Ordering::Less, // Version comes before no version
             (None, Some(_)) => Ordering::Greater,
-            (None, None) => source_preference(a).cmp(&source_preference(b))
+            (None, None) => source_preference(a).cmp(&source_preference(b)),
         }
     });
@@ -154,7 +156,8 @@ fn try_which_command() -> Option<ClaudeInstallation> {
     // Parse aliased output: "claude: aliased to /path/to/claude"
     let path = if output_str.starts_with("claude:") && output_str.contains("aliased to") {
-        output_str.split("aliased to")
+        output_str
+            .split("aliased to")
             .nth(1)
             .map(|s| s.trim().to_string())
     } else {
@@ -187,7 +190,10 @@ fn find_nvm_installations() -> Vec<ClaudeInstallation> {
     let mut installations = Vec::new();

     if let Ok(home) = std::env::var("HOME") {
-        let nvm_dir = PathBuf::from(&home).join(".nvm").join("versions").join("node");
+        let nvm_dir = PathBuf::from(&home)
+            .join(".nvm")
+            .join("versions")
+            .join("node");

         debug!("Checking NVM directory: {:?}", nvm_dir);
@@ -226,7 +232,10 @@ fn find_standard_installations() -> Vec<ClaudeInstallation> {
     // Common installation paths for claude
     let mut paths_to_check: Vec<(String, String)> = vec![
         ("/usr/local/bin/claude".to_string(), "system".to_string()),
-        ("/opt/homebrew/bin/claude".to_string(), "homebrew".to_string()),
+        (
+            "/opt/homebrew/bin/claude".to_string(),
+            "homebrew".to_string(),
+        ),
         ("/usr/bin/claude".to_string(), "system".to_string()),
         ("/bin/claude".to_string(), "system".to_string()),
     ];
@@ -234,15 +243,30 @@ fn find_standard_installations() -> Vec<ClaudeInstallation> {
     // Also check user-specific paths
     if let Ok(home) = std::env::var("HOME") {
         paths_to_check.extend(vec![
-            (format!("{}/.claude/local/claude", home), "claude-local".to_string()),
-            (format!("{}/.local/bin/claude", home), "local-bin".to_string()),
-            (format!("{}/.npm-global/bin/claude", home), "npm-global".to_string()),
+            (
+                format!("{}/.claude/local/claude", home),
+                "claude-local".to_string(),
+            ),
+            (
+                format!("{}/.local/bin/claude", home),
+                "local-bin".to_string(),
+            ),
+            (
+                format!("{}/.npm-global/bin/claude", home),
+                "npm-global".to_string(),
+            ),
             (format!("{}/.yarn/bin/claude", home), "yarn".to_string()),
             (format!("{}/.bun/bin/claude", home), "bun".to_string()),
             (format!("{}/bin/claude", home), "home-bin".to_string()),
             // Check common node_modules locations
-            (format!("{}/node_modules/.bin/claude", home), "node-modules".to_string()),
-            (format!("{}/.config/yarn/global/node_modules/.bin/claude", home), "yarn-global".to_string()),
+            (
+                format!("{}/node_modules/.bin/claude", home),
+                "node-modules".to_string(),
+            ),
+            (
+                format!("{}/.config/yarn/global/node_modules/.bin/claude", home),
+                "yarn-global".to_string(),
+            ),
         ]);
     }
@@ -302,11 +326,11 @@ fn extract_version_from_output(stdout: &[u8]) -> Option<String> {
     let output_str = String::from_utf8_lossy(stdout);

     // Extract version: first token before whitespace that looks like a version
-    output_str.split_whitespace()
+    output_str
+        .split_whitespace()
         .find(|token| {
             // Version usually contains dots and numbers
-            token.chars().any(|c| c == '.') &&
-            token.chars().any(|c| c.is_numeric())
+            token.chars().any(|c| c == '.') && token.chars().any(|c| c.is_numeric())
         })
         .map(|s| s.to_string())
 }
@@ -320,8 +344,7 @@ fn select_best_installation(installations: Vec<ClaudeInstallation>) -> Option<Cl
     // prefer binaries with version information when it is available so that
     // in development builds we keep the previous behaviour of picking the
     // most recent version.
-    installations.into_iter()
-        .max_by(|a, b| {
+    installations.into_iter().max_by(|a, b| {
         match (&a.version, &b.version) {
             // If both have versions, compare them semantically.
             (Some(v1), Some(v2)) => compare_versions(v1, v2),
@@ -347,7 +370,8 @@ fn select_best_installation(installations: Vec<ClaudeInstallation>) -> Option<Cl
 /// Compare two version strings
 fn compare_versions(a: &str, b: &str) -> Ordering {
     // Simple semantic version comparison
-    let a_parts: Vec<u32> = a.split('.')
+    let a_parts: Vec<u32> = a
+        .split('.')
         .filter_map(|s| {
             // Handle versions like "1.0.17-beta" by taking only numeric part
             s.chars()
@@ -358,7 +382,8 @@ fn compare_versions(a: &str, b: &str) -> Ordering {
         })
         .collect();

-    let b_parts: Vec<u32> = b.split('.')
+    let b_parts: Vec<u32> = b
+        .split('.')
         .filter_map(|s| {
             s.chars()
                 .take_while(|c| c.is_numeric())
@@ -389,10 +414,19 @@ pub fn create_command_with_env(program: &str) -> Command {
     // Inherit essential environment variables from parent process
     for (key, value) in std::env::vars() {
         // Pass through PATH and other essential environment variables
-        if key == "PATH" || key == "HOME" || key == "USER"
-            || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_")
-            || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN"
-            || key == "HOMEBREW_PREFIX" || key == "HOMEBREW_CELLAR" {
+        if key == "PATH"
+            || key == "HOME"
+            || key == "USER"
+            || key == "SHELL"
+            || key == "LANG"
+            || key == "LC_ALL"
+            || key.starts_with("LC_")
+            || key == "NODE_PATH"
+            || key == "NVM_DIR"
+            || key == "NVM_BIN"
+            || key == "HOMEBREW_PREFIX"
+            || key == "HOMEBREW_CELLAR"
+        {
             debug!("Inheriting env var: {}={}", key, value);
             cmd.env(&key, &value);
         }

View File

@@ -2,16 +2,16 @@ use crate::sandbox::profile::ProfileBuilder;
 use anyhow::Result;
 use chrono;
 use log::{debug, error, info, warn};
+use reqwest;
 use rusqlite::{params, Connection, Result as SqliteResult};
 use serde::{Deserialize, Serialize};
 use serde_json::Value as JsonValue;
 use std::path::PathBuf;
 use std::process::Stdio;
 use std::sync::Mutex;
-use tauri::{AppHandle, Manager, State, Emitter};
+use tauri::{AppHandle, Emitter, Manager, State};
 use tokio::io::{AsyncBufReadExt, BufReader};
 use tokio::process::Command;
-use reqwest;

 /// Finds the full path to the claude binary
 /// This is necessary because macOS apps have a limited PATH environment
@@ -125,14 +125,16 @@ impl AgentRunMetrics {
     }

     // Extract token usage - check both top-level and nested message.usage
-    let usage = json.get("usage")
+    let usage = json
+        .get("usage")
         .or_else(|| json.get("message").and_then(|m| m.get("usage")));

     if let Some(usage) = usage {
         if let Some(input_tokens) = usage.get("input_tokens").and_then(|t| t.as_i64()) {
             total_tokens += input_tokens;
         }
-        if let Some(output_tokens) = usage.get("output_tokens").and_then(|t| t.as_i64()) {
+        if let Some(output_tokens) = usage.get("output_tokens").and_then(|t| t.as_i64())
+        {
             total_tokens += output_tokens;
         }
     }
@@ -151,9 +153,17 @@ impl AgentRunMetrics {
         Self {
             duration_ms,
-            total_tokens: if total_tokens > 0 { Some(total_tokens) } else { None },
+            total_tokens: if total_tokens > 0 {
+                Some(total_tokens)
+            } else {
+                None
+            },
             cost_usd: if cost_usd > 0.0 { Some(cost_usd) } else { None },
-            message_count: if message_count > 0 { Some(message_count) } else { None },
+            message_count: if message_count > 0 {
+                Some(message_count)
+            } else {
+                None
+            },
         }
     }
 }
@@ -171,7 +181,10 @@ pub async fn read_session_jsonl(session_id: &str, project_path: &str) -> Result<
     let session_file = project_dir.join(format!("{}.jsonl", session_id));

     if !session_file.exists() {
-        return Err(format!("Session file not found: {}", session_file.display()));
+        return Err(format!(
+            "Session file not found: {}",
+            session_file.display()
+        ));
     }

     match tokio::fs::read_to_string(&session_file).await {
@@ -204,7 +217,10 @@ pub async fn get_agent_run_with_metrics(run: AgentRun) -> AgentRunWithMetrics {
 /// Initialize the agents database
 pub fn init_database(app: &AppHandle) -> SqliteResult<Connection> {
-    let app_dir = app.path().app_data_dir().expect("Failed to get app data dir");
+    let app_dir = app
+        .path()
+        .app_data_dir()
+        .expect("Failed to get app data dir");
     std::fs::create_dir_all(&app_dir).expect("Failed to create app data dir");

     let db_path = app_dir.join("agents.db");
@@ -231,12 +247,30 @@ pub fn init_database(app: &AppHandle) -> SqliteResult<Connection> {
     // Add columns to existing table if they don't exist
     let _ = conn.execute("ALTER TABLE agents ADD COLUMN default_task TEXT", []);
-    let _ = conn.execute("ALTER TABLE agents ADD COLUMN model TEXT DEFAULT 'sonnet'", []);
-    let _ = conn.execute("ALTER TABLE agents ADD COLUMN sandbox_profile_id INTEGER REFERENCES sandbox_profiles(id)", []);
-    let _ = conn.execute("ALTER TABLE agents ADD COLUMN sandbox_enabled BOOLEAN DEFAULT 1", []);
-    let _ = conn.execute("ALTER TABLE agents ADD COLUMN enable_file_read BOOLEAN DEFAULT 1", []);
-    let _ = conn.execute("ALTER TABLE agents ADD COLUMN enable_file_write BOOLEAN DEFAULT 1", []);
-    let _ = conn.execute("ALTER TABLE agents ADD COLUMN enable_network BOOLEAN DEFAULT 0", []);
+    let _ = conn.execute(
+        "ALTER TABLE agents ADD COLUMN model TEXT DEFAULT 'sonnet'",
+        [],
+    );
+    let _ = conn.execute(
+        "ALTER TABLE agents ADD COLUMN sandbox_profile_id INTEGER REFERENCES sandbox_profiles(id)",
+        [],
+    );
+    let _ = conn.execute(
+        "ALTER TABLE agents ADD COLUMN sandbox_enabled BOOLEAN DEFAULT 1",
+        [],
+    );
+    let _ = conn.execute(
+        "ALTER TABLE agents ADD COLUMN enable_file_read BOOLEAN DEFAULT 1",
+        [],
+    );
+    let _ = conn.execute(
+        "ALTER TABLE agents ADD COLUMN enable_file_write BOOLEAN DEFAULT 1",
+        [],
+    );
+    let _ = conn.execute(
+        "ALTER TABLE agents ADD COLUMN enable_network BOOLEAN DEFAULT 0",
+        [],
+    );

     // Create agent_runs table
     conn.execute(
@@ -261,16 +295,28 @@ pub fn init_database(app: &AppHandle) -> SqliteResult<Connection> {
     // Migrate existing agent_runs table if needed
     let _ = conn.execute("ALTER TABLE agent_runs ADD COLUMN session_id TEXT", []);
-    let _ = conn.execute("ALTER TABLE agent_runs ADD COLUMN status TEXT DEFAULT 'pending'", []);
+    let _ = conn.execute(
+        "ALTER TABLE agent_runs ADD COLUMN status TEXT DEFAULT 'pending'",
+        [],
+    );
     let _ = conn.execute("ALTER TABLE agent_runs ADD COLUMN pid INTEGER", []);
-    let _ = conn.execute("ALTER TABLE agent_runs ADD COLUMN process_started_at TEXT", []);
+    let _ = conn.execute(
+        "ALTER TABLE agent_runs ADD COLUMN process_started_at TEXT",
+        [],
+    );

     // Drop old columns that are no longer needed (data is now read from JSONL files)
     // Note: SQLite doesn't support DROP COLUMN, so we'll ignore errors for existing columns
-    let _ = conn.execute("UPDATE agent_runs SET session_id = '' WHERE session_id IS NULL", []);
+    let _ = conn.execute(
+        "UPDATE agent_runs SET session_id = '' WHERE session_id IS NULL",
+        [],
+    );
     let _ = conn.execute("UPDATE agent_runs SET status = 'completed' WHERE status IS NULL AND completed_at IS NOT NULL", []);
     let _ = conn.execute("UPDATE agent_runs SET status = 'failed' WHERE status IS NULL AND completed_at IS NOT NULL AND session_id = ''", []);
-    let _ = conn.execute("UPDATE agent_runs SET status = 'pending' WHERE status IS NULL", []);
+    let _ = conn.execute(
+        "UPDATE agent_runs SET status = 'pending' WHERE status IS NULL",
+        [],
+    );

     // Create trigger to update the updated_at timestamp
     conn.execute(
@@ -395,7 +441,9 @@ pub async fn list_agents(db: State<'_, AgentDb>) -> Result<Vec<Agent>, String> {
                 icon: row.get(2)?,
                 system_prompt: row.get(3)?,
                 default_task: row.get(4)?,
-                model: row.get::<_, String>(5).unwrap_or_else(|_| "sonnet".to_string()),
+                model: row
+                    .get::<_, String>(5)
+                    .unwrap_or_else(|_| "sonnet".to_string()),
                 sandbox_enabled: row.get::<_, bool>(6).unwrap_or(true),
                 enable_file_read: row.get::<_, bool>(7).unwrap_or(true),
                 enable_file_write: row.get::<_, bool>(8).unwrap_or(true),
@@ -486,7 +534,9 @@ pub async fn update_agent(
     let model = model.unwrap_or_else(|| "sonnet".to_string());

     // Build dynamic query based on provided parameters
-    let mut query = "UPDATE agents SET name = ?1, icon = ?2, system_prompt = ?3, default_task = ?4, model = ?5".to_string();
+    let mut query =
+        "UPDATE agents SET name = ?1, icon = ?2, system_prompt = ?3, default_task = ?4, model = ?5"
+            .to_string();
     let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = vec![
         Box::new(name),
         Box::new(icon),
@@ -521,7 +571,10 @@ pub async fn update_agent(
     query.push_str(&format!(" WHERE id = ?{}", param_count));
     params_vec.push(Box::new(id));

-    conn.execute(&query, rusqlite::params_from_iter(params_vec.iter().map(|p| p.as_ref())))
+    conn.execute(
+        &query,
+        rusqlite::params_from_iter(params_vec.iter().map(|p| p.as_ref())),
+    )
     .map_err(|e| e.to_string())?;

     // Fetch the updated agent
@@ -621,8 +674,14 @@ pub async fn list_agent_runs(
                 model: row.get(5)?,
                 project_path: row.get(6)?,
                 session_id: row.get(7)?,
-                status: row.get::<_, String>(8).unwrap_or_else(|_| "pending".to_string()),
-                pid: row.get::<_, Option<i64>>(9).ok().flatten().map(|p| p as u32),
+                status: row
+                    .get::<_, String>(8)
+                    .unwrap_or_else(|_| "pending".to_string()),
+                pid: row
+                    .get::<_, Option<i64>>(9)
+                    .ok()
+                    .flatten()
+                    .map(|p| p as u32),
                 process_started_at: row.get(10)?,
                 created_at: row.get(11)?,
                 completed_at: row.get(12)?,
@@ -676,7 +735,10 @@ pub async fn get_agent_run(db: State<'_, AgentDb>, id: i64) -> Result<AgentRun,
 /// Get agent run with real-time metrics from JSONL
 #[tauri::command]
-pub async fn get_agent_run_with_real_time_metrics(db: State<'_, AgentDb>, id: i64) -> Result<AgentRunWithMetrics, String> {
+pub async fn get_agent_run_with_real_time_metrics(
+    db: State<'_, AgentDb>,
+    id: i64,
+) -> Result<AgentRunWithMetrics, String> {
     let run = get_agent_run(db, id).await?;
     Ok(get_agent_run_with_metrics(run).await)
 }
@@ -731,8 +793,10 @@ pub async fn execute_agent(
info!("🔓 Agent '{}': Sandbox DISABLED", agent.name); info!("🔓 Agent '{}': Sandbox DISABLED", agent.name);
None None
} else { } else {
info!("🔒 Agent '{}': Sandbox enabled | File Read: {} | File Write: {} | Network: {}", info!(
agent.name, agent.enable_file_read, agent.enable_file_write, agent.enable_network); "🔒 Agent '{}': Sandbox enabled | File Read: {} | File Write: {} | Network: {}",
agent.name, agent.enable_file_read, agent.enable_file_write, agent.enable_network
);
// Create rules dynamically based on agent permissions // Create rules dynamically based on agent permissions
let mut rules = Vec::new(); let mut rules = Vec::new();
@@ -905,10 +969,16 @@ pub async fn execute_agent(
return Err(e); return Err(e);
} }
}; };
match std::process::Command::new(&claude_path).arg("--version").output() { match std::process::Command::new(&claude_path)
.arg("--version")
.output()
{
Ok(output) => { Ok(output) => {
if output.status.success() { if output.status.success() {
info!("✅ Claude command works: {}", String::from_utf8_lossy(&output.stdout).trim()); info!(
"✅ Claude command works: {}",
String::from_utf8_lossy(&output.stdout).trim()
);
} else { } else {
warn!("⚠️ Claude command failed with status: {}", output.status); warn!("⚠️ Claude command failed with status: {}", output.status);
warn!(" stdout: {}", String::from_utf8_lossy(&output.stdout)); warn!(" stdout: {}", String::from_utf8_lossy(&output.stdout));
@@ -924,7 +994,8 @@ pub async fn execute_agent(
     // Test if Claude can actually start a session (this might reveal auth issues)
     info!("🧪 Testing Claude with exact same arguments as agent (without sandbox env vars)...");
     let mut test_cmd = std::process::Command::new(&claude_path);
-    test_cmd.arg("-p")
+    test_cmd
+        .arg("-p")
         .arg(&task)
         .arg("--system-prompt")
         .arg(&agent.system_prompt)
@@ -991,33 +1062,38 @@ pub async fn execute_agent(
         agent.sandbox_enabled,
         agent.enable_file_read,
         agent.enable_file_write,
-        agent.enable_network
+        agent.enable_network,
     ) {
         Ok(build_result) => {
             // Create the enhanced sandbox executor
             #[cfg(unix)]
-            let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization(
+            let executor =
+                crate::sandbox::executor::SandboxExecutor::new_with_serialization(
                     build_result.profile,
                     project_path_buf.clone(),
-                    build_result.serialized
+                    build_result.serialized,
                 );
             #[cfg(not(unix))]
-            let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization(
+            let executor =
+                crate::sandbox::executor::SandboxExecutor::new_with_serialization(
                     (),
                     project_path_buf.clone(),
-                    build_result.serialized
+                    build_result.serialized,
                 );
             // Prepare the sandboxed command
             let args = vec![
-                "-p", &task,
-                "--system-prompt", &agent.system_prompt,
-                "--model", &execution_model,
-                "--output-format", "stream-json",
+                "-p",
+                &task,
+                "--system-prompt",
+                &agent.system_prompt,
+                "--model",
+                &execution_model,
+                "--output-format",
+                "stream-json",
                 "--verbose",
-                "--dangerously-skip-permissions"
+                "--dangerously-skip-permissions",
            ];
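Note that the reflowed vector keeps one argv entry per flag and per value. Arguments passed this way bypass shell parsing entirely, so prompts containing spaces or quotes arrive verbatim. A hedged sketch of the same construction (the flag names mirror the diff; the binary path is an assumption):

    use std::process::Command;

    fn build_cmd(binary: &str, prompt: &str, model: &str) -> Command {
        let mut cmd = Command::new(binary);
        // Each element is a separate argv entry; nothing is re-tokenized.
        cmd.args(["-p", prompt, "--model", model, "--output-format", "stream-json"]);
        cmd
    }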
             let claude_path = match find_claude_binary(&app) {
@@ -1057,7 +1133,10 @@ pub async fn execute_agent(
             }
         }
         Err(e) => {
-            error!("Failed to create ProfileBuilder: {}, falling back to non-sandboxed", e);
+            error!(
+                "Failed to create ProfileBuilder: {}, falling back to non-sandboxed",
+                e
+            );
             // Fall back to non-sandboxed command
             let claude_path = match find_claude_binary(&app) {
@@ -1086,7 +1165,10 @@ pub async fn execute_agent(
         }
     } else {
         // No sandbox or sandbox disabled, use regular command
-        warn!("🚨 Running agent '{}' WITHOUT SANDBOX - full system access!", agent.name);
+        warn!(
+            "🚨 Running agent '{}' WITHOUT SANDBOX - full system access!",
+            agent.name
+        );
         let claude_path = match find_claude_binary(&app) {
             Ok(path) => path,
             Err(e) => {
@@ -1168,7 +1250,10 @@ pub async fn execute_agent(
             // Log first output
             if !first_output_clone.load(std::sync::atomic::Ordering::Relaxed) {
-                info!("🎉 First output received from Claude process! Line: {}", line);
+                info!(
+                    "🎉 First output received from Claude process! Line: {}",
+                    line
+                );
                 first_output_clone.store(true, std::sync::atomic::Ordering::Relaxed);
             }
@@ -1205,7 +1290,10 @@ pub async fn execute_agent(
             let _ = app_handle.emit("agent-output", &line);
         }
-        info!("📖 Finished reading Claude stdout. Total lines: {}", line_count);
+        info!(
+            "📖 Finished reading Claude stdout. Total lines: {}",
+            line_count
+        );
     });
     let app_handle_stderr = app.clone();
@@ -1234,14 +1322,19 @@ pub async fn execute_agent(
         }
         if error_count > 0 {
-            warn!("📖 Finished reading Claude stderr. Total error lines: {}", error_count);
+            warn!(
+                "📖 Finished reading Claude stderr. Total error lines: {}",
+                error_count
+            );
         } else {
             info!("📖 Finished reading Claude stderr. No errors.");
         }
     });
     // Register the process in the registry for live output tracking (after stdout/stderr setup)
-    registry.0.register_process(
+    registry
+        .0
+        .register_process(
             run_id,
             agent_id,
             agent.name.clone(),
@@ -1250,11 +1343,15 @@ pub async fn execute_agent(
             task.clone(),
             execution_model.clone(),
             child,
-    ).map_err(|e| format!("Failed to register process: {}", e))?;
+        )
+        .map_err(|e| format!("Failed to register process: {}", e))?;
     info!("📋 Registered process in registry");
     // Create variables we need for the spawned task
-    let app_dir = app.path().app_data_dir().expect("Failed to get app data dir");
+    let app_dir = app
+        .path()
+        .app_data_dir()
+        .expect("Failed to get app data dir");
     let db_path = app_dir.join("agents.db");
     // Monitor process status and wait for completion
@@ -1262,9 +1359,13 @@ pub async fn execute_agent(
         info!("🕐 Starting process monitoring...");
         // Wait for first output with timeout
-        for i in 0..300 { // 30 seconds (300 * 100ms)
+        for i in 0..300 {
+            // 30 seconds (300 * 100ms)
             if first_output.load(std::sync::atomic::Ordering::Relaxed) {
-                info!("✅ Output detected after {}ms, continuing normal execution", i * 100);
+                info!(
+                    "✅ Output detected after {}ms, continuing normal execution",
+                    i * 100
+                );
                 break;
             }
@@ -1272,7 +1373,10 @@ pub async fn execute_agent(
             // Log progress every 5 seconds
             if i > 0 && i % 50 == 0 {
-                info!("⏳ Still waiting for Claude output... ({}s elapsed)", i / 10);
+                info!(
+                    "⏳ Still waiting for Claude output... ({}s elapsed)",
+                    i / 10
+                );
             }
         }
@@ -1287,7 +1391,10 @@ pub async fn execute_agent(
             warn!(" 5. Authentication issues (API key not found/invalid)");
             // Process timed out - kill it via PID
-            warn!("🔍 Process likely stuck waiting for input, attempting to kill PID: {}", pid);
+            warn!(
+                "🔍 Process likely stuck waiting for input, attempting to kill PID: {}",
+                pid
+            );
             let kill_result = std::process::Command::new("kill")
                 .arg("-TERM")
                 .arg(pid.to_string())
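The watchdog falls back to the external kill command here. A minimal sketch of the usual escalation under the same assumption (Unix, `kill` on PATH): SIGTERM first, SIGKILL if the process lingers.

    use std::{process::Command, thread, time::Duration};

    fn terminate(pid: u32, grace: Duration) {
        // Polite request to shut down.
        let _ = Command::new("kill").arg("-TERM").arg(pid.to_string()).status();
        thread::sleep(grace);
        // Forced kill; a failure here is harmless if the process already exited.
        let _ = Command::new("kill").arg("-KILL").arg(pid.to_string()).status();
    }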
@@ -1359,9 +1466,7 @@ pub async fn execute_agent(
 /// List all currently running agent sessions
 #[tauri::command]
-pub async fn list_running_sessions(
-    db: State<'_, AgentDb>,
-) -> Result<Vec<AgentRun>, String> {
+pub async fn list_running_sessions(db: State<'_, AgentDb>) -> Result<Vec<AgentRun>, String> {
     let conn = db.0.lock().map_err(|e| e.to_string())?;
     let mut stmt = conn.prepare(
@@ -1369,7 +1474,8 @@ pub async fn list_running_sessions(
         FROM agent_runs WHERE status = 'running' ORDER BY process_started_at DESC"
     ).map_err(|e| e.to_string())?;
-    let runs = stmt.query_map([], |row| {
+    let runs = stmt
+        .query_map([], |row| {
             Ok(AgentRun {
                 id: Some(row.get(0)?),
                 agent_id: row.get(1)?,
@@ -1379,8 +1485,14 @@ pub async fn list_running_sessions(
                 model: row.get(5)?,
                 project_path: row.get(6)?,
                 session_id: row.get(7)?,
-                status: row.get::<_, String>(8).unwrap_or_else(|_| "pending".to_string()),
-                pid: row.get::<_, Option<i64>>(9).ok().flatten().map(|p| p as u32),
+                status: row
+                    .get::<_, String>(8)
+                    .unwrap_or_else(|_| "pending".to_string()),
+                pid: row
+                    .get::<_, Option<i64>>(9)
+                    .ok()
+                    .flatten()
+                    .map(|p| p as u32),
                 process_started_at: row.get(10)?,
                 created_at: row.get(11)?,
                 completed_at: row.get(12)?,
@@ -1427,7 +1539,7 @@ pub async fn kill_agent_session(
         conn.query_row(
             "SELECT pid FROM agent_runs WHERE id = ?1 AND status = 'running'",
             params![run_id],
-            |row| row.get::<_, Option<i64>>(0)
+            |row| row.get::<_, Option<i64>>(0),
         )
         .map_err(|e| e.to_string())?
     };
@@ -1462,7 +1574,7 @@ pub async fn get_session_status(
     match conn.query_row(
         "SELECT status FROM agent_runs WHERE id = ?1",
         params![run_id],
-        |row| row.get::<_, String>(0)
+        |row| row.get::<_, String>(0),
     ) {
         Ok(status) => Ok(Some(status)),
         Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
@@ -1472,19 +1584,16 @@ pub async fn get_session_status(
 /// Cleanup finished processes and update their status
 #[tauri::command]
-pub async fn cleanup_finished_processes(
-    db: State<'_, AgentDb>,
-) -> Result<Vec<i64>, String> {
+pub async fn cleanup_finished_processes(db: State<'_, AgentDb>) -> Result<Vec<i64>, String> {
     let conn = db.0.lock().map_err(|e| e.to_string())?;
     // Get all running processes
-    let mut stmt = conn.prepare(
-        "SELECT id, pid FROM agent_runs WHERE status = 'running' AND pid IS NOT NULL"
-    ).map_err(|e| e.to_string())?;
+    let mut stmt = conn
+        .prepare("SELECT id, pid FROM agent_runs WHERE status = 'running' AND pid IS NOT NULL")
+        .map_err(|e| e.to_string())?;
-    let running_processes = stmt.query_map([], |row| {
-        Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?))
-    })
+    let running_processes = stmt
+        .query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))
         .map_err(|e| e.to_string())?
         .collect::<Result<Vec<_>, _>>()
         .map_err(|e| e.to_string())?;
@@ -1528,7 +1637,10 @@ pub async fn cleanup_finished_processes(
         if updated > 0 {
             cleaned_up.push(run_id);
-            info!("Marked agent run {} as completed (PID {} no longer running)", run_id, pid);
+            info!(
+                "Marked agent run {} as completed (PID {} no longer running)",
+                run_id, pid
+            );
         }
     }
 }
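The cleanup pass needs to decide whether a recorded PID still refers to a live process. One conventional Unix check, sketched here as an assumption about how such a probe could look (signal 0 tests existence without delivering anything):

    use std::process::Command;

    fn is_alive(pid: u32) -> bool {
        Command::new("kill")
            .arg("-0")
            .arg(pid.to_string())
            .status()
            .map(|s| s.success())
            .unwrap_or(false)
    }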
@@ -1615,7 +1727,8 @@ pub async fn stream_session_output(
         if current_size > last_size {
             // File has grown, read new content
             if let Ok(content) = tokio::fs::read_to_string(&session_file).await {
-                let _ = app.emit("session-output-update", &format!("{}:{}", run_id, content));
+                let _ = app
+                    .emit("session-output-update", &format!("{}:{}", run_id, content));
             }
             last_size = current_size;
         }
@@ -1629,12 +1742,15 @@ pub async fn stream_session_output(
         // Check if the session is still running by querying the database
         // If the session is no longer running, stop streaming
         if let Ok(conn) = rusqlite::Connection::open(
-            app.path().app_data_dir().expect("Failed to get app data dir").join("agents.db")
+            app.path()
+                .app_data_dir()
+                .expect("Failed to get app data dir")
+                .join("agents.db"),
         ) {
             if let Ok(status) = conn.query_row(
                 "SELECT status FROM agent_runs WHERE id = ?1",
                 rusqlite::params![run_id],
-                |row| row.get::<_, String>(0)
+                |row| row.get::<_, String>(0),
             ) {
                 if status != "running" {
                     debug!("Session {} is no longer running, stopping stream", run_id);
@@ -1642,7 +1758,10 @@ pub async fn stream_session_output(
                 }
             } else {
                 // If we can't query the status, assume it's still running
-                debug!("Could not query session status for {}, continuing stream", run_id);
+                debug!(
+                    "Could not query session status for {}, continuing stream",
+                    run_id
+                );
             }
         }
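The streaming loop above tails the session file by size. A self-contained sketch of that polling idea (the interval and callback shape are assumptions; a real caller would add a stop condition):

    use std::{path::Path, time::Duration};

    async fn tail_by_size(path: &Path, mut on_update: impl FnMut(String)) {
        let mut last_size = 0u64;
        loop {
            if let Ok(meta) = tokio::fs::metadata(path).await {
                if meta.len() > last_size {
                    // File grew: re-read and hand the contents to the caller.
                    if let Ok(content) = tokio::fs::read_to_string(path).await {
                        on_update(content);
                    }
                    last_size = meta.len();
                }
            }
            tokio::time::sleep(Duration::from_millis(500)).await;
        }
    }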
@@ -1695,13 +1814,16 @@ pub async fn export_agent(db: State<'_, AgentDb>, id: i64) -> Result<String, Str
 /// Export agent to file with native dialog
 #[tauri::command]
-pub async fn export_agent_to_file(db: State<'_, AgentDb>, id: i64, file_path: String) -> Result<(), String> {
+pub async fn export_agent_to_file(
+    db: State<'_, AgentDb>,
+    id: i64,
+    file_path: String,
+) -> Result<(), String> {
     // Get the JSON data
     let json_data = export_agent(db, id).await?;
     // Write to file
-    std::fs::write(&file_path, json_data)
-        .map_err(|e| format!("Failed to write file: {}", e))?;
+    std::fs::write(&file_path, json_data).map_err(|e| format!("Failed to write file: {}", e))?;
     Ok(())
 }
@@ -1750,14 +1872,16 @@ pub async fn set_claude_binary_path(db: State<'_, AgentDb>, path: String) -> Res
         "INSERT INTO app_settings (key, value) VALUES ('claude_binary_path', ?1)
         ON CONFLICT(key) DO UPDATE SET value = ?1",
         params![path],
-    ).map_err(|e| format!("Failed to save Claude binary path: {}", e))?;
+    )
+    .map_err(|e| format!("Failed to save Claude binary path: {}", e))?;
     Ok(())
 }
 /// List all available Claude installations on the system
 #[tauri::command]
-pub async fn list_claude_installations() -> Result<Vec<crate::claude_binary::ClaudeInstallation>, String> {
+pub async fn list_claude_installations(
+) -> Result<Vec<crate::claude_binary::ClaudeInstallation>, String> {
     let installations = crate::claude_binary::discover_claude_installations();
     if installations.is_empty() {
@@ -1779,10 +1903,19 @@ fn create_command_with_env(program: &str) -> Command {
     // Copy over all environment variables from the std::process::Command
     // This is a workaround since we can't directly convert between the two types
     for (key, value) in std::env::vars() {
-        if key == "PATH" || key == "HOME" || key == "USER"
-            || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_")
-            || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN"
-            || key == "HOMEBREW_PREFIX" || key == "HOMEBREW_CELLAR" {
+        if key == "PATH"
+            || key == "HOME"
+            || key == "USER"
+            || key == "SHELL"
+            || key == "LANG"
+            || key == "LC_ALL"
+            || key.starts_with("LC_")
+            || key == "NODE_PATH"
+            || key == "NVM_DIR"
+            || key == "NVM_BIN"
+            || key == "HOMEBREW_PREFIX"
+            || key == "HOMEBREW_CELLAR"
+        {
             tokio_cmd.env(&key, &value);
         }
     }
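The allowlist loop above copies selected variables into a command that otherwise inherits the parent's environment. A stricter variant, sketched as an assumption rather than what the app does, starts from an empty environment so nothing leaks through by default:

    use std::process::Command;

    fn locked_down(program: &str) -> Command {
        let mut cmd = Command::new(program);
        // Drop everything, then re-admit only the allowlist.
        cmd.env_clear();
        for (key, value) in std::env::vars() {
            if matches!(key.as_str(), "PATH" | "HOME" | "USER" | "SHELL" | "LANG")
                || key.starts_with("LC_")
            {
                cmd.env(&key, &value);
            }
        }
        cmd
    }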
@@ -1820,12 +1953,15 @@ fn create_command_with_env(program: &str) -> Command {
 #[tauri::command]
 pub async fn import_agent(db: State<'_, AgentDb>, json_data: String) -> Result<Agent, String> {
     // Parse the JSON data
-    let export_data: AgentExport = serde_json::from_str(&json_data)
-        .map_err(|e| format!("Invalid JSON format: {}", e))?;
+    let export_data: AgentExport =
+        serde_json::from_str(&json_data).map_err(|e| format!("Invalid JSON format: {}", e))?;
     // Validate version
     if export_data.version != 1 {
-        return Err(format!("Unsupported export version: {}. This version of the app only supports version 1.", export_data.version));
+        return Err(format!(
+            "Unsupported export version: {}. This version of the app only supports version 1.",
+            export_data.version
+        ));
     }
     let agent_data = export_data.agent;
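The version gate above is the standard versioned-envelope import pattern: parse the wrapper, check its schema version, then trust the payload. A minimal sketch with illustrative type and field names (`Export`, `payload` are not the app's real names):

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct Export {
        version: u32,
        payload: serde_json::Value,
    }

    fn import(json: &str) -> Result<serde_json::Value, String> {
        let export: Export =
            serde_json::from_str(json).map_err(|e| format!("Invalid JSON format: {}", e))?;
        // Reject envelopes written by an unknown (likely newer) exporter.
        if export.version != 1 {
            return Err(format!("Unsupported export version: {}", export.version));
        }
        Ok(export.payload)
    }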
@@ -1895,10 +2031,13 @@ pub async fn import_agent(db: State<'_, AgentDb>, json_data: String) -> Result<A
 /// Import agent from file
 #[tauri::command]
-pub async fn import_agent_from_file(db: State<'_, AgentDb>, file_path: String) -> Result<Agent, String> {
+pub async fn import_agent_from_file(
+    db: State<'_, AgentDb>,
+    file_path: String,
+) -> Result<Agent, String> {
     // Read the file
-    let json_data = std::fs::read_to_string(&file_path)
-        .map_err(|e| format!("Failed to read file: {}", e))?;
+    let json_data =
+        std::fs::read_to_string(&file_path).map_err(|e| format!("Failed to read file: {}", e))?;
     // Import the agent
     import_agent(db, json_data).await
@@ -1989,7 +2128,10 @@ pub async fn fetch_github_agent_content(download_url: String) -> Result<AgentExp
         .map_err(|e| format!("Failed to download agent: {}", e))?;
     if !response.status().is_success() {
-        return Err(format!("Failed to download agent: HTTP {}", response.status()));
+        return Err(format!(
+            "Failed to download agent: HTTP {}",
+            response.status()
+        ));
     }
     let json_text = response
@@ -2003,7 +2145,10 @@ pub async fn fetch_github_agent_content(download_url: String) -> Result<AgentExp
     // Validate version
     if export_data.version != 1 {
-        return Err(format!("Unsupported agent version: {}", export_data.version));
+        return Err(format!(
+            "Unsupported agent version: {}",
+            export_data.version
+        ));
     }
     Ok(export_data)


@@ -1,14 +1,14 @@
 use anyhow::{Context, Result};
 use serde::{Deserialize, Serialize};
 use std::fs;
-use std::path::PathBuf;
-use std::time::SystemTime;
 use std::io::{BufRead, BufReader};
+use std::path::PathBuf;
 use std::process::Stdio;
-use tauri::{AppHandle, Emitter, Manager};
-use tokio::process::{Command, Child};
-use tokio::sync::Mutex;
 use std::sync::Arc;
+use std::time::SystemTime;
+use tauri::{AppHandle, Emitter, Manager};
+use tokio::process::{Child, Command};
+use tokio::sync::Mutex;
 use uuid;
 /// Global state to track current Claude process
@@ -202,7 +202,9 @@ fn extract_first_user_message(jsonl_path: &PathBuf) -> (Option<String>, Option<S
         }
         // Skip if it starts with command tags
-        if content.starts_with("<command-name>") || content.starts_with("<local-command-stdout>") {
+        if content.starts_with("<command-name>")
+            || content.starts_with("<local-command-stdout>")
+        {
             continue;
         }
@@ -229,10 +231,19 @@ fn create_command_with_env(program: &str) -> Command {
     // Copy over all environment variables
     for (key, value) in std::env::vars() {
-        if key == "PATH" || key == "HOME" || key == "USER"
-            || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_")
-            || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN"
-            || key == "HOMEBREW_PREFIX" || key == "HOMEBREW_CELLAR" {
+        if key == "PATH"
+            || key == "HOME"
+            || key == "USER"
+            || key == "SHELL"
+            || key == "LANG"
+            || key == "LC_ALL"
+            || key.starts_with("LC_")
+            || key == "NODE_PATH"
+            || key == "NVM_DIR"
+            || key == "NVM_BIN"
+            || key == "HOMEBREW_PREFIX"
+            || key == "HOMEBREW_CELLAR"
+        {
             log::debug!("Inheriting env var: {}={}", key, value);
             tokio_cmd.env(&key, &value);
         }
@@ -308,8 +319,11 @@ pub async fn list_projects() -> Result<Vec<Project>, String> {
             if let Ok(session_entries) = fs::read_dir(&path) {
                 for session_entry in session_entries.flatten() {
                     let session_path = session_entry.path();
-                    if session_path.is_file() && session_path.extension().and_then(|s| s.to_str()) == Some("jsonl") {
-                        if let Some(session_id) = session_path.file_stem().and_then(|s| s.to_str()) {
+                    if session_path.is_file()
+                        && session_path.extension().and_then(|s| s.to_str()) == Some("jsonl")
+                    {
+                        if let Some(session_id) = session_path.file_stem().and_then(|s| s.to_str())
+                        {
                             sessions.push(session_id.to_string());
                         }
                     }
@@ -349,7 +363,11 @@ pub async fn get_project_sessions(project_id: String) -> Result<Vec<Session>, St
     let project_path = match get_project_path_from_sessions(&project_dir) {
         Ok(path) => path,
         Err(e) => {
-            log::warn!("Failed to get project path from sessions for {}: {}, falling back to decode", project_id, e);
+            log::warn!(
+                "Failed to get project path from sessions for {}: {}, falling back to decode",
+                project_id,
+                e
+            );
             decode_project_path(&project_id)
         }
     };
@@ -407,7 +425,11 @@ pub async fn get_project_sessions(project_id: String) -> Result<Vec<Session>, St
     // Sort sessions by creation time (newest first)
     sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at));
-    log::info!("Found {} sessions for project {}", sessions.len(), project_id);
+    log::info!(
+        "Found {} sessions for project {}",
+        sessions.len(),
+        project_id
+    );
     Ok(sessions)
 }
@@ -490,8 +512,7 @@ pub async fn get_system_prompt() -> Result<String, String> {
         return Ok(String::new());
     }
-    fs::read_to_string(&claude_md_path)
-        .map_err(|e| format!("Failed to read CLAUDE.md: {}", e))
+    fs::read_to_string(&claude_md_path).map_err(|e| format!("Failed to read CLAUDE.md: {}", e))
 }
 /// Checks if Claude Code is installed and gets its version
@@ -540,7 +561,11 @@ pub async fn check_claude_version(app: AppHandle) -> Result<ClaudeVersionStatus,
         Ok(output) => {
             let stdout = String::from_utf8_lossy(&output.stdout).to_string();
             let stderr = String::from_utf8_lossy(&output.stderr).to_string();
-            let full_output = if stderr.is_empty() { stdout.clone() } else { format!("{}\n{}", stdout, stderr) };
+            let full_output = if stderr.is_empty() {
+                stdout.clone()
+            } else {
+                format!("{}\n{}", stdout, stderr)
+            };
             // Check if the output matches the expected format
             // Expected format: "1.0.17 (Claude Code)" or similar
@@ -549,9 +574,7 @@ pub async fn check_claude_version(app: AppHandle) -> Result<ClaudeVersionStatus,
             // Extract version number if valid
             let version = if is_valid {
                 // Try to extract just the version number
-                stdout.split_whitespace()
-                    .next()
-                    .map(|s| s.to_string())
+                stdout.split_whitespace().next().map(|s| s.to_string())
             } else {
                 None
             };
@@ -582,8 +605,7 @@ pub async fn save_system_prompt(content: String) -> Result<String, String> {
     let claude_dir = get_claude_dir().map_err(|e| e.to_string())?;
     let claude_md_path = claude_dir.join("CLAUDE.md");
-    fs::write(&claude_md_path, content)
-        .map_err(|e| format!("Failed to write CLAUDE.md: {}", e))?;
+    fs::write(&claude_md_path, content).map_err(|e| format!("Failed to write CLAUDE.md: {}", e))?;
     Ok("System prompt saved successfully".to_string())
 }
@@ -649,7 +671,10 @@ fn find_claude_md_recursive(
         if path.is_dir() {
             // Skip common directories that shouldn't be scanned
             if let Some(dir_name) = path.file_name().and_then(|n| n.to_str()) {
-                if matches!(dir_name, "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__") {
+                if matches!(
+                    dir_name,
+                    "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__"
+                ) {
                     continue;
                 }
            }
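The `matches!` test above prunes well-known build and VCS directories from the recursive scan. Extracted into a standalone predicate for reference:

    fn should_skip(dir_name: &str) -> bool {
        matches!(
            dir_name,
            "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__"
        )
    }

    fn main() {
        assert!(should_skip("node_modules"));
        assert!(!should_skip("src"));
    }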
@@ -663,7 +688,8 @@ fn find_claude_md_recursive(
             let metadata = fs::metadata(&path)
                 .map_err(|e| format!("Failed to read file metadata: {}", e))?;
-            let relative_path = path.strip_prefix(project_root)
+            let relative_path = path
+                .strip_prefix(project_root)
                 .map_err(|e| format!("Failed to get relative path: {}", e))?
                 .to_string_lossy()
                 .to_string();
@@ -699,8 +725,7 @@ pub async fn read_claude_md_file(file_path: String) -> Result<String, String> {
         return Err(format!("File does not exist: {}", file_path));
     }
-    fs::read_to_string(&path)
-        .map_err(|e| format!("Failed to read file: {}", e))
+    fs::read_to_string(&path).map_err(|e| format!("Failed to read file: {}", e))
 }
 /// Saves a specific CLAUDE.md file by its absolute path
@@ -716,26 +741,35 @@ pub async fn save_claude_md_file(file_path: String, content: String) -> Result<S
             .map_err(|e| format!("Failed to create parent directory: {}", e))?;
     }
-    fs::write(&path, content)
-        .map_err(|e| format!("Failed to write file: {}", e))?;
+    fs::write(&path, content).map_err(|e| format!("Failed to write file: {}", e))?;
     Ok("File saved successfully".to_string())
 }
 /// Loads the JSONL history for a specific session
 #[tauri::command]
-pub async fn load_session_history(session_id: String, project_id: String) -> Result<Vec<serde_json::Value>, String> {
-    log::info!("Loading session history for session: {} in project: {}", session_id, project_id);
+pub async fn load_session_history(
+    session_id: String,
+    project_id: String,
+) -> Result<Vec<serde_json::Value>, String> {
+    log::info!(
+        "Loading session history for session: {} in project: {}",
+        session_id,
+        project_id
+    );
     let claude_dir = get_claude_dir().map_err(|e| e.to_string())?;
-    let session_path = claude_dir.join("projects").join(&project_id).join(format!("{}.jsonl", session_id));
+    let session_path = claude_dir
+        .join("projects")
+        .join(&project_id)
+        .join(format!("{}.jsonl", session_id));
     if !session_path.exists() {
         return Err(format!("Session file not found: {}", session_id));
     }
-    let file = fs::File::open(&session_path)
-        .map_err(|e| format!("Failed to open session file: {}", e))?;
+    let file =
+        fs::File::open(&session_path).map_err(|e| format!("Failed to open session file: {}", e))?;
     let reader = BufReader::new(file);
     let mut messages = Vec::new();
@@ -759,7 +793,11 @@ pub async fn execute_claude_code(
     prompt: String,
     model: String,
 ) -> Result<(), String> {
-    log::info!("Starting new Claude Code session in: {} with model: {}", project_path, model);
+    log::info!(
+        "Starting new Claude Code session in: {} with model: {}",
+        project_path,
+        model
+    );
     // Check if sandboxing should be used
     let use_sandbox = should_use_sandbox(&app)?;
@@ -794,7 +832,11 @@ pub async fn continue_claude_code(
     prompt: String,
     model: String,
 ) -> Result<(), String> {
-    log::info!("Continuing Claude Code conversation in: {} with model: {}", project_path, model);
+    log::info!(
+        "Continuing Claude Code conversation in: {} with model: {}",
+        project_path,
+        model
+    );
     // Check if sandboxing should be used
     let use_sandbox = should_use_sandbox(&app)?;
@@ -831,7 +873,12 @@ pub async fn resume_claude_code(
     prompt: String,
     model: String,
 ) -> Result<(), String> {
-    log::info!("Resuming Claude Code session: {} in: {} with model: {}", session_id, project_path, model);
+    log::info!(
+        "Resuming Claude Code session: {} in: {} with model: {}",
+        session_id,
+        project_path,
+        model
+    );
     // Check if sandboxing should be used
     let use_sandbox = should_use_sandbox(&app)?;
@@ -862,8 +909,14 @@ pub async fn resume_claude_code(
 /// Cancel the currently running Claude Code execution
 #[tauri::command]
-pub async fn cancel_claude_execution(app: AppHandle, session_id: Option<String>) -> Result<(), String> {
-    log::info!("Cancelling Claude Code execution for session: {:?}", session_id);
+pub async fn cancel_claude_execution(
+    app: AppHandle,
+    session_id: Option<String>,
+) -> Result<(), String> {
+    log::info!(
+        "Cancelling Claude Code execution for session: {:?}",
+        session_id
+    );
     let claude_state = app.state::<ClaudeProcessState>();
     let mut current_process = claude_state.current_process.lock().await;
@@ -914,7 +967,11 @@ fn should_use_sandbox(app: &AppHandle) -> Result<bool, String> {
     let settings = get_claude_settings_sync(app)?;
     // Check for a sandboxing setting in the settings
-    if let Some(sandbox_enabled) = settings.data.get("sandboxEnabled").and_then(|v| v.as_bool()) {
+    if let Some(sandbox_enabled) = settings
+        .data
+        .get("sandboxEnabled")
+        .and_then(|v| v.as_bool())
+    {
         return Ok(sandbox_enabled);
     }
@@ -924,12 +981,13 @@ fn should_use_sandbox(app: &AppHandle) -> Result<bool, String> {
 /// Helper function to create a sandboxed Claude command
 fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Result<Command, String> {
-    use crate::sandbox::{profile::ProfileBuilder, executor::create_sandboxed_command};
+    use crate::sandbox::{executor::create_sandboxed_command, profile::ProfileBuilder};
     use std::path::PathBuf;
     // Get the database connection
     let conn = {
-        let app_data_dir = app.path()
+        let app_data_dir = app
+            .path()
             .app_data_dir()
             .map_err(|e| format!("Failed to get app data dir: {}", e))?;
         let db_path = app_data_dir.join("agents.db");
@@ -948,21 +1006,28 @@ fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Resul
     match profile_id {
         Some(profile_id) => {
-            log::info!("Using default sandbox profile: {} (id: {})", profile_id, profile_id);
+            log::info!(
+                "Using default sandbox profile: {} (id: {})",
+                profile_id,
+                profile_id
+            );
             // Get all rules for this profile
-            let mut stmt = conn.prepare(
-                "SELECT operation_type, pattern_type, pattern_value, enabled, platform_support
-                FROM sandbox_rules WHERE profile_id = ?1 AND enabled = 1"
-            ).map_err(|e| e.to_string())?;
+            let mut stmt = conn
+                .prepare(
+                    "SELECT operation_type, pattern_type, pattern_value, enabled, platform_support
                    FROM sandbox_rules WHERE profile_id = ?1 AND enabled = 1",
+                )
+                .map_err(|e| e.to_string())?;
-            let rules = stmt.query_map(rusqlite::params![profile_id], |row| {
+            let rules = stmt
+                .query_map(rusqlite::params![profile_id], |row| {
                     Ok((
                         row.get::<_, String>(0)?,
                         row.get::<_, String>(1)?,
                         row.get::<_, String>(2)?,
                        row.get::<_, bool>(3)?,
-                        row.get::<_, Option<String>>(4)?
+                        row.get::<_, Option<String>>(4)?,
                    ))
                })
                .map_err(|e| e.to_string())?
@@ -979,10 +1044,14 @@ fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Resul
             // Convert database rules to SandboxRule structs
             let mut sandbox_rules = Vec::new();
-            for (idx, (op_type, pattern_type, pattern_value, enabled, platform_support)) in rules.into_iter().enumerate() {
+            for (idx, (op_type, pattern_type, pattern_value, enabled, platform_support)) in
+                rules.into_iter().enumerate()
+            {
                 // Check if this rule applies to the current platform
                 if let Some(platforms_json) = &platform_support {
-                    if let Ok(platforms) = serde_json::from_str::<Vec<String>>(platforms_json) {
+                    if let Ok(platforms) =
+                        serde_json::from_str::<Vec<String>>(platforms_json)
+                    {
                         let current_platform = if cfg!(target_os = "linux") {
                             "linux"
                         } else if cfg!(target_os = "macos") {
@@ -1022,11 +1091,19 @@ fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Resul
             // Use the helper function to create sandboxed command
             let claude_path = find_claude_binary(app)?;
             #[cfg(unix)]
-            return Ok(create_sandboxed_command(&claude_path, &[], &project_path_buf, profile, project_path_buf.clone()));
+            return Ok(create_sandboxed_command(
+                &claude_path,
+                &[],
+                &project_path_buf,
+                profile,
+                project_path_buf.clone(),
+            ));
             #[cfg(not(unix))]
             {
-                log::warn!("Sandboxing not supported on Windows, using regular command");
+                log::warn!(
+                    "Sandboxing not supported on Windows, using regular command"
+                );
                 Ok(create_command_with_env(&claude_path))
             }
         }
@@ -1038,7 +1115,10 @@ fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Resul
         }
     }
     Err(e) => {
-        log::error!("Failed to create ProfileBuilder: {}, falling back to non-sandboxed", e);
+        log::error!(
+            "Failed to create ProfileBuilder: {}, falling back to non-sandboxed",
+            e
+        );
         let claude_path = find_claude_binary(app)?;
         Ok(create_command_with_env(&claude_path))
     }
@@ -1075,7 +1155,8 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St
     use tokio::io::{AsyncBufReadExt, BufReader};
     // Generate a unique session ID for this Claude Code session
-    let session_id = format!("claude-{}-{}",
+    let session_id = format!(
+        "claude-{}-{}",
         std::time::SystemTime::now()
             .duration_since(std::time::UNIX_EPOCH)
             .unwrap_or_default()
@@ -1084,7 +1165,9 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St
     );
     // Spawn the process
-    let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn Claude: {}", e))?;
+    let mut child = cmd
+        .spawn()
+        .map_err(|e| format!("Failed to spawn Claude: {}", e))?;
     // Get stdout and stderr
     let stdout = child.stdout.take().ok_or("Failed to get stdout")?;
@@ -1092,7 +1175,11 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St
     // Get the child PID for logging
     let pid = child.id();
-    log::info!("Spawned Claude process with PID: {:?} and session ID: {}", pid, session_id);
+    log::info!(
+        "Spawned Claude process with PID: {:?} and session ID: {}",
+        pid,
+        session_id
+    );
     // Create readers
     let stdout_reader = BufReader::new(stdout);
@@ -1153,7 +1240,10 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St
             log::info!("Claude process exited with status: {}", status);
             // Add a small delay to ensure all messages are processed
             tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-            let _ = app_handle_wait.emit(&format!("claude-complete:{}", session_id_clone3), status.success());
+            let _ = app_handle_wait.emit(
+                &format!("claude-complete:{}", session_id_clone3),
+                status.success(),
+            );
             // Also emit to the generic event for backward compatibility
             let _ = app_handle_wait.emit("claude-complete", status.success());
         }
@@ -1161,7 +1251,8 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St
             log::error!("Failed to wait for Claude process: {}", e);
             // Add a small delay to ensure all messages are processed
             tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-            let _ = app_handle_wait.emit(&format!("claude-complete:{}", session_id_clone3), false);
+            let _ = app_handle_wait
+                .emit(&format!("claude-complete:{}", session_id_clone3), false);
             // Also emit to the generic event for backward compatibility
             let _ = app_handle_wait.emit("claude-complete", false);
         }
@@ -1173,7 +1264,10 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St
     });
     // Return the session ID to the frontend
-    let _ = app.emit(&format!("claude-session-started:{}", session_id), session_id.clone());
+    let _ = app.emit(
+        &format!("claude-session-started:{}", session_id),
+        session_id.clone(),
+    );
     Ok(())
 }
@@ -1204,13 +1298,14 @@ pub async fn list_directory_contents(directory_path: String) -> Result<Vec<FileE
     let mut entries = Vec::new();
-    let dir_entries = fs::read_dir(&path)
-        .map_err(|e| format!("Failed to read directory: {}", e))?;
+    let dir_entries =
+        fs::read_dir(&path).map_err(|e| format!("Failed to read directory: {}", e))?;
     for entry in dir_entries {
         let entry = entry.map_err(|e| format!("Failed to read entry: {}", e))?;
         let entry_path = entry.path();
-        let metadata = entry.metadata()
+        let metadata = entry
+            .metadata()
             .map_err(|e| format!("Failed to read metadata: {}", e))?;
         // Skip hidden files/directories unless they are .claude directories
@@ -1227,7 +1322,8 @@ pub async fn list_directory_contents(directory_path: String) -> Result<Vec<FileE
             .to_string();
         let extension = if metadata.is_file() {
-            entry_path.extension()
+            entry_path
+                .extension()
                 .and_then(|e| e.to_str())
                 .map(|e| e.to_string())
         } else {
@@ -1244,12 +1340,10 @@ pub async fn list_directory_contents(directory_path: String) -> Result<Vec<FileE
     }
     // Sort: directories first, then files, alphabetically within each group
-    entries.sort_by(|a, b| {
-        match (a.is_directory, b.is_directory) {
+    entries.sort_by(|a, b| match (a.is_directory, b.is_directory) {
         (true, false) => std::cmp::Ordering::Less,
         (false, true) => std::cmp::Ordering::Greater,
        _ => a.name.to_lowercase().cmp(&b.name.to_lowercase()),
-        }
    });
     Ok(entries)
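rustfmt collapses the closure above into the `sort_by(|a, b| match ...)` form. The comparator itself, restated on a toy entry type for reference:

    use std::cmp::Ordering;

    struct Entry {
        name: String,
        is_directory: bool,
    }

    fn sort_entries(entries: &mut [Entry]) {
        entries.sort_by(|a, b| match (a.is_directory, b.is_directory) {
            (true, false) => Ordering::Less, // directories sort ahead of files
            (false, true) => Ordering::Greater,
            _ => a.name.to_lowercase().cmp(&b.name.to_lowercase()),
        });
    }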
@@ -1330,11 +1424,13 @@ fn search_files_recursive(
         // Check if name matches query
         if name.to_lowercase().contains(query) {
-            let metadata = entry.metadata()
+            let metadata = entry
+                .metadata()
                 .map_err(|e| format!("Failed to read metadata: {}", e))?;
             let extension = if metadata.is_file() {
-                entry_path.extension()
+                entry_path
+                    .extension()
                     .and_then(|e| e.to_str())
                     .map(|e| e.to_string())
             } else {
@@ -1355,7 +1451,10 @@ fn search_files_recursive(
         if entry_path.is_dir() {
             // Skip common directories that shouldn't be searched
             if let Some(dir_name) = entry_path.file_name().and_then(|n| n.to_str()) {
-                if matches!(dir_name, "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__") {
+                if matches!(
+                    dir_name,
+                    "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__"
+                ) {
                     continue;
                 }
             }
@@ -1377,13 +1476,20 @@ pub async fn create_checkpoint(
     message_index: Option<usize>,
     description: Option<String>,
 ) -> Result<crate::checkpoint::CheckpointResult, String> {
-    log::info!("Creating checkpoint for session: {} in project: {}", session_id, project_id);
-    let manager = app.get_or_create_manager(
-        session_id.clone(),
-        project_id.clone(),
-        PathBuf::from(&project_path),
-    ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
+    log::info!(
+        "Creating checkpoint for session: {} in project: {}",
+        session_id,
+        project_id
+    );
+    let manager = app
+        .get_or_create_manager(
+            session_id.clone(),
+            project_id.clone(),
+            PathBuf::from(&project_path),
+        )
+        .await
+        .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
     // Always load current session messages from the JSONL file
     let session_path = get_claude_dir()
@@ -1405,14 +1511,18 @@ pub async fn create_checkpoint(
             }
         }
         if let Ok(line) = line {
-            manager.track_message(line).await
+            manager
+                .track_message(line)
+                .await
                 .map_err(|e| format!("Failed to track message: {}", e))?;
         }
         line_count += 1;
     }
 }
-    manager.create_checkpoint(description, None).await
+    manager
+        .create_checkpoint(description, None)
+        .await
         .map_err(|e| format!("Failed to create checkpoint: {}", e))
 }
@@ -1425,15 +1535,24 @@ pub async fn restore_checkpoint(
     project_id: String,
     project_path: String,
 ) -> Result<crate::checkpoint::CheckpointResult, String> {
-    log::info!("Restoring checkpoint: {} for session: {}", checkpoint_id, session_id);
-    let manager = app.get_or_create_manager(
-        session_id.clone(),
-        project_id.clone(),
-        PathBuf::from(&project_path),
-    ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
+    log::info!(
+        "Restoring checkpoint: {} for session: {}",
+        checkpoint_id,
+        session_id
+    );
+    let manager = app
+        .get_or_create_manager(
+            session_id.clone(),
+            project_id.clone(),
+            PathBuf::from(&project_path),
+        )
+        .await
+        .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
-    let result = manager.restore_checkpoint(&checkpoint_id).await
+    let result = manager
+        .restore_checkpoint(&checkpoint_id)
+        .await
         .map_err(|e| format!("Failed to restore checkpoint: {}", e))?;
     // Update the session JSONL file with restored messages
@@ -1445,11 +1564,10 @@ pub async fn restore_checkpoint(
     // The manager has already restored the messages internally,
     // but we need to update the actual session file
-    let (_, _, messages) = manager.storage.load_checkpoint(
-        &result.checkpoint.project_id,
-        &session_id,
-        &checkpoint_id,
-    ).map_err(|e| format!("Failed to load checkpoint data: {}", e))?;
+    let (_, _, messages) = manager
+        .storage
+        .load_checkpoint(&result.checkpoint.project_id, &session_id, &checkpoint_id)
+        .map_err(|e| format!("Failed to load checkpoint data: {}", e))?;
     fs::write(&session_path, messages)
         .map_err(|e| format!("Failed to update session file: {}", e))?;
@@ -1465,13 +1583,16 @@ pub async fn list_checkpoints(
     project_id: String,
     project_path: String,
 ) -> Result<Vec<crate::checkpoint::Checkpoint>, String> {
-    log::info!("Listing checkpoints for session: {} in project: {}", session_id, project_id);
-    let manager = app.get_or_create_manager(
-        session_id,
-        project_id,
-        PathBuf::from(&project_path),
-    ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
+    log::info!(
+        "Listing checkpoints for session: {} in project: {}",
+        session_id,
+        project_id
+    );
+    let manager = app
+        .get_or_create_manager(session_id, project_id, PathBuf::from(&project_path))
+        .await
+        .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
     Ok(manager.list_checkpoints().await)
 }
@@ -1487,7 +1608,11 @@ pub async fn fork_from_checkpoint(
     new_session_id: String,
     description: Option<String>,
 ) -> Result<crate::checkpoint::CheckpointResult, String> {
-    log::info!("Forking from checkpoint: {} to new session: {}", checkpoint_id, new_session_id);
+    log::info!(
+        "Forking from checkpoint: {} to new session: {}",
+        checkpoint_id,
+        new_session_id
+    );
     let claude_dir = get_claude_dir().map_err(|e| e.to_string())?;
@@ -1507,13 +1632,18 @@ pub async fn fork_from_checkpoint(
     }
     // Create manager for the new session
-    let manager = app.get_or_create_manager(
-        new_session_id.clone(),
-        project_id,
-        PathBuf::from(&project_path),
-    ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
+    let manager = app
+        .get_or_create_manager(
+            new_session_id.clone(),
+            project_id,
+            PathBuf::from(&project_path),
+        )
+        .await
+        .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
-    manager.fork_from_checkpoint(&checkpoint_id, description).await
+    manager
+        .fork_from_checkpoint(&checkpoint_id, description)
+        .await
         .map_err(|e| format!("Failed to fork checkpoint: {}", e))
 }
@@ -1525,13 +1655,16 @@ pub async fn get_session_timeline(
     project_id: String,
     project_path: String,
 ) -> Result<crate::checkpoint::SessionTimeline, String> {
-    log::info!("Getting timeline for session: {} in project: {}", session_id, project_id);
-    let manager = app.get_or_create_manager(
-        session_id,
-        project_id,
-        PathBuf::from(&project_path),
-    ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
+    log::info!(
+        "Getting timeline for session: {} in project: {}",
+        session_id,
+        project_id
+    );
+    let manager = app
+        .get_or_create_manager(session_id, project_id, PathBuf::from(&project_path))
+        .await
+        .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
     Ok(manager.get_timeline().await)
@@ -1555,16 +1688,22 @@ pub async fn update_checkpoint_settings(
         "per_prompt" => CheckpointStrategy::PerPrompt,
         "per_tool_use" => CheckpointStrategy::PerToolUse,
         "smart" => CheckpointStrategy::Smart,
-        _ => return Err(format!("Invalid checkpoint strategy: {}", checkpoint_strategy)),
+        _ => {
+            return Err(format!(
+                "Invalid checkpoint strategy: {}",
+                checkpoint_strategy
+            ))
+        }
     };
-    let manager = app.get_or_create_manager(
-        session_id,
-        project_id,
-        PathBuf::from(&project_path),
-    ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
+    let manager = app
+        .get_or_create_manager(session_id, project_id, PathBuf::from(&project_path))
+        .await
+        .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
-    manager.update_settings(auto_checkpoint_enabled, strategy).await
+    manager
+        .update_settings(auto_checkpoint_enabled, strategy)
+        .await
         .map_err(|e| format!("Failed to update settings: {}", e))
 }
@@ -1578,24 +1717,32 @@ pub async fn get_checkpoint_diff(
 ) -> Result<crate::checkpoint::CheckpointDiff, String> {
     use crate::checkpoint::storage::CheckpointStorage;
-    log::info!("Getting diff between checkpoints: {} -> {}", from_checkpoint_id, to_checkpoint_id);
+    log::info!(
+        "Getting diff between checkpoints: {} -> {}",
+        from_checkpoint_id,
+        to_checkpoint_id
+    );
     let claude_dir = get_claude_dir().map_err(|e| e.to_string())?;
     let storage = CheckpointStorage::new(claude_dir);
     // Load both checkpoints
-    let (from_checkpoint, from_files, _) = storage.load_checkpoint(&project_id, &session_id, &from_checkpoint_id)
+    let (from_checkpoint, from_files, _) = storage
+        .load_checkpoint(&project_id, &session_id, &from_checkpoint_id)
         .map_err(|e| format!("Failed to load source checkpoint: {}", e))?;
-    let (to_checkpoint, to_files, _) = storage.load_checkpoint(&project_id, &session_id, &to_checkpoint_id)
+    let (to_checkpoint, to_files, _) = storage
+        .load_checkpoint(&project_id, &session_id, &to_checkpoint_id)
         .map_err(|e| format!("Failed to load target checkpoint: {}", e))?;
     // Build file maps
-    let mut from_map: std::collections::HashMap<PathBuf, &crate::checkpoint::FileSnapshot> = std::collections::HashMap::new();
+    let mut from_map: std::collections::HashMap<PathBuf, &crate::checkpoint::FileSnapshot> =
+        std::collections::HashMap::new();
     for file in &from_files {
         from_map.insert(file.file_path.clone(), file);
     }
-    let mut to_map: std::collections::HashMap<PathBuf, &crate::checkpoint::FileSnapshot> = std::collections::HashMap::new();
+    let mut to_map: std::collections::HashMap<PathBuf, &crate::checkpoint::FileSnapshot> =
+        std::collections::HashMap::new();
     for file in &to_files {
         to_map.insert(file.file_path.clone(), file);
     }
@@ -1634,7 +1781,8 @@ pub async fn get_checkpoint_diff(
     }
     // Calculate token delta
-    let token_delta = (to_checkpoint.metadata.total_tokens as i64) - (from_checkpoint.metadata.total_tokens as i64);
+    let token_delta = (to_checkpoint.metadata.total_tokens as i64)
+        - (from_checkpoint.metadata.total_tokens as i64);
     Ok(crate::checkpoint::CheckpointDiff {
         from_checkpoint_id,
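The totals are cast to i64 before subtracting so the delta can go negative when a checkpoint shrinks; subtracting the (presumably unsigned) counters directly would underflow. The same guard in isolation, assuming u64 counters:

    fn token_delta(from_total: u64, to_total: u64) -> i64 {
        // Widen to i64 first so the difference keeps its sign.
        (to_total as i64) - (from_total as i64)
    }

    fn main() {
        assert_eq!(token_delta(1_500, 1_200), -300);
    }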
@@ -1657,13 +1805,14 @@ pub async fn track_checkpoint_message(
) -> Result<(), String> { ) -> Result<(), String> {
log::info!("Tracking message for session: {}", session_id); log::info!("Tracking message for session: {}", session_id);
let manager = app.get_or_create_manager( let manager = app
session_id, .get_or_create_manager(session_id, project_id, PathBuf::from(project_path))
project_id, .await
PathBuf::from(project_path), .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
manager.track_message(message).await manager
.track_message(message)
.await
.map_err(|e| format!("Failed to track message: {}", e)) .map_err(|e| format!("Failed to track message: {}", e))
} }
@@ -1678,11 +1827,10 @@ pub async fn check_auto_checkpoint(
) -> Result<bool, String> { ) -> Result<bool, String> {
log::info!("Checking auto-checkpoint for session: {}", session_id); log::info!("Checking auto-checkpoint for session: {}", session_id);
let manager = app.get_or_create_manager( let manager = app
session_id.clone(), .get_or_create_manager(session_id.clone(), project_id, PathBuf::from(project_path))
project_id, .await
PathBuf::from(project_path), .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
Ok(manager.should_auto_checkpoint(&message).await) Ok(manager.should_auto_checkpoint(&message).await)
} }
@@ -1696,15 +1844,24 @@ pub async fn cleanup_old_checkpoints(
project_path: String, project_path: String,
keep_count: usize, keep_count: usize,
) -> Result<usize, String> { ) -> Result<usize, String> {
log::info!("Cleaning up old checkpoints for session: {}, keeping {}", session_id, keep_count); log::info!(
"Cleaning up old checkpoints for session: {}, keeping {}",
session_id,
keep_count
);
let manager = app.get_or_create_manager( let manager = app
.get_or_create_manager(
session_id.clone(), session_id.clone(),
project_id.clone(), project_id.clone(),
PathBuf::from(project_path), PathBuf::from(project_path),
).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; )
.await
.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
manager.storage.cleanup_old_checkpoints(&project_id, &session_id, keep_count) manager
.storage
.cleanup_old_checkpoints(&project_id, &session_id, keep_count)
.map_err(|e| format!("Failed to cleanup checkpoints: {}", e)) .map_err(|e| format!("Failed to cleanup checkpoints: {}", e))
} }
@@ -1718,11 +1875,10 @@ pub async fn get_checkpoint_settings(
) -> Result<serde_json::Value, String> { ) -> Result<serde_json::Value, String> {
log::info!("Getting checkpoint settings for session: {}", session_id); log::info!("Getting checkpoint settings for session: {}", session_id);
let manager = app.get_or_create_manager( let manager = app
session_id, .get_or_create_manager(session_id, project_id, PathBuf::from(project_path))
project_id, .await
PathBuf::from(project_path), .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
let timeline = manager.get_timeline().await; let timeline = manager.get_timeline().await;
@@ -1771,13 +1927,16 @@ pub async fn get_recently_modified_files(
) -> Result<Vec<String>, String> { ) -> Result<Vec<String>, String> {
use chrono::{Duration, Utc}; use chrono::{Duration, Utc};
log::info!("Getting files modified in the last {} minutes for session: {}", minutes, session_id); log::info!(
"Getting files modified in the last {} minutes for session: {}",
minutes,
session_id
);
let manager = app.get_or_create_manager( let manager = app
session_id, .get_or_create_manager(session_id, project_id, PathBuf::from(project_path))
project_id, .await
PathBuf::from(project_path), .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
let since = Utc::now() - Duration::minutes(minutes); let since = Utc::now() - Duration::minutes(minutes);
let modified_files = manager.get_files_modified_since(since).await; let modified_files = manager.get_files_modified_since(since).await;
@@ -1787,7 +1946,8 @@ pub async fn get_recently_modified_files(
log::info!("Last file modification was at: {}", last_mod); log::info!("Last file modification was at: {}", last_mod);
} }
Ok(modified_files.into_iter() Ok(modified_files
.into_iter()
.map(|p| p.to_string_lossy().to_string()) .map(|p| p.to_string_lossy().to_string())
.collect()) .collect())
} }
@@ -1801,9 +1961,14 @@ pub async fn track_session_messages(
project_path: String, project_path: String,
messages: Vec<String>, messages: Vec<String>,
) -> Result<(), String> { ) -> Result<(), String> {
let mgr = state.get_or_create_manager( let mgr = state
session_id, project_id, std::path::PathBuf::from(project_path) .get_or_create_manager(
).await.map_err(|e| e.to_string())?; session_id,
project_id,
std::path::PathBuf::from(project_path),
)
.await
.map_err(|e| e.to_string())?;
for m in messages { for m in messages {
mgr.track_message(m).await.map_err(|e| e.to_string())?; mgr.track_message(m).await.map_err(|e| e.to_string())?;
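Aside: every hunk in this file is the same mechanical change — once a call chain overflows rustfmt's default max_width of 100 columns, each .await and .map_err call moves onto its own line. A minimal, self-contained sketch of the resulting shape, using a stand-in async function and assuming a tokio runtime rather than the real checkpoint manager:

async fn get_manager(session_id: &str) -> Result<String, String> {
    // Stand-in for the real manager lookup.
    Ok(format!("manager for {}", session_id))
}

#[tokio::main]
async fn main() -> Result<(), String> {
    // rustfmt keeps short chains on one line and splits one call per line,
    // as in the hunks above, only when the chain exceeds max_width.
    let manager = get_manager("session-1")
        .await
        .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?;
    println!("{}", manager);
    Ok(())
}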

View File

@@ -1,12 +1,12 @@
use tauri::AppHandle;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use dirs;
use log::{error, info};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::fs; use std::fs;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::Command; use std::process::Command;
use log::{info, error}; use tauri::AppHandle;
use dirs;
/// Helper function to create a std::process::Command with proper environment variables /// Helper function to create a std::process::Command with proper environment variables
/// This ensures commands like Claude can find Node.js and other dependencies /// This ensures commands like Claude can find Node.js and other dependencies
@@ -17,8 +17,7 @@ fn create_command_with_env(program: &str) -> Command {
/// Finds the full path to the claude binary /// Finds the full path to the claude binary
/// This is necessary because macOS apps have a limited PATH environment /// This is necessary because macOS apps have a limited PATH environment
fn find_claude_binary(app_handle: &AppHandle) -> Result<String> { fn find_claude_binary(app_handle: &AppHandle) -> Result<String> {
crate::claude_binary::find_claude_binary(app_handle) crate::claude_binary::find_claude_binary(app_handle).map_err(|e| anyhow::anyhow!(e))
.map_err(|e| anyhow::anyhow!(e))
} }
/// Represents an MCP server configuration /// Represents an MCP server configuration
@@ -107,8 +106,7 @@ fn execute_claude_mcp_command(app_handle: &AppHandle, args: Vec<&str>) -> Result
cmd.arg(arg); cmd.arg(arg);
} }
let output = cmd.output() let output = cmd.output().context("Failed to execute claude command")?;
.context("Failed to execute claude command")?;
if output.status.success() { if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).to_string()) Ok(String::from_utf8_lossy(&output.stdout).to_string())
@@ -133,7 +131,8 @@ pub async fn mcp_add(
info!("Adding MCP server: {} with transport: {}", name, transport); info!("Adding MCP server: {} with transport: {}", name, transport);
// Prepare owned strings for environment variables // Prepare owned strings for environment variables
let env_args: Vec<String> = env.iter() let env_args: Vec<String> = env
.iter()
.map(|(key, value)| format!("{}={}", key, value)) .map(|(key, value)| format!("{}={}", key, value))
.collect(); .collect();
@@ -262,11 +261,16 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
// If the next line starts with a server name pattern, break // If the next line starts with a server name pattern, break
if next_line.contains(':') { if next_line.contains(':') {
let potential_next_name = next_line.split(':').next().unwrap_or("").trim(); let potential_next_name =
info!("Found colon in next line, potential name: {:?}", potential_next_name); next_line.split(':').next().unwrap_or("").trim();
if !potential_next_name.is_empty() && info!(
!potential_next_name.contains('/') && "Found colon in next line, potential name: {:?}",
!potential_next_name.contains('\\') { potential_next_name
);
if !potential_next_name.is_empty()
&& !potential_next_name.contains('/')
&& !potential_next_name.contains('\\')
{
info!("Next line is a new server, breaking"); info!("Next line is a new server, breaking");
break; break;
} }
@@ -312,7 +316,10 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
info!("Found {} MCP servers total", servers.len()); info!("Found {} MCP servers total", servers.len());
for (idx, server) in servers.iter().enumerate() { for (idx, server) in servers.iter().enumerate() {
info!("Server {}: name='{}', command={:?}", idx, server.name, server.command); info!(
"Server {}: name='{}', command={:?}",
idx, server.name, server.command
);
} }
Ok(servers) Ok(servers)
} }
@@ -347,7 +354,9 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String>
scope = "local".to_string(); scope = "local".to_string();
} else if scope_part.to_lowercase().contains("project") { } else if scope_part.to_lowercase().contains("project") {
scope = "project".to_string(); scope = "project".to_string();
} else if scope_part.to_lowercase().contains("user") || scope_part.to_lowercase().contains("global") { } else if scope_part.to_lowercase().contains("user")
|| scope_part.to_lowercase().contains("global")
{
scope = "user".to_string(); scope = "user".to_string();
} }
} else if line.starts_with("Type:") { } else if line.starts_with("Type:") {
@@ -409,8 +418,16 @@ pub async fn mcp_remove(app: AppHandle, name: String) -> Result<String, String>
/// Adds an MCP server from JSON configuration /// Adds an MCP server from JSON configuration
#[tauri::command] #[tauri::command]
pub async fn mcp_add_json(app: AppHandle, name: String, json_config: String, scope: String) -> Result<AddServerResult, String> { pub async fn mcp_add_json(
info!("Adding MCP server from JSON: {} with scope: {}", name, scope); app: AppHandle,
name: String,
json_config: String,
scope: String,
) -> Result<AddServerResult, String> {
info!(
"Adding MCP server from JSON: {} with scope: {}",
name, scope
);
// Build command args // Build command args
let mut cmd_args = vec!["add-json", &name, &json_config]; let mut cmd_args = vec!["add-json", &name, &json_config];
@@ -442,8 +459,14 @@ pub async fn mcp_add_json(app: AppHandle, name: String, json_config: String, sco
/// Imports MCP servers from Claude Desktop /// Imports MCP servers from Claude Desktop
#[tauri::command] #[tauri::command]
pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Result<ImportResult, String> { pub async fn mcp_add_from_claude_desktop(
info!("Importing MCP servers from Claude Desktop with scope: {}", scope); app: AppHandle,
scope: String,
) -> Result<ImportResult, String> {
info!(
"Importing MCP servers from Claude Desktop with scope: {}",
scope
);
// Get Claude Desktop config path based on platform // Get Claude Desktop config path based on platform
let config_path = if cfg!(target_os = "macos") { let config_path = if cfg!(target_os = "macos") {
@@ -460,12 +483,17 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
.join("Claude") .join("Claude")
.join("claude_desktop_config.json") .join("claude_desktop_config.json")
} else { } else {
return Err("Import from Claude Desktop is only supported on macOS and Linux/WSL".to_string()); return Err(
"Import from Claude Desktop is only supported on macOS and Linux/WSL".to_string(),
);
}; };
// Check if config file exists // Check if config file exists
if !config_path.exists() { if !config_path.exists() {
return Err("Claude Desktop configuration not found. Make sure Claude Desktop is installed.".to_string()); return Err(
"Claude Desktop configuration not found. Make sure Claude Desktop is installed."
.to_string(),
);
} }
// Read and parse the config file // Read and parse the config file
@@ -476,7 +504,8 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
.map_err(|e| format!("Failed to parse Claude Desktop config: {}", e))?; .map_err(|e| format!("Failed to parse Claude Desktop config: {}", e))?;
// Extract MCP servers // Extract MCP servers
let mcp_servers = config.get("mcpServers") let mcp_servers = config
.get("mcpServers")
.and_then(|v| v.as_object()) .and_then(|v| v.as_object())
.ok_or_else(|| "No MCP servers found in Claude Desktop config".to_string())?; .ok_or_else(|| "No MCP servers found in Claude Desktop config".to_string())?;
@@ -492,11 +521,17 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
let mut json_config = serde_json::Map::new(); let mut json_config = serde_json::Map::new();
// All Claude Desktop servers are stdio type // All Claude Desktop servers are stdio type
json_config.insert("type".to_string(), serde_json::Value::String("stdio".to_string())); json_config.insert(
"type".to_string(),
serde_json::Value::String("stdio".to_string()),
);
// Add command // Add command
if let Some(command) = server_config.get("command").and_then(|v| v.as_str()) { if let Some(command) = server_config.get("command").and_then(|v| v.as_str()) {
json_config.insert("command".to_string(), serde_json::Value::String(command.to_string())); json_config.insert(
"command".to_string(),
serde_json::Value::String(command.to_string()),
);
} else { } else {
failed_count += 1; failed_count += 1;
server_results.push(ImportServerResult { server_results.push(ImportServerResult {
@@ -518,7 +553,10 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
if let Some(env) = server_config.get("env").and_then(|v| v.as_object()) { if let Some(env) = server_config.get("env").and_then(|v| v.as_object()) {
json_config.insert("env".to_string(), env.clone().into()); json_config.insert("env".to_string(), env.clone().into());
} else { } else {
json_config.insert("env".to_string(), serde_json::Value::Object(serde_json::Map::new())); json_config.insert(
"env".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
} }
// Convert to JSON string // Convert to JSON string
@@ -560,7 +598,10 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
} }
} }
info!("Import complete: {} imported, {} failed", imported_count, failed_count); info!(
"Import complete: {} imported, {} failed",
imported_count, failed_count
);
Ok(ImportResult { Ok(ImportResult {
imported_count, imported_count,
@@ -651,15 +692,13 @@ pub async fn mcp_read_project_config(project_path: String) -> Result<MCPProjectC
} }
match fs::read_to_string(&mcp_json_path) { match fs::read_to_string(&mcp_json_path) {
Ok(content) => { Ok(content) => match serde_json::from_str::<MCPProjectConfig>(&content) {
match serde_json::from_str::<MCPProjectConfig>(&content) {
Ok(config) => Ok(config), Ok(config) => Ok(config),
Err(e) => { Err(e) => {
error!("Failed to parse .mcp.json: {}", e); error!("Failed to parse .mcp.json: {}", e);
Err(format!("Failed to parse .mcp.json: {}", e)) Err(format!("Failed to parse .mcp.json: {}", e))
} }
} },
}
Err(e) => { Err(e) => {
error!("Failed to read .mcp.json: {}", e); error!("Failed to read .mcp.json: {}", e);
Err(format!("Failed to read .mcp.json: {}", e)) Err(format!("Failed to read .mcp.json: {}", e))

View File

@@ -1,6 +1,6 @@
pub mod claude;
pub mod agents; pub mod agents;
pub mod sandbox; pub mod claude;
pub mod usage;
pub mod mcp; pub mod mcp;
pub mod sandbox;
pub mod screenshot; pub mod screenshot;
pub mod usage;

View File

@@ -190,7 +190,10 @@ pub async fn delete_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<(
/// Get a single sandbox profile by ID /// Get a single sandbox profile by ID
#[tauri::command] #[tauri::command]
pub async fn get_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<SandboxProfile, String> { pub async fn get_sandbox_profile(
db: State<'_, AgentDb>,
id: i64,
) -> Result<SandboxProfile, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?; let conn = db.0.lock().map_err(|e| e.to_string())?;
let profile = conn let profile = conn
@@ -382,13 +385,13 @@ pub async fn test_sandbox_profile(
} }
// Try to build the gaol profile // Try to build the gaol profile
let test_path = std::env::current_dir() let test_path = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("/tmp"));
.unwrap_or_else(|_| std::path::PathBuf::from("/tmp"));
let builder = crate::sandbox::profile::ProfileBuilder::new(test_path.clone()) let builder = crate::sandbox::profile::ProfileBuilder::new(test_path.clone())
.map_err(|e| format!("Failed to create profile builder: {}", e))?; .map_err(|e| format!("Failed to create profile builder: {}", e))?;
let build_result = builder.build_profile_with_serialization(rules.clone()) let build_result = builder
.build_profile_with_serialization(rules.clone())
.map_err(|e| format!("Failed to build sandbox profile: {}", e))?; .map_err(|e| format!("Failed to build sandbox profile: {}", e))?;
// Check platform support // Check platform support
@@ -406,15 +409,11 @@ pub async fn test_sandbox_profile(
let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization( let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization(
build_result.profile, build_result.profile,
test_path.clone(), test_path.clone(),
build_result.serialized build_result.serialized,
); );
// Use a simple echo command for testing // Use a simple echo command for testing
let test_command = if cfg!(windows) { let test_command = if cfg!(windows) { "cmd" } else { "echo" };
"cmd"
} else {
"echo"
};
let test_args = if cfg!(windows) { let test_args = if cfg!(windows) {
vec!["/C", "echo", "sandbox test successful"] vec!["/C", "echo", "sandbox test successful"]
@@ -452,8 +451,7 @@ pub async fn test_sandbox_profile(
)) ))
} }
} }
Err(e) => { Err(e) => Ok(format!(
Ok(format!(
"⚠️ Profile '{}' validated with warnings.\n\n\ "⚠️ Profile '{}' validated with warnings.\n\n\
{} rules loaded and validated\n\ {} rules loaded and validated\n\
• Sandbox activation: Partial\n\ • Sandbox activation: Partial\n\
@@ -463,8 +461,7 @@ pub async fn test_sandbox_profile(
rules.len(), rules.len(),
e, e,
platform_caps.os platform_caps.os
)) )),
}
} }
} }
Err(e) => { Err(e) => {
@@ -540,7 +537,8 @@ pub async fn list_sandbox_violations(
if let Some(lim) = limit { if let Some(lim) = limit {
// All three parameters // All three parameters
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![pid, aid, lim], |row| { let rows = stmt
.query_map(params![pid, aid, lim], |row| {
Ok(SandboxViolation { Ok(SandboxViolation {
id: Some(row.get(0)?), id: Some(row.get(0)?),
profile_id: row.get(1)?, profile_id: row.get(1)?,
@@ -552,12 +550,15 @@ pub async fn list_sandbox_violations(
pid: row.get(7)?, pid: row.get(7)?,
denied_at: row.get(8)?, denied_at: row.get(8)?,
}) })
}).map_err(|e| e.to_string())?; })
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())? .map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} else { } else {
// profile_id and agent_id only // profile_id and agent_id only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![pid, aid], |row| { let rows = stmt
.query_map(params![pid, aid], |row| {
Ok(SandboxViolation { Ok(SandboxViolation {
id: Some(row.get(0)?), id: Some(row.get(0)?),
profile_id: row.get(1)?, profile_id: row.get(1)?,
@@ -569,13 +570,16 @@ pub async fn list_sandbox_violations(
pid: row.get(7)?, pid: row.get(7)?,
denied_at: row.get(8)?, denied_at: row.get(8)?,
}) })
}).map_err(|e| e.to_string())?; })
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())? .map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} }
} else if let Some(lim) = limit { } else if let Some(lim) = limit {
// profile_id and limit only // profile_id and limit only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![pid, lim], |row| { let rows = stmt
.query_map(params![pid, lim], |row| {
Ok(SandboxViolation { Ok(SandboxViolation {
id: Some(row.get(0)?), id: Some(row.get(0)?),
profile_id: row.get(1)?, profile_id: row.get(1)?,
@@ -587,12 +591,15 @@ pub async fn list_sandbox_violations(
pid: row.get(7)?, pid: row.get(7)?,
denied_at: row.get(8)?, denied_at: row.get(8)?,
}) })
}).map_err(|e| e.to_string())?; })
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())? .map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} else { } else {
// profile_id only // profile_id only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![pid], |row| { let rows = stmt
.query_map(params![pid], |row| {
Ok(SandboxViolation { Ok(SandboxViolation {
id: Some(row.get(0)?), id: Some(row.get(0)?),
profile_id: row.get(1)?, profile_id: row.get(1)?,
@@ -604,14 +611,17 @@ pub async fn list_sandbox_violations(
pid: row.get(7)?, pid: row.get(7)?,
denied_at: row.get(8)?, denied_at: row.get(8)?,
}) })
}).map_err(|e| e.to_string())?; })
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())? .map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} }
} else if let Some(aid) = agent_id { } else if let Some(aid) = agent_id {
if let Some(lim) = limit { if let Some(lim) = limit {
// agent_id and limit only // agent_id and limit only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![aid, lim], |row| { let rows = stmt
.query_map(params![aid, lim], |row| {
Ok(SandboxViolation { Ok(SandboxViolation {
id: Some(row.get(0)?), id: Some(row.get(0)?),
profile_id: row.get(1)?, profile_id: row.get(1)?,
@@ -623,12 +633,15 @@ pub async fn list_sandbox_violations(
pid: row.get(7)?, pid: row.get(7)?,
denied_at: row.get(8)?, denied_at: row.get(8)?,
}) })
}).map_err(|e| e.to_string())?; })
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())? .map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} else { } else {
// agent_id only // agent_id only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![aid], |row| { let rows = stmt
.query_map(params![aid], |row| {
Ok(SandboxViolation { Ok(SandboxViolation {
id: Some(row.get(0)?), id: Some(row.get(0)?),
profile_id: row.get(1)?, profile_id: row.get(1)?,
@@ -640,13 +653,16 @@ pub async fn list_sandbox_violations(
pid: row.get(7)?, pid: row.get(7)?,
denied_at: row.get(8)?, denied_at: row.get(8)?,
}) })
}).map_err(|e| e.to_string())?; })
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())? .map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} }
} else if let Some(lim) = limit { } else if let Some(lim) = limit {
// limit only // limit only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![lim], |row| { let rows = stmt
.query_map(params![lim], |row| {
Ok(SandboxViolation { Ok(SandboxViolation {
id: Some(row.get(0)?), id: Some(row.get(0)?),
profile_id: row.get(1)?, profile_id: row.get(1)?,
@@ -658,12 +674,15 @@ pub async fn list_sandbox_violations(
pid: row.get(7)?, pid: row.get(7)?,
denied_at: row.get(8)?, denied_at: row.get(8)?,
}) })
}).map_err(|e| e.to_string())?; })
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())? .map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} else { } else {
// No parameters // No parameters
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map([], |row| { let rows = stmt
.query_map([], |row| {
Ok(SandboxViolation { Ok(SandboxViolation {
id: Some(row.get(0)?), id: Some(row.get(0)?),
profile_id: row.get(1)?, profile_id: row.get(1)?,
@@ -675,8 +694,10 @@ pub async fn list_sandbox_violations(
pid: row.get(7)?, pid: row.get(7)?,
denied_at: row.get(8)?, denied_at: row.get(8)?,
}) })
}).map_err(|e| e.to_string())?; })
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())? .map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
}; };
Ok(violations) Ok(violations)
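Every branch above is one rusqlite idiom reformatted: prepare a statement, query_map each row into a struct, then collect the per-row results into a Vec. A compact sketch of that pattern (the table and struct here are invented for illustration, not taken from this codebase):

use rusqlite::{params, Connection};

struct Violation {
    id: i64,
    operation: String,
}

fn list_violations(conn: &Connection, limit: i64) -> Result<Vec<Violation>, String> {
    let mut stmt = conn
        .prepare("SELECT id, operation FROM violations LIMIT ?1")
        .map_err(|e| e.to_string())?;
    let rows = stmt
        .query_map(params![limit], |row| {
            Ok(Violation {
                id: row.get(0)?,
                operation: row.get(1)?,
            })
        })
        .map_err(|e| e.to_string())?;
    // Each row is itself a Result, so collect into Result<Vec<_>, _>.
    rows.collect::<Result<Vec<_>, _>>()
        .map_err(|e| e.to_string())
}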
@@ -723,8 +744,7 @@ pub async fn clear_sandbox_violations(
"DELETE FROM sandbox_violations".to_string() "DELETE FROM sandbox_violations".to_string()
}; };
let deleted = conn.execute(&query, []) let deleted = conn.execute(&query, []).map_err(|e| e.to_string())?;
.map_err(|e| e.to_string())?;
Ok(deleted as i64) Ok(deleted as i64)
} }
@@ -738,7 +758,9 @@ pub async fn get_sandbox_violation_stats(
// Get total violations // Get total violations
let total: i64 = conn let total: i64 = conn
.query_row("SELECT COUNT(*) FROM sandbox_violations", [], |row| row.get(0)) .query_row("SELECT COUNT(*) FROM sandbox_violations", [], |row| {
row.get(0)
})
.map_err(|e| e.to_string())?; .map_err(|e| e.to_string())?;
// Get violations by operation type // Get violations by operation type
@@ -747,7 +769,7 @@ pub async fn get_sandbox_violation_stats(
"SELECT operation_type, COUNT(*) as count "SELECT operation_type, COUNT(*) as count
FROM sandbox_violations FROM sandbox_violations
GROUP BY operation_type GROUP BY operation_type
ORDER BY count DESC" ORDER BY count DESC",
) )
.map_err(|e| e.to_string())?; .map_err(|e| e.to_string())?;
@@ -812,10 +834,7 @@ pub async fn export_all_sandbox_profiles(
for profile in profiles { for profile in profiles {
if let Some(id) = profile.id { if let Some(id) = profile.id {
let rules = list_sandbox_rules(db.clone(), id).await?; let rules = list_sandbox_rules(db.clone(), id).await?;
profile_exports.push(SandboxProfileWithRules { profile_exports.push(SandboxProfileWithRules { profile, rules });
profile,
rules,
});
} }
} }
@@ -837,7 +856,10 @@ pub async fn import_sandbox_profiles(
// Validate version // Validate version
if export_data.version != 1 { if export_data.version != 1 {
return Err(format!("Unsupported export version: {}", export_data.version)); return Err(format!(
"Unsupported export version: {}",
export_data.version
));
} }
for profile_export in export_data.profiles { for profile_export in export_data.profiles {
@@ -857,7 +879,11 @@ pub async fn import_sandbox_profiles(
let (imported, new_name) = match existing { let (imported, new_name) = match existing {
Ok(_) => { Ok(_) => {
// Name conflict - append timestamp // Name conflict - append timestamp
let new_name = format!("{} (imported {})", profile.name, chrono::Utc::now().format("%Y-%m-%d %H:%M")); let new_name = format!(
"{} (imported {})",
profile.name,
chrono::Utc::now().format("%Y-%m-%d %H:%M")
);
profile.name = new_name.clone(); profile.name = new_name.clone();
(true, Some(new_name)) (true, Some(new_name))
} }
@@ -870,11 +896,9 @@ pub async fn import_sandbox_profiles(
profile.is_default = false; // Never import as default profile.is_default = false; // Never import as default
// Create the profile // Create the profile
let created_profile = create_sandbox_profile( let created_profile =
db.clone(), create_sandbox_profile(db.clone(), profile.name.clone(), profile.description)
profile.name.clone(), .await?;
profile.description,
).await?;
if let Some(new_id) = created_profile.id { if let Some(new_id) = created_profile.id {
// Import rules // Import rules
@@ -889,7 +913,8 @@ pub async fn import_sandbox_profiles(
rule.pattern_value, rule.pattern_value,
rule.enabled, rule.enabled,
rule.platform_support, rule.platform_support,
).await; )
.await;
} }
} }
@@ -902,14 +927,17 @@ pub async fn import_sandbox_profiles(
created_profile.description, created_profile.description,
profile.is_active, profile.is_active,
false, // Never set as default on import false, // Never set as default on import
).await; )
.await;
} }
} }
results.push(ImportResult { results.push(ImportResult {
profile_name: original_name, profile_name: original_name,
imported: true, imported: true,
reason: new_name.as_ref().map(|_| "Name conflict resolved".to_string()), reason: new_name
.as_ref()
.map(|_| "Name conflict resolved".to_string()),
new_name, new_name,
}); });
} }

View File

@@ -1,5 +1,5 @@
use headless_chrome::{Browser, LaunchOptions};
use headless_chrome::protocol::cdp::Page; use headless_chrome::protocol::cdp::Page;
use headless_chrome::{Browser, LaunchOptions};
use std::fs; use std::fs;
use std::time::Duration; use std::time::Duration;
use tauri::AppHandle; use tauri::AppHandle;
@@ -32,9 +32,8 @@ pub async fn capture_url_screenshot(
); );
// Run the browser operations in a blocking task since headless_chrome is not async // Run the browser operations in a blocking task since headless_chrome is not async
let result = tokio::task::spawn_blocking(move || { let result =
capture_screenshot_sync(url, selector, full_page) tokio::task::spawn_blocking(move || capture_screenshot_sync(url, selector, full_page))
})
.await .await
.map_err(|e| format!("Failed to spawn blocking task: {}", e))?; .map_err(|e| format!("Failed to spawn blocking task: {}", e))?;
@@ -61,8 +60,8 @@ fn capture_screenshot_sync(
}; };
// Launch the browser // Launch the browser
let browser = Browser::new(launch_options) let browser =
.map_err(|e| format!("Failed to launch browser: {}", e))?; Browser::new(launch_options).map_err(|e| format!("Failed to launch browser: {}", e))?;
// Create a new tab // Create a new tab
let tab = browser let tab = browser
@@ -86,7 +85,10 @@ fn capture_screenshot_sync(
// Wait explicitly for the <body> element to exist - this often prevents // "Unable to capture screenshot" CDP errors on some pages
// "Unable to capture screenshot" CDP errors on some pages // "Unable to capture screenshot" CDP errors on some pages
if let Err(e) = tab.wait_for_element("body") { if let Err(e) = tab.wait_for_element("body") {
log::warn!("Timed out waiting for <body> element: {} continuing anyway", e); log::warn!(
"Timed out waiting for <body> element: {} continuing anyway",
e
);
} }
// Capture the screenshot // Capture the screenshot
@@ -103,7 +105,10 @@ fn capture_screenshot_sync(
.map_err(|e| format!("Failed to capture element screenshot: {}", e))? .map_err(|e| format!("Failed to capture element screenshot: {}", e))?
} else { } else {
// Capture the entire page or viewport // Capture the entire page or viewport
log::info!("Capturing {} screenshot", if full_page { "full page" } else { "viewport" }); log::info!(
"Capturing {} screenshot",
if full_page { "full page" } else { "viewport" }
);
// Get the page dimensions for full page screenshot // Get the page dimensions for full page screenshot
let clip = if full_page { let clip = if full_page {
@@ -176,12 +181,7 @@ fn capture_screenshot_sync(
err err
); );
tab.capture_screenshot( tab.capture_screenshot(Page::CaptureScreenshotFormatOption::Png, None, clip, true)
Page::CaptureScreenshotFormatOption::Png,
None,
clip,
true,
)
.map_err(|e| format!("Failed to capture screenshot after retry: {}", e))? .map_err(|e| format!("Failed to capture screenshot after retry: {}", e))?
} }
} }
@@ -222,15 +222,18 @@ pub async fn cleanup_screenshot_temp_files(
older_than_minutes: Option<u64>, older_than_minutes: Option<u64>,
) -> Result<usize, String> { ) -> Result<usize, String> {
let minutes = older_than_minutes.unwrap_or(60); let minutes = older_than_minutes.unwrap_or(60);
log::info!("Cleaning up screenshot files older than {} minutes", minutes); log::info!(
"Cleaning up screenshot files older than {} minutes",
minutes
);
let temp_dir = std::env::temp_dir(); let temp_dir = std::env::temp_dir();
let cutoff_time = chrono::Utc::now() - chrono::Duration::minutes(minutes as i64); let cutoff_time = chrono::Utc::now() - chrono::Duration::minutes(minutes as i64);
let mut deleted_count = 0; let mut deleted_count = 0;
// Read directory entries // Read directory entries
let entries = fs::read_dir(&temp_dir) let entries =
.map_err(|e| format!("Failed to read temp directory: {}", e))?; fs::read_dir(&temp_dir).map_err(|e| format!("Failed to read temp directory: {}", e))?;
for entry in entries { for entry in entries {
if let Ok(entry) = entry { if let Ok(entry) = entry {
@@ -239,7 +242,9 @@ pub async fn cleanup_screenshot_temp_files(
// Check if it's a claudia screenshot file // Check if it's a claudia screenshot file
if let Some(filename) = path.file_name() { if let Some(filename) = path.file_name() {
if let Some(filename_str) = filename.to_str() { if let Some(filename_str) = filename.to_str() {
if filename_str.starts_with("claudia_screenshot_") && filename_str.ends_with(".png") { if filename_str.starts_with("claudia_screenshot_")
&& filename_str.ends_with(".png")
{
// Check file age // Check file age
if let Ok(metadata) = fs::metadata(&path) { if let Ok(metadata) = fs::metadata(&path) {
if let Ok(modified) = metadata.modified() { if let Ok(modified) = metadata.modified() {

View File

@@ -1,9 +1,9 @@
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::PathBuf;
use chrono::{DateTime, Local, NaiveDate}; use chrono::{DateTime, Local, NaiveDate};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json; use serde_json;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::PathBuf;
use tauri::command; use tauri::command;
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
@@ -110,9 +110,19 @@ fn calculate_cost(model: &str, usage: &UsageData) -> f64 {
// Calculate cost based on model // Calculate cost based on model
let (input_price, output_price, cache_write_price, cache_read_price) = let (input_price, output_price, cache_write_price, cache_read_price) =
if model.contains("opus-4") || model.contains("claude-opus-4") { if model.contains("opus-4") || model.contains("claude-opus-4") {
(OPUS_4_INPUT_PRICE, OPUS_4_OUTPUT_PRICE, OPUS_4_CACHE_WRITE_PRICE, OPUS_4_CACHE_READ_PRICE) (
OPUS_4_INPUT_PRICE,
OPUS_4_OUTPUT_PRICE,
OPUS_4_CACHE_WRITE_PRICE,
OPUS_4_CACHE_READ_PRICE,
)
} else if model.contains("sonnet-4") || model.contains("claude-sonnet-4") { } else if model.contains("sonnet-4") || model.contains("claude-sonnet-4") {
(SONNET_4_INPUT_PRICE, SONNET_4_OUTPUT_PRICE, SONNET_4_CACHE_WRITE_PRICE, SONNET_4_CACHE_READ_PRICE) (
SONNET_4_INPUT_PRICE,
SONNET_4_OUTPUT_PRICE,
SONNET_4_CACHE_WRITE_PRICE,
SONNET_4_CACHE_READ_PRICE,
)
} else { } else {
// Return 0 for unknown models to avoid incorrect cost estimations. // Return 0 for unknown models to avoid incorrect cost estimations.
(0.0, 0.0, 0.0, 0.0) (0.0, 0.0, 0.0, 0.0)
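For context, a pricing tuple like the ones selected above is normally applied per million tokens; the formula and the numbers below are illustrative assumptions, not a quote of this file's calculate_cost body or its price constants:

struct TokenCounts {
    input: u64,
    output: u64,
    cache_write: u64,
    cache_read: u64,
}

// prices = (input, output, cache_write, cache_read) in USD per 1M tokens.
fn apply_prices(t: &TokenCounts, prices: (f64, f64, f64, f64)) -> f64 {
    let (inp, out, cw, cr) = prices;
    t.input as f64 / 1_000_000.0 * inp
        + t.output as f64 / 1_000_000.0 * out
        + t.cache_write as f64 / 1_000_000.0 * cw
        + t.cache_read as f64 / 1_000_000.0 * cr
}

fn main() {
    let t = TokenCounts { input: 12_000, output: 4_000, cache_write: 0, cache_read: 0 };
    // Hypothetical prices: $15/M input, $75/M output.
    let cost = apply_prices(&t, (15.0, 75.0, 18.75, 1.5));
    println!("${:.2}", cost); // 0.012 * 15 + 0.004 * 75 = 0.18 + 0.30 = $0.48
}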
@@ -137,7 +147,8 @@ fn parse_jsonl_file(
if let Ok(content) = fs::read_to_string(path) { if let Ok(content) = fs::read_to_string(path) {
// Extract session ID from the file path // Extract session ID from the file path
let session_id = path.parent() let session_id = path
.parent()
.and_then(|p| p.file_name()) .and_then(|p| p.file_name())
.and_then(|n| n.to_str()) .and_then(|n| n.to_str())
.unwrap_or("unknown") .unwrap_or("unknown")
@@ -170,10 +181,11 @@ fn parse_jsonl_file(
if let Some(usage) = &message.usage { if let Some(usage) = &message.usage {
// Skip entries without meaningful token usage // Skip entries without meaningful token usage
if usage.input_tokens.unwrap_or(0) == 0 && if usage.input_tokens.unwrap_or(0) == 0
usage.output_tokens.unwrap_or(0) == 0 && && usage.output_tokens.unwrap_or(0) == 0
usage.cache_creation_input_tokens.unwrap_or(0) == 0 && && usage.cache_creation_input_tokens.unwrap_or(0) == 0
usage.cache_read_input_tokens.unwrap_or(0) == 0 { && usage.cache_read_input_tokens.unwrap_or(0) == 0
{
continue; continue;
} }
@@ -186,15 +198,21 @@ fn parse_jsonl_file(
}); });
// Use actual project path if found, otherwise use encoded name // Use actual project path if found, otherwise use encoded name
let project_path = actual_project_path.clone() let project_path = actual_project_path
.clone()
.unwrap_or_else(|| encoded_project_name.to_string()); .unwrap_or_else(|| encoded_project_name.to_string());
entries.push(UsageEntry { entries.push(UsageEntry {
timestamp: entry.timestamp, timestamp: entry.timestamp,
model: message.model.clone().unwrap_or_else(|| "unknown".to_string()), model: message
.model
.clone()
.unwrap_or_else(|| "unknown".to_string()),
input_tokens: usage.input_tokens.unwrap_or(0), input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0), output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_tokens: usage.cache_creation_input_tokens.unwrap_or(0), cache_creation_tokens: usage
.cache_creation_input_tokens
.unwrap_or(0),
cache_read_tokens: usage.cache_read_input_tokens.unwrap_or(0), cache_read_tokens: usage.cache_read_input_tokens.unwrap_or(0),
cost, cost,
session_id: entry.session_id.unwrap_or_else(|| session_id.clone()), session_id: entry.session_id.unwrap_or_else(|| session_id.clone()),
@@ -296,7 +314,8 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
// Filter by days if specified // Filter by days if specified
let filtered_entries = if let Some(days) = days { let filtered_entries = if let Some(days) = days {
let cutoff = Local::now().naive_local().date() - chrono::Duration::days(days as i64); let cutoff = Local::now().naive_local().date() - chrono::Duration::days(days as i64);
all_entries.into_iter() all_entries
.into_iter()
.filter(|e| { .filter(|e| {
if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) { if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) {
dt.naive_local().date() >= cutoff dt.naive_local().date() >= cutoff
@@ -329,7 +348,9 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
total_cache_read_tokens += entry.cache_read_tokens; total_cache_read_tokens += entry.cache_read_tokens;
// Update model stats // Update model stats
let model_stat = model_stats.entry(entry.model.clone()).or_insert(ModelUsage { let model_stat = model_stats
.entry(entry.model.clone())
.or_insert(ModelUsage {
model: entry.model.clone(), model: entry.model.clone(),
total_cost: 0.0, total_cost: 0.0,
total_tokens: 0, total_tokens: 0,
@@ -348,7 +369,12 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
model_stat.session_count += 1; model_stat.session_count += 1;
// Update daily stats // Update daily stats
let date = entry.timestamp.split('T').next().unwrap_or(&entry.timestamp).to_string(); let date = entry
.timestamp
.split('T')
.next()
.unwrap_or(&entry.timestamp)
.to_string();
let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage { let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage {
date, date,
total_cost: 0.0, total_cost: 0.0,
@@ -356,15 +382,24 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
models_used: vec![], models_used: vec![],
}); });
daily_stat.total_cost += entry.cost; daily_stat.total_cost += entry.cost;
daily_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens; daily_stat.total_tokens += entry.input_tokens
+ entry.output_tokens
+ entry.cache_creation_tokens
+ entry.cache_read_tokens;
if !daily_stat.models_used.contains(&entry.model) { if !daily_stat.models_used.contains(&entry.model) {
daily_stat.models_used.push(entry.model.clone()); daily_stat.models_used.push(entry.model.clone());
} }
// Update project stats // Update project stats
let project_stat = project_stats.entry(entry.project_path.clone()).or_insert(ProjectUsage { let project_stat =
project_stats
.entry(entry.project_path.clone())
.or_insert(ProjectUsage {
project_path: entry.project_path.clone(), project_path: entry.project_path.clone(),
project_name: entry.project_path.split('/').last() project_name: entry
.project_path
.split('/')
.last()
.unwrap_or(&entry.project_path) .unwrap_or(&entry.project_path)
.to_string(), .to_string(),
total_cost: 0.0, total_cost: 0.0,
@@ -373,14 +408,20 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
last_used: entry.timestamp.clone(), last_used: entry.timestamp.clone(),
}); });
project_stat.total_cost += entry.cost; project_stat.total_cost += entry.cost;
project_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens; project_stat.total_tokens += entry.input_tokens
+ entry.output_tokens
+ entry.cache_creation_tokens
+ entry.cache_read_tokens;
project_stat.session_count += 1; project_stat.session_count += 1;
if entry.timestamp > project_stat.last_used { if entry.timestamp > project_stat.last_used {
project_stat.last_used = entry.timestamp.clone(); project_stat.last_used = entry.timestamp.clone();
} }
} }
let total_tokens = total_input_tokens + total_output_tokens + total_cache_creation_tokens + total_cache_read_tokens; let total_tokens = total_input_tokens
+ total_output_tokens
+ total_cache_creation_tokens
+ total_cache_read_tokens;
let total_sessions = filtered_entries.len() as u64; let total_sessions = filtered_entries.len() as u64;
// Convert hashmaps to sorted vectors // Convert hashmaps to sorted vectors
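The per-model, per-day, and per-project rollups above all lean on HashMap's entry().or_insert() API, which rustfmt now splits across lines. A small runnable sketch of the aggregation pattern, with made-up data:

use std::collections::HashMap;

struct DailyTotal {
    cost: f64,
    tokens: u64,
}

fn main() {
    let entries = [
        ("2025-06-24", 0.12, 1_500_u64),
        ("2025-06-24", 0.05, 700),
        ("2025-06-25", 0.30, 4_200),
    ];
    let mut daily: HashMap<&str, DailyTotal> = HashMap::new();
    for (date, cost, tokens) in entries {
        // Insert a zeroed bucket on first sight of the key, then accumulate.
        let stat = daily
            .entry(date)
            .or_insert(DailyTotal { cost: 0.0, tokens: 0 });
        stat.cost += cost;
        stat.tokens += tokens;
    }
    for (date, t) in &daily {
        println!("{}: ${:.2}, {} tokens", date, t.cost, t.tokens);
    }
}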
@@ -416,15 +457,13 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
let all_entries = get_all_usage_entries(&claude_path); let all_entries = get_all_usage_entries(&claude_path);
// Parse dates // Parse dates
let start = NaiveDate::parse_from_str(&start_date, "%Y-%m-%d") let start = NaiveDate::parse_from_str(&start_date, "%Y-%m-%d").or_else(|_| {
.or_else(|_| {
// Try parsing ISO datetime format // Try parsing ISO datetime format
DateTime::parse_from_rfc3339(&start_date) DateTime::parse_from_rfc3339(&start_date)
.map(|dt| dt.naive_local().date()) .map(|dt| dt.naive_local().date())
.map_err(|e| format!("Invalid start date: {}", e)) .map_err(|e| format!("Invalid start date: {}", e))
})?; })?;
let end = NaiveDate::parse_from_str(&end_date, "%Y-%m-%d") let end = NaiveDate::parse_from_str(&end_date, "%Y-%m-%d").or_else(|_| {
.or_else(|_| {
// Try parsing ISO datetime format // Try parsing ISO datetime format
DateTime::parse_from_rfc3339(&end_date) DateTime::parse_from_rfc3339(&end_date)
.map(|dt| dt.naive_local().date()) .map(|dt| dt.naive_local().date())
@@ -432,7 +471,8 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
})?; })?;
// Filter entries by date range // Filter entries by date range
let filtered_entries: Vec<_> = all_entries.into_iter() let filtered_entries: Vec<_> = all_entries
.into_iter()
.filter(|e| { .filter(|e| {
if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) { if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) {
let date = dt.naive_local().date(); let date = dt.naive_local().date();
@@ -478,7 +518,9 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
total_cache_read_tokens += entry.cache_read_tokens; total_cache_read_tokens += entry.cache_read_tokens;
// Update model stats // Update model stats
let model_stat = model_stats.entry(entry.model.clone()).or_insert(ModelUsage { let model_stat = model_stats
.entry(entry.model.clone())
.or_insert(ModelUsage {
model: entry.model.clone(), model: entry.model.clone(),
total_cost: 0.0, total_cost: 0.0,
total_tokens: 0, total_tokens: 0,
@@ -497,7 +539,12 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
model_stat.session_count += 1; model_stat.session_count += 1;
// Update daily stats // Update daily stats
let date = entry.timestamp.split('T').next().unwrap_or(&entry.timestamp).to_string(); let date = entry
.timestamp
.split('T')
.next()
.unwrap_or(&entry.timestamp)
.to_string();
let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage { let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage {
date, date,
total_cost: 0.0, total_cost: 0.0,
@@ -505,15 +552,24 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
models_used: vec![], models_used: vec![],
}); });
daily_stat.total_cost += entry.cost; daily_stat.total_cost += entry.cost;
daily_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens; daily_stat.total_tokens += entry.input_tokens
+ entry.output_tokens
+ entry.cache_creation_tokens
+ entry.cache_read_tokens;
if !daily_stat.models_used.contains(&entry.model) { if !daily_stat.models_used.contains(&entry.model) {
daily_stat.models_used.push(entry.model.clone()); daily_stat.models_used.push(entry.model.clone());
} }
// Update project stats // Update project stats
let project_stat = project_stats.entry(entry.project_path.clone()).or_insert(ProjectUsage { let project_stat =
project_stats
.entry(entry.project_path.clone())
.or_insert(ProjectUsage {
project_path: entry.project_path.clone(), project_path: entry.project_path.clone(),
project_name: entry.project_path.split('/').last() project_name: entry
.project_path
.split('/')
.last()
.unwrap_or(&entry.project_path) .unwrap_or(&entry.project_path)
.to_string(), .to_string(),
total_cost: 0.0, total_cost: 0.0,
@@ -522,14 +578,20 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
last_used: entry.timestamp.clone(), last_used: entry.timestamp.clone(),
}); });
project_stat.total_cost += entry.cost; project_stat.total_cost += entry.cost;
project_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens; project_stat.total_tokens += entry.input_tokens
+ entry.output_tokens
+ entry.cache_creation_tokens
+ entry.cache_read_tokens;
project_stat.session_count += 1; project_stat.session_count += 1;
if entry.timestamp > project_stat.last_used { if entry.timestamp > project_stat.last_used {
project_stat.last_used = entry.timestamp.clone(); project_stat.last_used = entry.timestamp.clone();
} }
} }
let total_tokens = total_input_tokens + total_output_tokens + total_cache_creation_tokens + total_cache_read_tokens; let total_tokens = total_input_tokens
+ total_output_tokens
+ total_cache_creation_tokens
+ total_cache_read_tokens;
let total_sessions = filtered_entries.len() as u64; let total_sessions = filtered_entries.len() as u64;
// Convert hashmaps to sorted vectors // Convert hashmaps to sorted vectors
@@ -557,7 +619,10 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
} }
#[command] #[command]
pub fn get_usage_details(project_path: Option<String>, date: Option<String>) -> Result<Vec<UsageEntry>, String> { pub fn get_usage_details(
project_path: Option<String>,
date: Option<String>,
) -> Result<Vec<UsageEntry>, String> {
let claude_path = dirs::home_dir() let claude_path = dirs::home_dir()
.ok_or("Failed to get home directory")? .ok_or("Failed to get home directory")?
.join(".claude"); .join(".claude");
@@ -609,7 +674,9 @@ pub fn get_session_stats(
let mut session_stats: HashMap<String, ProjectUsage> = HashMap::new(); let mut session_stats: HashMap<String, ProjectUsage> = HashMap::new();
for entry in &filtered_entries { for entry in &filtered_entries {
let session_key = format!("{}/{}", entry.project_path, entry.session_id); let session_key = format!("{}/{}", entry.project_path, entry.session_id);
let project_stat = session_stats.entry(session_key).or_insert_with(|| ProjectUsage { let project_stat = session_stats
.entry(session_key)
.or_insert_with(|| ProjectUsage {
project_path: entry.project_path.clone(), project_path: entry.project_path.clone(),
project_name: entry.session_id.clone(), // Using session_id as project_name for session view project_name: entry.session_id.clone(), // Using session_id as project_name for session view
total_cost: 0.0, total_cost: 0.0,
@@ -643,6 +710,5 @@ pub fn get_session_stats(
by_session.sort_by(|a, b| b.last_used.cmp(&a.last_used)); by_session.sort_by(|a, b| b.last_used.cmp(&a.last_used));
} }
Ok(by_session) Ok(by_session)
} }

View File

@@ -1,11 +1,11 @@
// Learn more about Tauri commands at https://tauri.app/develop/calling-rust/ // Learn more about Tauri commands at https://tauri.app/develop/calling-rust/
// Declare modules // Declare modules
pub mod commands;
pub mod sandbox;
pub mod checkpoint; pub mod checkpoint;
pub mod process;
pub mod claude_binary; pub mod claude_binary;
pub mod commands;
pub mod process;
pub mod sandbox;
#[cfg_attr(mobile, tauri::mobile_entry_point)] #[cfg_attr(mobile, tauri::mobile_entry_point)]
pub fn run() { pub fn run() {

View File

@@ -1,57 +1,52 @@
// Prevents additional console window on Windows in release, DO NOT REMOVE!! // Prevents additional console window on Windows in release, DO NOT REMOVE!!
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
mod commands;
mod sandbox;
mod checkpoint; mod checkpoint;
mod process;
mod claude_binary; mod claude_binary;
mod commands;
mod process;
mod sandbox;
use tauri::Manager; use checkpoint::state::CheckpointState;
use commands::claude::{
get_claude_settings, get_project_sessions, get_system_prompt, list_projects, open_new_session,
check_claude_version, save_system_prompt, save_claude_settings,
find_claude_md_files, read_claude_md_file, save_claude_md_file,
load_session_history, execute_claude_code, continue_claude_code, resume_claude_code,
list_directory_contents, search_files,
create_checkpoint, restore_checkpoint, list_checkpoints, fork_from_checkpoint,
get_session_timeline, update_checkpoint_settings, get_checkpoint_diff,
track_checkpoint_message, track_session_messages, check_auto_checkpoint, cleanup_old_checkpoints,
get_checkpoint_settings, clear_checkpoint_manager, get_checkpoint_state_stats,
get_recently_modified_files, cancel_claude_execution, ClaudeProcessState,
};
use commands::agents::{ use commands::agents::{
init_database, list_agents, create_agent, update_agent, delete_agent, cleanup_finished_processes, create_agent, delete_agent, execute_agent, export_agent,
get_agent, execute_agent, list_agent_runs, get_agent_run, export_agent_to_file, fetch_github_agent_content, fetch_github_agents, get_agent,
get_agent_run_with_real_time_metrics, list_agent_runs_with_metrics, get_agent_run, get_agent_run_with_real_time_metrics, get_claude_binary_path,
list_running_sessions, kill_agent_session, get_live_session_output, get_session_output, get_session_status, import_agent,
get_session_status, cleanup_finished_processes, get_session_output, import_agent_from_file, import_agent_from_github, init_database, kill_agent_session,
get_live_session_output, stream_session_output, get_claude_binary_path, list_agent_runs, list_agent_runs_with_metrics, list_agents, list_claude_installations,
set_claude_binary_path, export_agent, export_agent_to_file, import_agent, list_running_sessions, set_claude_binary_path, stream_session_output, update_agent, AgentDb,
import_agent_from_file, fetch_github_agents, fetch_github_agent_content,
import_agent_from_github, list_claude_installations, AgentDb
}; };
use commands::sandbox::{ use commands::claude::{
list_sandbox_profiles, create_sandbox_profile, update_sandbox_profile, delete_sandbox_profile, cancel_claude_execution, check_auto_checkpoint, check_claude_version, cleanup_old_checkpoints,
get_sandbox_profile, list_sandbox_rules, create_sandbox_rule, update_sandbox_rule, clear_checkpoint_manager, continue_claude_code, create_checkpoint, execute_claude_code,
delete_sandbox_rule, get_platform_capabilities, test_sandbox_profile, find_claude_md_files, fork_from_checkpoint, get_checkpoint_diff, get_checkpoint_settings,
list_sandbox_violations, log_sandbox_violation, clear_sandbox_violations, get_sandbox_violation_stats, get_checkpoint_state_stats, get_claude_settings, get_project_sessions,
export_sandbox_profile, export_all_sandbox_profiles, import_sandbox_profiles, get_recently_modified_files, get_session_timeline, get_system_prompt, list_checkpoints,
}; list_directory_contents, list_projects, load_session_history, open_new_session,
use commands::screenshot::{ read_claude_md_file, restore_checkpoint, resume_claude_code, save_claude_md_file,
capture_url_screenshot, cleanup_screenshot_temp_files, save_claude_settings, save_system_prompt, search_files, track_checkpoint_message,
}; track_session_messages, update_checkpoint_settings, ClaudeProcessState,
use commands::usage::{
get_usage_stats, get_usage_by_date_range, get_usage_details, get_session_stats,
}; };
use commands::mcp::{ use commands::mcp::{
mcp_add, mcp_list, mcp_get, mcp_remove, mcp_add_json, mcp_add_from_claude_desktop, mcp_add, mcp_add_from_claude_desktop, mcp_add_json, mcp_get, mcp_get_server_status, mcp_list,
mcp_serve, mcp_test_connection, mcp_reset_project_choices, mcp_get_server_status, mcp_read_project_config, mcp_remove, mcp_reset_project_choices, mcp_save_project_config,
mcp_read_project_config, mcp_save_project_config, mcp_serve, mcp_test_connection,
};
use commands::sandbox::{
clear_sandbox_violations, create_sandbox_profile, create_sandbox_rule, delete_sandbox_profile,
delete_sandbox_rule, export_all_sandbox_profiles, export_sandbox_profile,
get_platform_capabilities, get_sandbox_profile, get_sandbox_violation_stats,
import_sandbox_profiles, list_sandbox_profiles, list_sandbox_rules, list_sandbox_violations,
log_sandbox_violation, test_sandbox_profile, update_sandbox_profile, update_sandbox_rule,
};
use commands::screenshot::{capture_url_screenshot, cleanup_screenshot_temp_files};
use commands::usage::{
get_session_stats, get_usage_by_date_range, get_usage_details, get_usage_stats,
}; };
use std::sync::Mutex;
use checkpoint::state::CheckpointState;
use process::ProcessRegistryState; use process::ProcessRegistryState;
use std::sync::Mutex;
use tauri::Manager;
fn main() { fn main() {
// Initialize logger // Initialize logger
@@ -81,9 +76,11 @@ fn main() {
.ok_or_else(|| "Could not find home directory") .ok_or_else(|| "Could not find home directory")
.and_then(|home| { .and_then(|home| {
let claude_path = home.join(".claude"); let claude_path = home.join(".claude");
claude_path.canonicalize() claude_path
.canonicalize()
.map_err(|_| "Could not find ~/.claude directory") .map_err(|_| "Could not find ~/.claude directory")
}) { })
{
let state_clone = checkpoint_state.clone(); let state_clone = checkpoint_state.clone();
tauri::async_runtime::spawn(async move { tauri::async_runtime::spawn(async move {
state_clone.set_claude_dir(claude_dir).await; state_clone.set_claude_dir(claude_dir).await;

View File

@@ -1,8 +1,8 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use serde::{Deserialize, Serialize};
use tokio::process::Child; use tokio::process::Child;
use chrono::{DateTime, Utc};
/// Information about a running agent process /// Information about a running agent process
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -84,7 +84,10 @@ impl ProcessRegistry {
#[allow(dead_code)] #[allow(dead_code)]
pub fn get_running_processes(&self) -> Result<Vec<ProcessInfo>, String> { pub fn get_running_processes(&self) -> Result<Vec<ProcessInfo>, String> {
let processes = self.processes.lock().map_err(|e| e.to_string())?; let processes = self.processes.lock().map_err(|e| e.to_string())?;
Ok(processes.values().map(|handle| handle.info.clone()).collect()) Ok(processes
.values()
.map(|handle| handle.info.clone())
.collect())
} }
/// Get a specific running process /// Get a specific running process
@@ -96,7 +99,7 @@ impl ProcessRegistry {
/// Kill a running process with proper cleanup /// Kill a running process with proper cleanup
pub async fn kill_process(&self, run_id: i64) -> Result<bool, String> { pub async fn kill_process(&self, run_id: i64) -> Result<bool, String> {
use log::{info, warn, error}; use log::{error, info, warn};
// First check if the process exists and get its PID // First check if the process exists and get its PID
let (pid, child_arc) = { let (pid, child_arc) = {
@@ -108,7 +111,10 @@ impl ProcessRegistry {
} }
}; };
info!("Attempting graceful shutdown of process {} (PID: {})", run_id, pid); info!(
"Attempting graceful shutdown of process {} (PID: {})",
run_id, pid
);
// Send kill signal to the process // Send kill signal to the process
let kill_sent = { let kill_sent = {
@@ -134,9 +140,7 @@ impl ProcessRegistry {
} }
// Wait for the process to exit (with timeout) // Wait for the process to exit (with timeout)
let wait_result = tokio::time::timeout( let wait_result = tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
tokio::time::Duration::from_secs(5),
async {
loop { loop {
// Check if process has exited // Check if process has exited
let status = { let status = {
@@ -171,8 +175,8 @@ impl ProcessRegistry {
} }
} }
} }
} })
).await; .await;
match wait_result { match wait_result {
Ok(Ok(_)) => { Ok(Ok(_)) => {
@@ -198,7 +202,7 @@ impl ProcessRegistry {
/// Kill a process by PID using system commands (fallback method) /// Kill a process by PID using system commands (fallback method)
pub fn kill_process_by_pid(&self, run_id: i64, pid: u32) -> Result<bool, String> { pub fn kill_process_by_pid(&self, run_id: i64, pid: u32) -> Result<bool, String> {
use log::{info, warn, error}; use log::{error, info, warn};
info!("Attempting to kill process {} by PID {}", run_id, pid); info!("Attempting to kill process {} by PID {}", run_id, pid);
@@ -226,7 +230,10 @@ impl ProcessRegistry {
if let Ok(output) = check_result { if let Ok(output) = check_result {
if output.status.success() { if output.status.success() {
// Still running, send SIGKILL // Still running, send SIGKILL
warn!("Process {} still running after SIGTERM, sending SIGKILL", pid); warn!(
"Process {} still running after SIGTERM, sending SIGKILL",
pid
);
std::process::Command::new("kill") std::process::Command::new("kill")
.args(["-KILL", &pid.to_string()]) .args(["-KILL", &pid.to_string()])
.output() .output()
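The kill logic above wraps its exit-polling loop in tokio::time::timeout, and rustfmt collapses the old two-argument layout into the trailing async-block form seen on the right. A stripped-down sketch of the idiom — the 5-second budget matches the code above, everything else is invented:

use std::time::Duration;

// Hypothetical stand-in for checking whether the child has exited.
fn process_has_exited() -> bool {
    true
}

#[tokio::main]
async fn main() {
    let wait_result = tokio::time::timeout(Duration::from_secs(5), async {
        loop {
            if process_has_exited() {
                break "exited";
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    })
    .await;

    match wait_result {
        Ok(status) => println!("process {}", status),
        Err(_) => println!("timed out after 5s; escalating to a hard kill"),
    }
}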

View File

@@ -4,11 +4,9 @@ use rusqlite::{params, Connection, Result};
/// Create default sandbox profiles for initial setup
pub fn create_default_profiles(conn: &Connection) -> Result<()> {
    // Check if we already have profiles
-    let count: i64 = conn.query_row(
-        "SELECT COUNT(*) FROM sandbox_profiles",
-        [],
-        |row| row.get(0),
-    )?;
+    let count: i64 = conn.query_row("SELECT COUNT(*) FROM sandbox_profiles", [], |row| {
+        row.get(0)
+    })?;
    if count > 0 {
        // Already have profiles, don't create defaults
@@ -44,14 +42,49 @@ fn create_standard_profile(conn: &Connection) -> Result<()> {
    // Add rules
    let rules = vec![
        // File access
-        ("file_read_all", "subpath", "{{PROJECT_PATH}}", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/usr/lib", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/usr/local/lib", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/System/Library", true, Some(r#"["macos"]"#)),
-        ("file_read_metadata", "subpath", "/", true, Some(r#"["macos"]"#)),
+        (
+            "file_read_all",
+            "subpath",
+            "{{PROJECT_PATH}}",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/usr/lib",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/usr/local/lib",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/System/Library",
+            true,
+            Some(r#"["macos"]"#),
+        ),
+        (
+            "file_read_metadata",
+            "subpath",
+            "/",
+            true,
+            Some(r#"["macos"]"#),
+        ),
        // Network access
-        ("network_outbound", "all", "", true, Some(r#"["linux", "macos"]"#)),
+        (
+            "network_outbound",
+            "all",
+            "",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
    ];
    for (op_type, pattern_type, pattern_value, enabled, platforms) in rules {
@@ -113,16 +146,56 @@ fn create_development_profile(conn: &Connection) -> Result<()> {
    // Add development rules
    let rules = vec![
        // Broad file access
-        ("file_read_all", "subpath", "{{PROJECT_PATH}}", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "{{HOME}}", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/usr", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/opt", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/Applications", true, Some(r#"["macos"]"#)),
-        ("file_read_metadata", "subpath", "/", true, Some(r#"["macos"]"#)),
+        (
+            "file_read_all",
+            "subpath",
+            "{{PROJECT_PATH}}",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "{{HOME}}",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/usr",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/opt",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/Applications",
+            true,
+            Some(r#"["macos"]"#),
+        ),
+        (
+            "file_read_metadata",
+            "subpath",
+            "/",
+            true,
+            Some(r#"["macos"]"#),
+        ),
        // Network access
-        ("network_outbound", "all", "", true, Some(r#"["linux", "macos"]"#)),
+        (
+            "network_outbound",
+            "all",
+            "",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
        // System info (macOS only)
        ("system_info_read", "all", "", true, Some(r#"["macos"]"#)),
    ];
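
Each tuple above is (operation_type, pattern_type, pattern_value, enabled, platform_support), consumed by the insert loop that follows the vec. A minimal sketch of how one such tuple could be persisted with rusqlite — the column names are inferred from the rule fields, not quoted from the real schema:

use rusqlite::{params, Connection, Result};

/// Illustrative insert for one default rule tuple; the table layout is assumed.
fn insert_rule(
    conn: &Connection,
    profile_id: i64,
    (op_type, pattern_type, pattern_value, enabled, platforms): (&str, &str, &str, bool, Option<&str>),
) -> Result<()> {
    conn.execute(
        "INSERT INTO sandbox_rules
             (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support)
         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
        params![profile_id, op_type, pattern_type, pattern_value, enabled, platforms],
    )?;
    Ok(())
}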

View File

@@ -1,7 +1,9 @@
use anyhow::{Context, Result};
#[cfg(unix)]
-use gaol::sandbox::{ChildSandbox, ChildSandboxMethods, Command as GaolCommand, Sandbox, SandboxMethods};
-use log::{info, warn, error, debug};
+use gaol::sandbox::{
+    ChildSandbox, ChildSandboxMethods, Command as GaolCommand, Sandbox, SandboxMethods,
+};
+use log::{debug, error, info, warn};
use std::env;
use std::path::{Path, PathBuf};
use std::process::Stdio;
@@ -30,7 +32,7 @@ impl SandboxExecutor {
    pub fn new_with_serialization(
        profile: gaol::profile::Profile,
        project_path: PathBuf,
-        serialized_profile: SerializedProfile
+        serialized_profile: SerializedProfile,
    ) -> Self {
        Self {
            profile,
@@ -41,7 +43,12 @@ impl SandboxExecutor {
    /// Execute a command in the sandbox (for the parent process)
    /// This is used when we need to spawn a child process with sandbox
-    pub fn execute_sandboxed_spawn(&self, command: &str, args: &[&str], cwd: &Path) -> Result<std::process::Child> {
+    pub fn execute_sandboxed_spawn(
+        &self,
+        command: &str,
+        args: &[&str],
+        cwd: &Path,
+    ) -> Result<std::process::Child> {
        info!("Executing sandboxed command: {} {:?}", command, args);
        // On macOS, we need to check if the command is allowed by the system
@@ -49,7 +56,10 @@ impl SandboxExecutor {
        {
            // For testing purposes, we'll skip actual sandboxing for simple commands like echo
            if command == "echo" || command == "/bin/echo" {
-                debug!("Using direct execution for simple test command: {}", command);
+                debug!(
+                    "Using direct execution for simple test command: {}",
+                    command
+                );
                return std::process::Command::new(command)
                    .args(args)
                    .current_dir(cwd)
@@ -73,13 +83,22 @@ impl SandboxExecutor {
        // Set environment variables
        gaol_command.env("GAOL_CHILD_PROCESS", "1");
        gaol_command.env("GAOL_SANDBOX_ACTIVE", "1");
-        gaol_command.env("GAOL_PROJECT_PATH", self.project_path.to_string_lossy().as_ref());
+        gaol_command.env(
+            "GAOL_PROJECT_PATH",
+            self.project_path.to_string_lossy().as_ref(),
+        );
        // Inherit specific parent environment variables that are safe
        for (key, value) in env::vars() {
            // Only pass through safe environment variables
-            if key.starts_with("PATH") || key.starts_with("HOME") || key.starts_with("USER")
-                || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_") {
+            if key.starts_with("PATH")
+                || key.starts_with("HOME")
+                || key.starts_with("USER")
+                || key == "SHELL"
+                || key == "LANG"
+                || key == "LC_ALL"
+                || key.starts_with("LC_")
+            {
                gaol_command.env(&key, &value);
            }
        }
@@ -93,7 +112,9 @@ impl SandboxExecutor {
                // This is a limitation of the gaol library - we can't get the Child back
                // For now, we'll have to use the fallback approach
-                warn!("Gaol started the process but we can't get the Child handle - using fallback");
+                warn!(
+                    "Gaol started the process but we can't get the Child handle - using fallback"
+                );
                // Drop the process to avoid zombie
                drop(process);
@@ -118,16 +139,21 @@ impl SandboxExecutor {
        };
        let mut std_command = std::process::Command::new(command);
-        std_command.args(args)
+        std_command
+            .args(args)
            .current_dir(cwd)
            .env("GAOL_SANDBOX_ACTIVE", "1")
-            .env("GAOL_PROJECT_PATH", self.project_path.to_string_lossy().as_ref())
+            .env(
+                "GAOL_PROJECT_PATH",
+                self.project_path.to_string_lossy().as_ref(),
+            )
            .env("GAOL_SANDBOX_RULES", rules_json)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
-        std_command.spawn()
+        std_command
+            .spawn()
            .context("Failed to spawn process with sandbox environment")
    }
@@ -137,16 +163,23 @@ impl SandboxExecutor {
        info!("Preparing sandboxed command: {} {:?}", command, args);
        let mut cmd = Command::new(command);
-        cmd.args(args)
-            .current_dir(cwd);
+        cmd.args(args).current_dir(cwd);
        // Inherit essential environment variables from parent process
        // This is crucial for commands like Claude that need to find Node.js
        for (key, value) in env::vars() {
            // Pass through PATH and other essential environment variables
-            if key == "PATH" || key == "HOME" || key == "USER"
-                || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_")
-                || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN" {
+            if key == "PATH"
+                || key == "HOME"
+                || key == "USER"
+                || key == "SHELL"
+                || key == "LANG"
+                || key == "LC_ALL"
+                || key.starts_with("LC_")
+                || key == "NODE_PATH"
+                || key == "NVM_DIR"
+                || key == "NVM_BIN"
+            {
                debug!("Inheriting env var: {}={}", key, value);
                cmd.env(&key, &value);
            }
@@ -155,11 +188,19 @@ impl SandboxExecutor {
        // Serialize the sandbox rules for the child process
        let rules_json = if let Some(ref serialized) = self.serialized_profile {
            let json = serde_json::to_string(serialized).ok();
-            info!("🔧 Using serialized sandbox profile with {} operations", serialized.operations.len());
+            info!(
+                "🔧 Using serialized sandbox profile with {} operations",
+                serialized.operations.len()
+            );
            for (i, op) in serialized.operations.iter().enumerate() {
                match op {
                    SerializedOperation::FileReadAll { path, is_subpath } => {
-                        info!("  Rule {}: FileReadAll {} (subpath: {})", i, path.display(), is_subpath);
+                        info!(
+                            "  Rule {}: FileReadAll {} (subpath: {})",
+                            i,
+                            path.display(),
+                            is_subpath
+                        );
                    }
                    SerializedOperation::NetworkOutbound { pattern } => {
                        info!("  Rule {}: NetworkOutbound {}", i, pattern);
@@ -188,7 +229,10 @@ impl SandboxExecutor {
                warn!("🚨 TEMPORARILY DISABLED sandbox environment variables for debugging");
                info!("🔧 Would have set sandbox environment variables for child process");
                info!("   GAOL_SANDBOX_ACTIVE=1 (disabled)");
-                info!("   GAOL_PROJECT_PATH={} (disabled)", self.project_path.display());
+                info!(
+                    "   GAOL_PROJECT_PATH={} (disabled)",
+                    self.project_path.display()
+                );
                info!("   GAOL_SANDBOX_RULES={} chars (disabled)", json.len());
            } else {
                warn!("🚨 Failed to serialize sandbox rules - running without sandbox!");
@@ -210,10 +254,10 @@ impl SandboxExecutor {
            let operations = vec![
                SerializedOperation::FileReadAll {
                    path: self.project_path.clone(),
-                    is_subpath: true
+                    is_subpath: true,
                },
                SerializedOperation::NetworkOutbound {
-                    pattern: "all".to_string()
+                    pattern: "all".to_string(),
                },
            ];
@@ -231,17 +275,19 @@ impl SandboxExecutor {
        info!("Activating sandbox in child process");
        // Get project path
-        let project_path = env::var("GAOL_PROJECT_PATH")
-            .context("GAOL_PROJECT_PATH not set")?;
+        let project_path = env::var("GAOL_PROJECT_PATH").context("GAOL_PROJECT_PATH not set")?;
        let project_path = PathBuf::from(project_path);
        // Try to deserialize the sandbox rules from environment
        let profile = if let Ok(rules_json) = env::var("GAOL_SANDBOX_RULES") {
            match serde_json::from_str::<SerializedProfile>(&rules_json) {
                Ok(serialized) => {
-                    debug!("Deserializing {} sandbox rules", serialized.operations.len());
+                    debug!(
+                        "Deserializing {} sandbox rules",
+                        serialized.operations.len()
+                    );
                    deserialize_profile(serialized, &project_path)?
-                },
+                }
                Err(e) => {
                    warn!("Failed to deserialize sandbox rules: {}", e);
                    // Fallback to minimal profile
@@ -285,7 +331,7 @@ impl SandboxExecutor {
    pub fn new_with_serialization(
        _profile: (),
        project_path: PathBuf,
-        serialized_profile: SerializedProfile
+        serialized_profile: SerializedProfile,
    ) -> Self {
        Self {
            project_path,
@@ -294,8 +340,16 @@ impl SandboxExecutor {
    }
    /// Execute a command in the sandbox (Windows - no sandboxing)
-    pub fn execute_sandboxed_spawn(&self, command: &str, args: &[&str], cwd: &Path) -> Result<std::process::Child> {
-        info!("Executing command without sandbox on Windows: {} {:?}", command, args);
+    pub fn execute_sandboxed_spawn(
+        &self,
+        command: &str,
+        args: &[&str],
+        cwd: &Path,
+    ) -> Result<std::process::Child> {
+        info!(
+            "Executing command without sandbox on Windows: {} {:?}",
+            command, args
+        );
        std::process::Command::new(command)
            .args(args)
@@ -309,7 +363,10 @@ impl SandboxExecutor {
    /// Prepare a sandboxed tokio Command (Windows - no sandboxing)
    pub fn prepare_sandboxed_command(&self, command: &str, args: &[&str], cwd: &Path) -> Command {
-        info!("Preparing command without sandbox on Windows: {} {:?}", command, args);
+        info!(
+            "Preparing command without sandbox on Windows: {} {:?}",
+            command, args
+        );
        let mut cmd = Command::new(command);
        cmd.args(args)
@@ -345,7 +402,7 @@ pub fn create_sandboxed_command(
    args: &[&str],
    cwd: &Path,
    profile: gaol::profile::Profile,
-    project_path: PathBuf
+    project_path: PathBuf,
) -> Command {
    let executor = SandboxExecutor::new(profile, project_path);
    executor.prepare_sandboxed_command(command, args, cwd)
@@ -368,7 +425,10 @@ pub enum SerializedOperation {
}
#[cfg(unix)]
-fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Result<gaol::profile::Profile> {
+fn deserialize_profile(
+    serialized: SerializedProfile,
+    project_path: &Path,
+) -> Result<gaol::profile::Profile> {
    let mut operations = Vec::new();
    for op in serialized.operations {
@@ -401,12 +461,12 @@ fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Re
            }
            SerializedOperation::NetworkTcp { port } => {
                operations.push(gaol::profile::Operation::NetworkOutbound(
-                    gaol::profile::AddressPattern::Tcp(port)
+                    gaol::profile::AddressPattern::Tcp(port),
                ));
            }
            SerializedOperation::NetworkLocalSocket { path } => {
                operations.push(gaol::profile::Operation::NetworkOutbound(
-                    gaol::profile::AddressPattern::LocalSocket(path)
+                    gaol::profile::AddressPattern::LocalSocket(path),
                ));
            }
            SerializedOperation::SystemInfoRead => {
@@ -422,31 +482,29 @@ fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Re
    if !has_project_access {
        operations.push(gaol::profile::Operation::FileReadAll(
-            gaol::profile::PathPattern::Subpath(project_path.to_path_buf())
+            gaol::profile::PathPattern::Subpath(project_path.to_path_buf()),
        ));
    }
    let op_count = operations.len();
-    gaol::profile::Profile::new(operations)
-        .map_err(|e| {
+    gaol::profile::Profile::new(operations).map_err(|e| {
        error!("Failed to create profile: {:?}", e);
-        anyhow::anyhow!("Failed to create profile from {} operations: {:?}", op_count, e)
+        anyhow::anyhow!(
+            "Failed to create profile from {} operations: {:?}",
+            op_count,
+            e
+        )
    })
}
#[cfg(unix)]
fn create_minimal_profile(project_path: PathBuf) -> Result<gaol::profile::Profile> {
    let operations = vec![
-        gaol::profile::Operation::FileReadAll(
-            gaol::profile::PathPattern::Subpath(project_path)
-        ),
-        gaol::profile::Operation::NetworkOutbound(
-            gaol::profile::AddressPattern::All
-        ),
+        gaol::profile::Operation::FileReadAll(gaol::profile::PathPattern::Subpath(project_path)),
+        gaol::profile::Operation::NetworkOutbound(gaol::profile::AddressPattern::All),
    ];
-    gaol::profile::Profile::new(operations)
-        .map_err(|e| {
+    gaol::profile::Profile::new(operations).map_err(|e| {
        error!("Failed to create minimal profile: {:?}", e);
        anyhow::anyhow!("Failed to create minimal sandbox profile: {:?}", e)
    })
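
The executor above hands rules to the child process as JSON in the GAOL_SANDBOX_RULES environment variable and deserializes them on the other side, falling back to a minimal profile on failure. A minimal sketch of that round-trip with a stand-in envelope type (the real type is SerializedProfile; everything below is illustrative):

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct RulesEnvelope {
    operations: Vec<String>, // stand-in for the SerializedOperation variants
}

fn main() -> Result<(), serde_json::Error> {
    // Parent side: encode what the child is allowed to do.
    let json = serde_json::to_string(&RulesEnvelope {
        operations: vec!["file_read_all:/work/demo".into(), "network_outbound:all".into()],
    })?;

    // Child side: decode, falling back to an empty rule set on any failure,
    // mirroring the minimal-profile fallback above.
    let decoded: RulesEnvelope =
        serde_json::from_str(&json).unwrap_or(RulesEnvelope { operations: vec![] });
    println!("child would activate {} rule(s)", decoded.operations.len());
    Ok(())
}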

View File

@@ -1,21 +1,21 @@
#[allow(unused)]
-pub mod profile;
+pub mod defaults;
#[allow(unused)]
pub mod executor;
#[allow(unused)]
pub mod platform;
#[allow(unused)]
-pub mod defaults;
+pub mod profile;
// These are used in agents.rs and claude.rs via direct module paths
#[allow(unused)]
-pub use profile::{SandboxProfile, SandboxRule, ProfileBuilder};
+pub use profile::{ProfileBuilder, SandboxProfile, SandboxRule};
// These are used in main.rs and sandbox.rs
#[allow(unused)]
-pub use executor::{SandboxExecutor, should_activate_sandbox};
+pub use executor::{should_activate_sandbox, SandboxExecutor};
// These are used in sandbox.rs
#[allow(unused)]
-pub use platform::{PlatformCapabilities, get_platform_capabilities};
+pub use platform::{get_platform_capabilities, PlatformCapabilities};
// Used for initial setup
#[allow(unused)]
pub use defaults::create_default_profiles;

View File

@@ -1,3 +1,4 @@
+use crate::sandbox::executor::{SerializedOperation, SerializedProfile};
use anyhow::{Context, Result};
#[cfg(unix)]
use gaol::profile::{AddressPattern, Operation, OperationSupport, PathPattern, Profile};
@@ -5,7 +6,6 @@ use log::{debug, info, warn};
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
-use crate::sandbox::executor::{SerializedOperation, SerializedProfile};
/// Represents a sandbox profile from the database
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -50,8 +50,7 @@ pub struct ProfileBuilder {
impl ProfileBuilder {
    /// Create a new profile builder
    pub fn new(project_path: PathBuf) -> Result<Self> {
-        let home_dir = dirs::home_dir()
-            .context("Could not determine home directory")?;
+        let home_dir = dirs::home_dir().context("Could not determine home directory")?;
        Ok(Self {
            project_path,
@@ -60,12 +59,20 @@ impl ProfileBuilder {
    }
    /// Build a gaol Profile from database rules filtered by agent permissions
-    pub fn build_agent_profile(&self, rules: Vec<SandboxRule>, sandbox_enabled: bool, enable_file_read: bool, enable_file_write: bool, enable_network: bool) -> Result<ProfileBuildResult> {
+    pub fn build_agent_profile(
+        &self,
+        rules: Vec<SandboxRule>,
+        sandbox_enabled: bool,
+        enable_file_read: bool,
+        enable_file_write: bool,
+        enable_network: bool,
+    ) -> Result<ProfileBuildResult> {
        // If sandbox is completely disabled, return an empty profile
        if !sandbox_enabled {
            return Ok(ProfileBuildResult {
                #[cfg(unix)]
-                profile: Profile::new(vec![]).map_err(|_| anyhow::anyhow!("Failed to create empty profile"))?,
+                profile: Profile::new(vec![])
+                    .map_err(|_| anyhow::anyhow!("Failed to create empty profile"))?,
                #[cfg(not(unix))]
                profile: (),
                serialized: SerializedProfile { operations: vec![] },
@@ -84,7 +91,7 @@ impl ProfileBuilder {
                "file_read_all" | "file_read_metadata" => enable_file_read,
                "network_outbound" => enable_network,
                "system_info_read" => true, // Always allow system info reading
-                _ => true // Include unknown rule types by default
+                _ => true, // Include unknown rule types by default
            };
            if include_rule {
@@ -95,9 +102,9 @@ impl ProfileBuilder {
        // Always ensure project path access if file reading is enabled
        if enable_file_read {
            let has_project_access = filtered_rules.iter().any(|rule| {
-                rule.operation_type == "file_read_all" &&
-                rule.pattern_type == "subpath" &&
-                rule.pattern_value.contains("{{PROJECT_PATH}}")
+                rule.operation_type == "file_read_all"
+                    && rule.pattern_type == "subpath"
+                    && rule.pattern_value.contains("{{PROJECT_PATH}}")
            });
            if !has_project_access {
@@ -133,7 +140,10 @@ impl ProfileBuilder {
    }
    /// Build a gaol Profile from database rules and return serialized operations
-    pub fn build_profile_with_serialization(&self, rules: Vec<SandboxRule>) -> Result<ProfileBuildResult> {
+    pub fn build_profile_with_serialization(
+        &self,
+        rules: Vec<SandboxRule>,
+    ) -> Result<ProfileBuildResult> {
        #[cfg(unix)]
        {
            let mut operations = Vec::new();
@@ -146,25 +156,41 @@ impl ProfileBuilder {
                // Check platform support
                if !self.is_rule_supported_on_platform(&rule) {
-                    debug!("Skipping rule {} - not supported on current platform", rule.operation_type);
+                    debug!(
+                        "Skipping rule {} - not supported on current platform",
+                        rule.operation_type
+                    );
                    continue;
                }
                match self.build_operation_with_serialization(&rule) {
                    Ok(Some((op, serialized))) => {
                        // Check if operation is supported on current platform
-                        if matches!(op.support(), gaol::profile::OperationSupportLevel::CanBeAllowed) {
+                        if matches!(
+                            op.support(),
+                            gaol::profile::OperationSupportLevel::CanBeAllowed
+                        ) {
                            operations.push(op);
                            serialized_operations.push(serialized);
                        } else {
-                            warn!("Operation {:?} not supported at desired level on current platform", rule.operation_type);
+                            warn!(
+                                "Operation {:?} not supported at desired level on current platform",
+                                rule.operation_type
+                            );
                        }
-                    },
+                    }
                    Ok(None) => {
-                        debug!("Skipping unsupported operation type: {}", rule.operation_type);
-                    },
+                        debug!(
+                            "Skipping unsupported operation type: {}",
+                            rule.operation_type
+                        );
+                    }
                    Err(e) => {
-                        warn!("Failed to build operation for rule {}: {}", rule.id.unwrap_or(0), e);
+                        warn!(
+                            "Failed to build operation for rule {}: {}",
+                            rule.id.unwrap_or(0),
+                            e
+                        );
                    }
                }
            }
@@ -175,7 +201,9 @@ impl ProfileBuilder {
            });
            if !has_project_access {
-                operations.push(Operation::FileReadAll(PathPattern::Subpath(self.project_path.clone())));
+                operations.push(Operation::FileReadAll(PathPattern::Subpath(
+                    self.project_path.clone(),
+                )));
                serialized_operations.push(SerializedOperation::FileReadAll {
                    path: self.project_path.clone(),
                    is_subpath: true,
@@ -230,33 +258,39 @@ impl ProfileBuilder {
    /// Build a gaol Operation and its serialized form from a database rule
    #[cfg(unix)]
-    fn build_operation_with_serialization(&self, rule: &SandboxRule) -> Result<Option<(Operation, SerializedOperation)>> {
+    fn build_operation_with_serialization(
+        &self,
+        rule: &SandboxRule,
+    ) -> Result<Option<(Operation, SerializedOperation)>> {
        match rule.operation_type.as_str() {
            "file_read_all" => {
-                let (pattern, path, is_subpath) = self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?;
+                let (pattern, path, is_subpath) =
+                    self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?;
                Ok(Some((
                    Operation::FileReadAll(pattern),
-                    SerializedOperation::FileReadAll { path, is_subpath }
+                    SerializedOperation::FileReadAll { path, is_subpath },
                )))
-            },
+            }
            "file_read_metadata" => {
-                let (pattern, path, is_subpath) = self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?;
+                let (pattern, path, is_subpath) =
+                    self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?;
                Ok(Some((
                    Operation::FileReadMetadata(pattern),
-                    SerializedOperation::FileReadMetadata { path, is_subpath }
+                    SerializedOperation::FileReadMetadata { path, is_subpath },
                )))
-            },
+            }
            "network_outbound" => {
-                let (pattern, serialized) = self.build_address_pattern_with_serialization(&rule.pattern_type, &rule.pattern_value)?;
+                let (pattern, serialized) = self.build_address_pattern_with_serialization(
+                    &rule.pattern_type,
+                    &rule.pattern_value,
+                )?;
                Ok(Some((Operation::NetworkOutbound(pattern), serialized)))
-            },
-            "system_info_read" => {
-                Ok(Some((
-                    Operation::SystemInfoRead,
-                    SerializedOperation::SystemInfoRead
-                )))
-            },
-            _ => Ok(None)
+            }
+            "system_info_read" => Ok(Some((
+                Operation::SystemInfoRead,
+                SerializedOperation::SystemInfoRead,
+            ))),
+            _ => Ok(None),
        }
    }
@@ -269,7 +303,11 @@ impl ProfileBuilder {
    /// Build a PathPattern and return additional info for serialization
    #[cfg(unix)]
-    fn build_path_pattern_with_info(&self, pattern_type: &str, pattern_value: &str) -> Result<(PathPattern, PathBuf, bool)> {
+    fn build_path_pattern_with_info(
+        &self,
+        pattern_type: &str,
+        pattern_value: &str,
+    ) -> Result<(PathPattern, PathBuf, bool)> {
        // Replace template variables
        let expanded_value = pattern_value
            .replace("{{PROJECT_PATH}}", &self.project_path.to_string_lossy())
@@ -280,41 +318,59 @@ impl ProfileBuilder {
        match pattern_type {
            "literal" => Ok((PathPattern::Literal(path.clone()), path, false)),
            "subpath" => Ok((PathPattern::Subpath(path.clone()), path, true)),
-            _ => Err(anyhow::anyhow!("Unknown path pattern type: {}", pattern_type))
+            _ => Err(anyhow::anyhow!(
+                "Unknown path pattern type: {}",
+                pattern_type
+            )),
        }
    }
    /// Build an AddressPattern from pattern type and value
    #[cfg(unix)]
-    fn build_address_pattern(&self, pattern_type: &str, pattern_value: &str) -> Result<AddressPattern> {
-        let (pattern, _) = self.build_address_pattern_with_serialization(pattern_type, pattern_value)?;
+    fn build_address_pattern(
+        &self,
+        pattern_type: &str,
+        pattern_value: &str,
+    ) -> Result<AddressPattern> {
+        let (pattern, _) =
+            self.build_address_pattern_with_serialization(pattern_type, pattern_value)?;
        Ok(pattern)
    }
    /// Build an AddressPattern and its serialized form
    #[cfg(unix)]
-    fn build_address_pattern_with_serialization(&self, pattern_type: &str, pattern_value: &str) -> Result<(AddressPattern, SerializedOperation)> {
+    fn build_address_pattern_with_serialization(
+        &self,
+        pattern_type: &str,
+        pattern_value: &str,
+    ) -> Result<(AddressPattern, SerializedOperation)> {
        match pattern_type {
            "all" => Ok((
                AddressPattern::All,
-                SerializedOperation::NetworkOutbound { pattern: "all".to_string() }
+                SerializedOperation::NetworkOutbound {
+                    pattern: "all".to_string(),
+                },
            )),
            "tcp" => {
-                let port = pattern_value.parse::<u16>()
+                let port = pattern_value
+                    .parse::<u16>()
                    .context("Invalid TCP port number")?;
                Ok((
                    AddressPattern::Tcp(port),
-                    SerializedOperation::NetworkTcp { port }
+                    SerializedOperation::NetworkTcp { port },
                ))
-            },
+            }
            "local_socket" => {
                let path = PathBuf::from(pattern_value);
                Ok((
                    AddressPattern::LocalSocket(path.clone()),
-                    SerializedOperation::NetworkLocalSocket { path }
+                    SerializedOperation::NetworkLocalSocket { path },
                ))
-            },
-            _ => Err(anyhow::anyhow!("Unknown address pattern type: {}", pattern_type))
+            }
+            _ => Err(anyhow::anyhow!(
+                "Unknown address pattern type: {}",
+                pattern_type
+            )),
        }
    }
@@ -332,33 +388,38 @@ impl ProfileBuilder {
    /// Build only the serialized operation (for Windows)
    #[cfg(not(unix))]
-    fn build_serialized_operation(&self, rule: &SandboxRule) -> Result<Option<SerializedOperation>> {
+    fn build_serialized_operation(
+        &self,
+        rule: &SandboxRule,
+    ) -> Result<Option<SerializedOperation>> {
        let pattern_value = self.expand_pattern_value(&rule.pattern_value);
        match rule.operation_type.as_str() {
            "file_read_all" => {
-                let (path, is_subpath) = self.parse_path_pattern(&rule.pattern_type, &pattern_value)?;
+                let (path, is_subpath) =
+                    self.parse_path_pattern(&rule.pattern_type, &pattern_value)?;
                Ok(Some(SerializedOperation::FileReadAll { path, is_subpath }))
            }
            "file_read_metadata" => {
-                let (path, is_subpath) = self.parse_path_pattern(&rule.pattern_type, &pattern_value)?;
-                Ok(Some(SerializedOperation::FileReadMetadata { path, is_subpath }))
-            }
-            "network_outbound" => {
-                Ok(Some(SerializedOperation::NetworkOutbound { pattern: pattern_value }))
+                let (path, is_subpath) =
+                    self.parse_path_pattern(&rule.pattern_type, &pattern_value)?;
+                Ok(Some(SerializedOperation::FileReadMetadata {
+                    path,
+                    is_subpath,
+                }))
            }
+            "network_outbound" => Ok(Some(SerializedOperation::NetworkOutbound {
+                pattern: pattern_value,
+            })),
            "network_tcp" => {
-                let port = pattern_value.parse::<u16>()
-                    .context("Invalid TCP port")?;
+                let port = pattern_value.parse::<u16>().context("Invalid TCP port")?;
                Ok(Some(SerializedOperation::NetworkTcp { port }))
            }
            "network_local_socket" => {
                let path = PathBuf::from(pattern_value);
                Ok(Some(SerializedOperation::NetworkLocalSocket { path }))
            }
-            "system_info_read" => {
-                Ok(Some(SerializedOperation::SystemInfoRead))
-            }
+            "system_info_read" => Ok(Some(SerializedOperation::SystemInfoRead)),
            _ => Ok(None),
        }
    }
@@ -373,13 +434,20 @@ impl ProfileBuilder {
    /// Helper method to parse path patterns (Windows version)
    #[cfg(not(unix))]
-    fn parse_path_pattern(&self, pattern_type: &str, pattern_value: &str) -> Result<(PathBuf, bool)> {
+    fn parse_path_pattern(
+        &self,
+        pattern_type: &str,
+        pattern_value: &str,
+    ) -> Result<(PathBuf, bool)> {
        let path = PathBuf::from(pattern_value);
        match pattern_type {
            "literal" => Ok((path, false)),
            "subpath" => Ok((path, true)),
-            _ => Err(anyhow::anyhow!("Unknown path pattern type: {}", pattern_type))
+            _ => Err(anyhow::anyhow!(
+                "Unknown path pattern type: {}",
+                pattern_type
+            )),
        }
    }
}
@@ -400,7 +468,7 @@ pub fn load_profile(conn: &Connection, profile_id: i64) -> Result<SandboxProfile
            created_at: row.get(5)?,
            updated_at: row.get(6)?,
        })
-    }
+    },
    )
    .context("Failed to load sandbox profile")
}
@@ -421,7 +489,7 @@ pub fn load_default_profile(conn: &Connection) -> Result<SandboxProfile> {
            created_at: row.get(5)?,
            updated_at: row.get(6)?,
        })
-    }
+    },
    )
    .context("Failed to load default sandbox profile")
}
@@ -433,7 +501,8 @@ pub fn load_profile_rules(conn: &Connection, profile_id: i64) -> Result<Vec<Sand
         FROM sandbox_rules WHERE profile_id = ?1 AND enabled = 1"
    )?;
-    let rules = stmt.query_map(params![profile_id], |row| {
+    let rules = stmt
+        .query_map(params![profile_id], |row| {
            Ok(SandboxRule {
                id: Some(row.get(0)?),
                profile_id: row.get(1)?,
@@ -452,7 +521,11 @@ pub fn load_profile_rules(conn: &Connection, profile_id: i64) -> Result<Vec<Sand
/// Get or create the gaol Profile for execution
#[cfg(unix)]
-pub fn get_gaol_profile(conn: &Connection, profile_id: Option<i64>, project_path: PathBuf) -> Result<Profile> {
+pub fn get_gaol_profile(
+    conn: &Connection,
+    profile_id: Option<i64>,
+    project_path: PathBuf,
+) -> Result<Profile> {
    // Load the profile
    let profile = if let Some(id) = profile_id {
        load_profile(conn, id)?
@@ -473,7 +546,11 @@ pub fn get_gaol_profile(conn: &Connection, profile_id: Option<i64>, project_path
/// Get or create the gaol Profile for execution (Windows stub)
#[cfg(not(unix))]
-pub fn get_gaol_profile(_conn: &Connection, _profile_id: Option<i64>, _project_path: PathBuf) -> Result<()> {
+pub fn get_gaol_profile(
+    _conn: &Connection,
+    _profile_id: Option<i64>,
+    _project_path: PathBuf,
+) -> Result<()> {
    warn!("Sandbox profiles are not supported on Windows");
    Ok(())
}
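
ProfileBuilder expands the {{PROJECT_PATH}} and {{HOME}} placeholders before building path patterns, as seen in build_path_pattern_with_info above. A small sketch of that expansion step in isolation (the helper and paths are illustrative):

use std::path::{Path, PathBuf};

/// Replace the rule-template placeholders with concrete directories.
fn expand_pattern(pattern: &str, project_path: &Path, home: &Path) -> PathBuf {
    PathBuf::from(
        pattern
            .replace("{{PROJECT_PATH}}", &project_path.to_string_lossy())
            .replace("{{HOME}}", &home.to_string_lossy()),
    )
}

fn main() {
    let expanded = expand_pattern(
        "{{PROJECT_PATH}}/src",
        Path::new("/work/demo"),
        Path::new("/home/demo"),
    );
    assert_eq!(expanded, PathBuf::from("/work/demo/src"));
}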

View File

@@ -29,7 +29,8 @@ pub fn execute_claude_task(
    }
    // Always add these flags for testing
-    cmd.arg("--output-format").arg("stream-json")
+    cmd.arg("--output-format")
+        .arg("stream-json")
        .arg("--verbose")
        .arg("--dangerously-skip-permissions")
        .current_dir(project_path)
@@ -63,12 +64,12 @@ pub fn execute_claude_task(
    let output = if timeout_cmd.is_empty() {
        // Run without timeout wrapper
-        cmd.output()
-            .context("Failed to execute Claude command")?
+        cmd.output().context("Failed to execute Claude command")?
    } else {
        // Run with timeout wrapper
        let mut timeout_cmd = Command::new(timeout_cmd);
-        timeout_cmd.arg(timeout_secs.to_string())
+        timeout_cmd
+            .arg(timeout_secs.to_string())
            .arg("claude")
            .args(cmd.get_args())
            .current_dir(project_path)
@@ -119,9 +120,9 @@ impl ClaudeOutput {
        let op_lower = operation.to_lowercase();
        // Check if operation was mentioned along with a block pattern
-        blocked_patterns.iter().any(|pattern| {
-            output.contains(&op_lower) && output.contains(pattern)
-        })
+        blocked_patterns
+            .iter()
+            .any(|pattern| output.contains(&op_lower) && output.contains(pattern))
    }
    /// Check if file read was successful
@@ -134,7 +135,9 @@ impl ClaudeOutput {
            "test content", // Our test files contain this
        ];
-        patterns.iter().any(|pattern| self.contains_operation(pattern))
+        patterns
+            .iter()
+            .any(|pattern| self.contains_operation(pattern))
    }
    /// Check if network connection was attempted
@@ -146,7 +149,9 @@ impl ClaudeOutput {
            host,
        ];
-        patterns.iter().any(|pattern| self.contains_operation(pattern))
+        patterns
+            .iter()
+            .any(|pattern| self.contains_operation(pattern))
    }
}
@@ -169,7 +174,10 @@ pub mod tasks {
    /// Task to test file write
    pub fn write_file(filename: &str, content: &str) -> String {
-        format!("Create a file called {} with the content '{}'", filename, content)
+        format!(
+            "Create a file called {} with the content '{}'",
+            filename, content
+        )
    }
    /// Task to test process spawning
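
The reformatted matchers above all reduce to the same idea: an operation counts as blocked when it co-occurs with a block marker in Claude's output. A compact sketch of that check (the function name and patterns are illustrative):

fn looks_blocked(output: &str, operation: &str, blocked_patterns: &[&str]) -> bool {
    let output = output.to_lowercase();
    let op = operation.to_lowercase();
    blocked_patterns
        .iter()
        .any(|pattern| output.contains(&op) && output.contains(pattern))
}

fn main() {
    assert!(looks_blocked(
        "file read: operation not permitted",
        "File Read",
        &["not permitted", "denied"],
    ));
}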

View File

@@ -10,9 +10,8 @@ use tempfile::{tempdir, TempDir};
/// Using parking_lot::Mutex which doesn't poison on panic
use parking_lot::Mutex;
-pub static TEST_DB: Lazy<Mutex<TestDatabase>> = Lazy::new(|| {
-    Mutex::new(TestDatabase::new().expect("Failed to create test database"))
-});
+pub static TEST_DB: Lazy<Mutex<TestDatabase>> =
+    Lazy::new(|| Mutex::new(TestDatabase::new().expect("Failed to create test database")));
/// Test database manager
pub struct TestDatabase {
@@ -273,7 +272,10 @@ impl TestFileSystem {
        // Create project files
        std::fs::write(project_path.join("main.rs"), "fn main() {}")?;
-        std::fs::write(project_path.join("Cargo.toml"), "[package]\nname = \"test\"")?;
+        std::fs::write(
+            project_path.join("Cargo.toml"),
+            "[package]\nname = \"test\"",
+        )?;
        Ok(Self {
            root,
@@ -290,9 +292,7 @@ pub mod profiles {
    /// Minimal profile - only project access
    pub fn minimal(project_path: &str) -> Vec<TestRule> {
-        vec![
-            TestRule::file_read(project_path, true),
-        ]
+        vec![TestRule::file_read(project_path, true)]
    }
    /// Standard profile - project + system libraries
@@ -319,14 +319,13 @@ pub mod profiles {
    /// Network-only profile
    pub fn network_only() -> Vec<TestRule> {
-        vec![
-            TestRule::network_all(),
-        ]
+        vec![TestRule::network_all()]
    }
    /// File-only profile
    pub fn file_only(paths: Vec<&str>) -> Vec<TestRule> {
-        paths.into_iter()
+        paths
+            .into_iter()
            .map(|path| TestRule::file_read(path, true))
            .collect()
    }
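
The TEST_DB static above combines once_cell's Lazy with parking_lot's Mutex so the fixture is built once and a panicking test cannot poison the lock. The same pattern in miniature, with a stand-in resource:

use once_cell::sync::Lazy;
use parking_lot::Mutex;

struct TestDb {
    name: String,
}

static TEST_DB: Lazy<Mutex<TestDb>> =
    Lazy::new(|| Mutex::new(TestDb { name: "shared".into() }));

fn main() {
    // parking_lot's lock() returns the guard directly — no Result, no poisoning.
    let db = TEST_DB.lock();
    println!("using fixture {}", db.name);
}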

View File

@@ -15,7 +15,10 @@ pub fn is_sandboxing_supported() -> bool {
macro_rules! skip_if_unsupported {
    () => {
        if !$crate::sandbox::common::is_sandboxing_supported() {
-            eprintln!("Skipping test: sandboxing not supported on {}", std::env::consts::OS);
+            eprintln!(
+                "Skipping test: sandboxing not supported on {}",
+                std::env::consts::OS
+            );
            return;
        }
    };
@@ -134,8 +137,7 @@ impl TestCommand {
        use std::time::Instant;
        let start = Instant::now();
-        let mut child = cmd.spawn()
-            .context("Failed to spawn command")?;
+        let mut child = cmd.spawn().context("Failed to spawn command")?;
        loop {
            match child.try_wait() {
@@ -162,8 +164,7 @@ impl TestCommand {
        #[cfg(not(unix))]
        {
            // Fallback for non-Unix platforms
-            cmd.output()
-                .context("Failed to execute command")
+            cmd.output().context("Failed to execute command")
        }
    }
@@ -198,11 +199,7 @@ impl TestCommand {
}
/// Create a simple test binary that attempts an operation
-pub fn create_test_binary(
-    name: &str,
-    code: &str,
-    test_dir: &Path,
-) -> Result<PathBuf> {
+pub fn create_test_binary(name: &str, code: &str, test_dir: &Path) -> Result<PathBuf> {
    create_test_binary_with_deps(name, code, test_dir, &[])
}
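
TestCommand's Unix path spawns the child and polls try_wait() against a deadline rather than blocking on output(). A self-contained sketch of that poll-with-deadline loop (the helper name and interval are illustrative):

use anyhow::{Context, Result};
use std::process::Command;
use std::time::{Duration, Instant};

/// Run a command, returning its exit code, or None if the deadline passes first.
fn run_with_deadline(cmd: &mut Command, deadline: Duration) -> Result<Option<i32>> {
    let start = Instant::now();
    let mut child = cmd.spawn().context("Failed to spawn command")?;
    loop {
        if let Some(status) = child.try_wait()? {
            return Ok(status.code());
        }
        if start.elapsed() > deadline {
            child.kill().ok(); // best-effort cleanup on timeout
            return Ok(None);
        }
        std::thread::sleep(Duration::from_millis(50));
    }
}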

View File

@@ -1,8 +1,8 @@
//! Common test utilities and helpers for sandbox testing
+pub mod claude_real;
pub mod fixtures;
pub mod helpers;
-pub mod claude_real;
+pub use claude_real::*;
pub use fixtures::*;
pub use helpers::*;
-pub use claude_real::*;

View File

@@ -16,7 +16,8 @@ fn test_agent_with_minimal_profile() {
    // Create minimal sandbox profile
    let rules = profiles::minimal(&test_fs.project_path.to_string_lossy());
-    let profile_id = test_db.create_test_profile("minimal_agent_test", rules)
+    let profile_id = test_db
+        .create_test_profile("minimal_agent_test", rules)
        .expect("Failed to create test profile");
    // Create test agent
@@ -41,7 +42,8 @@ fn test_agent_with_minimal_profile() {
        Some("sonnet"),
        Some(profile_id),
        20, // 20 second timeout
-    ).expect("Failed to execute Claude command");
+    )
+    .expect("Failed to execute Claude command");
    // Debug output
    eprintln!("=== Claude Output ===");
@@ -52,8 +54,11 @@ fn test_agent_with_minimal_profile() {
    eprintln!("===================");
    // Basic verification - just check Claude ran
-    assert!(result.exit_code == 0 || result.exit_code == 124, // 0 = success, 124 = timeout
-        "Claude should execute (exit code: {})", result.exit_code);
+    assert!(
+        result.exit_code == 0 || result.exit_code == 124, // 0 = success, 124 = timeout
+        "Claude should execute (exit code: {})",
+        result.exit_code
+    );
}
/// Test agent execution with standard sandbox profile
@@ -69,7 +74,8 @@ fn test_agent_with_standard_profile() {
    // Create standard sandbox profile
    let rules = profiles::standard(&test_fs.project_path.to_string_lossy());
-    let profile_id = test_db.create_test_profile("standard_agent_test", rules)
+    let profile_id = test_db
+        .create_test_profile("standard_agent_test", rules)
        .expect("Failed to create test profile");
    // Create test agent
@@ -94,7 +100,8 @@ fn test_agent_with_standard_profile() {
        Some("sonnet"),
        Some(profile_id),
        20, // 20 second timeout
-    ).expect("Failed to execute Claude command");
+    )
+    .expect("Failed to execute Claude command");
    // Debug output
    eprintln!("=== Claude Output (Standard Profile) ===");
@@ -104,8 +111,8 @@ fn test_agent_with_standard_profile() {
    eprintln!("===================");
    // Basic verification
-    assert!(result.exit_code == 0 || result.exit_code == 124,
-        "Claude should execute with standard profile (exit code: {})", result.exit_code);
+    assert!(
+        result.exit_code == 0 || result.exit_code == 124,
+        "Claude should execute with standard profile (exit code: {})",
+        result.exit_code
+    );
}
/// Test agent execution without sandbox (control test)
@@ -120,7 +130,9 @@ fn test_agent_without_sandbox() {
    test_db.reset().expect("Failed to reset database");
    // Create agent without sandbox profile
-    test_db.conn.execute(
+    test_db
+        .conn
+        .execute(
            "INSERT INTO agents (name, icon, system_prompt, model) VALUES (?1, ?2, ?3, ?4)",
            rusqlite::params![
                "Unsandboxed Agent",
@@ -128,7 +140,8 @@ fn test_agent_without_sandbox() {
                "You are a test agent without sandbox restrictions.",
                "sonnet"
            ],
-    ).expect("Failed to create agent");
+        )
+        .expect("Failed to create agent");
    let _agent_id = test_db.conn.last_insert_rowid();
@@ -140,7 +153,8 @@ fn test_agent_without_sandbox() {
        Some("sonnet"),
        None, // No sandbox profile
        20, // 20 second timeout
-    ).expect("Failed to execute Claude command");
+    )
+    .expect("Failed to execute Claude command");
    // Debug output
    eprintln!("=== Claude Output (No Sandbox) ===");
@@ -150,8 +164,11 @@ fn test_agent_without_sandbox() {
    eprintln!("===================");
    // Basic verification
-    assert!(result.exit_code == 0 || result.exit_code == 124,
-        "Claude should execute without sandbox (exit code: {})", result.exit_code);
+    assert!(
+        result.exit_code == 0 || result.exit_code == 124,
+        "Claude should execute without sandbox (exit code: {})",
+        result.exit_code
+    );
}
/// Test agent run violation logging
@@ -165,7 +182,8 @@ fn test_agent_run_violation_logging() {
    test_db.reset().expect("Failed to reset database");
    // Create a test profile first
-    let profile_id = test_db.create_test_profile("violation_test", vec![])
+    let profile_id = test_db
+        .create_test_profile("violation_test", vec![])
        .expect("Failed to create test profile");
    // Create a test agent
@@ -205,11 +223,14 @@ fn test_agent_run_violation_logging() {
    ).expect("Failed to insert violation");
    // Query violations
-    let count: i64 = test_db.conn.query_row(
-        "SELECT COUNT(*) FROM sandbox_violations WHERE agent_id = ?1",
-        rusqlite::params![agent_id],
-        |row| row.get(0),
-    ).expect("Failed to query violations");
+    let count: i64 = test_db
+        .conn
+        .query_row(
+            "SELECT COUNT(*) FROM sandbox_violations WHERE agent_id = ?1",
+            rusqlite::params![agent_id],
+            |row| row.get(0),
+        )
+        .expect("Failed to query violations");
    assert_eq!(count, 1, "Should have recorded one violation");
}
@@ -227,11 +248,13 @@ fn test_profile_switching() {
    // Create two different profiles
    let minimal_rules = profiles::minimal(&test_fs.project_path.to_string_lossy());
-    let minimal_id = test_db.create_test_profile("minimal_switch", minimal_rules)
+    let minimal_id = test_db
+        .create_test_profile("minimal_switch", minimal_rules)
        .expect("Failed to create minimal profile");
    let standard_rules = profiles::standard(&test_fs.project_path.to_string_lossy());
-    let standard_id = test_db.create_test_profile("standard_switch", standard_rules)
+    let standard_id = test_db
+        .create_test_profile("standard_switch", standard_rules)
        .expect("Failed to create standard profile");
    // Create agent initially with minimal profile
@@ -249,17 +272,23 @@ fn test_profile_switching() {
    let agent_id = test_db.conn.last_insert_rowid();
    // Update agent to use standard profile
-    test_db.conn.execute(
-        "UPDATE agents SET sandbox_profile_id = ?1 WHERE id = ?2",
-        rusqlite::params![standard_id, agent_id],
-    ).expect("Failed to update agent profile");
+    test_db
+        .conn
+        .execute(
+            "UPDATE agents SET sandbox_profile_id = ?1 WHERE id = ?2",
+            rusqlite::params![standard_id, agent_id],
+        )
+        .expect("Failed to update agent profile");
    // Verify profile was updated
-    let current_profile: i64 = test_db.conn.query_row(
-        "SELECT sandbox_profile_id FROM agents WHERE id = ?1",
-        rusqlite::params![agent_id],
-        |row| row.get(0),
-    ).expect("Failed to query agent profile");
+    let current_profile: i64 = test_db
+        .conn
+        .query_row(
+            "SELECT sandbox_profile_id FROM agents WHERE id = ?1",
+            rusqlite::params![agent_id],
+            |row| row.get(0),
+        )
+        .expect("Failed to query agent profile");
    assert_eq!(current_profile, standard_id, "Profile should be updated");
}
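
The assertions above accept exit code 124 alongside 0 because the tests wrap Claude in the coreutils timeout command, which exits 124 when the wrapped command runs out of time. A tiny demonstration of that convention, assuming timeout(1) is on PATH (via coreutils; gtimeout on macOS):

use std::process::Command;

fn main() {
    // `timeout 1 sleep 5` is killed after one second and reports 124.
    let status = Command::new("timeout")
        .args(["1", "sleep", "5"])
        .status()
        .expect("coreutils timeout should be on PATH");
    assert_eq!(status.code(), Some(124));
}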

View File

@@ -16,14 +16,18 @@ fn test_claude_with_default_sandbox() {
    // Create default sandbox profile
    let rules = profiles::standard(&test_fs.project_path.to_string_lossy());
-    let profile_id = test_db.create_test_profile("claude_default", rules)
+    let profile_id = test_db
+        .create_test_profile("claude_default", rules)
        .expect("Failed to create test profile");
    // Set as default and active
-    test_db.conn.execute(
-        "UPDATE sandbox_profiles SET is_default = 1, is_active = 1 WHERE id = ?1",
-        rusqlite::params![profile_id],
-    ).expect("Failed to set default profile");
+    test_db
+        .conn
+        .execute(
+            "UPDATE sandbox_profiles SET is_default = 1, is_active = 1 WHERE id = ?1",
+            rusqlite::params![profile_id],
+        )
+        .expect("Failed to set default profile");
    // Execute real Claude command with default sandbox profile
    let result = execute_claude_task(
@@ -33,7 +37,8 @@ fn test_claude_with_default_sandbox() {
        Some("sonnet"),
        Some(profile_id),
        20, // 20 second timeout
-    ).expect("Failed to execute Claude command");
+    )
+    .expect("Failed to execute Claude command");
    // Debug output
    eprintln!("=== Claude Output (Default Sandbox) ===");
@@ -43,8 +48,11 @@ fn test_claude_with_default_sandbox() {
    eprintln!("===================");
    // Basic verification
-    assert!(result.exit_code == 0 || result.exit_code == 124,
-        "Claude should execute with default sandbox (exit code: {})", result.exit_code);
+    assert!(
+        result.exit_code == 0 || result.exit_code == 124,
+        "Claude should execute with default sandbox (exit code: {})",
+        result.exit_code
+    );
}
/// Test Claude Code with sandboxing disabled
@@ -60,14 +68,18 @@ fn test_claude_sandbox_disabled() {
    // Create profile but mark as inactive
    let rules = profiles::standard(&test_fs.project_path.to_string_lossy());
-    let profile_id = test_db.create_test_profile("claude_inactive", rules)
+    let profile_id = test_db
+        .create_test_profile("claude_inactive", rules)
        .expect("Failed to create test profile");
    // Set as default but inactive
-    test_db.conn.execute(
-        "UPDATE sandbox_profiles SET is_default = 1, is_active = 0 WHERE id = ?1",
-        rusqlite::params![profile_id],
-    ).expect("Failed to set inactive profile");
+    test_db
+        .conn
+        .execute(
+            "UPDATE sandbox_profiles SET is_default = 1, is_active = 0 WHERE id = ?1",
+            rusqlite::params![profile_id],
+        )
+        .expect("Failed to set inactive profile");
    // Execute real Claude command without active sandbox
    let result = execute_claude_task(
@@ -77,7 +89,8 @@ fn test_claude_sandbox_disabled() {
        Some("sonnet"),
        None, // No sandbox since profile is inactive
        20, // 20 second timeout
-    ).expect("Failed to execute Claude command");
+    )
+    .expect("Failed to execute Claude command");
    // Debug output
    eprintln!("=== Claude Output (Inactive Sandbox) ===");
@@ -87,8 +100,11 @@ fn test_claude_sandbox_disabled() {
    eprintln!("===================");
    // Basic verification
-    assert!(result.exit_code == 0 || result.exit_code == 124,
-        "Claude should execute without active sandbox (exit code: {})", result.exit_code);
+    assert!(
+        result.exit_code == 0 || result.exit_code == 124,
+        "Claude should execute without active sandbox (exit code: {})",
+        result.exit_code
+    );
}
/// Test Claude Code session operations
@@ -144,17 +160,22 @@ fn test_claude_settings_sandbox_config() {
        "model": "sonnet"
    });
-    std::fs::write(&settings_file, serde_json::to_string_pretty(&settings).unwrap())
+    std::fs::write(
+        &settings_file,
+        serde_json::to_string_pretty(&settings).unwrap(),
+    )
    .expect("Failed to write settings");
    // Read and verify settings
-    let content = std::fs::read_to_string(&settings_file)
-        .expect("Failed to read settings");
-    let parsed: serde_json::Value = serde_json::from_str(&content)
-        .expect("Failed to parse settings");
+    let content = std::fs::read_to_string(&settings_file).expect("Failed to read settings");
+    let parsed: serde_json::Value =
+        serde_json::from_str(&content).expect("Failed to parse settings");
    assert_eq!(parsed["sandboxEnabled"], true, "Sandbox should be enabled");
-    assert_eq!(parsed["defaultSandboxProfile"], "standard", "Default profile should be standard");
+    assert_eq!(
+        parsed["defaultSandboxProfile"], "standard",
+        "Default profile should be standard"
+    );
}
/// Test profile-based file access restrictions
@@ -175,7 +196,8 @@ fn test_profile_file_access_simulation() {
        TestRule::file_read("/etc/hosts", false), // Literal file
    ];
-    let profile_id = test_db.create_test_profile("file_access_test", custom_rules)
+    let profile_id = test_db
+        .create_test_profile("file_access_test", custom_rules)
        .expect("Failed to create test profile");
    // Load the profile rules
@@ -191,6 +213,8 @@ fn test_profile_file_access_simulation() {
    // Verify rules were created correctly
    assert_eq!(loaded_rules.len(), 3, "Should have 3 rules");
-    assert!(loaded_rules.iter().any(|(op, _, _)| op == "file_read_all"),
-        "Should have file_read_all operation");
+    assert!(
+        loaded_rules.iter().any(|(op, _, _)| op == "file_read_all"),
+        "Should have file_read_all operation"
+    );
}

View File

@@ -3,7 +3,7 @@ use crate::sandbox::common::*;
use crate::skip_if_unsupported; use crate::skip_if_unsupported;
use claudia_lib::sandbox::executor::SandboxExecutor; use claudia_lib::sandbox::executor::SandboxExecutor;
use claudia_lib::sandbox::profile::ProfileBuilder; use claudia_lib::sandbox::profile::ProfileBuilder;
use gaol::profile::{Profile, Operation, PathPattern}; use gaol::profile::{Operation, PathPattern, Profile};
use serial_test::serial; use serial_test::serial;
use tempfile::TempDir; use tempfile::TempDir;
@@ -23,9 +23,9 @@ fn test_allowed_file_read() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create profile allowing project path access // Create profile allowing project path access
let operations = vec![ let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -74,9 +74,9 @@ fn test_forbidden_file_read() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create profile allowing only project path (not forbidden path) // Create profile allowing only project path (not forbidden path)
let operations = vec![ let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -105,7 +105,9 @@ fn test_forbidden_file_read() {
// On some platforms (like macOS), gaol might not block all file reads // On some platforms (like macOS), gaol might not block all file reads
// so we check if the operation failed OR if it's a platform limitation // so we check if the operation failed OR if it's a platform limitation
if status.success() { if status.success() {
eprintln!("WARNING: File read was not blocked - this might be a platform limitation"); eprintln!(
"WARNING: File read was not blocked - this might be a platform limitation"
);
// Check if we're on a platform where this is expected // Check if we're on a platform where this is expected
let platform_config = PlatformConfig::current(); let platform_config = PlatformConfig::current();
if !platform_config.supports_file_read { if !platform_config.supports_file_read {
@@ -129,9 +131,9 @@ fn test_file_write_always_forbidden() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create profile with file read permissions (write should still be blocked) // Create profile with file read permissions (write should still be blocked)
let operations = vec![ let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -189,14 +191,14 @@ fn test_file_metadata_operations() {
// Create profile with metadata read permission // Create profile with metadata read permission
let operations = if platform.supports_metadata_read { let operations = if platform.supports_metadata_read {
vec![ vec![Operation::FileReadMetadata(PathPattern::Subpath(
Operation::FileReadMetadata(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
] ))]
} else { } else {
// On Linux, metadata is allowed if file read is allowed // On Linux, metadata is allowed if file read is allowed
vec![ vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
] ))]
}; };
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
@@ -224,7 +226,10 @@ fn test_file_metadata_operations() {
Ok(mut child) => { Ok(mut child) => {
let status = child.wait().expect("Failed to wait for child"); let status = child.wait().expect("Failed to wait for child");
if platform.supports_metadata_read || platform.supports_file_read { if platform.supports_metadata_read || platform.supports_file_read {
assert!(status.success(), "Metadata read should succeed when allowed"); assert!(
status.success(),
"Metadata read should succeed when allowed"
);
} }
} }
Err(e) => { Err(e) => {
@@ -251,11 +256,10 @@ fn test_template_variable_expansion() {
// Create a profile with template variables // Create a profile with template variables
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
let rules = vec![ let rules = vec![TestRule::file_read("{{PROJECT_PATH}}", true)];
TestRule::file_read("{{PROJECT_PATH}}", true),
];
let profile_id = test_db.create_test_profile("template_test", rules) let profile_id = test_db
.create_test_profile("template_test", rules)
.expect("Failed to create test profile"); .expect("Failed to create test profile");
// Load and build the profile // Load and build the profile
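The {{PROJECT_PATH}} placeholder in the rule is expanded by ProfileBuilder before the gaol pattern is built; the implementation is not shown in this diff, but the substitution presumably amounts to something like this hypothetical helper:

use std::path::{Path, PathBuf};

/// Hypothetical stand-in for ProfileBuilder's template expansion:
/// replace the documented placeholders with concrete paths.
fn expand_template(pattern: &str, project_path: &Path) -> PathBuf {
    let home = std::env::var("HOME").unwrap_or_default();
    PathBuf::from(
        pattern
            .replace("{{PROJECT_PATH}}", &project_path.to_string_lossy())
            .replace("{{HOME}}", &home),
    )
}

fn main() {
    let project = Path::new("/test/project");
    assert_eq!(
        expand_template("{{PROJECT_PATH}}/src", project),
        PathBuf::from("/test/project/src")
    );
}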

View File

@@ -4,8 +4,8 @@ mod file_operations;
#[cfg(test)] #[cfg(test)]
mod network_operations; mod network_operations;
#[cfg(test)] #[cfg(test)]
mod system_info;
#[cfg(test)]
mod process_isolation; mod process_isolation;
#[cfg(test)] #[cfg(test)]
mod system_info;
#[cfg(test)]
mod violations; mod violations;
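The reshuffle in this hunk (and in the use statements throughout the commit) is not a hand edit: rustfmt sorts module declarations and import lists alphabetically by default, via its reorder_modules and reorder_imports options (both true unless overridden in rustfmt.toml). That is why mod system_info now sorts below mod process_isolation here, and why import groups such as {Operation, PathPattern, Profile} are reordered across the other files.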

View File

@@ -2,7 +2,7 @@
use crate::sandbox::common::*; use crate::sandbox::common::*;
use crate::skip_if_unsupported; use crate::skip_if_unsupported;
use claudia_lib::sandbox::executor::SandboxExecutor; use claudia_lib::sandbox::executor::SandboxExecutor;
use gaol::profile::{Profile, Operation, AddressPattern}; use gaol::profile::{AddressPattern, Operation, Profile};
use serial_test::serial; use serial_test::serial;
use std::net::TcpListener; use std::net::TcpListener;
use tempfile::TempDir; use tempfile::TempDir;
@@ -10,7 +10,10 @@ use tempfile::TempDir;
/// Get an available port for testing /// Get an available port for testing
fn get_available_port() -> u16 { fn get_available_port() -> u16 {
let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind to 0"); let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind to 0");
let port = listener.local_addr().expect("Failed to get local addr").port(); let port = listener
.local_addr()
.expect("Failed to get local addr")
.port();
drop(listener); // Release the port drop(listener); // Release the port
port port
} }
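get_available_port leans on the ephemeral-port trick: binding to port 0 makes the OS pick a free port, and dropping the listener releases it for the test. There is an inherent race (another process may grab the port between the drop and the re-bind) that these tests tolerate. A usage sketch, restating the helper with the caveats as comments:

use std::net::TcpListener;

fn get_available_port() -> u16 {
    // Port 0 asks the OS to assign any free ephemeral port.
    let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind");
    let port = listener.local_addr().expect("Failed to get local addr").port();
    drop(listener); // Release it; a later bind can race, which the tests accept.
    port
}

fn main() {
    let port = get_available_port();
    // The port is only *probably* free by the time it is reused.
    let listener = TcpListener::bind(("127.0.0.1", port)).expect("rebind failed");
    println!("listening on {}", listener.local_addr().unwrap());
}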
@@ -31,9 +34,7 @@ fn test_allowed_network_all() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create profile allowing all network access // Create profile allowing all network access
let operations = vec![ let operations = vec![Operation::NetworkOutbound(AddressPattern::All)];
Operation::NetworkOutbound(AddressPattern::All),
];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -51,8 +52,8 @@ fn test_allowed_network_all() {
.expect("Failed to create test binary"); .expect("Failed to create test binary");
// Start a listener on the port // Start a listener on the port
let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) let listener =
.expect("Failed to bind listener"); TcpListener::bind(format!("127.0.0.1:{}", port)).expect("Failed to bind listener");
// Execute in sandbox // Execute in sandbox
let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
@@ -68,7 +69,10 @@ fn test_allowed_network_all() {
}); });
let status = child.wait().expect("Failed to wait for child"); let status = child.wait().expect("Failed to wait for child");
assert!(status.success(), "Network connection should succeed when allowed"); assert!(
status.success(),
"Network connection should succeed when allowed"
);
} }
Err(e) => { Err(e) => {
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
@@ -86,9 +90,9 @@ fn test_forbidden_network() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create profile without network permissions // Create profile without network permissions
let operations = vec![ let operations = vec![Operation::FileReadAll(gaol::profile::PathPattern::Subpath(
Operation::FileReadAll(gaol::profile::PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -146,9 +150,9 @@ fn test_network_tcp_port_specific() {
let forbidden_port = get_available_port(); let forbidden_port = get_available_port();
// Create profile allowing only specific port // Create profile allowing only specific port
let operations = vec![ let operations = vec![Operation::NetworkOutbound(AddressPattern::Tcp(
Operation::NetworkOutbound(AddressPattern::Tcp(allowed_port)), allowed_port,
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -180,7 +184,10 @@ fn test_network_tcp_port_specific() {
}); });
let status = child.wait().expect("Failed to wait for child"); let status = child.wait().expect("Failed to wait for child");
assert!(status.success(), "Connection to allowed port should succeed"); assert!(
status.success(),
"Connection to allowed port should succeed"
);
} }
Err(e) => { Err(e) => {
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
@@ -203,7 +210,10 @@ fn test_network_tcp_port_specific() {
) { ) {
Ok(mut child) => { Ok(mut child) => {
let status = child.wait().expect("Failed to wait for child"); let status = child.wait().expect("Failed to wait for child");
assert!(!status.success(), "Connection to forbidden port should fail"); assert!(
!status.success(),
"Connection to forbidden port should fail"
);
} }
Err(e) => { Err(e) => {
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
@@ -227,14 +237,12 @@ fn test_local_socket_connections() {
// Create appropriate profile based on platform // Create appropriate profile based on platform
let operations = if platform.supports_network_local { let operations = if platform.supports_network_local {
vec![ vec![Operation::NetworkOutbound(AddressPattern::LocalSocket(
Operation::NetworkOutbound(AddressPattern::LocalSocket(socket_path.clone())), socket_path.clone(),
] ))]
} else if platform.supports_network_all { } else if platform.supports_network_all {
// Fallback to allowing all network // Fallback to allowing all network
vec![ vec![Operation::NetworkOutbound(AddressPattern::All)]
Operation::NetworkOutbound(AddressPattern::All),
]
} else { } else {
eprintln!("Skipping test: no network support on this platform"); eprintln!("Skipping test: no network support on this platform");
return; return;
@@ -289,7 +297,10 @@ fn main() {{
}); });
let status = child.wait().expect("Failed to wait for child"); let status = child.wait().expect("Failed to wait for child");
assert!(status.success(), "Local socket connection should succeed when allowed"); assert!(
status.success(),
"Local socket connection should succeed when allowed"
);
} }
Err(e) => { Err(e) => {
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);

View File

@@ -2,7 +2,7 @@
use crate::sandbox::common::*; use crate::sandbox::common::*;
use crate::skip_if_unsupported; use crate::skip_if_unsupported;
use claudia_lib::sandbox::executor::SandboxExecutor; use claudia_lib::sandbox::executor::SandboxExecutor;
use gaol::profile::{Profile, Operation, PathPattern, AddressPattern}; use gaol::profile::{AddressPattern, Operation, PathPattern, Profile};
use serial_test::serial; use serial_test::serial;
use tempfile::TempDir; use tempfile::TempDir;
@@ -49,7 +49,10 @@ fn test_process_spawn_forbidden() {
eprintln!("WARNING: Process spawning was not blocked"); eprintln!("WARNING: Process spawning was not blocked");
// macOS sandbox might have limitations // macOS sandbox might have limitations
if std::env::consts::OS != "linux" { if std::env::consts::OS != "linux" {
eprintln!("Process spawning might not be fully blocked on {}", std::env::consts::OS); eprintln!(
"Process spawning might not be fully blocked on {}",
std::env::consts::OS
);
} else { } else {
panic!("Process spawning should be blocked on Linux"); panic!("Process spawning should be blocked on Linux");
} }
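Similarly, test_code::spawn_process (defined in the common module, not in this diff) presumably reduces to a probe of this shape — an assumption about its contents, not the repo's actual snippet:

use std::process::Command;

fn main() {
    // Exit 0 if spawning succeeds (sandbox failed to block it), 1 if blocked.
    match Command::new("echo").arg("hello").spawn() {
        Ok(mut child) => {
            let _ = child.wait();
            std::process::exit(0);
        }
        Err(_) => std::process::exit(1),
    }
}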
@@ -72,9 +75,9 @@ fn test_fork_forbidden() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create minimal profile // Create minimal profile
let operations = vec![ let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -88,7 +91,12 @@ fn test_fork_forbidden() {
let test_code = test_code::fork_process(); let test_code = test_code::fork_process();
let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_dir = TempDir::new().expect("Failed to create temp dir");
let binary_path = create_test_binary_with_deps("test_fork", test_code, binary_dir.path(), &[("libc", "0.2")]) let binary_path = create_test_binary_with_deps(
"test_fork",
test_code,
binary_dir.path(),
&[("libc", "0.2")],
)
.expect("Failed to create test binary"); .expect("Failed to create test binary");
// Execute in sandbox // Execute in sandbox
@@ -125,9 +133,9 @@ fn test_exec_forbidden() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create minimal profile // Create minimal profile
let operations = vec![ let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -141,7 +149,12 @@ fn test_exec_forbidden() {
let test_code = test_code::exec_process(); let test_code = test_code::exec_process();
let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_dir = TempDir::new().expect("Failed to create temp dir");
let binary_path = create_test_binary_with_deps("test_exec", test_code, binary_dir.path(), &[("libc", "0.2")]) let binary_path = create_test_binary_with_deps(
"test_exec",
test_code,
binary_dir.path(),
&[("libc", "0.2")],
)
.expect("Failed to create test binary"); .expect("Failed to create test binary");
// Execute in sandbox // Execute in sandbox
@@ -177,9 +190,9 @@ fn test_thread_creation_allowed() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create minimal profile // Create minimal profile
let operations = vec![ let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,

View File

@@ -2,7 +2,7 @@
use crate::sandbox::common::*; use crate::sandbox::common::*;
use crate::skip_if_unsupported; use crate::skip_if_unsupported;
use claudia_lib::sandbox::executor::SandboxExecutor; use claudia_lib::sandbox::executor::SandboxExecutor;
use gaol::profile::{Profile, Operation}; use gaol::profile::{Operation, Profile};
use serial_test::serial; use serial_test::serial;
use tempfile::TempDir; use tempfile::TempDir;
@@ -22,9 +22,7 @@ fn test_system_info_read() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create profile allowing system info read // Create profile allowing system info read
let operations = vec![ let operations = vec![Operation::SystemInfoRead];
Operation::SystemInfoRead,
];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -49,7 +47,10 @@ fn test_system_info_read() {
) { ) {
Ok(mut child) => { Ok(mut child) => {
let status = child.wait().expect("Failed to wait for child"); let status = child.wait().expect("Failed to wait for child");
assert!(status.success(), "System info read should succeed when allowed"); assert!(
status.success(),
"System info read should succeed when allowed"
);
} }
Err(e) => { Err(e) => {
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
@@ -66,9 +67,9 @@ fn test_forbidden_system_info() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create profile without system info permission // Create profile without system info permission
let operations = vec![ let operations = vec![Operation::FileReadAll(gaol::profile::PathPattern::Subpath(
Operation::FileReadAll(gaol::profile::PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -124,18 +125,24 @@ fn test_platform_specific_system_info() {
match std::env::consts::OS { match std::env::consts::OS {
"linux" => { "linux" => {
// On Linux, system info is never allowed // On Linux, system info is never allowed
assert!(!platform.supports_system_info, assert!(
"Linux should not support system info read"); !platform.supports_system_info,
"Linux should not support system info read"
);
} }
"macos" => { "macos" => {
// On macOS, system info can be allowed // On macOS, system info can be allowed
assert!(platform.supports_system_info, assert!(
"macOS should support system info read"); platform.supports_system_info,
"macOS should support system info read"
);
} }
"freebsd" => { "freebsd" => {
// On FreeBSD, system info is always allowed (can't be restricted) // On FreeBSD, system info is always allowed (can't be restricted)
assert!(platform.supports_system_info, assert!(
"FreeBSD always allows system info read"); platform.supports_system_info,
"FreeBSD always allows system info read"
);
} }
_ => { _ => {
eprintln!("Unknown platform behavior for system info"); eprintln!("Unknown platform behavior for system info");

View File

@@ -2,7 +2,7 @@
use crate::sandbox::common::*; use crate::sandbox::common::*;
use crate::skip_if_unsupported; use crate::skip_if_unsupported;
use claudia_lib::sandbox::executor::SandboxExecutor; use claudia_lib::sandbox::executor::SandboxExecutor;
use gaol::profile::{Profile, Operation, PathPattern}; use gaol::profile::{Operation, PathPattern, Profile};
use serial_test::serial; use serial_test::serial;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use tempfile::TempDir; use tempfile::TempDir;
@@ -62,9 +62,9 @@ fn test_violation_detection() {
let collector = ViolationCollector::new(); let collector = ViolationCollector::new();
// Create profile allowing only project path // Create profile allowing only project path
let operations = vec![ let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,
@@ -76,9 +76,21 @@ fn test_violation_detection() {
// Test various forbidden operations // Test various forbidden operations
let test_cases = vec![ let test_cases = vec![
("file_read", test_code::file_read(&test_fs.forbidden_path.join("secret.txt").to_string_lossy()), "file_read_forbidden"), (
("file_write", test_code::file_write(&test_fs.project_path.join("new.txt").to_string_lossy()), "file_write_forbidden"), "file_read",
("process_spawn", test_code::spawn_process().to_string(), "process_spawn_forbidden"), test_code::file_read(&test_fs.forbidden_path.join("secret.txt").to_string_lossy()),
"file_read_forbidden",
),
(
"file_write",
test_code::file_write(&test_fs.project_path.join("new.txt").to_string_lossy()),
"file_write_forbidden",
),
(
"process_spawn",
test_code::spawn_process().to_string(),
"process_spawn_forbidden",
),
]; ];
for (op_type, test_code, binary_name) in test_cases { for (op_type, test_code, binary_name) in test_cases {
@@ -150,7 +162,11 @@ fn test_violation_patterns() {
}; };
// Test accessing different forbidden paths // Test accessing different forbidden paths
let forbidden_db_path = test_fs.forbidden_path.join("data.db").to_string_lossy().to_string(); let forbidden_db_path = test_fs
.forbidden_path
.join("data.db")
.to_string_lossy()
.to_string();
let forbidden_paths = vec![ let forbidden_paths = vec![
("/etc/passwd", "system_file"), ("/etc/passwd", "system_file"),
("/tmp/test.txt", "temp_file"), ("/tmp/test.txt", "temp_file"),
@@ -173,7 +189,10 @@ fn test_violation_patterns() {
let status = child.wait().expect("Failed to wait for child"); let status = child.wait().expect("Failed to wait for child");
// Some platforms might not block all file access // Some platforms might not block all file access
if status.success() { if status.success() {
eprintln!("WARNING: Access to {} was allowed (possible platform limitation)", path); eprintln!(
"WARNING: Access to {} was allowed (possible platform limitation)",
path
);
if std::env::consts::OS == "linux" && path.starts_with("/etc") { if std::env::consts::OS == "linux" && path.starts_with("/etc") {
panic!("Access to {} should be denied on Linux", path); panic!("Access to {} should be denied on Linux", path);
} }
@@ -196,9 +215,9 @@ fn test_multiple_violations_sequence() {
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
// Create minimal profile // Create minimal profile
let operations = vec![ let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), test_fs.project_path.clone(),
]; ))];
let profile = match Profile::new(operations) { let profile = match Profile::new(operations) {
Ok(p) => p, Ok(p) => p,

View File

@@ -1,6 +1,6 @@
//! Unit tests for SandboxExecutor //! Unit tests for SandboxExecutor
use claudia_lib::sandbox::executor::{SandboxExecutor, should_activate_sandbox}; use claudia_lib::sandbox::executor::{should_activate_sandbox, SandboxExecutor};
use gaol::profile::{Profile, Operation, PathPattern, AddressPattern}; use gaol::profile::{AddressPattern, Operation, PathPattern, Profile};
use std::env; use std::env;
use std::path::PathBuf; use std::path::PathBuf;
@@ -27,15 +27,24 @@ fn test_executor_creation() {
fn test_should_activate_sandbox_env_var() { fn test_should_activate_sandbox_env_var() {
// Test when env var is not set // Test when env var is not set
env::remove_var("GAOL_SANDBOX_ACTIVE"); env::remove_var("GAOL_SANDBOX_ACTIVE");
assert!(!should_activate_sandbox(), "Should not activate when env var is not set"); assert!(
!should_activate_sandbox(),
"Should not activate when env var is not set"
);
// Test when env var is set to "1" // Test when env var is set to "1"
env::set_var("GAOL_SANDBOX_ACTIVE", "1"); env::set_var("GAOL_SANDBOX_ACTIVE", "1");
assert!(should_activate_sandbox(), "Should activate when env var is '1'"); assert!(
should_activate_sandbox(),
"Should activate when env var is '1'"
);
// Test when env var is set to other value // Test when env var is set to other value
env::set_var("GAOL_SANDBOX_ACTIVE", "0"); env::set_var("GAOL_SANDBOX_ACTIVE", "0");
assert!(!should_activate_sandbox(), "Should not activate when env var is not '1'"); assert!(
!should_activate_sandbox(),
"Should not activate when env var is not '1'"
);
// Clean up // Clean up
env::remove_var("GAOL_SANDBOX_ACTIVE"); env::remove_var("GAOL_SANDBOX_ACTIVE");
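The three assertions pin down the semantics: activation requires the variable to be exactly "1", not merely set. The function under test therefore presumably reduces to an exact string comparison, along these lines:

use std::env;

/// Sketch of the behaviour the assertions above pin down (assumed, not copied
/// from the repo): activate only when GAOL_SANDBOX_ACTIVE is exactly "1".
fn should_activate_sandbox() -> bool {
    env::var("GAOL_SANDBOX_ACTIVE").map(|v| v == "1").unwrap_or(false)
}

fn main() {
    env::set_var("GAOL_SANDBOX_ACTIVE", "1");
    assert!(should_activate_sandbox());
    env::set_var("GAOL_SANDBOX_ACTIVE", "0");
    assert!(!should_activate_sandbox());
    env::remove_var("GAOL_SANDBOX_ACTIVE");
    assert!(!should_activate_sandbox());
}

Since the process environment is global state, tests like this cannot safely run in parallel — hence the serial_test usage elsewhere in the suite.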
@@ -78,7 +87,8 @@ fn test_executor_with_complex_profile() {
]; ];
// Only create profile with supported operations // Only create profile with supported operations
let filtered_ops: Vec<_> = operations.into_iter() let filtered_ops: Vec<_> = operations
.into_iter()
.filter(|op| { .filter(|op| {
use gaol::profile::{OperationSupport, OperationSupportLevel}; use gaol::profile::{OperationSupport, OperationSupportLevel};
matches!(op.support(), OperationSupportLevel::CanBeAllowed) matches!(op.support(), OperationSupportLevel::CanBeAllowed)

View File

@@ -1,7 +1,7 @@
//! Unit tests for sandbox components //! Unit tests for sandbox components
#[cfg(test)] #[cfg(test)]
mod profile_builder; mod executor;
#[cfg(test)] #[cfg(test)]
mod platform; mod platform;
#[cfg(test)] #[cfg(test)]
mod executor; mod profile_builder;

View File

@@ -1,7 +1,7 @@
//! Unit tests for platform capabilities //! Unit tests for platform capabilities
use claudia_lib::sandbox::platform::{get_platform_capabilities, is_sandboxing_available}; use claudia_lib::sandbox::platform::{get_platform_capabilities, is_sandboxing_available};
use std::env;
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use std::env;
#[test] #[test]
fn test_sandboxing_availability() { fn test_sandboxing_availability() {
@@ -20,9 +20,14 @@ fn test_platform_capabilities_structure() {
// Verify basic structure // Verify basic structure
assert_eq!(caps.os, env::consts::OS, "OS should match current platform"); assert_eq!(caps.os, env::consts::OS, "OS should match current platform");
assert!(!caps.operations.is_empty() || !caps.sandboxing_supported, assert!(
"Should have operations if sandboxing is supported"); !caps.operations.is_empty() || !caps.sandboxing_supported,
assert!(!caps.notes.is_empty(), "Should have platform-specific notes"); "Should have operations if sandboxing is supported"
);
assert!(
!caps.notes.is_empty(),
"Should have platform-specific notes"
);
} }
#[test] #[test]
@@ -34,27 +39,37 @@ fn test_linux_capabilities() {
assert!(caps.sandboxing_supported); assert!(caps.sandboxing_supported);
// Verify Linux-specific capabilities // Verify Linux-specific capabilities
let file_read = caps.operations.iter() let file_read = caps
.operations
.iter()
.find(|op| op.operation == "file_read_all") .find(|op| op.operation == "file_read_all")
.expect("file_read_all should be present"); .expect("file_read_all should be present");
assert_eq!(file_read.support_level, "can_be_allowed"); assert_eq!(file_read.support_level, "can_be_allowed");
let metadata_read = caps.operations.iter() let metadata_read = caps
.operations
.iter()
.find(|op| op.operation == "file_read_metadata") .find(|op| op.operation == "file_read_metadata")
.expect("file_read_metadata should be present"); .expect("file_read_metadata should be present");
assert_eq!(metadata_read.support_level, "cannot_be_precisely"); assert_eq!(metadata_read.support_level, "cannot_be_precisely");
let network_all = caps.operations.iter() let network_all = caps
.operations
.iter()
.find(|op| op.operation == "network_outbound_all") .find(|op| op.operation == "network_outbound_all")
.expect("network_outbound_all should be present"); .expect("network_outbound_all should be present");
assert_eq!(network_all.support_level, "can_be_allowed"); assert_eq!(network_all.support_level, "can_be_allowed");
let network_tcp = caps.operations.iter() let network_tcp = caps
.operations
.iter()
.find(|op| op.operation == "network_outbound_tcp") .find(|op| op.operation == "network_outbound_tcp")
.expect("network_outbound_tcp should be present"); .expect("network_outbound_tcp should be present");
assert_eq!(network_tcp.support_level, "cannot_be_precisely"); assert_eq!(network_tcp.support_level, "cannot_be_precisely");
let system_info = caps.operations.iter() let system_info = caps
.operations
.iter()
.find(|op| op.operation == "system_info_read") .find(|op| op.operation == "system_info_read")
.expect("system_info_read should be present"); .expect("system_info_read should be present");
assert_eq!(system_info.support_level, "never"); assert_eq!(system_info.support_level, "never");
@@ -69,22 +84,30 @@ fn test_macos_capabilities() {
assert!(caps.sandboxing_supported); assert!(caps.sandboxing_supported);
// Verify macOS-specific capabilities // Verify macOS-specific capabilities
let file_read = caps.operations.iter() let file_read = caps
.operations
.iter()
.find(|op| op.operation == "file_read_all") .find(|op| op.operation == "file_read_all")
.expect("file_read_all should be present"); .expect("file_read_all should be present");
assert_eq!(file_read.support_level, "can_be_allowed"); assert_eq!(file_read.support_level, "can_be_allowed");
let metadata_read = caps.operations.iter() let metadata_read = caps
.operations
.iter()
.find(|op| op.operation == "file_read_metadata") .find(|op| op.operation == "file_read_metadata")
.expect("file_read_metadata should be present"); .expect("file_read_metadata should be present");
assert_eq!(metadata_read.support_level, "can_be_allowed"); assert_eq!(metadata_read.support_level, "can_be_allowed");
let network_tcp = caps.operations.iter() let network_tcp = caps
.operations
.iter()
.find(|op| op.operation == "network_outbound_tcp") .find(|op| op.operation == "network_outbound_tcp")
.expect("network_outbound_tcp should be present"); .expect("network_outbound_tcp should be present");
assert_eq!(network_tcp.support_level, "can_be_allowed"); assert_eq!(network_tcp.support_level, "can_be_allowed");
let system_info = caps.operations.iter() let system_info = caps
.operations
.iter()
.find(|op| op.operation == "system_info_read") .find(|op| op.operation == "system_info_read")
.expect("system_info_read should be present"); .expect("system_info_read should be present");
assert_eq!(system_info.support_level, "can_be_allowed"); assert_eq!(system_info.support_level, "can_be_allowed");
@@ -99,12 +122,16 @@ fn test_freebsd_capabilities() {
assert!(caps.sandboxing_supported); assert!(caps.sandboxing_supported);
// Verify FreeBSD-specific capabilities // Verify FreeBSD-specific capabilities
let file_read = caps.operations.iter() let file_read = caps
.operations
.iter()
.find(|op| op.operation == "file_read_all") .find(|op| op.operation == "file_read_all")
.expect("file_read_all should be present"); .expect("file_read_all should be present");
assert_eq!(file_read.support_level, "never"); assert_eq!(file_read.support_level, "never");
let system_info = caps.operations.iter() let system_info = caps
.operations
.iter()
.find(|op| op.operation == "system_info_read") .find(|op| op.operation == "system_info_read")
.expect("system_info_read should be present"); .expect("system_info_read should be present");
assert_eq!(system_info.support_level, "always"); assert_eq!(system_info.support_level, "always");
@@ -125,10 +152,16 @@ fn test_all_operations_have_descriptions() {
let caps = get_platform_capabilities(); let caps = get_platform_capabilities();
for op in &caps.operations { for op in &caps.operations {
assert!(!op.description.is_empty(), assert!(
"Operation {} should have a description", op.operation); !op.description.is_empty(),
assert!(!op.support_level.is_empty(), "Operation {} should have a description",
"Operation {} should have a support level", op.operation); op.operation
);
assert!(
!op.support_level.is_empty(),
"Operation {} should have a support level",
op.operation
);
} }
} }
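The unit tests above repeat the caps.operations.iter().find(...).expect(...) chain once per operation; a small lookup helper would collapse that boilerplate. A hypothetical sketch against minimal stand-ins for the capability types (field names inferred from the tests, not the repo's definitions):

/// Minimal stand-ins for the types the platform tests exercise.
struct OperationSupport {
    operation: String,
    support_level: String,
}

struct PlatformCapabilities {
    operations: Vec<OperationSupport>,
}

impl PlatformCapabilities {
    /// Find one operation's support level, with a test-friendly panic message.
    fn support_level(&self, op: &str) -> &str {
        self.operations
            .iter()
            .find(|o| o.operation == op)
            .map(|o| o.support_level.as_str())
            .unwrap_or_else(|| panic!("{} should be present", op))
    }
}

fn main() {
    let caps = PlatformCapabilities {
        operations: vec![OperationSupport {
            operation: "file_read_all".into(),
            support_level: "can_be_allowed".into(),
        }],
    };
    assert_eq!(caps.support_level("file_read_all"), "can_be_allowed");
}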

View File

@@ -18,8 +18,7 @@ fn make_rule(
pattern_value: pattern_value.to_string(), pattern_value: pattern_value.to_string(),
enabled: true, enabled: true,
platform_support: platforms.map(|p| { platform_support: platforms.map(|p| {
serde_json::to_string(&p.iter().map(|s| s.to_string()).collect::<Vec<_>>()) serde_json::to_string(&p.iter().map(|s| s.to_string()).collect::<Vec<_>>()).unwrap()
.unwrap()
}), }),
created_at: String::new(), created_at: String::new(),
} }
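make_rule stores platform_support as a JSON-encoded string array; the builder presumably decodes it and keeps the rule only when std::env::consts::OS appears in the list (see test_platform_filtering below). An illustrative round trip of that encoding, using the serde_json dependency already in the codebase:

fn main() {
    // Encode the platform list the way make_rule does.
    let platforms = ["linux", "macos"];
    let encoded =
        serde_json::to_string(&platforms.iter().map(|s| s.to_string()).collect::<Vec<_>>())
            .unwrap();
    assert_eq!(encoded, r#"["linux","macos"]"#);

    // Assumed filtering step: keep the rule only if the current OS is listed.
    let decoded: Vec<String> = serde_json::from_str(&encoded).unwrap();
    println!("applies here: {}", decoded.iter().any(|p| p == std::env::consts::OS));
}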
@@ -30,7 +29,10 @@ fn test_profile_builder_creation() {
let project_path = PathBuf::from("/test/project"); let project_path = PathBuf::from("/test/project");
let builder = ProfileBuilder::new(project_path.clone()); let builder = ProfileBuilder::new(project_path.clone());
assert!(builder.is_ok(), "ProfileBuilder should be created successfully"); assert!(
builder.is_ok(),
"ProfileBuilder should be created successfully"
);
} }
#[test] #[test]
@@ -39,7 +41,10 @@ fn test_empty_rules_creates_empty_profile() {
let builder = ProfileBuilder::new(project_path).unwrap(); let builder = ProfileBuilder::new(project_path).unwrap();
let profile = builder.build_profile(vec![]); let profile = builder.build_profile(vec![]);
assert!(profile.is_ok(), "Empty rules should create valid empty profile"); assert!(
profile.is_ok(),
"Empty rules should create valid empty profile"
);
} }
#[test] #[test]
@@ -48,15 +53,28 @@ fn test_file_read_rule_parsing() {
let builder = ProfileBuilder::new(project_path.clone()).unwrap(); let builder = ProfileBuilder::new(project_path.clone()).unwrap();
let rules = vec![ let rules = vec![
make_rule("file_read_all", "literal", "/usr/lib/test.so", Some(&["linux", "macos"])), make_rule(
make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"])), "file_read_all",
"literal",
"/usr/lib/test.so",
Some(&["linux", "macos"]),
),
make_rule(
"file_read_all",
"subpath",
"/usr/lib",
Some(&["linux", "macos"]),
),
]; ];
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);
// Profile creation might fail on unsupported platforms, but parsing should work // Profile creation might fail on unsupported platforms, but parsing should work
if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" { if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" {
assert!(_profile.is_ok(), "File read rules should be parsed on supported platforms"); assert!(
_profile.is_ok(),
"File read rules should be parsed on supported platforms"
);
} }
} }
@@ -68,13 +86,21 @@ fn test_network_rule_parsing() {
let rules = vec![ let rules = vec![
make_rule("network_outbound", "all", "", Some(&["linux", "macos"])), make_rule("network_outbound", "all", "", Some(&["linux", "macos"])),
make_rule("network_outbound", "tcp", "8080", Some(&["macos"])), make_rule("network_outbound", "tcp", "8080", Some(&["macos"])),
make_rule("network_outbound", "local_socket", "/tmp/socket", Some(&["macos"])), make_rule(
"network_outbound",
"local_socket",
"/tmp/socket",
Some(&["macos"]),
),
]; ];
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);
if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" { if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" {
assert!(_profile.is_ok(), "Network rules should be parsed on supported platforms"); assert!(
_profile.is_ok(),
"Network rules should be parsed on supported platforms"
);
} }
} }
@@ -83,14 +109,15 @@ fn test_system_info_rule_parsing() {
let project_path = PathBuf::from("/test/project"); let project_path = PathBuf::from("/test/project");
let builder = ProfileBuilder::new(project_path).unwrap(); let builder = ProfileBuilder::new(project_path).unwrap();
let rules = vec![ let rules = vec![make_rule("system_info_read", "all", "", Some(&["macos"]))];
make_rule("system_info_read", "all", "", Some(&["macos"])),
];
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);
if std::env::consts::OS == "macos" { if std::env::consts::OS == "macos" {
assert!(_profile.is_ok(), "System info rule should be parsed on macOS"); assert!(
_profile.is_ok(),
"System info rule should be parsed on macOS"
);
} }
} }
@@ -100,8 +127,18 @@ fn test_template_variable_replacement() {
let builder = ProfileBuilder::new(project_path.clone()).unwrap(); let builder = ProfileBuilder::new(project_path.clone()).unwrap();
let rules = vec![ let rules = vec![
make_rule("file_read_all", "subpath", "{{PROJECT_PATH}}/src", Some(&["linux", "macos"])), make_rule(
make_rule("file_read_all", "subpath", "{{HOME}}/.config", Some(&["linux", "macos"])), "file_read_all",
"subpath",
"{{PROJECT_PATH}}/src",
Some(&["linux", "macos"]),
),
make_rule(
"file_read_all",
"subpath",
"{{HOME}}/.config",
Some(&["linux", "macos"]),
),
]; ];
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);
@@ -114,7 +151,12 @@ fn test_disabled_rules_are_ignored() {
let project_path = PathBuf::from("/test/project"); let project_path = PathBuf::from("/test/project");
let builder = ProfileBuilder::new(project_path).unwrap(); let builder = ProfileBuilder::new(project_path).unwrap();
let mut rule = make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"])); let mut rule = make_rule(
"file_read_all",
"subpath",
"/usr/lib",
Some(&["linux", "macos"]),
);
rule.enabled = false; rule.enabled = false;
let profile = builder.build_profile(vec![rule]); let profile = builder.build_profile(vec![rule]);
@@ -127,7 +169,11 @@ fn test_platform_filtering() {
let builder = ProfileBuilder::new(project_path).unwrap(); let builder = ProfileBuilder::new(project_path).unwrap();
let current_os = std::env::consts::OS; let current_os = std::env::consts::OS;
let other_os = if current_os == "linux" { "macos" } else { "linux" }; let other_os = if current_os == "linux" {
"macos"
} else {
"linux"
};
let rules = vec![ let rules = vec![
// Rule for current platform // Rule for current platform
@@ -135,7 +181,12 @@ fn test_platform_filtering() {
// Rule for other platform // Rule for other platform
make_rule("file_read_all", "subpath", "/test2", Some(&[other_os])), make_rule("file_read_all", "subpath", "/test2", Some(&[other_os])),
// Rule for both platforms // Rule for both platforms
make_rule("file_read_all", "subpath", "/test3", Some(&["linux", "macos"])), make_rule(
"file_read_all",
"subpath",
"/test3",
Some(&["linux", "macos"]),
),
// Rule with no platform specification (should be included) // Rule with no platform specification (should be included)
make_rule("file_read_all", "subpath", "/test4", None), make_rule("file_read_all", "subpath", "/test4", None),
]; ];
@@ -149,9 +200,12 @@ fn test_invalid_operation_type() {
let project_path = PathBuf::from("/test/project"); let project_path = PathBuf::from("/test/project");
let builder = ProfileBuilder::new(project_path).unwrap(); let builder = ProfileBuilder::new(project_path).unwrap();
let rules = vec![ let rules = vec![make_rule(
make_rule("invalid_operation", "subpath", "/test", Some(&["linux", "macos"])), "invalid_operation",
]; "subpath",
"/test",
Some(&["linux", "macos"]),
)];
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);
assert!(_profile.is_ok(), "Invalid operations should be skipped"); assert!(_profile.is_ok(), "Invalid operations should be skipped");
@@ -162,9 +216,12 @@ fn test_invalid_pattern_type() {
let project_path = PathBuf::from("/test/project"); let project_path = PathBuf::from("/test/project");
let builder = ProfileBuilder::new(project_path).unwrap(); let builder = ProfileBuilder::new(project_path).unwrap();
let rules = vec![ let rules = vec![make_rule(
make_rule("file_read_all", "invalid_pattern", "/test", Some(&["linux", "macos"])), "file_read_all",
]; "invalid_pattern",
"/test",
Some(&["linux", "macos"]),
)];
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);
// Should either skip the rule or fail gracefully // Should either skip the rule or fail gracefully
@@ -175,9 +232,12 @@ fn test_invalid_tcp_port() {
let project_path = PathBuf::from("/test/project"); let project_path = PathBuf::from("/test/project");
let builder = ProfileBuilder::new(project_path).unwrap(); let builder = ProfileBuilder::new(project_path).unwrap();
let rules = vec![ let rules = vec![make_rule(
make_rule("network_outbound", "tcp", "not_a_number", Some(&["macos"])), "network_outbound",
]; "tcp",
"not_a_number",
Some(&["macos"]),
)];
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);
// Should handle invalid port gracefully // Should handle invalid port gracefully
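The graceful handling asserted here most plausibly comes down to a fallible parse of the TCP pattern value, with unparseable rules skipped rather than panicking — an assumption about ProfileBuilder's internals, sketched as:

/// Hypothetical parsing step: a TCP rule's pattern value must be a valid port.
fn parse_tcp_port(pattern_value: &str) -> Option<u16> {
    pattern_value.parse::<u16>().ok()
}

fn main() {
    assert_eq!(parse_tcp_port("8080"), Some(8080));
    assert_eq!(parse_tcp_port("not_a_number"), None); // rule skipped, no panic
}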
@@ -188,7 +248,6 @@ fn test_invalid_tcp_port() {
#[test_case("network_outbound", "all", "" ; "network all operation")] #[test_case("network_outbound", "all", "" ; "network all operation")]
#[test_case("system_info_read", "all", "" ; "system info operation")] #[test_case("system_info_read", "all", "" ; "system info operation")]
fn test_operation_support_level(operation_type: &str, pattern_type: &str, pattern_value: &str) { fn test_operation_support_level(operation_type: &str, pattern_type: &str, pattern_value: &str) {
let project_path = PathBuf::from("/test/project"); let project_path = PathBuf::from("/test/project");
let builder = ProfileBuilder::new(project_path).unwrap(); let builder = ProfileBuilder::new(project_path).unwrap();
@@ -214,16 +273,29 @@ fn test_complex_profile_with_multiple_rules() {
let rules = vec![ let rules = vec![
// File operations // File operations
make_rule("file_read_all", "subpath", "{{PROJECT_PATH}}", Some(&["linux", "macos"])), make_rule(
make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"])), "file_read_all",
make_rule("file_read_all", "literal", "/etc/hosts", Some(&["linux", "macos"])), "subpath",
"{{PROJECT_PATH}}",
Some(&["linux", "macos"]),
),
make_rule(
"file_read_all",
"subpath",
"/usr/lib",
Some(&["linux", "macos"]),
),
make_rule(
"file_read_all",
"literal",
"/etc/hosts",
Some(&["linux", "macos"]),
),
make_rule("file_read_metadata", "subpath", "/", Some(&["macos"])), make_rule("file_read_metadata", "subpath", "/", Some(&["macos"])),
// Network operations // Network operations
make_rule("network_outbound", "all", "", Some(&["linux", "macos"])), make_rule("network_outbound", "all", "", Some(&["linux", "macos"])),
make_rule("network_outbound", "tcp", "443", Some(&["macos"])), make_rule("network_outbound", "tcp", "443", Some(&["macos"])),
make_rule("network_outbound", "tcp", "80", Some(&["macos"])), make_rule("network_outbound", "tcp", "80", Some(&["macos"])),
// System info // System info
make_rule("system_info_read", "all", "", Some(&["macos"])), make_rule("system_info_read", "all", "", Some(&["macos"])),
]; ];
@@ -231,7 +303,10 @@ fn test_complex_profile_with_multiple_rules() {
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);
if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" { if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" {
assert!(_profile.is_ok(), "Complex profile should be created on supported platforms"); assert!(
_profile.is_ok(),
"Complex profile should be created on supported platforms"
);
} }
} }
@@ -242,9 +317,19 @@ fn test_rule_order_preservation() {
// Create rules with specific order // Create rules with specific order
let rules = vec![ let rules = vec![
make_rule("file_read_all", "subpath", "/first", Some(&["linux", "macos"])), make_rule(
"file_read_all",
"subpath",
"/first",
Some(&["linux", "macos"]),
),
make_rule("network_outbound", "all", "", Some(&["linux", "macos"])), make_rule("network_outbound", "all", "", Some(&["linux", "macos"])),
make_rule("file_read_all", "subpath", "/second", Some(&["linux", "macos"])), make_rule(
"file_read_all",
"subpath",
"/second",
Some(&["linux", "macos"]),
),
]; ];
let _profile = builder.build_profile(rules); let _profile = builder.build_profile(rules);