From bcffce0a08388fe778cee8c3aca2e89b1cd28420 Mon Sep 17 00:00:00 2001 From: Mufeed VH Date: Wed, 25 Jun 2025 03:45:59 +0530 Subject: [PATCH] style: apply cargo fmt across entire Rust codebase - Remove Rust formatting check from CI workflow since formatting is now applied - Standardize import ordering and organization throughout codebase - Fix indentation, spacing, and line breaks for consistency - Clean up trailing whitespace and formatting inconsistencies - Apply rustfmt to all Rust source files including checkpoint, sandbox, commands, and test modules This establishes a consistent code style baseline for the project. --- .github/workflows/build-test.yml | 8 - src-tauri/src/checkpoint/manager.rs | 342 ++++--- src-tauri/src/checkpoint/mod.rs | 34 +- src-tauri/src/checkpoint/state.rs | 102 +- src-tauri/src/checkpoint/storage.rs | 272 +++-- src-tauri/src/claude_binary.rs | 212 ++-- src-tauri/src/commands/agents.rs | 777 +++++++++------ src-tauri/src/commands/claude.rs | 943 ++++++++++-------- src-tauri/src/commands/mcp.rs | 233 +++-- src-tauri/src/commands/mod.rs | 8 +- src-tauri/src/commands/sandbox.rs | 462 +++++---- src-tauri/src/commands/screenshot.rs | 119 +-- src-tauri/src/commands/usage.rs | 324 +++--- src-tauri/src/lib.rs | 6 +- src-tauri/src/main.rs | 97 +- src-tauri/src/process/mod.rs | 2 +- src-tauri/src/process/registry.rs | 133 +-- src-tauri/src/sandbox/defaults.rs | 151 ++- src-tauri/src/sandbox/executor.rs | 258 +++-- src-tauri/src/sandbox/mod.rs | 12 +- src-tauri/src/sandbox/platform.rs | 4 +- src-tauri/src/sandbox/profile.rs | 317 +++--- src-tauri/tests/sandbox/common/claude_real.rs | 82 +- src-tauri/tests/sandbox/common/fixtures.rs | 79 +- src-tauri/tests/sandbox/common/helpers.rs | 83 +- src-tauri/tests/sandbox/common/mod.rs | 4 +- src-tauri/tests/sandbox/e2e/agent_sandbox.rs | 181 ++-- src-tauri/tests/sandbox/e2e/claude_sandbox.rs | 138 +-- src-tauri/tests/sandbox/e2e/mod.rs | 2 +- .../sandbox/integration/file_operations.rs | 114 ++- src-tauri/tests/sandbox/integration/mod.rs | 6 +- .../sandbox/integration/network_operations.rs | 125 +-- .../sandbox/integration/process_isolation.rs | 91 +- .../tests/sandbox/integration/system_info.rs | 61 +- .../tests/sandbox/integration/violations.rs | 95 +- src-tauri/tests/sandbox/mod.rs | 4 +- src-tauri/tests/sandbox/unit/executor.rs | 58 +- src-tauri/tests/sandbox/unit/mod.rs | 4 +- src-tauri/tests/sandbox/unit/platform.rs | 111 ++- .../tests/sandbox/unit/profile_builder.rs | 221 ++-- src-tauri/tests/sandbox_tests.rs | 4 +- 41 files changed, 3617 insertions(+), 2662 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 44430ca..008b215 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -100,14 +100,6 @@ jobs: - name: Build frontend run: bun run build - # Check Rust formatting - - name: Check Rust formatting - if: matrix.platform.os == 'ubuntu-latest' - working-directory: ./src-tauri - run: | - rustup component add rustfmt - cargo fmt -- --check - # Run Rust linter - name: Run Rust linter if: matrix.platform.os == 'ubuntu-latest' diff --git a/src-tauri/src/checkpoint/manager.rs b/src-tauri/src/checkpoint/manager.rs index 09fa40d..d6c0e6f 100644 --- a/src-tauri/src/checkpoint/manager.rs +++ b/src-tauri/src/checkpoint/manager.rs @@ -1,16 +1,16 @@ use anyhow::{Context, Result}; +use chrono::{DateTime, TimeZone, Utc}; +use log; use std::collections::HashMap; use std::fs; use std::path::PathBuf; use std::sync::Arc; -use chrono::{Utc, TimeZone, DateTime}; use 
tokio::sync::RwLock; -use log; use super::{ - Checkpoint, CheckpointMetadata, FileSnapshot, FileTracker, FileState, - CheckpointResult, SessionTimeline, CheckpointStrategy, CheckpointPaths, - storage::{CheckpointStorage, self}, + storage::{self, CheckpointStorage}, + Checkpoint, CheckpointMetadata, CheckpointPaths, CheckpointResult, CheckpointStrategy, + FileSnapshot, FileState, FileTracker, SessionTimeline, }; /// Manages checkpoint operations for a session
@@ -33,10 +33,10 @@ impl CheckpointManager { claude_dir: PathBuf, ) -> Result<Self> { let storage = Arc::new(CheckpointStorage::new(claude_dir.clone())); - + // Initialize storage storage.init_storage(&project_id, &session_id)?; - + // Load or create timeline let paths = CheckpointPaths::new(&claude_dir, &project_id, &session_id); let timeline = if paths.timeline_file.exists() {
@@ -44,11 +44,11 @@ } else { SessionTimeline::new(session_id.clone()) }; - + let file_tracker = FileTracker { tracked_files: HashMap::new(), }; - + Ok(Self { project_id, session_id,
@@ -59,12 +59,12 @@ }) } - + /// Track a new message in the session pub async fn track_message(&self, jsonl_message: String) -> Result<()> { let mut messages = self.current_messages.write().await; messages.push(jsonl_message.clone()); - + // Parse message to check for tool usage if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&jsonl_message) { if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
@@ -81,10 +81,10 @@ } } } - + Ok(()) } - + /// Track file operations from tool usage async fn track_tool_operation(&self, tool: &str, input: &serde_json::Value) -> Result<()> { match tool.to_lowercase().as_str() {
@@ -103,47 +103,51 @@ } Ok(()) } - + /// Track a file modification pub async fn track_file_modification(&self, file_path: &str) -> Result<()> { let mut tracker = self.file_tracker.write().await; let full_path = self.project_path.join(file_path); - + // Read current file state let (hash, exists, _size, modified) = if full_path.exists() { - let content = fs::read_to_string(&full_path) - .unwrap_or_default(); + let content = fs::read_to_string(&full_path).unwrap_or_default(); let metadata = fs::metadata(&full_path)?; - let modified = metadata.modified() + let modified = metadata + .modified() .ok() .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok()) - .map(|d| Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos()).unwrap()) + .map(|d| { + Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos()) + .unwrap() + }) .unwrap_or_else(Utc::now); - + ( storage::CheckpointStorage::calculate_file_hash(&content), true, metadata.len(), - modified + modified, ) } else { (String::new(), false, 0, Utc::now()) }; - + // Check if file has actually changed - let is_modified = if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) { - // File is modified if: - // 1. Hash has changed - // 2. Existence state has changed - // 3. It was already marked as modified - existing_state.last_hash != hash || - existing_state.exists != exists || - existing_state.is_modified - } else { - // New file is always considered modified - true - }; - + let is_modified = + if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) { + // File is modified if: + // 1. Hash has changed + // 2. Existence state has changed + // 3. It was already marked as modified + existing_state.last_hash != hash + || existing_state.exists != exists + || existing_state.is_modified + } else { + // New file is always considered modified + true + }; + tracker.tracked_files.insert( PathBuf::from(file_path), FileState {
@@ -153,18 +157,18 @@ exists, }, ); - + Ok(()) } - + /// Track potential file changes from bash commands async fn track_bash_side_effects(&self, command: &str) -> Result<()> { // Common file-modifying commands let file_commands = [ - "echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk", - "npm", "yarn", "pnpm", "bun", "cargo", "make", "gcc", "g++", + "echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk", "npm", "yarn", "pnpm", "bun", + "cargo", "make", "gcc", "g++", ]; - + // Simple heuristic: if command contains file-modifying operations for cmd in &file_commands { if command.contains(cmd) {
@@ -176,10 +180,10 @@ } break; } } - + Ok(()) } - + /// Create a checkpoint pub async fn create_checkpoint( &self,
@@ -188,13 +192,18 @@ ) -> Result<CheckpointResult> { let messages = self.current_messages.read().await; let message_index = messages.len().saturating_sub(1); - + // Extract metadata from the last user message - let (user_prompt, model_used, total_tokens) = self.extract_checkpoint_metadata(&messages).await?; - + let (user_prompt, model_used, total_tokens) = + self.extract_checkpoint_metadata(&messages).await?; + // Ensure every file in the project is tracked so new checkpoints include all files // Recursively walk the project directory and track each file - fn collect_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<PathBuf>) -> Result<(), std::io::Error> { + fn collect_files( + dir: &std::path::Path, + base: &std::path::Path, + files: &mut Vec<PathBuf>, + ) -> Result<(), std::io::Error> { for entry in std::fs::read_dir(dir)? { let entry = entry?; let path = entry.path();
@@ -224,13 +233,13 @@ let _ = self.track_file_modification(p).await; } } - + // Generate checkpoint ID early so snapshots reference it let checkpoint_id = storage::CheckpointStorage::generate_checkpoint_id(); // Create file snapshots let file_snapshots = self.create_file_snapshots(&checkpoint_id).await?; - + // Generate checkpoint struct let checkpoint = Checkpoint { id: checkpoint_id.clone(),
@@ -259,7 +268,7 @@ ), }, }; - + // Save checkpoint let messages_content = messages.join("\n"); let result = self.storage.save_checkpoint(
@@ -269,7 +278,7 @@ file_snapshots, &messages_content, )?; - + // Reload timeline from disk so in-memory timeline has updated nodes and total_checkpoints let claude_dir = self.storage.claude_dir.clone(); let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id);
@@ -278,20 +287,20 @@ let mut timeline_lock = self.timeline.write().await; *timeline_lock = updated_timeline; } - + // Update timeline (current checkpoint only) let mut timeline = self.timeline.write().await; timeline.current_checkpoint_id = Some(checkpoint_id); - + // Reset file tracker let mut tracker = self.file_tracker.write().await; for (_, state) in tracker.tracked_files.iter_mut() { state.is_modified = false; } - + Ok(result) } - + /// Extract metadata from messages for checkpoint async fn extract_checkpoint_metadata( &self,
@@ -300,13 +309,14 @@ let mut user_prompt = String::new(); let mut model_used = String::from("unknown"); let mut total_tokens = 0u64; - + // Iterate through messages in reverse to find the last user prompt for msg_str in messages.iter().rev() { if let Ok(msg) = serde_json::from_str::<serde_json::Value>(msg_str) { // Check for user message if msg.get("type").and_then(|t| t.as_str()) == Some("user") { - if let Some(content) = msg.get("message") + if let Some(content) = msg + .get("message") .and_then(|m| m.get("content")) .and_then(|c| c.as_array()) {
@@ -320,19 +330,19 @@ } } } - + // Extract model info if let Some(model) = msg.get("model").and_then(|m| m.as_str()) { model_used = model.to_string(); } - + // Also check for model in message.model (assistant messages) if let Some(message) = msg.get("message") { if let Some(model) = message.get("model").and_then(|m| m.as_str()) { model_used = model.to_string(); } } - + // Count tokens - check both top-level and nested usage // First check for usage in message.usage (assistant messages) if let Some(message) = msg.get("message") {
@@ -344,15 +354,21 @@ total_tokens += output; } // Also count cache tokens - if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) { + if let Some(cache_creation) = usage + .get("cache_creation_input_tokens") + .and_then(|t| t.as_u64()) + { total_tokens += cache_creation; } - if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) { + if let Some(cache_read) = usage + .get("cache_read_input_tokens") + .and_then(|t| t.as_u64()) + { total_tokens += cache_read; } } } - + // Then check for top-level usage (result messages) if let Some(usage) = msg.get("usage") { if let Some(input) = usage.get("input_tokens").and_then(|t| t.as_u64()) {
@@ -362,40 +378,45 @@ total_tokens += output; } // Also count cache tokens - if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) { + if let Some(cache_creation) = usage + .get("cache_creation_input_tokens") + .and_then(|t| t.as_u64()) + { total_tokens += cache_creation; } - if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) { + if let Some(cache_read) = usage + .get("cache_read_input_tokens") + .and_then(|t| t.as_u64()) + { total_tokens += cache_read; } } } } - + Ok((user_prompt, model_used, total_tokens)) } - +
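The token accounting above sums input, output, and both cache counters, taken either from message.usage (assistant messages) or from top-level usage (result messages). A sketch of the shape it reads, with made-up numbers and model name (field names mirror the parser above, not an official Claude Code schema):

// Hypothetical assistant-message JSONL line; extract_checkpoint_metadata
// would add 100 + 50 + 20 + 30 = 200 to total_tokens for it.
let line = serde_json::json!({
    "type": "assistant",
    "message": {
        "model": "claude-sonnet-4", // also picked up as model_used
        "usage": {
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_creation_input_tokens": 20,
            "cache_read_input_tokens": 30
        }
    }
}).to_string();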
/// Create file snapshots for all tracked modified files async fn create_file_snapshots(&self, checkpoint_id: &str) -> Result<Vec<FileSnapshot>> { let tracker = self.file_tracker.read().await; let mut snapshots = Vec::new(); - + for (rel_path, state) in &tracker.tracked_files { // Skip files that haven't been modified if !state.is_modified { continue; } - + let full_path = self.project_path.join(rel_path); - + let (content, exists, permissions, size, current_hash) = if full_path.exists() { - let content = fs::read_to_string(&full_path) - .unwrap_or_default(); + let content = fs::read_to_string(&full_path).unwrap_or_default(); let current_hash = storage::CheckpointStorage::calculate_file_hash(&content); - + // Don't skip based on hash - if is_modified is true, we should snapshot it // The hash check in track_file_modification already determined if it changed - + let metadata = fs::metadata(&full_path)?; let permissions = { #[cfg(unix)]
@@ -412,7 +433,7 @@ } else { (String::new(), false, None, 0, String::new()) }; - + snapshots.push(FileSnapshot { checkpoint_id: checkpoint_id.to_string(), file_path: rel_path.clone(),
@@ -423,21 +444,23 @@ size, }); } - + Ok(snapshots) } - + /// Restore a checkpoint pub async fn restore_checkpoint(&self, checkpoint_id: &str) -> Result<CheckpointResult> { // Load checkpoint data - let (checkpoint, file_snapshots, messages) = self.storage.load_checkpoint( - &self.project_id, - &self.session_id, - checkpoint_id, - )?; - + let (checkpoint, file_snapshots, messages) = + self.storage + .load_checkpoint(&self.project_id, &self.session_id, checkpoint_id)?; + // First, collect all files currently in the project to handle deletions - fn collect_all_project_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<PathBuf>) -> Result<(), std::io::Error> { + fn collect_all_project_files( + dir: &std::path::Path, + base: &std::path::Path, + files: &mut Vec<PathBuf>, + ) -> Result<(), std::io::Error> { for entry in std::fs::read_dir(dir)? { let entry = entry?; let path = entry.path();
@@ -458,10 +481,11 @@ } Ok(()) } - + let mut current_files = Vec::new(); - let _ = collect_all_project_files(&self.project_path, &self.project_path, &mut current_files); - + let _ = + collect_all_project_files(&self.project_path, &self.project_path, &mut current_files); + // Create a set of files that should exist after restore let mut checkpoint_files = std::collections::HashSet::new(); for snapshot in &file_snapshots {
@@ -469,11 +493,11 @@ checkpoint_files.insert(snapshot.file_path.clone()); } } - + // Delete files that exist now but shouldn't exist in the checkpoint let mut warnings = Vec::new(); let mut files_processed = 0; - + for current_file in current_files { if !checkpoint_files.contains(&current_file) { // This file exists now but not in the checkpoint, so delete it
@@ -484,18 +508,25 @@ log::info!("Deleted file not in checkpoint: {:?}", current_file); } Err(e) => { - warnings.push(format!("Failed to delete {}: {}", current_file.display(), e)); + warnings.push(format!( + "Failed to delete {}: {}", + current_file.display(), + e + )); } } } } - + // Clean up empty directories - fn remove_empty_dirs(dir: &std::path::Path, base: &std::path::Path) -> Result<bool, std::io::Error> { + fn remove_empty_dirs( + dir: &std::path::Path, + base: &std::path::Path, + ) -> Result<bool, std::io::Error> { if dir == base { return Ok(false); // Don't remove the base directory } - + let mut is_empty = true; for entry in fs::read_dir(dir)? { let entry = entry?;
@@ -508,7 +539,7 @@ is_empty = false; } } - + if is_empty { fs::remove_dir(dir)?; Ok(true)
@@ -516,30 +547,33 @@ Ok(false) } } - + // Clean up any empty directories left after file deletion let _ = remove_empty_dirs(&self.project_path, &self.project_path); - + // Restore files from checkpoint for snapshot in &file_snapshots { match self.restore_file_snapshot(snapshot).await { Ok(_) => files_processed += 1, - Err(e) => warnings.push(format!("Failed to restore {}: {}", - snapshot.file_path.display(), e)), + Err(e) => warnings.push(format!( + "Failed to restore {}: {}", + snapshot.file_path.display(), + e + )), } } - + // Update current messages let mut current_messages = self.current_messages.write().await; current_messages.clear(); for line in messages.lines() { current_messages.push(line.to_string()); } - + // Update timeline let mut timeline = self.timeline.write().await; timeline.current_checkpoint_id = Some(checkpoint_id.to_string()); - + // Update file tracker let mut tracker = self.file_tracker.write().await; tracker.tracked_files.clear();
@@ -556,35 +590,32 @@ ); } } - + Ok(CheckpointResult { checkpoint: checkpoint.clone(), files_processed, warnings, }) } - + /// Restore a single file from snapshot async fn restore_file_snapshot(&self, snapshot: &FileSnapshot) -> Result<()> { let full_path = self.project_path.join(&snapshot.file_path); - + if snapshot.is_deleted { // Delete the file if it exists if full_path.exists() { - fs::remove_file(&full_path) - .context("Failed to delete file")?; + fs::remove_file(&full_path).context("Failed to delete file")?; } } else { // Create parent directories if needed if let Some(parent) = full_path.parent() { - fs::create_dir_all(parent) - .context("Failed to create parent directories")?; + fs::create_dir_all(parent).context("Failed to create parent directories")?; } - + // Write file content - fs::write(&full_path, &snapshot.content) - .context("Failed to write file")?; - + fs::write(&full_path, &snapshot.content).context("Failed to write file")?; + // Restore permissions if available #[cfg(unix)] if let Some(mode) = snapshot.permissions {
@@ -594,35 +625,38 @@ .context("Failed to set file permissions")?; } } - + Ok(()) } - + /// Get the current timeline pub async fn get_timeline(&self) -> SessionTimeline { self.timeline.read().await.clone() } - + /// List all checkpoints pub async fn list_checkpoints(&self) -> Vec<Checkpoint> { let timeline = self.timeline.read().await; let mut checkpoints = Vec::new(); - + if let Some(root) = &timeline.root_node { Self::collect_checkpoints_from_node(root, &mut checkpoints); } - + checkpoints } - + /// Recursively collect checkpoints from timeline tree - fn collect_checkpoints_from_node(node: &super::TimelineNode, checkpoints: &mut Vec<Checkpoint>) { + fn collect_checkpoints_from_node( + node: &super::TimelineNode, + checkpoints: &mut Vec<Checkpoint>, + ) { checkpoints.push(node.checkpoint.clone()); for child in &node.children { Self::collect_checkpoints_from_node(child, checkpoints); } } - + /// Fork from a checkpoint pub async fn fork_from_checkpoint( &self,
@@ -630,31 +664,29 @@ description: Option<String>, ) -> Result<CheckpointResult> { // Load the checkpoint to fork from - let (_base_checkpoint, _, _) = self.storage.load_checkpoint( - &self.project_id, - &self.session_id, - checkpoint_id, - )?; - + let (_base_checkpoint, _, _) = + self.storage
 .load_checkpoint(&self.project_id, &self.session_id, checkpoint_id)?; + // Restore to that checkpoint first self.restore_checkpoint(checkpoint_id).await?; - + // Create a new checkpoint with the fork - let fork_description = description.unwrap_or_else(|| { - format!("Fork from checkpoint {}", &checkpoint_id[..8]) - }); - - self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string())).await + let fork_description = + description.unwrap_or_else(|| format!("Fork from checkpoint {}", &checkpoint_id[..8])); + + self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string())) + .await } - + /// Check if auto-checkpoint should be triggered pub async fn should_auto_checkpoint(&self, message: &str) -> bool { let timeline = self.timeline.read().await; - + if !timeline.auto_checkpoint_enabled { return false; } - + match timeline.checkpoint_strategy { CheckpointStrategy::Manual => false, CheckpointStrategy::PerPrompt => {
@@ -668,7 +700,11 @@ CheckpointStrategy::PerToolUse => { // Check if message contains tool use if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) { - if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) { + if let Some(content) = msg + .get("message") + .and_then(|m| m.get("content")) + .and_then(|c| c.as_array()) + { content.iter().any(|item| { item.get("type").and_then(|t| t.as_str()) == Some("tool_use") })
@@ -682,12 +718,19 @@ CheckpointStrategy::Smart => { // Smart strategy: checkpoint after destructive operations if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) { - if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) { + if let Some(content) = msg + .get("message") + .and_then(|m| m.get("content")) + .and_then(|c| c.as_array()) + { content.iter().any(|item| { if item.get("type").and_then(|t| t.as_str()) == Some("tool_use") { - let tool_name = item.get("name").and_then(|n| n.as_str()).unwrap_or(""); - matches!(tool_name.to_lowercase().as_str(), - "write" | "edit" | "multiedit" | "bash" | "rm" | "delete") + let tool_name = + item.get("name").and_then(|n| n.as_str()).unwrap_or(""); + matches!( + tool_name.to_lowercase().as_str(), + "write" | "edit" | "multiedit" | "bash" | "rm" | "delete" + ) } else { false }
@@ -701,7 +744,7 @@ } } } - + /// Update checkpoint settings pub async fn update_settings( &self,
@@ -711,31 +754,34 @@ let mut timeline = self.timeline.write().await; timeline.auto_checkpoint_enabled = auto_checkpoint_enabled; timeline.checkpoint_strategy = checkpoint_strategy; - + // Save updated timeline let claude_dir = self.storage.claude_dir.clone(); let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id); - self.storage.save_timeline(&paths.timeline_file, &timeline)?; - + self.storage + .save_timeline(&paths.timeline_file, &timeline)?; + Ok(()) } - + /// Get files modified since a given timestamp pub async fn get_files_modified_since(&self, since: DateTime<Utc>) -> Vec<PathBuf> { let tracker = self.file_tracker.read().await; - tracker.tracked_files + tracker + .tracked_files .iter() .filter(|(_, state)| state.last_modified > since && state.is_modified) .map(|(path, _)| path.clone()) .collect() } - + /// Get the last modification time of any tracked file pub async fn get_last_modification_time(&self) -> Option<DateTime<Utc>> { let tracker = self.file_tracker.read().await; - tracker.tracked_files + tracker + .tracked_files .values() .map(|state| state.last_modified) .max() } -} \ No newline at end of file +}
diff --git a/src-tauri/src/checkpoint/mod.rs b/src-tauri/src/checkpoint/mod.rs
index 4c4f9ec..030418b 100644
--- a/src-tauri/src/checkpoint/mod.rs
+++ b/src-tauri/src/checkpoint/mod.rs
@@ -1,11 +1,11 @@ +use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::PathBuf; -use chrono::{DateTime, Utc}; pub mod manager; -pub mod storage; pub mod state; +pub mod storage; /// Represents a checkpoint in the session timeline #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -188,24 +188,25 @@ total_checkpoints: 0, } } - + /// Find a checkpoint by ID in the timeline tree pub fn find_checkpoint(&self, checkpoint_id: &str) -> Option<&TimelineNode> { - self.root_node.as_ref() + self.root_node + .as_ref() .and_then(|root| Self::find_in_tree(root, checkpoint_id)) } - + fn find_in_tree<'a>(node: &'a TimelineNode, checkpoint_id: &str) -> Option<&'a TimelineNode> { if node.checkpoint.id == checkpoint_id { return Some(node); } - + for child in &node.children { if let Some(found) = Self::find_in_tree(child, checkpoint_id) { return Some(found); } } - + None } }
@@ -224,35 +225,38 @@ .join(project_id) .join(".timelines") .join(session_id); - + Self { timeline_file: base_dir.join("timeline.json"), checkpoints_dir: base_dir.join("checkpoints"), files_dir: base_dir.join("files"), } } - + pub fn checkpoint_dir(&self, checkpoint_id: &str) -> PathBuf { self.checkpoints_dir.join(checkpoint_id) } - + pub fn checkpoint_metadata_file(&self, checkpoint_id: &str) -> PathBuf { self.checkpoint_dir(checkpoint_id).join("metadata.json") } - + pub fn checkpoint_messages_file(&self, checkpoint_id: &str) -> PathBuf { self.checkpoint_dir(checkpoint_id).join("messages.jsonl") } - + #[allow(dead_code)] pub fn file_snapshot_path(&self, _checkpoint_id: &str, file_hash: &str) -> PathBuf { // In content-addressable storage, files are stored by hash in the content pool self.files_dir.join("content_pool").join(file_hash) } - + #[allow(dead_code)] pub fn file_reference_path(&self, checkpoint_id: &str, safe_filename: &str) -> PathBuf { // References are stored per checkpoint - self.files_dir.join("refs").join(checkpoint_id).join(format!("{}.json", safe_filename)) + self.files_dir + .join("refs") + .join(checkpoint_id) + .join(format!("{}.json", safe_filename)) } -} \ No newline at end of file +}
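For orientation, the CheckpointPaths helpers above imply the following on-disk layout; a sketch, assuming claude_dir resolves to ~/.claude and a claude_dir: PathBuf in scope (the actual root depends on the caller):

// ~/.claude/projects/{project_id}/.timelines/{session_id}/
//   timeline.json                       <- SessionTimeline tree
//   checkpoints/{checkpoint_id}/
//     metadata.json                     <- serialized Checkpoint
//     messages.jsonl                    <- zstd-compressed session messages
//   files/
//     content_pool/{hash}               <- deduplicated, compressed file content
//     refs/{checkpoint_id}/{name}.json  <- per-checkpoint pointers into the pool
let paths = CheckpointPaths::new(&claude_dir, "my-project", "session-123");
assert!(paths
    .checkpoint_metadata_file("abc")
    .ends_with("checkpoints/abc/metadata.json"));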
diff --git a/src-tauri/src/checkpoint/state.rs b/src-tauri/src/checkpoint/state.rs
index 4f41aa7..a633ebc 100644
--- a/src-tauri/src/checkpoint/state.rs
+++ b/src-tauri/src/checkpoint/state.rs
@@ -1,13 +1,13 @@ +use anyhow::Result; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::RwLock; -use anyhow::Result; use super::manager::CheckpointManager; /// Manages checkpoint managers for active sessions -/// +/// /// This struct maintains a stateful collection of CheckpointManager instances, /// one per active session, to avoid recreating them on every command invocation. /// It provides thread-safe access to managers and handles their lifecycle.
@@ -28,25 +28,25 @@ claude_dir: Arc::new(RwLock::new(None)), } } - + /// Sets the Claude directory path - /// + /// /// This should be called once during application initialization pub async fn set_claude_dir(&self, claude_dir: PathBuf) { let mut dir = self.claude_dir.write().await; *dir = Some(claude_dir); } - + /// Gets or creates a CheckpointManager for a session - /// + /// /// If a manager already exists for the session, it returns the existing one. /// Otherwise, it creates a new manager and stores it for future use. - /// + /// /// # Arguments /// * `session_id` - The session identifier /// * `project_id` - The project identifier /// * `project_path` - The path to the project directory - /// + /// /// # Returns /// An Arc reference to the CheckpointManager for thread-safe sharing pub async fn get_or_create_manager( &self,
@@ -56,12 +56,12 @@ project_path: PathBuf, ) -> Result<Arc<CheckpointManager>> { let mut managers = self.managers.write().await; - + // Check if manager already exists if let Some(manager) = managers.get(&session_id) { return Ok(Arc::clone(manager)); } - + // Get Claude directory let claude_dir = { let dir = self.claude_dir.read().await;
@@ -69,65 +69,62 @@ .ok_or_else(|| anyhow::anyhow!("Claude directory not set"))? .clone() }; - + // Create new manager - let manager = CheckpointManager::new( - project_id, - session_id.clone(), - project_path, - claude_dir, - ).await?; - + let manager = + CheckpointManager::new(project_id, session_id.clone(), project_path, claude_dir) + .await?; + let manager_arc = Arc::new(manager); managers.insert(session_id, Arc::clone(&manager_arc)); - + Ok(manager_arc) } - + /// Gets an existing CheckpointManager for a session - /// + /// /// Returns None if no manager exists for the session #[allow(dead_code)] pub async fn get_manager(&self, session_id: &str) -> Option<Arc<CheckpointManager>> { let managers = self.managers.read().await; managers.get(session_id).map(Arc::clone) } - + /// Removes a CheckpointManager for a session - /// + /// /// This should be called when a session ends to free resources pub async fn remove_manager(&self, session_id: &str) -> Option<Arc<CheckpointManager>> { let mut managers = self.managers.write().await; managers.remove(session_id) } - + /// Clears all managers - /// + /// /// This is useful for cleanup during application shutdown #[allow(dead_code)] pub async fn clear_all(&self) { let mut managers = self.managers.write().await; managers.clear(); } - + /// Gets the number of active managers pub async fn active_count(&self) -> usize { let managers = self.managers.read().await; managers.len() } - + /// Lists all active session IDs pub async fn list_active_sessions(&self) -> Vec<String> { let managers = self.managers.read().await; managers.keys().cloned().collect() } - + /// Checks if a session has an active manager #[allow(dead_code)] pub async fn has_active_manager(&self, session_id: &str) -> bool { self.get_manager(session_id).await.is_some() } - + /// Clears all managers and returns the count that were cleared #[allow(dead_code)] pub async fn clear_all_and_count(&self) -> usize {
@@ -141,50 +138,47 @@ mod tests { use super::*; use tempfile::TempDir; - + #[tokio::test] async fn test_checkpoint_state_lifecycle() { let state = CheckpointState::new(); let temp_dir = TempDir::new().unwrap(); let claude_dir = temp_dir.path().to_path_buf(); - + // Set Claude directory state.set_claude_dir(claude_dir.clone()).await; - + // Create a manager let session_id = "test-session-123".to_string(); let project_id = "test-project".to_string(); let project_path = temp_dir.path().join("project"); std::fs::create_dir_all(&project_path).unwrap(); - - let manager1 = state.get_or_create_manager( - session_id.clone(), - project_id.clone(), - project_path.clone(), - ).await.unwrap(); - + + let manager1 = state + .get_or_create_manager(session_id.clone(), project_id.clone(), project_path.clone()) + .await + .unwrap(); + // Getting the same session should return the same manager - let manager2 = state.get_or_create_manager( - session_id.clone(), - project_id.clone(), - project_path.clone(), - ).await.unwrap(); - + let manager2 = state + .get_or_create_manager(session_id.clone(), project_id.clone(), project_path.clone()) + .await + .unwrap(); + assert!(Arc::ptr_eq(&manager1, &manager2)); assert_eq!(state.active_count().await, 1); - + // Remove the manager let removed = state.remove_manager(&session_id).await; assert!(removed.is_some()); assert_eq!(state.active_count().await, 0); - + // Getting after removal should create a new one - let manager3 = state.get_or_create_manager( - session_id.clone(), - project_id, - project_path, - ).await.unwrap(); - + let manager3 = state + .get_or_create_manager(session_id.clone(), project_id, project_path) + .await + .unwrap(); + assert!(!Arc::ptr_eq(&manager1, &manager3)); } -} \ No newline at end of file +}
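CheckpointState hands out Arc clones so multiple Tauri commands can share one CheckpointManager per session. A sketch of the intended calling pattern, with illustrative IDs and paths (assumes an async context and anyhow's `?` operator):

let state = CheckpointState::new();
state.set_claude_dir(PathBuf::from("/home/user/.claude")).await;
let mgr = state
    .get_or_create_manager("sess-1".to_string(), "proj".to_string(), PathBuf::from("/tmp/proj"))
    .await?; // created once, then reused by later commands
mgr.track_message("{}".to_string()).await?; // shared Arc<CheckpointManager>
state.remove_manager("sess-1").await; // free it when the session ends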
diff --git a/src-tauri/src/checkpoint/storage.rs b/src-tauri/src/checkpoint/storage.rs
index d689b8e..ce82de3 100644
--- a/src-tauri/src/checkpoint/storage.rs
+++ b/src-tauri/src/checkpoint/storage.rs
@@ -1,13 +1,12 @@ use anyhow::{Context, Result}; +use sha2::{Digest, Sha256}; use std::fs; use std::path::{Path, PathBuf}; -use sha2::{Sha256, Digest}; -use zstd::stream::{encode_all, decode_all}; use uuid::Uuid; +use zstd::stream::{decode_all, encode_all}; use super::{ - Checkpoint, FileSnapshot, SessionTimeline, - TimelineNode, CheckpointPaths, CheckpointResult + Checkpoint, CheckpointPaths, CheckpointResult, FileSnapshot, SessionTimeline, TimelineNode, }; /// Manages checkpoint storage operations
@@ -24,26 +23,25 @@ compression_level: 3, // Default zstd compression level } } - + /// Initialize checkpoint storage for a session pub fn init_storage(&self, project_id: &str, session_id: &str) -> Result<()> { let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id); - + // Create directory structure fs::create_dir_all(&paths.checkpoints_dir) .context("Failed to create checkpoints directory")?; - fs::create_dir_all(&paths.files_dir) - .context("Failed to create files directory")?; - + fs::create_dir_all(&paths.files_dir).context("Failed to create files directory")?; + // Initialize empty timeline if it doesn't exist if !paths.timeline_file.exists() { let timeline = SessionTimeline::new(session_id.to_string()); self.save_timeline(&paths.timeline_file, &timeline)?; } - + Ok(()) } - + /// Save a checkpoint to disk pub fn save_checkpoint( &self,
@@ -55,76 +53,73 @@ ) -> Result<CheckpointResult> { let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id); let checkpoint_dir = paths.checkpoint_dir(&checkpoint.id); - + // Create checkpoint directory - fs::create_dir_all(&checkpoint_dir) - .context("Failed to create checkpoint directory")?; - + fs::create_dir_all(&checkpoint_dir).context("Failed to create checkpoint directory")?; + // Save checkpoint metadata let metadata_path = paths.checkpoint_metadata_file(&checkpoint.id); let metadata_json = serde_json::to_string_pretty(checkpoint) .context("Failed to serialize checkpoint metadata")?; - fs::write(&metadata_path, metadata_json) - .context("Failed to write checkpoint metadata")?; - + fs::write(&metadata_path, metadata_json).context("Failed to write checkpoint metadata")?; + // Save messages (compressed) let messages_path = paths.checkpoint_messages_file(&checkpoint.id); let compressed_messages = encode_all(messages.as_bytes(), self.compression_level) .context("Failed to compress messages")?; fs::write(&messages_path, compressed_messages) .context("Failed to write compressed messages")?; - + // Save file snapshots let mut warnings = Vec::new(); let mut files_processed = 0; - + for snapshot in &file_snapshots { match self.save_file_snapshot(&paths, snapshot) { Ok(_) => files_processed += 1, - Err(e) => warnings.push(format!("Failed to save {}: {}", - snapshot.file_path.display(), e)), + Err(e) => warnings.push(format!( + "Failed to save {}: {}", + snapshot.file_path.display(), + e + )), } } - + // Update timeline - self.update_timeline_with_checkpoint( - &paths.timeline_file, - checkpoint, - &file_snapshots - )?; - + self.update_timeline_with_checkpoint(&paths.timeline_file, checkpoint, &file_snapshots)?; + Ok(CheckpointResult { checkpoint: checkpoint.clone(), files_processed, warnings, }) } - + /// Save a single file snapshot fn save_file_snapshot(&self, paths: &CheckpointPaths, snapshot: &FileSnapshot) -> Result<()> { // Use content-addressable storage: store files by their hash // This prevents duplication of identical file content across checkpoints let content_pool_dir = paths.files_dir.join("content_pool"); - fs::create_dir_all(&content_pool_dir) - .context("Failed to create content pool directory")?; - + fs::create_dir_all(&content_pool_dir).context("Failed to create content pool directory")?; + // Store the actual content in the content pool let content_file = content_pool_dir.join(&snapshot.hash); - + // Only write the content if it doesn't already exist if !content_file.exists() { // Compress and save file content - let compressed_content = encode_all(snapshot.content.as_bytes(), self.compression_level) - .context("Failed to compress file content")?; + let compressed_content = + encode_all(snapshot.content.as_bytes(), self.compression_level) + .context("Failed to compress file content")?; fs::write(&content_file, compressed_content) .context("Failed to write file content to pool")?; } - + // Create a reference in the checkpoint-specific directory let checkpoint_refs_dir = paths.files_dir.join("refs").join(&snapshot.checkpoint_id); fs::create_dir_all(&checkpoint_refs_dir) .context("Failed to create checkpoint refs directory")?; - + // Save file metadata with reference to content let ref_metadata = serde_json::json!({ "path": snapshot.file_path,
@@ -133,20 +128,21 @@ "permissions": snapshot.permissions, "size": snapshot.size, }); - + // Use a sanitized filename for the reference - let safe_filename = snapshot.file_path + let safe_filename = snapshot + .file_path .to_string_lossy() .replace('/', "_") .replace('\\', "_"); let ref_path = checkpoint_refs_dir.join(format!("{}.json", safe_filename)); - + fs::write(&ref_path, serde_json::to_string_pretty(&ref_metadata)?) .context("Failed to write file reference")?; - + Ok(()) } - + /// Load a checkpoint from disk pub fn load_checkpoint( &self,
@@ -155,75 +151,78 @@ checkpoint_id: &str, ) -> Result<(Checkpoint, Vec<FileSnapshot>, String)> { let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id); - + // Load checkpoint metadata let metadata_path = paths.checkpoint_metadata_file(checkpoint_id); - let metadata_json = fs::read_to_string(&metadata_path) - .context("Failed to read checkpoint metadata")?; - let checkpoint: Checkpoint = serde_json::from_str(&metadata_json) - .context("Failed to parse checkpoint metadata")?; - + let metadata_json = + fs::read_to_string(&metadata_path).context("Failed to read checkpoint metadata")?; + let checkpoint: Checkpoint = + serde_json::from_str(&metadata_json).context("Failed to parse checkpoint metadata")?; + // Load messages let messages_path = paths.checkpoint_messages_file(checkpoint_id); - let compressed_messages = fs::read(&messages_path) - .context("Failed to read compressed messages")?; - let messages = String::from_utf8(decode_all(&compressed_messages[..]) - .context("Failed to decompress messages")?) - .context("Invalid UTF-8 in messages")?; - + let compressed_messages = + fs::read(&messages_path).context("Failed to read compressed messages")?; + let messages = String::from_utf8( + decode_all(&compressed_messages[..]).context("Failed to decompress messages")?, + ) + .context("Invalid UTF-8 in messages")?; + // Load file snapshots let file_snapshots = self.load_file_snapshots(&paths, checkpoint_id)?; - + Ok((checkpoint, file_snapshots, messages)) } - + /// Load all file snapshots for a checkpoint fn load_file_snapshots( - &self, - paths: &CheckpointPaths, - checkpoint_id: &str + &self, + paths: &CheckpointPaths, + checkpoint_id: &str, ) -> Result<Vec<FileSnapshot>> { let refs_dir = paths.files_dir.join("refs").join(checkpoint_id); if !refs_dir.exists() { return Ok(Vec::new()); } - + let content_pool_dir = paths.files_dir.join("content_pool"); let mut snapshots = Vec::new(); - + // Read all reference files for entry in fs::read_dir(&refs_dir)? { let entry = entry?; let path = entry.path(); - + // Skip non-JSON files if path.extension().and_then(|e| e.to_str()) != Some("json") { continue; } - + // Load reference metadata - let ref_json = fs::read_to_string(&path) - .context("Failed to read file reference")?; - let ref_metadata: serde_json::Value = serde_json::from_str(&ref_json) - .context("Failed to parse file reference")?; - - let hash = ref_metadata["hash"].as_str() + let ref_json = fs::read_to_string(&path).context("Failed to read file reference")?; + let ref_metadata: serde_json::Value = + serde_json::from_str(&ref_json).context("Failed to parse file reference")?; + + let hash = ref_metadata["hash"] + .as_str() .ok_or_else(|| anyhow::anyhow!("Missing hash in reference"))?; - + // Load content from pool let content_file = content_pool_dir.join(hash); let content = if content_file.exists() { - let compressed_content = fs::read(&content_file) - .context("Failed to read file content from pool")?; - String::from_utf8(decode_all(&compressed_content[..]) - .context("Failed to decompress file content")?) - .context("Invalid UTF-8 in file content")? + let compressed_content = + fs::read(&content_file).context("Failed to read file content from pool")?; + String::from_utf8( + decode_all(&compressed_content[..]) + .context("Failed to decompress file content")?, + ) + .context("Invalid UTF-8 in file content")? } else { // Handle missing content gracefully log::warn!("Content file missing for hash: {}", hash); String::new() }; - + snapshots.push(FileSnapshot { checkpoint_id: checkpoint_id.to_string(), file_path: PathBuf::from(ref_metadata["path"].as_str().unwrap_or("")),
@@ -234,28 +233,26 @@ size: ref_metadata["size"].as_u64().unwrap_or(0), }); } - + Ok(snapshots) }
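Because snapshots are stored by content hash, identical file content is written to the pool once no matter how many checkpoints reference it; deleting a checkpoint removes only its refs, and garbage_collect_content below reclaims pool entries no ref names. A small sketch of the dedup property, using the real hash helper with illustrative strings:

// Identical content always maps to the same SHA-256 hex string,
// hence the same content_pool/{hash} file on disk.
let a = CheckpointStorage::calculate_file_hash("fn main() {}\n");
let b = CheckpointStorage::calculate_file_hash("fn main() {}\n");
assert_eq!(a, b);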
- + /// Save timeline to disk pub fn save_timeline(&self, timeline_path: &Path, timeline: &SessionTimeline) -> Result<()> { - let timeline_json = serde_json::to_string_pretty(timeline) - .context("Failed to serialize timeline")?; - fs::write(timeline_path, timeline_json) - .context("Failed to write timeline")?; + let timeline_json = + serde_json::to_string_pretty(timeline).context("Failed to serialize timeline")?; + fs::write(timeline_path, timeline_json).context("Failed to write timeline")?; Ok(()) } - + /// Load timeline from disk pub fn load_timeline(&self, timeline_path: &Path) -> Result<SessionTimeline> { - let timeline_json = fs::read_to_string(timeline_path) - .context("Failed to read timeline")?; - let timeline: SessionTimeline = serde_json::from_str(&timeline_json) - .context("Failed to parse timeline")?; + let timeline_json = fs::read_to_string(timeline_path).context("Failed to read timeline")?; + let timeline: SessionTimeline = + serde_json::from_str(&timeline_json).context("Failed to parse timeline")?; Ok(timeline) } - + /// Update timeline with a new checkpoint fn update_timeline_with_checkpoint( &self,
@@ -264,15 +261,13 @@ file_snapshots: &[FileSnapshot], ) -> Result<()> { let mut timeline = self.load_timeline(timeline_path)?; - + let new_node = TimelineNode { checkpoint: checkpoint.clone(), children: Vec::new(), - file_snapshot_ids: file_snapshots.iter() - .map(|s| s.hash.clone()) - .collect(), + file_snapshot_ids: file_snapshots.iter().map(|s| s.hash.clone()).collect(), }; - + // If this is the first checkpoint if timeline.root_node.is_none() { timeline.root_node = Some(new_node);
@@ -280,7 +275,7 @@ } else if let Some(parent_id) = &checkpoint.parent_checkpoint_id { // Check if parent exists before modifying let parent_exists = timeline.find_checkpoint(parent_id).is_some(); - + if parent_exists { if let Some(root) = &mut timeline.root_node { Self::add_child_to_node(root, parent_id, new_node)?;
@@ -290,59 +285,54 @@ anyhow::bail!("Parent checkpoint not found: {}", parent_id); } } - + timeline.total_checkpoints += 1; self.save_timeline(timeline_path, &timeline)?; - + Ok(()) } - + /// Recursively add a child node to the timeline tree fn add_child_to_node( - node: &mut TimelineNode, - parent_id: &str, - child: TimelineNode + node: &mut TimelineNode, + parent_id: &str, + child: TimelineNode, ) -> Result<()> { if node.checkpoint.id == parent_id { node.children.push(child); return Ok(()); } - + for child_node in &mut node.children { if Self::add_child_to_node(child_node, parent_id, child.clone()).is_ok() { return Ok(()); } } - + anyhow::bail!("Parent checkpoint not found: {}", parent_id) } - + /// Calculate hash of file content pub fn calculate_file_hash(content: &str) -> String { let mut hasher = Sha256::new(); hasher.update(content.as_bytes()); format!("{:x}", hasher.finalize()) } - + /// Generate a new checkpoint ID pub fn generate_checkpoint_id() -> String { Uuid::new_v4().to_string() } - + /// Estimate storage size for a checkpoint - pub fn estimate_checkpoint_size( - messages: &str, - file_snapshots: &[FileSnapshot], - ) -> u64 { + pub fn estimate_checkpoint_size(messages: &str, file_snapshots: &[FileSnapshot]) -> u64 { let messages_size = messages.len() as u64; - let files_size: u64 = file_snapshots.iter() - .map(|s| s.content.len() as u64) - .sum(); - + let files_size: u64 = file_snapshots.iter().map(|s| s.content.len() as u64).sum(); + // Estimate compressed size (typically 20-30% of original for text) (messages_size + files_size) / 4 } - + /// Clean up old checkpoints based on retention policy pub fn cleanup_old_checkpoints( &self,
@@ -352,26 +342,26 @@ ) -> Result<usize> { let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id); let timeline = self.load_timeline(&paths.timeline_file)?; - + // Collect all checkpoint IDs in chronological order let mut all_checkpoints = Vec::new(); if let Some(root) = &timeline.root_node { Self::collect_checkpoints(root, &mut all_checkpoints); } - + // Sort by timestamp (oldest first) all_checkpoints.sort_by(|a, b| a.timestamp.cmp(&b.timestamp)); - + // Keep only the most recent checkpoints let to_remove = all_checkpoints.len().saturating_sub(keep_count); let mut removed_count = 0; - + for checkpoint in all_checkpoints.into_iter().take(to_remove) { if self.remove_checkpoint(&paths, &checkpoint.id).is_ok() { removed_count += 1; } } - + // Run garbage collection to clean up orphaned content if removed_count > 0 { match self.garbage_collect_content(project_id, session_id) {
@@ -383,10 +373,10 @@ } } } - + Ok(removed_count) } - + /// Collect all checkpoints from the tree in order fn collect_checkpoints(node: &TimelineNode, checkpoints: &mut Vec<Checkpoint>) { checkpoints.push(node.checkpoint.clone());
@@ -394,46 +384,40 @@ Self::collect_checkpoints(child, checkpoints); } } - + /// Remove a checkpoint and its associated files fn remove_checkpoint(&self, paths: &CheckpointPaths, checkpoint_id: &str) -> Result<()> { // Remove checkpoint metadata directory let checkpoint_dir = paths.checkpoint_dir(checkpoint_id); if checkpoint_dir.exists() { - fs::remove_dir_all(&checkpoint_dir) - .context("Failed to remove checkpoint directory")?; + fs::remove_dir_all(&checkpoint_dir).context("Failed to remove checkpoint directory")?; } - + // Remove file references for this checkpoint let refs_dir = paths.files_dir.join("refs").join(checkpoint_id); if refs_dir.exists() { - fs::remove_dir_all(&refs_dir) - .context("Failed to remove file references")?; + fs::remove_dir_all(&refs_dir).context("Failed to remove file references")?; } - + // Note: We don't remove content from the pool here as it might be // referenced by other checkpoints. Use garbage_collect_content() for that. - + Ok(()) } - + /// Garbage collect unreferenced content from the content pool - pub fn garbage_collect_content( - &self, - project_id: &str, - session_id: &str, - ) -> Result<usize> { + pub fn garbage_collect_content(&self, project_id: &str, session_id: &str) -> Result<usize> { let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id); let content_pool_dir = paths.files_dir.join("content_pool"); let refs_dir = paths.files_dir.join("refs"); - + if !content_pool_dir.exists() { return Ok(0); } - + // Collect all referenced hashes let mut referenced_hashes = std::collections::HashSet::new(); - + if refs_dir.exists() { for checkpoint_entry in fs::read_dir(&refs_dir)? 
{ let checkpoint_dir = checkpoint_entry?.path(); @@ -442,7 +426,9 @@ impl CheckpointStorage { let ref_path = ref_entry?.path(); if ref_path.extension().and_then(|e| e.to_str()) == Some("json") { if let Ok(ref_json) = fs::read_to_string(&ref_path) { - if let Ok(ref_metadata) = serde_json::from_str::(&ref_json) { + if let Ok(ref_metadata) = + serde_json::from_str::(&ref_json) + { if let Some(hash) = ref_metadata["hash"].as_str() { referenced_hashes.insert(hash.to_string()); } @@ -453,7 +439,7 @@ impl CheckpointStorage { } } } - + // Remove unreferenced content let mut removed_count = 0; for entry in fs::read_dir(&content_pool_dir)? { @@ -468,7 +454,7 @@ impl CheckpointStorage { } } } - + Ok(removed_count) } -} \ No newline at end of file +} diff --git a/src-tauri/src/claude_binary.rs b/src-tauri/src/claude_binary.rs index af46ef3..adf3365 100644 --- a/src-tauri/src/claude_binary.rs +++ b/src-tauri/src/claude_binary.rs @@ -1,12 +1,12 @@ +use anyhow::Result; +use log::{debug, error, info, warn}; +use serde::{Deserialize, Serialize}; +use std::cmp::Ordering; /// Shared module for detecting Claude Code binary installations /// Supports NVM installations, aliased paths, and version-based selection use std::path::PathBuf; use std::process::Command; -use log::{info, warn, debug, error}; -use anyhow::Result; -use std::cmp::Ordering; use tauri::Manager; -use serde::{Serialize, Deserialize}; /// Represents a Claude installation with metadata #[derive(Debug, Clone, Serialize, Deserialize)] @@ -23,7 +23,7 @@ pub struct ClaudeInstallation { /// Checks database first, then discovers all installations and selects the best one pub fn find_claude_binary(app_handle: &tauri::AppHandle) -> Result { info!("Searching for claude binary..."); - + // First check if we have a stored path in the database if let Ok(app_data_dir) = app_handle.path().app_data_dir() { let db_path = app_data_dir.join("agents.db"); @@ -45,24 +45,26 @@ pub fn find_claude_binary(app_handle: &tauri::AppHandle) -> Result Result Vec { info!("Discovering all Claude installations..."); - + let installations = discover_all_installations(); - + // Sort by version (highest first), then by source preference let mut sorted = installations; sorted.sort_by(|a, b| { @@ -87,15 +89,15 @@ pub fn discover_claude_installations() -> Vec { // If versions are equal, prefer by source source_preference(a).cmp(&source_preference(b)) } - other => other + other => other, } } (Some(_), None) => Ordering::Less, // Version comes before no version (None, Some(_)) => Ordering::Greater, - (None, None) => source_preference(a).cmp(&source_preference(b)) + (None, None) => source_preference(a).cmp(&source_preference(b)), } }); - + sorted } @@ -121,57 +123,58 @@ fn source_preference(installation: &ClaudeInstallation) -> u8 { /// Discovers all Claude installations on the system fn discover_all_installations() -> Vec { let mut installations = Vec::new(); - + // 1. Try 'which' command first (now works in production) if let Some(installation) = try_which_command() { installations.push(installation); } - + // 2. Check NVM paths installations.extend(find_nvm_installations()); - + // 3. 
Check standard paths installations.extend(find_standard_installations()); - + // Remove duplicates by path let mut unique_paths = std::collections::HashSet::new(); installations.retain(|install| unique_paths.insert(install.path.clone())); - + installations } /// Try using the 'which' command to find Claude fn try_which_command() -> Option<ClaudeInstallation> { debug!("Trying 'which claude' to find binary..."); - + match Command::new("which").arg("claude").output() { Ok(output) if output.status.success() => { let output_str = String::from_utf8_lossy(&output.stdout).trim().to_string(); - + if output_str.is_empty() { return None; } - + // Parse aliased output: "claude: aliased to /path/to/claude" let path = if output_str.starts_with("claude:") && output_str.contains("aliased to") { - output_str.split("aliased to") + output_str + .split("aliased to") .nth(1) .map(|s| s.trim().to_string()) } else { Some(output_str) }?; - + debug!("'which' found claude at: {}", path); - + // Verify the path exists if !PathBuf::from(&path).exists() { warn!("Path from 'which' does not exist: {}", path); return None; } - + // Get version let version = get_claude_version(&path).ok().flatten(); - + Some(ClaudeInstallation { path, version,
@@ -185,26 +188,29 @@ /// Find Claude installations in NVM directories fn find_nvm_installations() -> Vec<ClaudeInstallation> { let mut installations = Vec::new(); - + if let Ok(home) = std::env::var("HOME") { - let nvm_dir = PathBuf::from(&home).join(".nvm").join("versions").join("node"); - + let nvm_dir = PathBuf::from(&home) + .join(".nvm") + .join("versions") + .join("node"); + debug!("Checking NVM directory: {:?}", nvm_dir); - + if let Ok(entries) = std::fs::read_dir(&nvm_dir) { for entry in entries.flatten() { if entry.file_type().map(|t| t.is_dir()).unwrap_or(false) { let claude_path = entry.path().join("bin").join("claude"); - + if claude_path.exists() && claude_path.is_file() { let path_str = claude_path.to_string_lossy().to_string(); let node_version = entry.file_name().to_string_lossy().to_string(); - + debug!("Found Claude in NVM node {}: {}", node_version, path_str); - + // Get Claude version let version = get_claude_version(&path_str).ok().flatten(); - + installations.push(ClaudeInstallation { path: path_str, version,
@@ -215,46 +221,64 @@ } } } - + installations } /// Check standard installation paths fn find_standard_installations() -> Vec<ClaudeInstallation> { let mut installations = Vec::new(); - + // Common installation paths for claude let mut paths_to_check: Vec<(String, String)> = vec![ ("/usr/local/bin/claude".to_string(), "system".to_string()), - ("/opt/homebrew/bin/claude".to_string(), "homebrew".to_string()), + ( + "/opt/homebrew/bin/claude".to_string(), + "homebrew".to_string(), + ), ("/usr/bin/claude".to_string(), "system".to_string()), ("/bin/claude".to_string(), "system".to_string()), ]; - + // Also check user-specific paths if let Ok(home) = std::env::var("HOME") { paths_to_check.extend(vec![ - (format!("{}/.claude/local/claude", home), "claude-local".to_string()), - (format!("{}/.local/bin/claude", home), "local-bin".to_string()), - (format!("{}/.npm-global/bin/claude", home), "npm-global".to_string()), + ( + format!("{}/.claude/local/claude", home), + "claude-local".to_string(), + ), + ( + format!("{}/.local/bin/claude", home), + "local-bin".to_string(), + ), + ( + format!("{}/.npm-global/bin/claude", home), + "npm-global".to_string(), + ), (format!("{}/.yarn/bin/claude", home), "yarn".to_string()), (format!("{}/.bun/bin/claude", home), "bun".to_string()), (format!("{}/bin/claude", home), "home-bin".to_string()), // Check common node_modules locations - (format!("{}/node_modules/.bin/claude", home), "node-modules".to_string()), - (format!("{}/.config/yarn/global/node_modules/.bin/claude", home), "yarn-global".to_string()), + ( + format!("{}/node_modules/.bin/claude", home), + "node-modules".to_string(), + ), + ( + format!("{}/.config/yarn/global/node_modules/.bin/claude", home), + "yarn-global".to_string(), + ), ]); } - + // Check each path for (path, source) in paths_to_check { let path_buf = PathBuf::from(&path); if path_buf.exists() && path_buf.is_file() { debug!("Found claude at standard path: {} ({})", path, source); - + // Get version let version = get_claude_version(&path).ok().flatten(); - + installations.push(ClaudeInstallation { path, version,
@@ -262,13 +286,13 @@ }); } } - + // Also check if claude is available in PATH (without full path) if let Ok(output) = Command::new("claude").arg("--version").output() { if output.status.success() { debug!("claude is available in PATH"); let version = extract_version_from_output(&output.stdout); - + installations.push(ClaudeInstallation { path: "claude".to_string(), version,
@@ -276,7 +300,7 @@ }); } } - + installations }
@@ -300,13 +324,13 @@ fn get_claude_version(path: &str) -> Result<Option<String>, String> { /// Extract version string from command output fn extract_version_from_output(stdout: &[u8]) -> Option<String> { let output_str = String::from_utf8_lossy(stdout); - + // Extract version: first token before whitespace that looks like a version - output_str.split_whitespace() + output_str + .split_whitespace() .find(|token| { // Version usually contains dots and numbers - token.chars().any(|c| c == '.') && - token.chars().any(|c| c.is_numeric()) + token.chars().any(|c| c == '.') && token.chars().any(|c| c.is_numeric()) }) .map(|s| s.to_string()) }
@@ -320,34 +344,34 @@ fn select_best_installation(installations: Vec<ClaudeInstallation>) -> Option<ClaudeInstallation> { - installations.into_iter().max_by(|a, b| match (&a.version, &b.version) { - // If both have versions, compare them semantically. - (Some(v1), Some(v2)) => compare_versions(v1, v2), - // Prefer the entry that actually has version information. - (Some(_), None) => Ordering::Greater, - (None, Some(_)) => Ordering::Less, - // Neither have version info: prefer the one that is not just - // the bare "claude" lookup from PATH, because that may fail - // at runtime if PATH is sandbox-stripped. - (None, None) => { - if a.path == "claude" && b.path != "claude" { - Ordering::Less - } else if a.path != "claude" && b.path == "claude" { - Ordering::Greater - } else { - Ordering::Equal - } + installations.into_iter().max_by(|a, b| { match (&a.version, &b.version) { + // If both have versions, compare them semantically. + (Some(v1), Some(v2)) => compare_versions(v1, v2), + // Prefer the entry that actually has version information. + (Some(_), None) => Ordering::Greater, + (None, Some(_)) => Ordering::Less, + // Neither have version info: prefer the one that is not just + // the bare "claude" lookup from PATH, because that may fail + // at runtime if PATH is sandbox-stripped. 
+ (None, None) => { + if a.path == "claude" && b.path != "claude" { + Ordering::Less + } else if a.path != "claude" && b.path == "claude" { + Ordering::Greater + } else { + Ordering::Equal } } - }) + } + }) } /// Compare two version strings fn compare_versions(a: &str, b: &str) -> Ordering { // Simple semantic version comparison - let a_parts: Vec = a.split('.') + let a_parts: Vec = a + .split('.') .filter_map(|s| { // Handle versions like "1.0.17-beta" by taking only numeric part s.chars() @@ -357,8 +381,9 @@ fn compare_versions(a: &str, b: &str) -> Ordering { .ok() }) .collect(); - - let b_parts: Vec = b.split('.') + + let b_parts: Vec = b + .split('.') .filter_map(|s| { s.chars() .take_while(|c| c.is_numeric()) @@ -367,7 +392,7 @@ fn compare_versions(a: &str, b: &str) -> Ordering { .ok() }) .collect(); - + // Compare each part for i in 0..std::cmp::max(a_parts.len(), b_parts.len()) { let a_val = a_parts.get(i).unwrap_or(&0); @@ -377,7 +402,7 @@ fn compare_versions(a: &str, b: &str) -> Ordering { other => return other, } } - + Ordering::Equal } @@ -385,19 +410,28 @@ fn compare_versions(a: &str, b: &str) -> Ordering { /// This ensures commands like Claude can find Node.js and other dependencies pub fn create_command_with_env(program: &str) -> Command { let mut cmd = Command::new(program); - + // Inherit essential environment variables from parent process for (key, value) in std::env::vars() { // Pass through PATH and other essential environment variables - if key == "PATH" || key == "HOME" || key == "USER" - || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_") - || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN" - || key == "HOMEBREW_PREFIX" || key == "HOMEBREW_CELLAR" { + if key == "PATH" + || key == "HOME" + || key == "USER" + || key == "SHELL" + || key == "LANG" + || key == "LC_ALL" + || key.starts_with("LC_") + || key == "NODE_PATH" + || key == "NVM_DIR" + || key == "NVM_BIN" + || key == "HOMEBREW_PREFIX" + || key == "HOMEBREW_CELLAR" + { debug!("Inheriting env var: {}={}", key, value); cmd.env(&key, &value); } } - + // Add NVM support if the program is in an NVM directory if program.contains("/.nvm/versions/node/") { if let Some(node_bin_dir) = std::path::Path::new(program).parent() { @@ -411,6 +445,6 @@ pub fn create_command_with_env(program: &str) -> Command { } } } - + cmd -} \ No newline at end of file +} diff --git a/src-tauri/src/commands/agents.rs b/src-tauri/src/commands/agents.rs index 7f61f1d..b903f40 100644 --- a/src-tauri/src/commands/agents.rs +++ b/src-tauri/src/commands/agents.rs @@ -2,16 +2,16 @@ use crate::sandbox::profile::ProfileBuilder; use anyhow::Result; use chrono; use log::{debug, error, info, warn}; +use reqwest; use rusqlite::{params, Connection, Result as SqliteResult}; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; use std::path::PathBuf; use std::process::Stdio; use std::sync::Mutex; -use tauri::{AppHandle, Manager, State, Emitter}; +use tauri::{AppHandle, Emitter, Manager, State}; use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; -use reqwest; /// Finds the full path to the claude binary /// This is necessary because macOS apps have a limited PATH environment @@ -47,7 +47,7 @@ pub struct AgentRun { pub model: String, pub project_path: String, pub session_id: String, // UUID session ID from Claude Code - pub status: String, // 'pending', 'running', 'completed', 'failed', 'cancelled' + pub status: String, // 'pending', 'running', 'completed', 'failed', 'cancelled' pub 
pid: Option, pub process_started_at: Option, pub created_at: String, @@ -125,14 +125,16 @@ impl AgentRunMetrics { } // Extract token usage - check both top-level and nested message.usage - let usage = json.get("usage") + let usage = json + .get("usage") .or_else(|| json.get("message").and_then(|m| m.get("usage"))); - + if let Some(usage) = usage { if let Some(input_tokens) = usage.get("input_tokens").and_then(|t| t.as_i64()) { total_tokens += input_tokens; } - if let Some(output_tokens) = usage.get("output_tokens").and_then(|t| t.as_i64()) { + if let Some(output_tokens) = usage.get("output_tokens").and_then(|t| t.as_i64()) + { total_tokens += output_tokens; } } @@ -151,9 +153,17 @@ impl AgentRunMetrics { Self { duration_ms, - total_tokens: if total_tokens > 0 { Some(total_tokens) } else { None }, + total_tokens: if total_tokens > 0 { + Some(total_tokens) + } else { + None + }, cost_usd: if cost_usd > 0.0 { Some(cost_usd) } else { None }, - message_count: if message_count > 0 { Some(message_count) } else { None }, + message_count: if message_count > 0 { + Some(message_count) + } else { + None + }, } } } @@ -171,7 +181,10 @@ pub async fn read_session_jsonl(session_id: &str, project_path: &str) -> Result< let session_file = project_dir.join(format!("{}.jsonl", session_id)); if !session_file.exists() { - return Err(format!("Session file not found: {}", session_file.display())); + return Err(format!( + "Session file not found: {}", + session_file.display() + )); } match tokio::fs::read_to_string(&session_file).await { @@ -204,12 +217,15 @@ pub async fn get_agent_run_with_metrics(run: AgentRun) -> AgentRunWithMetrics { /// Initialize the agents database pub fn init_database(app: &AppHandle) -> SqliteResult { - let app_dir = app.path().app_data_dir().expect("Failed to get app data dir"); + let app_dir = app + .path() + .app_data_dir() + .expect("Failed to get app data dir"); std::fs::create_dir_all(&app_dir).expect("Failed to create app data dir"); - + let db_path = app_dir.join("agents.db"); let conn = Connection::open(db_path)?; - + // Create agents table conn.execute( "CREATE TABLE IF NOT EXISTS agents ( @@ -228,16 +244,34 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { )", [], )?; - + // Add columns to existing table if they don't exist let _ = conn.execute("ALTER TABLE agents ADD COLUMN default_task TEXT", []); - let _ = conn.execute("ALTER TABLE agents ADD COLUMN model TEXT DEFAULT 'sonnet'", []); - let _ = conn.execute("ALTER TABLE agents ADD COLUMN sandbox_profile_id INTEGER REFERENCES sandbox_profiles(id)", []); - let _ = conn.execute("ALTER TABLE agents ADD COLUMN sandbox_enabled BOOLEAN DEFAULT 1", []); - let _ = conn.execute("ALTER TABLE agents ADD COLUMN enable_file_read BOOLEAN DEFAULT 1", []); - let _ = conn.execute("ALTER TABLE agents ADD COLUMN enable_file_write BOOLEAN DEFAULT 1", []); - let _ = conn.execute("ALTER TABLE agents ADD COLUMN enable_network BOOLEAN DEFAULT 0", []); - + let _ = conn.execute( + "ALTER TABLE agents ADD COLUMN model TEXT DEFAULT 'sonnet'", + [], + ); + let _ = conn.execute( + "ALTER TABLE agents ADD COLUMN sandbox_profile_id INTEGER REFERENCES sandbox_profiles(id)", + [], + ); + let _ = conn.execute( + "ALTER TABLE agents ADD COLUMN sandbox_enabled BOOLEAN DEFAULT 1", + [], + ); + let _ = conn.execute( + "ALTER TABLE agents ADD COLUMN enable_file_read BOOLEAN DEFAULT 1", + [], + ); + let _ = conn.execute( + "ALTER TABLE agents ADD COLUMN enable_file_write BOOLEAN DEFAULT 1", + [], + ); + let _ = conn.execute( + "ALTER TABLE agents ADD COLUMN 
enable_network BOOLEAN DEFAULT 0", + [], + ); + // Create agent_runs table conn.execute( "CREATE TABLE IF NOT EXISTS agent_runs ( @@ -261,17 +295,29 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { // Migrate existing agent_runs table if needed let _ = conn.execute("ALTER TABLE agent_runs ADD COLUMN session_id TEXT", []); - let _ = conn.execute("ALTER TABLE agent_runs ADD COLUMN status TEXT DEFAULT 'pending'", []); + let _ = conn.execute( + "ALTER TABLE agent_runs ADD COLUMN status TEXT DEFAULT 'pending'", + [], + ); let _ = conn.execute("ALTER TABLE agent_runs ADD COLUMN pid INTEGER", []); - let _ = conn.execute("ALTER TABLE agent_runs ADD COLUMN process_started_at TEXT", []); - + let _ = conn.execute( + "ALTER TABLE agent_runs ADD COLUMN process_started_at TEXT", + [], + ); + // Drop old columns that are no longer needed (data is now read from JSONL files) // Note: SQLite doesn't support DROP COLUMN, so we'll ignore errors for existing columns - let _ = conn.execute("UPDATE agent_runs SET session_id = '' WHERE session_id IS NULL", []); + let _ = conn.execute( + "UPDATE agent_runs SET session_id = '' WHERE session_id IS NULL", + [], + ); let _ = conn.execute("UPDATE agent_runs SET status = 'completed' WHERE status IS NULL AND completed_at IS NOT NULL", []); let _ = conn.execute("UPDATE agent_runs SET status = 'failed' WHERE status IS NULL AND completed_at IS NOT NULL AND session_id = ''", []); - let _ = conn.execute("UPDATE agent_runs SET status = 'pending' WHERE status IS NULL", []); - + let _ = conn.execute( + "UPDATE agent_runs SET status = 'pending' WHERE status IS NULL", + [], + ); + // Create trigger to update the updated_at timestamp conn.execute( "CREATE TRIGGER IF NOT EXISTS update_agent_timestamp @@ -282,7 +328,7 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { END", [], )?; - + // Create sandbox profiles table conn.execute( "CREATE TABLE IF NOT EXISTS sandbox_profiles ( @@ -296,7 +342,7 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { )", [], )?; - + // Create sandbox rules table conn.execute( "CREATE TABLE IF NOT EXISTS sandbox_rules ( @@ -312,7 +358,7 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { )", [], )?; - + // Create trigger to update sandbox profile timestamp conn.execute( "CREATE TRIGGER IF NOT EXISTS update_sandbox_profile_timestamp @@ -323,7 +369,7 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { END", [], )?; - + // Create sandbox violations table conn.execute( "CREATE TABLE IF NOT EXISTS sandbox_violations ( @@ -342,17 +388,17 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { )", [], )?; - + // Create index for efficient querying conn.execute( "CREATE INDEX IF NOT EXISTS idx_sandbox_violations_denied_at ON sandbox_violations(denied_at DESC)", [], )?; - + // Create default sandbox profiles if they don't exist crate::sandbox::defaults::create_default_profiles(&conn)?; - + // Create settings table for app-wide settings conn.execute( "CREATE TABLE IF NOT EXISTS app_settings ( @@ -363,7 +409,7 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { )", [], )?; - + // Create trigger to update the updated_at timestamp conn.execute( "CREATE TRIGGER IF NOT EXISTS update_app_settings_timestamp @@ -374,7 +420,7 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { END", [], )?; - + Ok(conn) } @@ -382,11 +428,11 @@ pub fn init_database(app: &AppHandle) -> SqliteResult { #[tauri::command] pub async fn list_agents(db: State<'_, AgentDb>) -> Result, String> { let conn = db.0.lock().map_err(|e| e.to_string())?; 
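    // Editor's note (illustrative sketch, not part of the patch): the commands
    // in this module all follow this same lock-then-query shape, mapping
    // rusqlite errors to String so they can cross the Tauri IPC boundary, e.g.:
    //
    //     let count: i64 = conn
    //         .query_row("SELECT COUNT(*) FROM agents", [], |row| row.get(0))
    //         .map_err(|e| e.to_string())?;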
- + let mut stmt = conn .prepare("SELECT id, name, icon, system_prompt, default_task, model, sandbox_enabled, enable_file_read, enable_file_write, enable_network, created_at, updated_at FROM agents ORDER BY created_at DESC") .map_err(|e| e.to_string())?; - + let agents = stmt .query_map([], |row| { Ok(Agent { @@ -395,7 +441,9 @@ pub async fn list_agents(db: State<'_, AgentDb>) -> Result, String> { icon: row.get(2)?, system_prompt: row.get(3)?, default_task: row.get(4)?, - model: row.get::<_, String>(5).unwrap_or_else(|_| "sonnet".to_string()), + model: row + .get::<_, String>(5) + .unwrap_or_else(|_| "sonnet".to_string()), sandbox_enabled: row.get::<_, bool>(6).unwrap_or(true), enable_file_read: row.get::<_, bool>(7).unwrap_or(true), enable_file_write: row.get::<_, bool>(8).unwrap_or(true), @@ -407,7 +455,7 @@ pub async fn list_agents(db: State<'_, AgentDb>) -> Result, String> { .map_err(|e| e.to_string())? .collect::, _>>() .map_err(|e| e.to_string())?; - + Ok(agents) } @@ -431,15 +479,15 @@ pub async fn create_agent( let enable_file_read = enable_file_read.unwrap_or(true); let enable_file_write = enable_file_write.unwrap_or(true); let enable_network = enable_network.unwrap_or(false); - + conn.execute( "INSERT INTO agents (name, icon, system_prompt, default_task, model, sandbox_enabled, enable_file_read, enable_file_write, enable_network) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", params![name, icon, system_prompt, default_task, model, sandbox_enabled, enable_file_read, enable_file_write, enable_network], ) .map_err(|e| e.to_string())?; - + let id = conn.last_insert_rowid(); - + // Fetch the created agent let agent = conn .query_row( @@ -463,7 +511,7 @@ pub async fn create_agent( }, ) .map_err(|e| e.to_string())?; - + Ok(agent) } @@ -484,9 +532,11 @@ pub async fn update_agent( ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; let model = model.unwrap_or_else(|| "sonnet".to_string()); - + // Build dynamic query based on provided parameters - let mut query = "UPDATE agents SET name = ?1, icon = ?2, system_prompt = ?3, default_task = ?4, model = ?5".to_string(); + let mut query = + "UPDATE agents SET name = ?1, icon = ?2, system_prompt = ?3, default_task = ?4, model = ?5" + .to_string(); let mut params_vec: Vec> = vec![ Box::new(name), Box::new(icon), @@ -495,7 +545,7 @@ pub async fn update_agent( Box::new(model), ]; let mut param_count = 5; - + if let Some(se) = sandbox_enabled { param_count += 1; query.push_str(&format!(", sandbox_enabled = ?{}", param_count)); @@ -516,14 +566,17 @@ pub async fn update_agent( query.push_str(&format!(", enable_network = ?{}", param_count)); params_vec.push(Box::new(en)); } - + param_count += 1; query.push_str(&format!(" WHERE id = ?{}", param_count)); params_vec.push(Box::new(id)); - - conn.execute(&query, rusqlite::params_from_iter(params_vec.iter().map(|p| p.as_ref()))) - .map_err(|e| e.to_string())?; - + + conn.execute( + &query, + rusqlite::params_from_iter(params_vec.iter().map(|p| p.as_ref())), + ) + .map_err(|e| e.to_string())?; + // Fetch the updated agent let agent = conn .query_row( @@ -547,7 +600,7 @@ pub async fn update_agent( }, ) .map_err(|e| e.to_string())?; - + Ok(agent) } @@ -555,10 +608,10 @@ pub async fn update_agent( #[tauri::command] pub async fn delete_agent(db: State<'_, AgentDb>, id: i64) -> Result<(), String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + conn.execute("DELETE FROM agents WHERE id = ?1", params![id]) .map_err(|e| e.to_string())?; - + Ok(()) } @@ -566,7 +619,7 @@ pub async fn 
delete_agent(db: State<'_, AgentDb>, id: i64) -> Result<(), String> #[tauri::command] pub async fn get_agent(db: State<'_, AgentDb>, id: i64) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + let agent = conn .query_row( "SELECT id, name, icon, system_prompt, default_task, model, sandbox_enabled, enable_file_read, enable_file_write, enable_network, created_at, updated_at FROM agents WHERE id = ?1", @@ -589,7 +642,7 @@ pub async fn get_agent(db: State<'_, AgentDb>, id: i64) -> Result }, ) .map_err(|e| e.to_string())?; - + Ok(agent) } @@ -600,7 +653,7 @@ pub async fn list_agent_runs( agent_id: Option, ) -> Result, String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + let query = if agent_id.is_some() { "SELECT id, agent_id, agent_name, agent_icon, task, model, project_path, session_id, status, pid, process_started_at, created_at, completed_at FROM agent_runs WHERE agent_id = ?1 ORDER BY created_at DESC" @@ -608,9 +661,9 @@ pub async fn list_agent_runs( "SELECT id, agent_id, agent_name, agent_icon, task, model, project_path, session_id, status, pid, process_started_at, created_at, completed_at FROM agent_runs ORDER BY created_at DESC" }; - + let mut stmt = conn.prepare(query).map_err(|e| e.to_string())?; - + let run_mapper = |row: &rusqlite::Row| -> rusqlite::Result { Ok(AgentRun { id: Some(row.get(0)?), @@ -621,14 +674,20 @@ pub async fn list_agent_runs( model: row.get(5)?, project_path: row.get(6)?, session_id: row.get(7)?, - status: row.get::<_, String>(8).unwrap_or_else(|_| "pending".to_string()), - pid: row.get::<_, Option>(9).ok().flatten().map(|p| p as u32), + status: row + .get::<_, String>(8) + .unwrap_or_else(|_| "pending".to_string()), + pid: row + .get::<_, Option>(9) + .ok() + .flatten() + .map(|p| p as u32), process_started_at: row.get(10)?, created_at: row.get(11)?, completed_at: row.get(12)?, }) }; - + let runs = if let Some(aid) = agent_id { stmt.query_map(params![aid], run_mapper) } else { @@ -637,7 +696,7 @@ pub async fn list_agent_runs( .map_err(|e| e.to_string())? 
.collect::, _>>() .map_err(|e| e.to_string())?; - + Ok(runs) } @@ -645,7 +704,7 @@ pub async fn list_agent_runs( #[tauri::command] pub async fn get_agent_run(db: State<'_, AgentDb>, id: i64) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + let run = conn .query_row( "SELECT id, agent_id, agent_name, agent_icon, task, model, project_path, session_id, status, pid, process_started_at, created_at, completed_at @@ -670,13 +729,16 @@ pub async fn get_agent_run(db: State<'_, AgentDb>, id: i64) -> Result, id: i64) -> Result { +pub async fn get_agent_run_with_real_time_metrics( + db: State<'_, AgentDb>, + id: i64, +) -> Result { let run = get_agent_run(db, id).await?; Ok(get_agent_run_with_metrics(run).await) } @@ -689,12 +751,12 @@ pub async fn list_agent_runs_with_metrics( ) -> Result, String> { let runs = list_agent_runs(db, agent_id).await?; let mut runs_with_metrics = Vec::new(); - + for run in runs { let run_with_metrics = get_agent_run_with_metrics(run).await; runs_with_metrics.push(run_with_metrics); } - + Ok(runs_with_metrics) } @@ -710,11 +772,11 @@ pub async fn execute_agent( registry: State<'_, crate::process::ProcessRegistryState>, ) -> Result { info!("Executing agent {} with task: {}", agent_id, task); - + // Get the agent from database let agent = get_agent(db.clone(), agent_id).await?; let execution_model = model.unwrap_or(agent.model.clone()); - + // Create a new run record let run_id = { let conn = db.0.lock().map_err(|e| e.to_string())?; @@ -725,18 +787,20 @@ pub async fn execute_agent( .map_err(|e| e.to_string())?; conn.last_insert_rowid() }; - + // Create sandbox rules based on agent-specific permissions (no database dependency) let sandbox_profile = if !agent.sandbox_enabled { info!("🔓 Agent '{}': Sandbox DISABLED", agent.name); None } else { - info!("🔒 Agent '{}': Sandbox enabled | File Read: {} | File Write: {} | Network: {}", - agent.name, agent.enable_file_read, agent.enable_file_write, agent.enable_network); - + info!( + "🔒 Agent '{}': Sandbox enabled | File Read: {} | File Write: {} | Network: {}", + agent.name, agent.enable_file_read, agent.enable_file_write, agent.enable_network + ); + // Create rules dynamically based on agent permissions let mut rules = Vec::new(); - + // Add file read rules if enabled if agent.enable_file_read { // Project directory access @@ -750,7 +814,7 @@ pub async fn execute_agent( platform_support: Some(r#"["linux", "macos", "windows"]"#.to_string()), created_at: String::new(), }); - + // System libraries (for language runtimes, etc.) 
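        // Editor's aside (hedged): each rule pushed below reuses the same
        // SandboxRule shape seen above -- a synthetic id, profile_id 0, and a
        // platform_support JSON list; only the path pattern and the platform
        // list vary, e.g. Some(r#"["macos"]"#.to_string()) for macOS-only
        // grants versus r#"["linux", "macos"]"# for cross-platform ones.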
rules.push(crate::sandbox::profile::SandboxRule { id: Some(2), @@ -762,7 +826,7 @@ pub async fn execute_agent( platform_support: Some(r#"["linux", "macos"]"#.to_string()), created_at: String::new(), }); - + rules.push(crate::sandbox::profile::SandboxRule { id: Some(3), profile_id: 0, @@ -773,7 +837,7 @@ pub async fn execute_agent( platform_support: Some(r#"["linux", "macos"]"#.to_string()), created_at: String::new(), }); - + rules.push(crate::sandbox::profile::SandboxRule { id: Some(4), profile_id: 0, @@ -784,7 +848,7 @@ pub async fn execute_agent( platform_support: Some(r#"["macos"]"#.to_string()), created_at: String::new(), }); - + rules.push(crate::sandbox::profile::SandboxRule { id: Some(5), profile_id: 0, @@ -796,7 +860,7 @@ pub async fn execute_agent( created_at: String::new(), }); } - + // Add network rules if enabled if agent.enable_network { rules.push(crate::sandbox::profile::SandboxRule { @@ -810,7 +874,7 @@ pub async fn execute_agent( created_at: String::new(), }); } - + // Always add essential system paths (needed for executables to run) rules.push(crate::sandbox::profile::SandboxRule { id: Some(7), @@ -822,7 +886,7 @@ pub async fn execute_agent( platform_support: Some(r#"["linux", "macos"]"#.to_string()), created_at: String::new(), }); - + rules.push(crate::sandbox::profile::SandboxRule { id: Some(8), profile_id: 0, @@ -833,7 +897,7 @@ pub async fn execute_agent( platform_support: Some(r#"["macos"]"#.to_string()), created_at: String::new(), }); - + rules.push(crate::sandbox::profile::SandboxRule { id: Some(9), profile_id: 0, @@ -844,7 +908,7 @@ pub async fn execute_agent( platform_support: Some(r#"["linux", "macos"]"#.to_string()), created_at: String::new(), }); - + rules.push(crate::sandbox::profile::SandboxRule { id: Some(10), profile_id: 0, @@ -855,7 +919,7 @@ pub async fn execute_agent( platform_support: Some(r#"["linux", "macos"]"#.to_string()), created_at: String::new(), }); - + // System libraries (needed for executables to link) rules.push(crate::sandbox::profile::SandboxRule { id: Some(11), @@ -867,7 +931,7 @@ pub async fn execute_agent( platform_support: Some(r#"["linux", "macos"]"#.to_string()), created_at: String::new(), }); - + rules.push(crate::sandbox::profile::SandboxRule { id: Some(12), profile_id: 0, @@ -878,7 +942,7 @@ pub async fn execute_agent( platform_support: Some(r#"["macos"]"#.to_string()), created_at: String::new(), }); - + // Always add system info reading (minimal requirement) rules.push(crate::sandbox::profile::SandboxRule { id: Some(13), @@ -890,10 +954,10 @@ pub async fn execute_agent( platform_support: Some(r#"["linux", "macos"]"#.to_string()), created_at: String::new(), }); - + Some(("Agent-specific".to_string(), rules)) }; - + // Build the command let mut cmd = if let Some((_profile_name, rules)) = sandbox_profile { info!("🧪 DEBUG: Testing Claude command first without sandbox..."); @@ -905,10 +969,16 @@ pub async fn execute_agent( return Err(e); } }; - match std::process::Command::new(&claude_path).arg("--version").output() { + match std::process::Command::new(&claude_path) + .arg("--version") + .output() + { Ok(output) => { if output.status.success() { - info!("✅ Claude command works: {}", String::from_utf8_lossy(&output.stdout).trim()); + info!( + "✅ Claude command works: {}", + String::from_utf8_lossy(&output.stdout).trim() + ); } else { warn!("⚠️ Claude command failed with status: {}", output.status); warn!(" stdout: {}", String::from_utf8_lossy(&output.stdout)); @@ -920,11 +990,12 @@ pub async fn execute_agent( error!(" This could be 
why the agent is failing to start"); } } - + // Test if Claude can actually start a session (this might reveal auth issues) info!("🧪 Testing Claude with exact same arguments as agent (without sandbox env vars)..."); let mut test_cmd = std::process::Command::new(&claude_path); - test_cmd.arg("-p") + test_cmd + .arg("-p") .arg(&task) .arg("--system-prompt") .arg(&agent.system_prompt) @@ -935,17 +1006,17 @@ pub async fn execute_agent( .arg("--verbose") .arg("--dangerously-skip-permissions") .current_dir(&project_path); - + info!("🧪 Testing command: claude -p \"{}\" --system-prompt \"{}\" --model {} --output-format stream-json --verbose --dangerously-skip-permissions", task, agent.system_prompt, execution_model); - + // Start the test process and give it 5 seconds to produce output match test_cmd.spawn() { Ok(mut child) => { // Wait for 5 seconds to see if it produces output let start = std::time::Instant::now(); let mut output_received = false; - + while start.elapsed() < std::time::Duration::from_secs(5) { match child.try_wait() { Ok(Some(status)) => { @@ -963,7 +1034,7 @@ pub async fn execute_agent( } } } - + if !output_received { warn!("🧪 Test process is still running after 5 seconds - this suggests Claude might be waiting for input"); // Kill the test process @@ -977,49 +1048,54 @@ pub async fn execute_agent( error!("❌ Failed to spawn test Claude process: {}", e); } } - + info!("🧪 End of Claude test, proceeding with sandbox..."); - + // Build the gaol profile using agent-specific permissions let project_path_buf = PathBuf::from(&project_path); - + match ProfileBuilder::new(project_path_buf.clone()) { Ok(builder) => { // Build agent-specific profile with permission filtering match builder.build_agent_profile( - rules, - agent.sandbox_enabled, - agent.enable_file_read, - agent.enable_file_write, - agent.enable_network + rules, + agent.sandbox_enabled, + agent.enable_file_read, + agent.enable_file_write, + agent.enable_network, ) { Ok(build_result) => { - // Create the enhanced sandbox executor #[cfg(unix)] - let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization( - build_result.profile, - project_path_buf.clone(), - build_result.serialized - ); - + let executor = + crate::sandbox::executor::SandboxExecutor::new_with_serialization( + build_result.profile, + project_path_buf.clone(), + build_result.serialized, + ); + #[cfg(not(unix))] - let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization( - (), - project_path_buf.clone(), - build_result.serialized - ); - + let executor = + crate::sandbox::executor::SandboxExecutor::new_with_serialization( + (), + project_path_buf.clone(), + build_result.serialized, + ); + // Prepare the sandboxed command let args = vec![ - "-p", &task, - "--system-prompt", &agent.system_prompt, - "--model", &execution_model, - "--output-format", "stream-json", + "-p", + &task, + "--system-prompt", + &agent.system_prompt, + "--model", + &execution_model, + "--output-format", + "stream-json", "--verbose", - "--dangerously-skip-permissions" + "--dangerously-skip-permissions", ]; - + let claude_path = match find_claude_binary(&app) { Ok(path) => path, Err(e) => { @@ -1057,8 +1133,11 @@ pub async fn execute_agent( } } Err(e) => { - error!("Failed to create ProfileBuilder: {}, falling back to non-sandboxed", e); - + error!( + "Failed to create ProfileBuilder: {}, falling back to non-sandboxed", + e + ); + // Fall back to non-sandboxed command let claude_path = match find_claude_binary(&app) { Ok(path) => path, @@ -1086,7 
+1165,10 @@ pub async fn execute_agent( } } else { // No sandbox or sandbox disabled, use regular command - warn!("🚨 Running agent '{}' WITHOUT SANDBOX - full system access!", agent.name); + warn!( + "🚨 Running agent '{}' WITHOUT SANDBOX - full system access!", + agent.name + ); let claude_path = match find_claude_binary(&app) { Ok(path) => path, Err(e) => { @@ -1106,26 +1188,26 @@ pub async fn execute_agent( .arg("--verbose") .arg("--dangerously-skip-permissions") .current_dir(&project_path) - .stdin(Stdio::null()) // Don't pipe stdin - we have no input to send + .stdin(Stdio::null()) // Don't pipe stdin - we have no input to send .stdout(Stdio::piped()) .stderr(Stdio::piped()); cmd }; - + // Spawn the process info!("🚀 Spawning Claude process..."); let mut child = cmd.spawn().map_err(|e| { error!("❌ Failed to spawn Claude process: {}", e); format!("Failed to spawn Claude: {}", e) })?; - + info!("🔌 Using Stdio::null() for stdin - no input expected"); - + // Get the PID and register the process let pid = child.id().unwrap_or(0); let now = chrono::Utc::now().to_rfc3339(); info!("✅ Claude process spawned successfully with PID: {}", pid); - + // Update the database with PID and status { let conn = db.0.lock().map_err(|e| e.to_string())?; @@ -1135,21 +1217,21 @@ pub async fn execute_agent( ).map_err(|e| e.to_string())?; info!("📝 Updated database with running status and PID"); } - + // Get stdout and stderr let stdout = child.stdout.take().ok_or("Failed to get stdout")?; let stderr = child.stderr.take().ok_or("Failed to get stderr")?; info!("📡 Set up stdout/stderr readers"); - + // Create readers let stdout_reader = BufReader::new(stdout); let stderr_reader = BufReader::new(stderr); - + // Shared state for collecting session ID and live output let session_id = std::sync::Arc::new(Mutex::new(String::new())); let live_output = std::sync::Arc::new(Mutex::new(String::new())); let start_time = std::time::Instant::now(); - + // Spawn tasks to read stdout and stderr let app_handle = app.clone(); let session_id_clone = session_id.clone(); @@ -1157,36 +1239,39 @@ pub async fn execute_agent( let registry_clone = registry.0.clone(); let first_output = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); let first_output_clone = first_output.clone(); - + let stdout_task = tokio::spawn(async move { info!("📖 Starting to read Claude stdout..."); let mut lines = stdout_reader.lines(); let mut line_count = 0; - + while let Ok(Some(line)) = lines.next_line().await { line_count += 1; - + // Log first output if !first_output_clone.load(std::sync::atomic::Ordering::Relaxed) { - info!("🎉 First output received from Claude process! Line: {}", line); + info!( + "🎉 First output received from Claude process! 
Line: {}", + line + ); first_output_clone.store(true, std::sync::atomic::Ordering::Relaxed); } - + if line_count <= 5 { info!("stdout[{}]: {}", line_count, line); } else { debug!("stdout[{}]: {}", line_count, line); } - + // Store live output in both local buffer and registry if let Ok(mut output) = live_output_clone.lock() { output.push_str(&line); output.push('\n'); } - + // Also store in process registry for cross-session access let _ = registry_clone.append_live_output(run_id, &line); - + // Extract session ID from JSONL output if let Ok(json) = serde_json::from_str::(&line) { if let Some(sid) = json.get("sessionId").and_then(|s| s.as_str()) { @@ -1198,84 +1283,103 @@ pub async fn execute_agent( } } } - + // Emit the line to the frontend with run_id for isolation let _ = app_handle.emit(&format!("agent-output:{}", run_id), &line); // Also emit to the generic event for backward compatibility let _ = app_handle.emit("agent-output", &line); } - - info!("📖 Finished reading Claude stdout. Total lines: {}", line_count); + + info!( + "📖 Finished reading Claude stdout. Total lines: {}", + line_count + ); }); - + let app_handle_stderr = app.clone(); let first_error = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); let first_error_clone = first_error.clone(); - + let stderr_task = tokio::spawn(async move { info!("📖 Starting to read Claude stderr..."); let mut lines = stderr_reader.lines(); let mut error_count = 0; - + while let Ok(Some(line)) = lines.next_line().await { error_count += 1; - + // Log first error if !first_error_clone.load(std::sync::atomic::Ordering::Relaxed) { warn!("⚠️ First error output from Claude process! Line: {}", line); first_error_clone.store(true, std::sync::atomic::Ordering::Relaxed); } - + error!("stderr[{}]: {}", error_count, line); // Emit error lines to the frontend with run_id for isolation let _ = app_handle_stderr.emit(&format!("agent-error:{}", run_id), &line); // Also emit to the generic event for backward compatibility let _ = app_handle_stderr.emit("agent-error", &line); } - + if error_count > 0 { - warn!("📖 Finished reading Claude stderr. Total error lines: {}", error_count); + warn!( + "📖 Finished reading Claude stderr. Total error lines: {}", + error_count + ); } else { info!("📖 Finished reading Claude stderr. 
No errors."); } }); - + // Register the process in the registry for live output tracking (after stdout/stderr setup) - registry.0.register_process( - run_id, - agent_id, - agent.name.clone(), - pid, - project_path.clone(), - task.clone(), - execution_model.clone(), - child, - ).map_err(|e| format!("Failed to register process: {}", e))?; + registry + .0 + .register_process( + run_id, + agent_id, + agent.name.clone(), + pid, + project_path.clone(), + task.clone(), + execution_model.clone(), + child, + ) + .map_err(|e| format!("Failed to register process: {}", e))?; info!("📋 Registered process in registry"); - + // Create variables we need for the spawned task - let app_dir = app.path().app_data_dir().expect("Failed to get app data dir"); + let app_dir = app + .path() + .app_data_dir() + .expect("Failed to get app data dir"); let db_path = app_dir.join("agents.db"); - + // Monitor process status and wait for completion tokio::spawn(async move { info!("🕐 Starting process monitoring..."); - + // Wait for first output with timeout - for i in 0..300 { // 30 seconds (300 * 100ms) + for i in 0..300 { + // 30 seconds (300 * 100ms) if first_output.load(std::sync::atomic::Ordering::Relaxed) { - info!("✅ Output detected after {}ms, continuing normal execution", i * 100); + info!( + "✅ Output detected after {}ms, continuing normal execution", + i * 100 + ); break; } - + tokio::time::sleep(std::time::Duration::from_millis(100)).await; - + // Log progress every 5 seconds if i > 0 && i % 50 == 0 { - info!("⏳ Still waiting for Claude output... ({}s elapsed)", i / 10); + info!( + "⏳ Still waiting for Claude output... ({}s elapsed)", + i / 10 + ); } } - + // Check if we timed out if !first_output.load(std::sync::atomic::Ordering::Relaxed) { warn!("⏰ TIMEOUT: No output from Claude process after 30 seconds"); @@ -1285,14 +1389,17 @@ pub async fn execute_agent( warn!(" 3. Claude failed to initialize but didn't report an error"); warn!(" 4. Network connectivity issues"); warn!(" 5. 
Authentication issues (API key not found/invalid)"); - + // Process timed out - kill it via PID - warn!("🔍 Process likely stuck waiting for input, attempting to kill PID: {}", pid); + warn!( + "🔍 Process likely stuck waiting for input, attempting to kill PID: {}", + pid + ); let kill_result = std::process::Command::new("kill") .arg("-TERM") .arg(pid.to_string()) .output(); - + match kill_result { Ok(output) if output.status.success() => { warn!("🔍 Successfully sent TERM signal to process"); @@ -1308,7 +1415,7 @@ pub async fn execute_agent( warn!("🔍 Error killing process: {}", e); } } - + // Update database if let Ok(conn) = Connection::open(&db_path) { let _ = conn.execute( @@ -1316,30 +1423,30 @@ pub async fn execute_agent( params![run_id], ); } - + let _ = app.emit("agent-complete", false); let _ = app.emit(&format!("agent-complete:{}", run_id), false); return; } - + // Wait for reading tasks to complete info!("⏳ Waiting for stdout/stderr reading to complete..."); let _ = stdout_task.await; let _ = stderr_task.await; - + let duration_ms = start_time.elapsed().as_millis() as i64; info!("⏱️ Process execution took {} ms", duration_ms); - + // Get the session ID that was extracted let extracted_session_id = if let Ok(sid) = session_id.lock() { sid.clone() } else { String::new() }; - + // Wait for process completion and update status info!("✅ Claude process execution monitoring complete"); - + // Update the run record with session ID and mark as completed - open a new connection if let Ok(conn) = Connection::open(&db_path) { let _ = conn.execute( @@ -1347,49 +1454,54 @@ pub async fn execute_agent( params![extracted_session_id, run_id], ); } - + // Cleanup will be handled by the cleanup_finished_processes function - + let _ = app.emit("agent-complete", true); let _ = app.emit(&format!("agent-complete:{}", run_id), true); }); - + Ok(run_id) } /// List all currently running agent sessions #[tauri::command] -pub async fn list_running_sessions( - db: State<'_, AgentDb>, -) -> Result, String> { +pub async fn list_running_sessions(db: State<'_, AgentDb>) -> Result, String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + let mut stmt = conn.prepare( "SELECT id, agent_id, agent_name, agent_icon, task, model, project_path, session_id, status, pid, process_started_at, created_at, completed_at FROM agent_runs WHERE status = 'running' ORDER BY process_started_at DESC" ).map_err(|e| e.to_string())?; - - let runs = stmt.query_map([], |row| { - Ok(AgentRun { - id: Some(row.get(0)?), - agent_id: row.get(1)?, - agent_name: row.get(2)?, - agent_icon: row.get(3)?, - task: row.get(4)?, - model: row.get(5)?, - project_path: row.get(6)?, - session_id: row.get(7)?, - status: row.get::<_, String>(8).unwrap_or_else(|_| "pending".to_string()), - pid: row.get::<_, Option>(9).ok().flatten().map(|p| p as u32), - process_started_at: row.get(10)?, - created_at: row.get(11)?, - completed_at: row.get(12)?, + + let runs = stmt + .query_map([], |row| { + Ok(AgentRun { + id: Some(row.get(0)?), + agent_id: row.get(1)?, + agent_name: row.get(2)?, + agent_icon: row.get(3)?, + task: row.get(4)?, + model: row.get(5)?, + project_path: row.get(6)?, + session_id: row.get(7)?, + status: row + .get::<_, String>(8) + .unwrap_or_else(|_| "pending".to_string()), + pid: row + .get::<_, Option>(9) + .ok() + .flatten() + .map(|p| p as u32), + process_started_at: row.get(10)?, + created_at: row.get(11)?, + completed_at: row.get(12)?, + }) }) - }) - .map_err(|e| e.to_string())? 
-    .collect::<Result<Vec<_>, _>>()
-    .map_err(|e| e.to_string())?;
-    
+        .map_err(|e| e.to_string())?
+        .collect::<Result<Vec<_>, _>>()
+        .map_err(|e| e.to_string())?;
+
     Ok(runs)
 }
 
@@ -1402,7 +1514,7 @@ pub async fn kill_agent_session(
     run_id: i64,
 ) -> Result<bool, String> {
     info!("Attempting to kill agent session {}", run_id);
-    
+
     // First try to kill using the process registry
     let killed_via_registry = match registry.0.kill_process(run_id).await {
         Ok(success) => {
@@ -1419,7 +1531,7 @@ pub async fn kill_agent_session(
             false
         }
     };
-    
+
     // If registry kill didn't work, try fallback with PID from database
     if !killed_via_registry {
        let pid_result = {
@@ -1427,27 +1539,27 @@
             conn.query_row(
                 "SELECT pid FROM agent_runs WHERE id = ?1 AND status = 'running'",
                 params![run_id],
-                |row| row.get::<_, Option<i64>>(0)
+                |row| row.get::<_, Option<i64>>(0),
             )
             .map_err(|e| e.to_string())?
         };
-        
+
         if let Some(pid) = pid_result {
             info!("Attempting fallback kill for PID {} from database", pid);
             let _ = registry.0.kill_process_by_pid(run_id, pid as u32)?;
         }
     }
-    
+
     // Update the database to mark as cancelled
     let conn = db.0.lock().map_err(|e| e.to_string())?;
     let updated = conn.execute(
         "UPDATE agent_runs SET status = 'cancelled', completed_at = CURRENT_TIMESTAMP WHERE id = ?1 AND status = 'running'",
         params![run_id],
     ).map_err(|e| e.to_string())?;
-    
+
     // Emit cancellation event with run_id for proper isolation
     let _ = app.emit(&format!("agent-cancelled:{}", run_id), true);
-    
+
     Ok(updated > 0 || killed_via_registry)
 }
 
@@ -1458,11 +1570,11 @@ pub async fn get_session_status(
     run_id: i64,
 ) -> Result<Option<String>, String> {
     let conn = db.0.lock().map_err(|e| e.to_string())?;
-    
+
     match conn.query_row(
         "SELECT status FROM agent_runs WHERE id = ?1",
         params![run_id],
-        |row| row.get::<_, String>(0)
+        |row| row.get::<_, String>(0),
     ) {
         Ok(status) => Ok(Some(status)),
         Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
@@ -1472,27 +1584,24 @@
 /// Cleanup finished processes and update their status
 #[tauri::command]
-pub async fn cleanup_finished_processes(
-    db: State<'_, AgentDb>,
-) -> Result<Vec<i64>, String> {
+pub async fn cleanup_finished_processes(db: State<'_, AgentDb>) -> Result<Vec<i64>, String> {
     let conn = db.0.lock().map_err(|e| e.to_string())?;
-    
+
     // Get all running processes
-    let mut stmt = conn.prepare(
-        "SELECT id, pid FROM agent_runs WHERE status = 'running' AND pid IS NOT NULL"
-    ).map_err(|e| e.to_string())?;
-    
-    let running_processes = stmt.query_map([], |row| {
-        Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?))
-    })
-    .map_err(|e| e.to_string())?
-    .collect::<Result<Vec<_>, _>>()
-    .map_err(|e| e.to_string())?;
-    
+    let mut stmt = conn
+        .prepare("SELECT id, pid FROM agent_runs WHERE status = 'running' AND pid IS NOT NULL")
+        .map_err(|e| e.to_string())?;
+
+    let running_processes = stmt
+        .query_map([], |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)))
+        .map_err(|e| e.to_string())?
+ .collect::, _>>() + .map_err(|e| e.to_string())?; + drop(stmt); - + let mut cleaned_up = Vec::new(); - + for (run_id, pid) in running_processes { // Check if the process is still running let is_running = if cfg!(target_os = "windows") { @@ -1518,21 +1627,24 @@ pub async fn cleanup_finished_processes( Err(_) => false, } }; - + if !is_running { // Process has finished, update status let updated = conn.execute( "UPDATE agent_runs SET status = 'completed', completed_at = CURRENT_TIMESTAMP WHERE id = ?1", params![run_id], ).map_err(|e| e.to_string())?; - + if updated > 0 { cleaned_up.push(run_id); - info!("Marked agent run {} as completed (PID {} no longer running)", run_id, pid); + info!( + "Marked agent run {} as completed (PID {} no longer running)", + run_id, pid + ); } } } - + Ok(cleaned_up) } @@ -1554,7 +1666,7 @@ pub async fn get_session_output( ) -> Result { // Get the session information let run = get_agent_run(db, run_id).await?; - + // If no session ID yet, try to get live output from registry if run.session_id.is_empty() { let live_output = registry.0.get_live_output(run_id)?; @@ -1563,7 +1675,7 @@ pub async fn get_session_output( } return Ok(String::new()); } - + // Read the JSONL content match read_session_jsonl(&run.session_id, &run.project_path).await { Ok(content) => Ok(content), @@ -1584,38 +1696,39 @@ pub async fn stream_session_output( ) -> Result<(), String> { // Get the session information let run = get_agent_run(db, run_id).await?; - + // If no session ID yet, can't stream if run.session_id.is_empty() { return Err("Session not started yet".to_string()); } - + let session_id = run.session_id.clone(); let project_path = run.project_path.clone(); - + // Spawn a task to monitor the file tokio::spawn(async move { let claude_dir = match dirs::home_dir() { Some(home) => home.join(".claude").join("projects"), None => return, }; - + let encoded_project = project_path.replace('/', "-"); let project_dir = claude_dir.join(&encoded_project); let session_file = project_dir.join(format!("{}.jsonl", session_id)); - + let mut last_size = 0u64; - + // Monitor file changes continuously while session is running loop { if session_file.exists() { if let Ok(metadata) = tokio::fs::metadata(&session_file).await { let current_size = metadata.len(); - + if current_size > last_size { // File has grown, read new content if let Ok(content) = tokio::fs::read_to_string(&session_file).await { - let _ = app.emit("session-output-update", &format!("{}:{}", run_id, content)); + let _ = app + .emit("session-output-update", &format!("{}:{}", run_id, content)); } last_size = current_size; } @@ -1625,16 +1738,19 @@ pub async fn stream_session_output( tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; continue; } - + // Check if the session is still running by querying the database // If the session is no longer running, stop streaming if let Ok(conn) = rusqlite::Connection::open( - app.path().app_data_dir().expect("Failed to get app data dir").join("agents.db") + app.path() + .app_data_dir() + .expect("Failed to get app data dir") + .join("agents.db"), ) { if let Ok(status) = conn.query_row( "SELECT status FROM agent_runs WHERE id = ?1", rusqlite::params![run_id], - |row| row.get::<_, String>(0) + |row| row.get::<_, String>(0), ) { if status != "running" { debug!("Session {} is no longer running, stopping stream", run_id); @@ -1642,16 +1758,19 @@ pub async fn stream_session_output( } } else { // If we can't query the status, assume it's still running - debug!("Could not query session status for {}, 
continuing stream", run_id); + debug!( + "Could not query session status for {}, continuing stream", + run_id + ); } } - + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; } - + debug!("Stopped streaming for session {}", run_id); }); - + Ok(()) } @@ -1659,7 +1778,7 @@ pub async fn stream_session_output( #[tauri::command] pub async fn export_agent(db: State<'_, AgentDb>, id: i64) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + // Fetch the agent let agent = conn .query_row( @@ -1680,14 +1799,14 @@ pub async fn export_agent(db: State<'_, AgentDb>, id: i64) -> Result, id: i64) -> Result, id: i64, file_path: String) -> Result<(), String> { +pub async fn export_agent_to_file( + db: State<'_, AgentDb>, + id: i64, + file_path: String, +) -> Result<(), String> { // Get the JSON data let json_data = export_agent(db, id).await?; - + // Write to file - std::fs::write(&file_path, json_data) - .map_err(|e| format!("Failed to write file: {}", e))?; - + std::fs::write(&file_path, json_data).map_err(|e| format!("Failed to write file: {}", e))?; + Ok(()) } @@ -1710,7 +1832,7 @@ pub async fn export_agent_to_file(db: State<'_, AgentDb>, id: i64, file_path: St #[tauri::command] pub async fn get_claude_binary_path(db: State<'_, AgentDb>) -> Result, String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + match conn.query_row( "SELECT value FROM app_settings WHERE key = 'claude_binary_path'", [], @@ -1726,13 +1848,13 @@ pub async fn get_claude_binary_path(db: State<'_, AgentDb>) -> Result, path: String) -> Result<(), String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + // Validate that the path exists and is executable let path_buf = std::path::PathBuf::from(&path); if !path_buf.exists() { return Err(format!("File does not exist: {}", path)); } - + // Check if it's executable (on Unix systems) #[cfg(unix)] { @@ -1744,26 +1866,28 @@ pub async fn set_claude_binary_path(db: State<'_, AgentDb>, path: String) -> Res return Err(format!("File is not executable: {}", path)); } } - + // Insert or update the setting conn.execute( "INSERT INTO app_settings (key, value) VALUES ('claude_binary_path', ?1) ON CONFLICT(key) DO UPDATE SET value = ?1", params![path], - ).map_err(|e| format!("Failed to save Claude binary path: {}", e))?; - + ) + .map_err(|e| format!("Failed to save Claude binary path: {}", e))?; + Ok(()) } /// List all available Claude installations on the system #[tauri::command] -pub async fn list_claude_installations() -> Result, String> { +pub async fn list_claude_installations( +) -> Result, String> { let installations = crate::claude_binary::discover_claude_installations(); - + if installations.is_empty() { return Err("No Claude Code installations found on the system".to_string()); } - + Ok(installations) } @@ -1772,21 +1896,30 @@ pub async fn list_claude_installations() -> Result Command { // Convert std::process::Command to tokio::process::Command let _std_cmd = crate::claude_binary::create_command_with_env(program); - + // Create a new tokio Command from the program path let mut tokio_cmd = Command::new(program); - + // Copy over all environment variables from the std::process::Command // This is a workaround since we can't directly convert between the two types for (key, value) in std::env::vars() { - if key == "PATH" || key == "HOME" || key == "USER" - || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_") - || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN" - || key == "HOMEBREW_PREFIX" || key == 
"HOMEBREW_CELLAR" { + if key == "PATH" + || key == "HOME" + || key == "USER" + || key == "SHELL" + || key == "LANG" + || key == "LC_ALL" + || key.starts_with("LC_") + || key == "NODE_PATH" + || key == "NVM_DIR" + || key == "NVM_BIN" + || key == "HOMEBREW_PREFIX" + || key == "HOMEBREW_CELLAR" + { tokio_cmd.env(&key, &value); } } - + // Add NVM support if the program is in an NVM directory if program.contains("/.nvm/versions/node/") { if let Some(node_bin_dir) = std::path::Path::new(program).parent() { @@ -1820,17 +1953,20 @@ fn create_command_with_env(program: &str) -> Command { #[tauri::command] pub async fn import_agent(db: State<'_, AgentDb>, json_data: String) -> Result { // Parse the JSON data - let export_data: AgentExport = serde_json::from_str(&json_data) - .map_err(|e| format!("Invalid JSON format: {}", e))?; - + let export_data: AgentExport = + serde_json::from_str(&json_data).map_err(|e| format!("Invalid JSON format: {}", e))?; + // Validate version if export_data.version != 1 { - return Err(format!("Unsupported export version: {}. This version of the app only supports version 1.", export_data.version)); + return Err(format!( + "Unsupported export version: {}. This version of the app only supports version 1.", + export_data.version + )); } - + let agent_data = export_data.agent; let conn = db.0.lock().map_err(|e| e.to_string())?; - + // Check if an agent with the same name already exists let existing_count: i64 = conn .query_row( @@ -1839,14 +1975,14 @@ pub async fn import_agent(db: State<'_, AgentDb>, json_data: String) -> Result 0 { format!("{} (Imported)", agent_data.name) } else { agent_data.name }; - + // Create the agent conn.execute( "INSERT INTO agents (name, icon, system_prompt, default_task, model, sandbox_enabled, enable_file_read, enable_file_write, enable_network) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", @@ -1863,9 +1999,9 @@ pub async fn import_agent(db: State<'_, AgentDb>, json_data: String) -> Result, json_data: String) -> Result, file_path: String) -> Result { +pub async fn import_agent_from_file( + db: State<'_, AgentDb>, + file_path: String, +) -> Result { // Read the file - let json_data = std::fs::read_to_string(&file_path) - .map_err(|e| format!("Failed to read file: {}", e))?; - + let json_data = + std::fs::read_to_string(&file_path).map_err(|e| format!("Failed to read file: {}", e))?; + // Import the agent import_agent(db, json_data).await } @@ -1932,10 +2071,10 @@ struct GitHubApiResponse { #[tauri::command] pub async fn fetch_github_agents() -> Result, String> { info!("Fetching agents from GitHub repository..."); - + let client = reqwest::Client::new(); let url = "https://api.github.com/repos/getAsterisk/claudia/contents/cc_agents"; - + let response = client .get(url) .header("Accept", "application/vnd.github+json") @@ -1943,18 +2082,18 @@ pub async fn fetch_github_agents() -> Result, String> { .send() .await .map_err(|e| format!("Failed to fetch from GitHub: {}", e))?; - + if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_default(); return Err(format!("GitHub API error ({}): {}", status, error_text)); } - + let api_files: Vec = response .json() .await .map_err(|e| format!("Failed to parse GitHub response: {}", e))?; - + // Filter only .claudia.json files let agent_files: Vec = api_files .into_iter() @@ -1969,7 +2108,7 @@ pub async fn fetch_github_agents() -> Result, String> { }) }) .collect(); - + info!("Found {} agents on GitHub", agent_files.len()); Ok(agent_files) } @@ -1978,7 
+2117,7 @@ pub async fn fetch_github_agents() -> Result, String> { #[tauri::command] pub async fn fetch_github_agent_content(download_url: String) -> Result { info!("Fetching agent content from: {}", download_url); - + let client = reqwest::Client::new(); let response = client .get(&download_url) @@ -1987,25 +2126,31 @@ pub async fn fetch_github_agent_content(download_url: String) -> Result Result { info!("Importing agent from GitHub: {}", download_url); - + // First, fetch the agent content let export_data = fetch_github_agent_content(download_url).await?; - + // Convert to JSON string and use existing import logic let json_data = serde_json::to_string(&export_data) .map_err(|e| format!("Failed to serialize agent data: {}", e))?; - + // Import using existing function import_agent(db, json_data).await } diff --git a/src-tauri/src/commands/claude.rs b/src-tauri/src/commands/claude.rs index 5793c10..754a4d5 100644 --- a/src-tauri/src/commands/claude.rs +++ b/src-tauri/src/commands/claude.rs @@ -1,14 +1,14 @@ use anyhow::{Context, Result}; use serde::{Deserialize, Serialize}; use std::fs; -use std::path::PathBuf; -use std::time::SystemTime; use std::io::{BufRead, BufReader}; +use std::path::PathBuf; use std::process::Stdio; -use tauri::{AppHandle, Emitter, Manager}; -use tokio::process::{Command, Child}; -use tokio::sync::Mutex; use std::sync::Arc; +use std::time::SystemTime; +use tauri::{AppHandle, Emitter, Manager}; +use tokio::process::{Child, Command}; +use tokio::sync::Mutex; use uuid; /// Global state to track current Claude process @@ -147,7 +147,7 @@ fn get_project_path_from_sessions(project_dir: &PathBuf) -> Result Result (Option, Option file, Err(_) => return (None, None), }; - + let reader = BufReader::new(file); - + for line in reader.lines() { if let Ok(line) = line { if let Ok(entry) = serde_json::from_str::(&line) { @@ -200,12 +200,14 @@ fn extract_first_user_message(jsonl_path: &PathBuf) -> (Option, Option") || content.starts_with("") { + if content.starts_with("") + || content.starts_with("") + { continue; } - + // Found a valid user message return (Some(content), entry.timestamp); } @@ -214,7 +216,7 @@ fn extract_first_user_message(jsonl_path: &PathBuf) -> (Option, Option (Option, Option Command { // Convert std::process::Command to tokio::process::Command let _std_cmd = crate::claude_binary::create_command_with_env(program); - + // Create a new tokio Command from the program path let mut tokio_cmd = Command::new(program); - + // Copy over all environment variables for (key, value) in std::env::vars() { - if key == "PATH" || key == "HOME" || key == "USER" - || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_") - || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN" - || key == "HOMEBREW_PREFIX" || key == "HOMEBREW_CELLAR" { + if key == "PATH" + || key == "HOME" + || key == "USER" + || key == "SHELL" + || key == "LANG" + || key == "LC_ALL" + || key.starts_with("LC_") + || key == "NODE_PATH" + || key == "NVM_DIR" + || key == "NVM_BIN" + || key == "HOMEBREW_PREFIX" + || key == "HOMEBREW_CELLAR" + { log::debug!("Inheriting env var: {}={}", key, value); tokio_cmd.env(&key, &value); } } - + // Add NVM support if the program is in an NVM directory if program.contains("/.nvm/versions/node/") { if let Some(node_bin_dir) = std::path::Path::new(program).parent() { @@ -249,7 +260,7 @@ fn create_command_with_env(program: &str) -> Command { } } } - + tokio_cmd } @@ -257,35 +268,35 @@ fn create_command_with_env(program: &str) -> Command { #[tauri::command] 
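// Editor's sketch of the assumed on-disk layout this command walks (inferred
// from the '/'-to-'-' path encoding used by stream_session_output; the
// concrete names below are illustrative only):
//
//     ~/.claude/projects/
//         -Users-me-myapp/        one directory per project path
//             9f2c1b7e.jsonl      one JSONL file per session id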
pub async fn list_projects() -> Result, String> { log::info!("Listing projects from ~/.claude/projects"); - + let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; let projects_dir = claude_dir.join("projects"); - + if !projects_dir.exists() { log::warn!("Projects directory does not exist: {:?}", projects_dir); return Ok(Vec::new()); } - + let mut projects = Vec::new(); - + // Read all directories in the projects folder let entries = fs::read_dir(&projects_dir) .map_err(|e| format!("Failed to read projects directory: {}", e))?; - + for entry in entries { let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?; let path = entry.path(); - + if path.is_dir() { let dir_name = path .file_name() .and_then(|n| n.to_str()) .ok_or_else(|| "Invalid directory name".to_string())?; - + // Get directory creation time let metadata = fs::metadata(&path) .map_err(|e| format!("Failed to read directory metadata: {}", e))?; - + let created_at = metadata .created() .or_else(|_| metadata.modified()) @@ -293,7 +304,7 @@ pub async fn list_projects() -> Result, String> { .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or_default() .as_secs(); - + // Get the actual project path from JSONL files let project_path = match get_project_path_from_sessions(&path) { Ok(path) => path, @@ -302,20 +313,23 @@ pub async fn list_projects() -> Result, String> { decode_project_path(dir_name) } }; - + // List all JSONL files (sessions) in this project directory let mut sessions = Vec::new(); if let Ok(session_entries) = fs::read_dir(&path) { for session_entry in session_entries.flatten() { let session_path = session_entry.path(); - if session_path.is_file() && session_path.extension().and_then(|s| s.to_str()) == Some("jsonl") { - if let Some(session_id) = session_path.file_stem().and_then(|s| s.to_str()) { + if session_path.is_file() + && session_path.extension().and_then(|s| s.to_str()) == Some("jsonl") + { + if let Some(session_id) = session_path.file_stem().and_then(|s| s.to_str()) + { sessions.push(session_id.to_string()); } } } } - + projects.push(Project { id: dir_name.to_string(), path: project_path, @@ -324,10 +338,10 @@ pub async fn list_projects() -> Result, String> { }); } } - + // Sort projects by creation time (newest first) projects.sort_by(|a, b| b.created_at.cmp(&a.created_at)); - + log::info!("Found {} projects", projects.len()); Ok(projects) } @@ -336,40 +350,44 @@ pub async fn list_projects() -> Result, String> { #[tauri::command] pub async fn get_project_sessions(project_id: String) -> Result, String> { log::info!("Getting sessions for project: {}", project_id); - + let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; let project_dir = claude_dir.join("projects").join(&project_id); let todos_dir = claude_dir.join("todos"); - + if !project_dir.exists() { return Err(format!("Project directory not found: {}", project_id)); } - + // Get the actual project path from JSONL files let project_path = match get_project_path_from_sessions(&project_dir) { Ok(path) => path, Err(e) => { - log::warn!("Failed to get project path from sessions for {}: {}, falling back to decode", project_id, e); + log::warn!( + "Failed to get project path from sessions for {}: {}, falling back to decode", + project_id, + e + ); decode_project_path(&project_id) } }; - + let mut sessions = Vec::new(); - + // Read all JSONL files in the project directory let entries = fs::read_dir(&project_dir) .map_err(|e| format!("Failed to read project directory: {}", e))?; - + for entry in entries { let entry = 
entry.map_err(|e| format!("Failed to read directory entry: {}", e))?; let path = entry.path(); - + if path.is_file() && path.extension().and_then(|s| s.to_str()) == Some("jsonl") { if let Some(session_id) = path.file_stem().and_then(|s| s.to_str()) { // Get file creation time let metadata = fs::metadata(&path) .map_err(|e| format!("Failed to read file metadata: {}", e))?; - + let created_at = metadata .created() .or_else(|_| metadata.modified()) @@ -377,10 +395,10 @@ pub async fn get_project_sessions(project_id: String) -> Result, St .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or_default() .as_secs(); - + // Extract first user message and timestamp let (first_message, message_timestamp) = extract_first_user_message(&path); - + // Try to load associated todo data let todo_path = todos_dir.join(format!("{}.json", session_id)); let todo_data = if todo_path.exists() { @@ -390,7 +408,7 @@ pub async fn get_project_sessions(project_id: String) -> Result, St } else { None }; - + sessions.push(Session { id: session_id.to_string(), project_id: project_id.clone(), @@ -403,11 +421,15 @@ pub async fn get_project_sessions(project_id: String) -> Result, St } } } - + // Sort sessions by creation time (newest first) sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at)); - - log::info!("Found {} sessions for project {}", sessions.len(), project_id); + + log::info!( + "Found {} sessions for project {}", + sessions.len(), + project_id + ); Ok(sessions) } @@ -415,23 +437,23 @@ pub async fn get_project_sessions(project_id: String) -> Result, St #[tauri::command] pub async fn get_claude_settings() -> Result { log::info!("Reading Claude settings"); - + let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; let settings_path = claude_dir.join("settings.json"); - + if !settings_path.exists() { log::warn!("Settings file not found, returning empty settings"); return Ok(ClaudeSettings { data: serde_json::json!({}), }); } - + let content = fs::read_to_string(&settings_path) .map_err(|e| format!("Failed to read settings file: {}", e))?; - + let data: serde_json::Value = serde_json::from_str(&content) .map_err(|e| format!("Failed to parse settings JSON: {}", e))?; - + Ok(ClaudeSettings { data }) } @@ -439,13 +461,13 @@ pub async fn get_claude_settings() -> Result { #[tauri::command] pub async fn open_new_session(app: AppHandle, path: Option) -> Result { log::info!("Opening new Claude Code session at path: {:?}", path); - + #[cfg(not(debug_assertions))] let _claude_path = find_claude_binary(&app)?; - + #[cfg(debug_assertions)] let claude_path = find_claude_binary(&app)?; - + // In production, we can't use std::process::Command directly // The user should launch Claude Code through other means or use the execute_claude_code command #[cfg(not(debug_assertions))] @@ -453,16 +475,16 @@ pub async fn open_new_session(app: AppHandle, path: Option) -> Result { @@ -481,24 +503,23 @@ pub async fn open_new_session(app: AppHandle, path: Option) -> Result Result { log::info!("Reading CLAUDE.md system prompt"); - + let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; let claude_md_path = claude_dir.join("CLAUDE.md"); - + if !claude_md_path.exists() { log::warn!("CLAUDE.md not found"); return Ok(String::new()); } - - fs::read_to_string(&claude_md_path) - .map_err(|e| format!("Failed to read CLAUDE.md: {}", e)) + + fs::read_to_string(&claude_md_path).map_err(|e| format!("Failed to read CLAUDE.md: {}", e)) } /// Checks if Claude Code is installed and gets its version #[tauri::command] pub async fn 
check_claude_version(app: AppHandle) -> Result<ClaudeVersionStatus, String> { log::info!("Checking Claude Code version"); - + let claude_path = match find_claude_binary(&app) { Ok(path) => path, Err(e) => { @@ -509,7 +530,7 @@ pub async fn check_claude_version(app: AppHandle) -> Result Result { let stdout = String::from_utf8_lossy(&output.stdout).to_string(); let stderr = String::from_utf8_lossy(&output.stderr).to_string(); - let full_output = if stderr.is_empty() { stdout.clone() } else { format!("{}\n{}", stdout, stderr) }; - + let full_output = if stderr.is_empty() { + stdout.clone() + } else { + format!("{}\n{}", stdout, stderr) + }; + // Check if the output matches the expected format // Expected format: "1.0.17 (Claude Code)" or similar let is_valid = stdout.contains("(Claude Code)") || stdout.contains("Claude Code"); - + // Extract version number if valid let version = if is_valid { // Try to extract just the version number - stdout.split_whitespace() - .next() - .map(|s| s.to_string()) + stdout.split_whitespace().next().map(|s| s.to_string()) } else { None }; - + Ok(ClaudeVersionStatus { is_installed: is_valid && output.status.success(), version, @@ -578,13 +601,12 @@ pub async fn check_claude_version(app: AppHandle) -> Result<ClaudeVersionStatus, String> #[tauri::command] pub async fn save_system_prompt(content: String) -> Result<String, String> { log::info!("Saving CLAUDE.md system prompt"); - + let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; let claude_md_path = claude_dir.join("CLAUDE.md"); - - fs::write(&claude_md_path, content) - .map_err(|e| format!("Failed to write CLAUDE.md: {}", e))?; - + + fs::write(&claude_md_path, content).map_err(|e| format!("Failed to write CLAUDE.md: {}", e))?; + Ok("System prompt saved successfully".to_string()) } @@ -592,17 +614,17 @@ pub async fn save_system_prompt(content: String) -> Result<String, String> { #[tauri::command] pub async fn save_claude_settings(settings: serde_json::Value) -> Result<String, String> { log::info!("Saving Claude settings"); - + let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; let settings_path = claude_dir.join("settings.json"); - + // Pretty print the JSON with 2-space indentation let json_string = serde_json::to_string_pretty(&settings) .map_err(|e| format!("Failed to serialize settings: {}", e))?; - + fs::write(&settings_path, json_string) .map_err(|e| format!("Failed to write settings file: {}", e))?; - + Ok("Settings saved successfully".to_string()) } @@ -610,18 +632,18 @@ pub async fn save_claude_settings(settings: serde_json::Value) -> Result<String, String> #[tauri::command] pub async fn find_claude_md_files(project_path: String) -> Result<Vec<ClaudeMdFile>, String> { log::info!("Finding CLAUDE.md files in project: {}", project_path); - + let path = PathBuf::from(&project_path); if !path.exists() { return Err(format!("Project path does not exist: {}", project_path)); } - + let mut claude_files = Vec::new(); find_claude_md_recursive(&path, &path, &mut claude_files)?; - + // Sort by relative path claude_files.sort_by(|a, b| a.relative_path.cmp(&b.relative_path)); - + log::info!("Found {} CLAUDE.md files", claude_files.len()); Ok(claude_files) } @@ -634,26 +656,29 @@ fn find_claude_md_recursive( ) -> Result<(), String> { let entries = fs::read_dir(current_path) .map_err(|e| format!("Failed to read directory {:?}: {}", current_path, e))?; - + for entry in entries { let entry = entry.map_err(|e| format!("Failed to read directory entry: {}", e))?; let path = entry.path(); - + // Skip hidden directories and files if let Some(name) = path.file_name().and_then(|n| n.to_str()) { if name.starts_with('.') && name != ".claude" { continue; } } - + if path.is_dir() { // Skip common directories that shouldn't be scanned if let Some(dir_name) = path.file_name().and_then(|n|
n.to_str()) { - if matches!(dir_name, "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__") { + if matches!( + dir_name, + "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__" + ) { continue; } } - + // Recurse into subdirectory find_claude_md_recursive(&path, project_root, claude_files)?; } else if path.is_file() { @@ -662,19 +687,20 @@ fn find_claude_md_recursive( if file_name.eq_ignore_ascii_case("CLAUDE.md") { let metadata = fs::metadata(&path) .map_err(|e| format!("Failed to read file metadata: {}", e))?; - - let relative_path = path.strip_prefix(project_root) + + let relative_path = path + .strip_prefix(project_root) .map_err(|e| format!("Failed to get relative path: {}", e))? .to_string_lossy() .to_string(); - + let modified = metadata .modified() .unwrap_or(SystemTime::UNIX_EPOCH) .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or_default() .as_secs(); - + claude_files.push(ClaudeMdFile { relative_path, absolute_path: path.to_string_lossy().to_string(), @@ -685,7 +711,7 @@ fn find_claude_md_recursive( } } } - + Ok(()) } @@ -693,53 +719,61 @@ fn find_claude_md_recursive( #[tauri::command] pub async fn read_claude_md_file(file_path: String) -> Result<String, String> { log::info!("Reading CLAUDE.md file: {}", file_path); - + let path = PathBuf::from(&file_path); if !path.exists() { return Err(format!("File does not exist: {}", file_path)); } - - fs::read_to_string(&path) - .map_err(|e| format!("Failed to read file: {}", e)) + + fs::read_to_string(&path).map_err(|e| format!("Failed to read file: {}", e)) } /// Saves a specific CLAUDE.md file by its absolute path #[tauri::command] pub async fn save_claude_md_file(file_path: String, content: String) -> Result<String, String> { log::info!("Saving CLAUDE.md file: {}", file_path); - + let path = PathBuf::from(&file_path); - + // Ensure the parent directory exists if let Some(parent) = path.parent() { fs::create_dir_all(parent) .map_err(|e| format!("Failed to create parent directory: {}", e))?; } - - fs::write(&path, content) - .map_err(|e| format!("Failed to write file: {}", e))?; - + + fs::write(&path, content).map_err(|e| format!("Failed to write file: {}", e))?; + Ok("File saved successfully".to_string()) } /// Loads the JSONL history for a specific session #[tauri::command] -pub async fn load_session_history(session_id: String, project_id: String) -> Result<Vec<serde_json::Value>, String> { - log::info!("Loading session history for session: {} in project: {}", session_id, project_id); - +pub async fn load_session_history( + session_id: String, + project_id: String, +) -> Result<Vec<serde_json::Value>, String> { + log::info!( + "Loading session history for session: {} in project: {}", + session_id, + project_id + ); + let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; - let session_path = claude_dir.join("projects").join(&project_id).join(format!("{}.jsonl", session_id)); - + let session_path = claude_dir + .join("projects") + .join(&project_id) + .join(format!("{}.jsonl", session_id)); + if !session_path.exists() { return Err(format!("Session file not found: {}", session_id)); } - - let file = fs::File::open(&session_path) - .map_err(|e| format!("Failed to open session file: {}", e))?; - + + let file = + fs::File::open(&session_path).map_err(|e| format!("Failed to open session file: {}", e))?; + let reader = BufReader::new(file); let mut messages = Vec::new(); - + for line in reader.lines() { if let Ok(line) = line { if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) { @@ -747,7 +781,7 @@ pub async fn load_session_history(session_id: String,
project_id: String) -> Res } } } - + Ok(messages) } @@ -759,18 +793,22 @@ pub async fn execute_claude_code( prompt: String, model: String, ) -> Result<(), String> { - log::info!("Starting new Claude Code session in: {} with model: {}", project_path, model); - + log::info!( + "Starting new Claude Code session in: {} with model: {}", + project_path, + model + ); + // Check if sandboxing should be used let use_sandbox = should_use_sandbox(&app)?; - + let mut cmd = if use_sandbox { create_sandboxed_claude_command(&app, &project_path)? } else { let claude_path = find_claude_binary(&app)?; create_command_with_env(&claude_path) }; - + cmd.arg("-p") .arg(&prompt) .arg("--model") @@ -782,7 +820,7 @@ pub async fn execute_claude_code( .current_dir(&project_path) .stdout(Stdio::piped()) .stderr(Stdio::piped()); - + spawn_claude_process(app, cmd).await } @@ -794,19 +832,23 @@ pub async fn continue_claude_code( prompt: String, model: String, ) -> Result<(), String> { - log::info!("Continuing Claude Code conversation in: {} with model: {}", project_path, model); - + log::info!( + "Continuing Claude Code conversation in: {} with model: {}", + project_path, + model + ); + // Check if sandboxing should be used let use_sandbox = should_use_sandbox(&app)?; - + let mut cmd = if use_sandbox { create_sandboxed_claude_command(&app, &project_path)? } else { let claude_path = find_claude_binary(&app)?; create_command_with_env(&claude_path) }; - - cmd.arg("-c") // Continue flag + + cmd.arg("-c") // Continue flag .arg("-p") .arg(&prompt) .arg("--model") @@ -818,7 +860,7 @@ pub async fn continue_claude_code( .current_dir(&project_path) .stdout(Stdio::piped()) .stderr(Stdio::piped()); - + spawn_claude_process(app, cmd).await } @@ -831,18 +873,23 @@ pub async fn resume_claude_code( prompt: String, model: String, ) -> Result<(), String> { - log::info!("Resuming Claude Code session: {} in: {} with model: {}", session_id, project_path, model); - + log::info!( + "Resuming Claude Code session: {} in: {} with model: {}", + session_id, + project_path, + model + ); + // Check if sandboxing should be used let use_sandbox = should_use_sandbox(&app)?; - + let mut cmd = if use_sandbox { create_sandboxed_claude_command(&app, &project_path)? 
} else { let claude_path = find_claude_binary(&app)?; create_command_with_env(&claude_path) }; - + cmd.arg("--resume") .arg(&session_id) .arg("-p") @@ -856,35 +903,41 @@ pub async fn resume_claude_code( .current_dir(&project_path) .stdout(Stdio::piped()) .stderr(Stdio::piped()); - + spawn_claude_process(app, cmd).await } /// Cancel the currently running Claude Code execution #[tauri::command] -pub async fn cancel_claude_execution(app: AppHandle, session_id: Option<String>) -> Result<(), String> { - log::info!("Cancelling Claude Code execution for session: {:?}", session_id); - +pub async fn cancel_claude_execution( + app: AppHandle, + session_id: Option<String>, +) -> Result<(), String> { + log::info!( + "Cancelling Claude Code execution for session: {:?}", + session_id + ); + let claude_state = app.state::<ClaudeProcessState>(); let mut current_process = claude_state.current_process.lock().await; - + if let Some(mut child) = current_process.take() { // Try to get the PID before killing let pid = child.id(); log::info!("Attempting to kill Claude process with PID: {:?}", pid); - + // Kill the process match child.kill().await { Ok(_) => { log::info!("Successfully killed Claude process"); - + // If we have a session ID, emit session-specific events if let Some(sid) = session_id { let _ = app.emit(&format!("claude-cancelled:{}", sid), true); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; let _ = app.emit(&format!("claude-complete:{}", sid), false); } - + // Also emit generic events for backward compatibility let _ = app.emit("claude-cancelled", true); tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; @@ -909,27 +962,32 @@ fn should_use_sandbox(app: &AppHandle) -> Result<bool, String> { log::info!("Sandboxing not available on this platform"); return Ok(false); } - + // Check if a setting exists to enable/disable sandboxing let settings = get_claude_settings_sync(app)?; - + // Check for a sandboxing setting in the settings - if let Some(sandbox_enabled) = settings.data.get("sandboxEnabled").and_then(|v| v.as_bool()) { + if let Some(sandbox_enabled) = settings + .data + .get("sandboxEnabled") + .and_then(|v| v.as_bool()) + { return Ok(sandbox_enabled); } - + // Default to true (sandboxing enabled) on supported platforms Ok(true) } /// Helper function to create a sandboxed Claude command fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Result<Command, String> { - use crate::sandbox::{profile::ProfileBuilder, executor::create_sandboxed_command}; + use crate::sandbox::{executor::create_sandboxed_command, profile::ProfileBuilder}; use std::path::PathBuf; - + // Get the database connection let conn = { - let app_data_dir = app.path() + let app_data_dir = app + .path() .app_data_dir() .map_err(|e| format!("Failed to get app data dir: {}", e))?; let db_path = app_data_dir.join("agents.db"); @@ -945,44 +1003,55 @@ fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Resul |row| row.get(0), ) .ok(); - + match profile_id { Some(profile_id) => { - log::info!("Using default sandbox profile: {} (id: {})", profile_id, profile_id); - + log::info!( + "Using default sandbox profile: {} (id: {})", + profile_id, + profile_id + ); + // Get all rules for this profile - let mut stmt = conn.prepare( - "SELECT operation_type, pattern_type, pattern_value, enabled, platform_support - FROM sandbox_rules WHERE profile_id = ?1 AND enabled = 1" - ).map_err(|e| e.to_string())?; - - let rules = stmt.query_map(rusqlite::params![profile_id], |row| { - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?,
- row.get::<_, String>(2)?, - row.get::<_, bool>(3)?, - row.get::<_, Option<String>>(4)? - )) - }) - .map_err(|e| e.to_string())? - .collect::<Result<Vec<_>, _>>() - .map_err(|e| e.to_string())?; - + let mut stmt = conn + .prepare( + "SELECT operation_type, pattern_type, pattern_value, enabled, platform_support + FROM sandbox_rules WHERE profile_id = ?1 AND enabled = 1", + ) + .map_err(|e| e.to_string())?; + + let rules = stmt + .query_map(rusqlite::params![profile_id], |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + row.get::<_, bool>(3)?, + row.get::<_, Option<String>>(4)?, + )) + }) + .map_err(|e| e.to_string())? + .collect::<Result<Vec<_>, _>>() + .map_err(|e| e.to_string())?; + log::info!("Building sandbox profile with {} rules", rules.len()); - + // Build the gaol profile let project_path_buf = PathBuf::from(project_path); - + match ProfileBuilder::new(project_path_buf.clone()) { Ok(builder) => { // Convert database rules to SandboxRule structs let mut sandbox_rules = Vec::new(); - - for (idx, (op_type, pattern_type, pattern_value, enabled, platform_support)) in rules.into_iter().enumerate() { + + for (idx, (op_type, pattern_type, pattern_value, enabled, platform_support)) in + rules.into_iter().enumerate() + { // Check if this rule applies to the current platform if let Some(platforms_json) = &platform_support { - if let Ok(platforms) = serde_json::from_str::<Vec<String>>(platforms_json) { + if let Ok(platforms) = + serde_json::from_str::<Vec<String>>(platforms_json) + { let current_platform = if cfg!(target_os = "linux") { "linux" } else if cfg!(target_os = "macos") { @@ -992,13 +1061,13 @@ fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Resul } else { "unsupported" }; - + if !platforms.contains(&current_platform.to_string()) { continue; } } } - + // Create SandboxRule struct let rule = crate::sandbox::profile::SandboxRule { id: Some(idx as i64), @@ -1010,23 +1079,31 @@ fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Resul platform_support, created_at: String::new(), }; - + sandbox_rules.push(rule); } - + // Try to build the profile match builder.build_profile(sandbox_rules) { Ok(profile) => { log::info!("Successfully built sandbox profile '{}'", profile_id); - + // Use the helper function to create sandboxed command let claude_path = find_claude_binary(app)?; #[cfg(unix)] - return Ok(create_sandboxed_command(&claude_path, &[], &project_path_buf, profile, project_path_buf.clone())); - + return Ok(create_sandboxed_command( + &claude_path, + &[], + &project_path_buf, + profile, + project_path_buf.clone(), + )); + #[cfg(not(unix))] { - log::warn!("Sandboxing not supported on Windows, using regular command"); + log::warn!( + "Sandboxing not supported on Windows, using regular command" + ); Ok(create_command_with_env(&claude_path)) } } @@ -1038,7 +1115,10 @@ fn create_sandboxed_claude_command(app: &AppHandle, project_path: &str) -> Resul } } Err(e) => { - log::error!("Failed to create ProfileBuilder: {}, falling back to non-sandboxed", e); + log::error!( + "Failed to create ProfileBuilder: {}, falling back to non-sandboxed", + e + ); let claude_path = find_claude_binary(app)?; Ok(create_command_with_env(&claude_path)) } @@ -1060,44 +1140,51 @@ fn get_claude_settings_sync(_app: &AppHandle) -> Result<ClaudeSettings, String> if !settings_path.exists() { return Ok(ClaudeSettings::default()); } - + let content = std::fs::read_to_string(&settings_path) .map_err(|e| format!("Failed to read settings file: {}", e))?; - + let data: serde_json::Value = serde_json::from_str(&content)
.map_err(|e| format!("Failed to parse settings JSON: {}", e))?; - + Ok(ClaudeSettings { data }) } /// Helper function to spawn Claude process and handle streaming async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), String> { use tokio::io::{AsyncBufReadExt, BufReader}; - + // Generate a unique session ID for this Claude Code session - let session_id = format!("claude-{}-{}", + let session_id = format!( + "claude-{}-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_millis(), uuid::Uuid::new_v4().to_string() ); - + // Spawn the process - let mut child = cmd.spawn().map_err(|e| format!("Failed to spawn Claude: {}", e))?; - + let mut child = cmd + .spawn() + .map_err(|e| format!("Failed to spawn Claude: {}", e))?; + // Get stdout and stderr let stdout = child.stdout.take().ok_or("Failed to get stdout")?; let stderr = child.stderr.take().ok_or("Failed to get stderr")?; - + // Get the child PID for logging let pid = child.id(); - log::info!("Spawned Claude process with PID: {:?} and session ID: {}", pid, session_id); - + log::info!( + "Spawned Claude process with PID: {:?} and session ID: {}", + pid, + session_id + ); + // Create readers let stdout_reader = BufReader::new(stdout); let stderr_reader = BufReader::new(stderr); - + // Store the child process in the global state (for backward compatibility) let claude_state = app.state::(); { @@ -1109,7 +1196,7 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St } *current_process = Some(child); } - + // Spawn tasks to read stdout and stderr let app_handle = app.clone(); let session_id_clone = session_id.clone(); @@ -1123,7 +1210,7 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St let _ = app_handle.emit("claude-output", &line); } }); - + let app_handle_stderr = app.clone(); let session_id_clone2 = session_id.clone(); let stderr_task = tokio::spawn(async move { @@ -1136,7 +1223,7 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St let _ = app_handle_stderr.emit("claude-error", &line); } }); - + // Wait for the process to complete let app_handle_wait = app.clone(); let claude_state_wait = claude_state.current_process.clone(); @@ -1144,7 +1231,7 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St tokio::spawn(async move { let _ = stdout_task.await; let _ = stderr_task.await; - + // Get the child from the state to wait on it let mut current_process = claude_state_wait.lock().await; if let Some(mut child) = current_process.take() { @@ -1153,7 +1240,10 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St log::info!("Claude process exited with status: {}", status); // Add a small delay to ensure all messages are processed tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let _ = app_handle_wait.emit(&format!("claude-complete:{}", session_id_clone3), status.success()); + let _ = app_handle_wait.emit( + &format!("claude-complete:{}", session_id_clone3), + status.success(), + ); // Also emit to the generic event for backward compatibility let _ = app_handle_wait.emit("claude-complete", status.success()); } @@ -1161,20 +1251,24 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St log::error!("Failed to wait for Claude process: {}", e); // Add a small delay to ensure all messages are processed tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let _ = 
app_handle_wait.emit(&format!("claude-complete:{}", session_id_clone3), false); + let _ = app_handle_wait + .emit(&format!("claude-complete:{}", session_id_clone3), false); // Also emit to the generic event for backward compatibility let _ = app_handle_wait.emit("claude-complete", false); } } } - + // Clear the process from state *current_process = None; }); - + // Return the session ID to the frontend - let _ = app.emit(&format!("claude-session-started:{}", session_id), session_id.clone()); - + let _ = app.emit( + &format!("claude-session-started:{}", session_id), + session_id.clone(), + ); + Ok(()) } @@ -1182,58 +1276,60 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St #[tauri::command] pub async fn list_directory_contents(directory_path: String) -> Result, String> { log::info!("Listing directory contents: '{}'", directory_path); - + // Check if path is empty if directory_path.trim().is_empty() { log::error!("Directory path is empty or whitespace"); return Err("Directory path cannot be empty".to_string()); } - + let path = PathBuf::from(&directory_path); log::debug!("Resolved path: {:?}", path); - + if !path.exists() { log::error!("Path does not exist: {:?}", path); return Err(format!("Path does not exist: {}", directory_path)); } - + if !path.is_dir() { log::error!("Path is not a directory: {:?}", path); return Err(format!("Path is not a directory: {}", directory_path)); } - + let mut entries = Vec::new(); - - let dir_entries = fs::read_dir(&path) - .map_err(|e| format!("Failed to read directory: {}", e))?; - + + let dir_entries = + fs::read_dir(&path).map_err(|e| format!("Failed to read directory: {}", e))?; + for entry in dir_entries { let entry = entry.map_err(|e| format!("Failed to read entry: {}", e))?; let entry_path = entry.path(); - let metadata = entry.metadata() + let metadata = entry + .metadata() .map_err(|e| format!("Failed to read metadata: {}", e))?; - + // Skip hidden files/directories unless they are .claude directories if let Some(name) = entry_path.file_name().and_then(|n| n.to_str()) { if name.starts_with('.') && name != ".claude" { continue; } } - + let name = entry_path .file_name() .and_then(|n| n.to_str()) .unwrap_or("") .to_string(); - + let extension = if metadata.is_file() { - entry_path.extension() + entry_path + .extension() .and_then(|e| e.to_str()) .map(|e| e.to_string()) } else { None }; - + entries.push(FileEntry { name, path: entry_path.to_string_lossy().to_string(), @@ -1242,16 +1338,14 @@ pub async fn list_directory_contents(directory_path: String) -> Result std::cmp::Ordering::Less, - (false, true) => std::cmp::Ordering::Greater, - _ => a.name.to_lowercase().cmp(&b.name.to_lowercase()), - } + entries.sort_by(|a, b| match (a.is_directory, b.is_directory) { + (true, false) => std::cmp::Ordering::Less, + (false, true) => std::cmp::Ordering::Greater, + _ => a.name.to_lowercase().cmp(&b.name.to_lowercase()), }); - + Ok(entries) } @@ -1259,47 +1353,47 @@ pub async fn list_directory_contents(directory_path: String) -> Result Result, String> { log::info!("Searching files in '{}' for: '{}'", base_path, query); - + // Check if path is empty if base_path.trim().is_empty() { log::error!("Base path is empty or whitespace"); return Err("Base path cannot be empty".to_string()); } - + // Check if query is empty if query.trim().is_empty() { log::warn!("Search query is empty, returning empty results"); return Ok(Vec::new()); } - + let path = PathBuf::from(&base_path); log::debug!("Resolved search base path: {:?}", path); - + if 
!path.exists() { log::error!("Base path does not exist: {:?}", path); return Err(format!("Path does not exist: {}", base_path)); } - + let query_lower = query.to_lowercase(); let mut results = Vec::new(); - + search_files_recursive(&path, &path, &query_lower, &mut results, 0)?; - + // Sort by relevance: exact matches first, then by name results.sort_by(|a, b| { let a_exact = a.name.to_lowercase() == query_lower; let b_exact = b.name.to_lowercase() == query_lower; - + match (a_exact, b_exact) { (true, false) => std::cmp::Ordering::Less, (false, true) => std::cmp::Ordering::Greater, _ => a.name.to_lowercase().cmp(&b.name.to_lowercase()), } }); - + // Limit results to prevent overwhelming the UI results.truncate(50); - + Ok(results) } @@ -1314,33 +1408,35 @@ fn search_files_recursive( if depth > 5 || results.len() >= 50 { return Ok(()); } - + let entries = fs::read_dir(current_path) .map_err(|e| format!("Failed to read directory {:?}: {}", current_path, e))?; - + for entry in entries { let entry = entry.map_err(|e| format!("Failed to read entry: {}", e))?; let entry_path = entry.path(); - + // Skip hidden files/directories if let Some(name) = entry_path.file_name().and_then(|n| n.to_str()) { if name.starts_with('.') { continue; } - + // Check if name matches query if name.to_lowercase().contains(query) { - let metadata = entry.metadata() + let metadata = entry + .metadata() .map_err(|e| format!("Failed to read metadata: {}", e))?; - + let extension = if metadata.is_file() { - entry_path.extension() + entry_path + .extension() .and_then(|e| e.to_str()) .map(|e| e.to_string()) } else { None }; - + results.push(FileEntry { name: name.to_string(), path: entry_path.to_string_lossy().to_string(), @@ -1350,20 +1446,23 @@ fn search_files_recursive( }); } } - + // Recurse into directories if entry_path.is_dir() { // Skip common directories that shouldn't be searched if let Some(dir_name) = entry_path.file_name().and_then(|n| n.to_str()) { - if matches!(dir_name, "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__") { + if matches!( + dir_name, + "node_modules" | "target" | ".git" | "dist" | "build" | ".next" | "__pycache__" + ) { continue; } } - + search_files_recursive(&entry_path, base_path, query, results, depth + 1)?; } } - + Ok(()) } @@ -1377,26 +1476,33 @@ pub async fn create_checkpoint( message_index: Option, description: Option, ) -> Result { - log::info!("Creating checkpoint for session: {} in project: {}", session_id, project_id); - - let manager = app.get_or_create_manager( - session_id.clone(), - project_id.clone(), - PathBuf::from(&project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - + log::info!( + "Creating checkpoint for session: {} in project: {}", + session_id, + project_id + ); + + let manager = app + .get_or_create_manager( + session_id.clone(), + project_id.clone(), + PathBuf::from(&project_path), + ) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + // Always load current session messages from the JSONL file let session_path = get_claude_dir() .map_err(|e| e.to_string())? 
.join("projects") .join(&project_id) .join(format!("{}.jsonl", session_id)); - + if session_path.exists() { let file = fs::File::open(&session_path) .map_err(|e| format!("Failed to open session file: {}", e))?; let reader = BufReader::new(file); - + let mut line_count = 0; for line in reader.lines() { if let Some(index) = message_index { @@ -1405,14 +1511,18 @@ pub async fn create_checkpoint( } } if let Ok(line) = line { - manager.track_message(line).await + manager + .track_message(line) + .await .map_err(|e| format!("Failed to track message: {}", e))?; } line_count += 1; } } - - manager.create_checkpoint(description, None).await + + manager + .create_checkpoint(description, None) + .await .map_err(|e| format!("Failed to create checkpoint: {}", e)) } @@ -1425,35 +1535,43 @@ pub async fn restore_checkpoint( project_id: String, project_path: String, ) -> Result { - log::info!("Restoring checkpoint: {} for session: {}", checkpoint_id, session_id); - - let manager = app.get_or_create_manager( - session_id.clone(), - project_id.clone(), - PathBuf::from(&project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - - let result = manager.restore_checkpoint(&checkpoint_id).await + log::info!( + "Restoring checkpoint: {} for session: {}", + checkpoint_id, + session_id + ); + + let manager = app + .get_or_create_manager( + session_id.clone(), + project_id.clone(), + PathBuf::from(&project_path), + ) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + + let result = manager + .restore_checkpoint(&checkpoint_id) + .await .map_err(|e| format!("Failed to restore checkpoint: {}", e))?; - + // Update the session JSONL file with restored messages let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; let session_path = claude_dir .join("projects") .join(&result.checkpoint.project_id) .join(format!("{}.jsonl", session_id)); - + // The manager has already restored the messages internally, // but we need to update the actual session file - let (_, _, messages) = manager.storage.load_checkpoint( - &result.checkpoint.project_id, - &session_id, - &checkpoint_id, - ).map_err(|e| format!("Failed to load checkpoint data: {}", e))?; - + let (_, _, messages) = manager + .storage + .load_checkpoint(&result.checkpoint.project_id, &session_id, &checkpoint_id) + .map_err(|e| format!("Failed to load checkpoint data: {}", e))?; + fs::write(&session_path, messages) .map_err(|e| format!("Failed to update session file: {}", e))?; - + Ok(result) } @@ -1465,14 +1583,17 @@ pub async fn list_checkpoints( project_id: String, project_path: String, ) -> Result, String> { - log::info!("Listing checkpoints for session: {} in project: {}", session_id, project_id); - - let manager = app.get_or_create_manager( + log::info!( + "Listing checkpoints for session: {} in project: {}", session_id, - project_id, - PathBuf::from(&project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - + project_id + ); + + let manager = app + .get_or_create_manager(session_id, project_id, PathBuf::from(&project_path)) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + Ok(manager.list_checkpoints().await) } @@ -1487,10 +1608,14 @@ pub async fn fork_from_checkpoint( new_session_id: String, description: Option, ) -> Result { - log::info!("Forking from checkpoint: {} to new session: {}", checkpoint_id, new_session_id); - + log::info!( + "Forking from checkpoint: {} to new session: {}", + checkpoint_id, + new_session_id + ); + let 
claude_dir = get_claude_dir().map_err(|e| e.to_string())?; - + // First, copy the session file to the new session let source_session_path = claude_dir .join("projects") @@ -1500,20 +1625,25 @@ pub async fn fork_from_checkpoint( .join("projects") .join(&project_id) .join(format!("{}.jsonl", new_session_id)); - + if source_session_path.exists() { fs::copy(&source_session_path, &new_session_path) .map_err(|e| format!("Failed to copy session file: {}", e))?; } - + // Create manager for the new session - let manager = app.get_or_create_manager( - new_session_id.clone(), - project_id, - PathBuf::from(&project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - - manager.fork_from_checkpoint(&checkpoint_id, description).await + let manager = app + .get_or_create_manager( + new_session_id.clone(), + project_id, + PathBuf::from(&project_path), + ) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + + manager + .fork_from_checkpoint(&checkpoint_id, description) + .await .map_err(|e| format!("Failed to fork checkpoint: {}", e)) } @@ -1525,14 +1655,17 @@ pub async fn get_session_timeline( project_id: String, project_path: String, ) -> Result { - log::info!("Getting timeline for session: {} in project: {}", session_id, project_id); - - let manager = app.get_or_create_manager( + log::info!( + "Getting timeline for session: {} in project: {}", session_id, - project_id, - PathBuf::from(&project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - + project_id + ); + + let manager = app + .get_or_create_manager(session_id, project_id, PathBuf::from(&project_path)) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + Ok(manager.get_timeline().await) } @@ -1547,24 +1680,30 @@ pub async fn update_checkpoint_settings( checkpoint_strategy: String, ) -> Result<(), String> { use crate::checkpoint::CheckpointStrategy; - + log::info!("Updating checkpoint settings for session: {}", session_id); - + let strategy = match checkpoint_strategy.as_str() { "manual" => CheckpointStrategy::Manual, "per_prompt" => CheckpointStrategy::PerPrompt, "per_tool_use" => CheckpointStrategy::PerToolUse, "smart" => CheckpointStrategy::Smart, - _ => return Err(format!("Invalid checkpoint strategy: {}", checkpoint_strategy)), + _ => { + return Err(format!( + "Invalid checkpoint strategy: {}", + checkpoint_strategy + )) + } }; - - let manager = app.get_or_create_manager( - session_id, - project_id, - PathBuf::from(&project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - - manager.update_settings(auto_checkpoint_enabled, strategy).await + + let manager = app + .get_or_create_manager(session_id, project_id, PathBuf::from(&project_path)) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + + manager + .update_settings(auto_checkpoint_enabled, strategy) + .await .map_err(|e| format!("Failed to update settings: {}", e)) } @@ -1577,34 +1716,42 @@ pub async fn get_checkpoint_diff( project_id: String, ) -> Result { use crate::checkpoint::storage::CheckpointStorage; - - log::info!("Getting diff between checkpoints: {} -> {}", from_checkpoint_id, to_checkpoint_id); - + + log::info!( + "Getting diff between checkpoints: {} -> {}", + from_checkpoint_id, + to_checkpoint_id + ); + let claude_dir = get_claude_dir().map_err(|e| e.to_string())?; let storage = CheckpointStorage::new(claude_dir); - + // Load both checkpoints - let (from_checkpoint, from_files, _) = 
storage.load_checkpoint(&project_id, &session_id, &from_checkpoint_id) + let (from_checkpoint, from_files, _) = storage + .load_checkpoint(&project_id, &session_id, &from_checkpoint_id) .map_err(|e| format!("Failed to load source checkpoint: {}", e))?; - let (to_checkpoint, to_files, _) = storage.load_checkpoint(&project_id, &session_id, &to_checkpoint_id) + let (to_checkpoint, to_files, _) = storage + .load_checkpoint(&project_id, &session_id, &to_checkpoint_id) .map_err(|e| format!("Failed to load target checkpoint: {}", e))?; - + // Build file maps - let mut from_map: std::collections::HashMap = std::collections::HashMap::new(); + let mut from_map: std::collections::HashMap = + std::collections::HashMap::new(); for file in &from_files { from_map.insert(file.file_path.clone(), file); } - - let mut to_map: std::collections::HashMap = std::collections::HashMap::new(); + + let mut to_map: std::collections::HashMap = + std::collections::HashMap::new(); for file in &to_files { to_map.insert(file.file_path.clone(), file); } - + // Calculate differences let mut modified_files = Vec::new(); let mut added_files = Vec::new(); let mut deleted_files = Vec::new(); - + // Check for modified and deleted files for (path, from_file) in &from_map { if let Some(to_file) = to_map.get(path) { @@ -1612,7 +1759,7 @@ pub async fn get_checkpoint_diff( // File was modified let additions = to_file.content.lines().count(); let deletions = from_file.content.lines().count(); - + modified_files.push(crate::checkpoint::FileDiff { path: path.clone(), additions, @@ -1625,17 +1772,18 @@ pub async fn get_checkpoint_diff( deleted_files.push(path.clone()); } } - + // Check for added files for (path, _) in &to_map { if !from_map.contains_key(path) { added_files.push(path.clone()); } } - + // Calculate token delta - let token_delta = (to_checkpoint.metadata.total_tokens as i64) - (from_checkpoint.metadata.total_tokens as i64); - + let token_delta = (to_checkpoint.metadata.total_tokens as i64) + - (from_checkpoint.metadata.total_tokens as i64); + Ok(crate::checkpoint::CheckpointDiff { from_checkpoint_id, to_checkpoint_id, @@ -1656,14 +1804,15 @@ pub async fn track_checkpoint_message( message: String, ) -> Result<(), String> { log::info!("Tracking message for session: {}", session_id); - - let manager = app.get_or_create_manager( - session_id, - project_id, - PathBuf::from(project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - - manager.track_message(message).await + + let manager = app + .get_or_create_manager(session_id, project_id, PathBuf::from(project_path)) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + + manager + .track_message(message) + .await .map_err(|e| format!("Failed to track message: {}", e)) } @@ -1677,13 +1826,12 @@ pub async fn check_auto_checkpoint( message: String, ) -> Result { log::info!("Checking auto-checkpoint for session: {}", session_id); - - let manager = app.get_or_create_manager( - session_id.clone(), - project_id, - PathBuf::from(project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - + + let manager = app + .get_or_create_manager(session_id.clone(), project_id, PathBuf::from(project_path)) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + Ok(manager.should_auto_checkpoint(&message).await) } @@ -1696,15 +1844,24 @@ pub async fn cleanup_old_checkpoints( project_path: String, keep_count: usize, ) -> Result { - log::info!("Cleaning up old checkpoints for session: 
{}, keeping {}", session_id, keep_count); - - let manager = app.get_or_create_manager( - session_id.clone(), - project_id.clone(), - PathBuf::from(project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - - manager.storage.cleanup_old_checkpoints(&project_id, &session_id, keep_count) + log::info!( + "Cleaning up old checkpoints for session: {}, keeping {}", + session_id, + keep_count + ); + + let manager = app + .get_or_create_manager( + session_id.clone(), + project_id.clone(), + PathBuf::from(project_path), + ) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + + manager + .storage + .cleanup_old_checkpoints(&project_id, &session_id, keep_count) .map_err(|e| format!("Failed to cleanup checkpoints: {}", e)) } @@ -1717,15 +1874,14 @@ pub async fn get_checkpoint_settings( project_path: String, ) -> Result { log::info!("Getting checkpoint settings for session: {}", session_id); - - let manager = app.get_or_create_manager( - session_id, - project_id, - PathBuf::from(project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - + + let manager = app + .get_or_create_manager(session_id, project_id, PathBuf::from(project_path)) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + let timeline = manager.get_timeline().await; - + Ok(serde_json::json!({ "auto_checkpoint_enabled": timeline.auto_checkpoint_enabled, "checkpoint_strategy": timeline.checkpoint_strategy, @@ -1741,7 +1897,7 @@ pub async fn clear_checkpoint_manager( session_id: String, ) -> Result<(), String> { log::info!("Clearing checkpoint manager for session: {}", session_id); - + app.remove_manager(&session_id).await; Ok(()) } @@ -1753,7 +1909,7 @@ pub async fn get_checkpoint_state_stats( ) -> Result { let active_count = app.active_count().await; let active_sessions = app.list_active_sessions().await; - + Ok(serde_json::json!({ "active_managers": active_count, "active_sessions": active_sessions, @@ -1770,24 +1926,28 @@ pub async fn get_recently_modified_files( minutes: i64, ) -> Result, String> { use chrono::{Duration, Utc}; - - log::info!("Getting files modified in the last {} minutes for session: {}", minutes, session_id); - - let manager = app.get_or_create_manager( - session_id, - project_id, - PathBuf::from(project_path), - ).await.map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; - + + log::info!( + "Getting files modified in the last {} minutes for session: {}", + minutes, + session_id + ); + + let manager = app + .get_or_create_manager(session_id, project_id, PathBuf::from(project_path)) + .await + .map_err(|e| format!("Failed to get checkpoint manager: {}", e))?; + let since = Utc::now() - Duration::minutes(minutes); let modified_files = manager.get_files_modified_since(since).await; - + // Also log the last modification time if let Some(last_mod) = manager.get_last_modification_time().await { log::info!("Last file modification was at: {}", last_mod); } - - Ok(modified_files.into_iter() + + Ok(modified_files + .into_iter() .map(|p| p.to_string_lossy().to_string()) .collect()) } @@ -1801,12 +1961,17 @@ pub async fn track_session_messages( project_path: String, messages: Vec, ) -> Result<(), String> { - let mgr = state.get_or_create_manager( - session_id, project_id, std::path::PathBuf::from(project_path) - ).await.map_err(|e| e.to_string())?; + let mgr = state + .get_or_create_manager( + session_id, + project_id, + std::path::PathBuf::from(project_path), + ) + .await + .map_err(|e| 
e.to_string())?; for m in messages { mgr.track_message(m).await.map_err(|e| e.to_string())?; } Ok(()) -} \ No newline at end of file +} diff --git a/src-tauri/src/commands/mcp.rs b/src-tauri/src/commands/mcp.rs index 3f0d400..2db974f 100644 --- a/src-tauri/src/commands/mcp.rs +++ b/src-tauri/src/commands/mcp.rs @@ -1,12 +1,12 @@ -use tauri::AppHandle; use anyhow::{Context, Result}; +use dirs; +use log::{error, info}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; use std::path::PathBuf; use std::process::Command; -use log::{info, error}; -use dirs; +use tauri::AppHandle; /// Helper function to create a std::process::Command with proper environment variables /// This ensures commands like Claude can find Node.js and other dependencies @@ -17,8 +17,7 @@ fn create_command_with_env(program: &str) -> Command { /// Finds the full path to the claude binary /// This is necessary because macOS apps have a limited PATH environment fn find_claude_binary(app_handle: &AppHandle) -> Result { - crate::claude_binary::find_claude_binary(app_handle) - .map_err(|e| anyhow::anyhow!(e)) + crate::claude_binary::find_claude_binary(app_handle).map_err(|e| anyhow::anyhow!(e)) } /// Represents an MCP server configuration @@ -99,17 +98,16 @@ pub struct ImportServerResult { /// Executes a claude mcp command fn execute_claude_mcp_command(app_handle: &AppHandle, args: Vec<&str>) -> Result { info!("Executing claude mcp command with args: {:?}", args); - + let claude_path = find_claude_binary(app_handle)?; let mut cmd = create_command_with_env(&claude_path); cmd.arg("mcp"); for arg in args { cmd.arg(arg); } - - let output = cmd.output() - .context("Failed to execute claude command")?; - + + let output = cmd.output().context("Failed to execute claude command")?; + if output.status.success() { Ok(String::from_utf8_lossy(&output.stdout).to_string()) } else { @@ -131,33 +129,34 @@ pub async fn mcp_add( scope: String, ) -> Result { info!("Adding MCP server: {} with transport: {}", name, transport); - + // Prepare owned strings for environment variables - let env_args: Vec = env.iter() + let env_args: Vec = env + .iter() .map(|(key, value)| format!("{}={}", key, value)) .collect(); - + let mut cmd_args = vec!["add"]; - + // Add scope flag cmd_args.push("-s"); cmd_args.push(&scope); - + // Add transport flag for SSE if transport == "sse" { cmd_args.push("--transport"); cmd_args.push("sse"); } - + // Add environment variables for (i, _) in env.iter().enumerate() { cmd_args.push("-e"); cmd_args.push(&env_args[i]); } - + // Add name cmd_args.push(&name); - + // Add command/URL based on transport if transport == "stdio" { if let Some(cmd) = &command { @@ -188,7 +187,7 @@ pub async fn mcp_add( }); } } - + match execute_claude_mcp_command(&app, cmd_args) { Ok(output) => { info!("Successfully added MCP server: {}", name); @@ -213,19 +212,19 @@ pub async fn mcp_add( #[tauri::command] pub async fn mcp_list(app: AppHandle) -> Result, String> { info!("Listing MCP servers"); - + match execute_claude_mcp_command(&app, vec!["list"]) { Ok(output) => { info!("Raw output from 'claude mcp list': {:?}", output); let trimmed = output.trim(); info!("Trimmed output: {:?}", trimmed); - + // Check if no servers are configured if trimmed.contains("No MCP servers configured") || trimmed.is_empty() { info!("No servers found - empty or 'No MCP servers' message"); return Ok(vec![]); } - + // Parse the text output, handling multi-line commands let mut servers = Vec::new(); let lines: Vec<&str> = trimmed.lines().collect(); 
@@ -233,13 +232,13 @@ pub async fn mcp_list(app: AppHandle) -> Result, String> { for (idx, line) in lines.iter().enumerate() { info!("Line {}: {:?}", idx, line); } - + let mut i = 0; - + while i < lines.len() { let line = lines[i]; info!("Processing line {}: {:?}", i, line); - + // Check if this line starts a new server entry if let Some(colon_pos) = line.find(':') { info!("Found colon at position {} in line: {:?}", colon_pos, line); @@ -247,26 +246,31 @@ pub async fn mcp_list(app: AppHandle) -> Result, String> { // Server names typically don't contain '/' or '\' let potential_name = line[..colon_pos].trim(); info!("Potential server name: {:?}", potential_name); - + if !potential_name.contains('/') && !potential_name.contains('\\') { info!("Valid server name detected: {:?}", potential_name); let name = potential_name.to_string(); let mut command_parts = vec![line[colon_pos + 1..].trim().to_string()]; info!("Initial command part: {:?}", command_parts[0]); - + // Check if command continues on next lines i += 1; while i < lines.len() { let next_line = lines[i]; info!("Checking next line {} for continuation: {:?}", i, next_line); - + // If the next line starts with a server name pattern, break if next_line.contains(':') { - let potential_next_name = next_line.split(':').next().unwrap_or("").trim(); - info!("Found colon in next line, potential name: {:?}", potential_next_name); - if !potential_next_name.is_empty() && - !potential_next_name.contains('/') && - !potential_next_name.contains('\\') { + let potential_next_name = + next_line.split(':').next().unwrap_or("").trim(); + info!( + "Found colon in next line, potential name: {:?}", + potential_next_name + ); + if !potential_next_name.is_empty() + && !potential_next_name.contains('/') + && !potential_next_name.contains('\\') + { info!("Next line is a new server, breaking"); break; } @@ -276,11 +280,11 @@ pub async fn mcp_list(app: AppHandle) -> Result, String> { command_parts.push(next_line.trim().to_string()); i += 1; } - + // Join all command parts let full_command = command_parts.join(" "); info!("Full command for server '{}': {:?}", name, full_command); - + // For now, we'll create a basic server entry servers.push(MCPServer { name: name.clone(), @@ -298,7 +302,7 @@ pub async fn mcp_list(app: AppHandle) -> Result, String> { }, }); info!("Added server: {:?}", name); - + continue; } else { info!("Skipping line - name contains path separators"); @@ -306,13 +310,16 @@ pub async fn mcp_list(app: AppHandle) -> Result, String> { } else { info!("No colon found in line {}", i); } - + i += 1; } - + info!("Found {} MCP servers total", servers.len()); for (idx, server) in servers.iter().enumerate() { - info!("Server {}: name='{}', command={:?}", idx, server.name, server.command); + info!( + "Server {}: name='{}', command={:?}", + idx, server.name, server.command + ); } Ok(servers) } @@ -327,7 +334,7 @@ pub async fn mcp_list(app: AppHandle) -> Result, String> { #[tauri::command] pub async fn mcp_get(app: AppHandle, name: String) -> Result { info!("Getting MCP server details for: {}", name); - + match execute_claude_mcp_command(&app, vec!["get", &name]) { Ok(output) => { // Parse the structured text output @@ -337,17 +344,19 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result let mut args = vec![]; let env = HashMap::new(); let mut url = None; - + for line in output.lines() { let line = line.trim(); - + if line.starts_with("Scope:") { let scope_part = line.replace("Scope:", "").trim().to_string(); if 
scope_part.to_lowercase().contains("local") { scope = "local".to_string(); } else if scope_part.to_lowercase().contains("project") { scope = "project".to_string(); - } else if scope_part.to_lowercase().contains("user") || scope_part.to_lowercase().contains("global") { + } else if scope_part.to_lowercase().contains("user") + || scope_part.to_lowercase().contains("global") + { scope = "user".to_string(); } } else if line.starts_with("Type:") { @@ -366,7 +375,7 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result // For now, we'll leave it empty } } - + Ok(MCPServer { name, transport, @@ -394,7 +403,7 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result #[tauri::command] pub async fn mcp_remove(app: AppHandle, name: String) -> Result { info!("Removing MCP server: {}", name); - + match execute_claude_mcp_command(&app, vec!["remove", &name]) { Ok(output) => { info!("Successfully removed MCP server: {}", name); @@ -409,17 +418,25 @@ pub async fn mcp_remove(app: AppHandle, name: String) -> Result /// Adds an MCP server from JSON configuration #[tauri::command] -pub async fn mcp_add_json(app: AppHandle, name: String, json_config: String, scope: String) -> Result { - info!("Adding MCP server from JSON: {} with scope: {}", name, scope); - +pub async fn mcp_add_json( + app: AppHandle, + name: String, + json_config: String, + scope: String, +) -> Result { + info!( + "Adding MCP server from JSON: {} with scope: {}", + name, scope + ); + // Build command args let mut cmd_args = vec!["add-json", &name, &json_config]; - + // Add scope flag let scope_flag = "-s"; cmd_args.push(scope_flag); cmd_args.push(&scope); - + match execute_claude_mcp_command(&app, cmd_args) { Ok(output) => { info!("Successfully added MCP server from JSON: {}", name); @@ -442,9 +459,15 @@ pub async fn mcp_add_json(app: AppHandle, name: String, json_config: String, sco /// Imports MCP servers from Claude Desktop #[tauri::command] -pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Result { - info!("Importing MCP servers from Claude Desktop with scope: {}", scope); - +pub async fn mcp_add_from_claude_desktop( + app: AppHandle, + scope: String, +) -> Result { + info!( + "Importing MCP servers from Claude Desktop with scope: {}", + scope + ); + // Get Claude Desktop config path based on platform let config_path = if cfg!(target_os = "macos") { dirs::home_dir() @@ -460,43 +483,55 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul .join("Claude") .join("claude_desktop_config.json") } else { - return Err("Import from Claude Desktop is only supported on macOS and Linux/WSL".to_string()); + return Err( + "Import from Claude Desktop is only supported on macOS and Linux/WSL".to_string(), + ); }; - + // Check if config file exists if !config_path.exists() { - return Err("Claude Desktop configuration not found. Make sure Claude Desktop is installed.".to_string()); + return Err( + "Claude Desktop configuration not found. Make sure Claude Desktop is installed." 
+ .to_string(), + ); } - + // Read and parse the config file let config_content = fs::read_to_string(&config_path) .map_err(|e| format!("Failed to read Claude Desktop config: {}", e))?; - + let config: serde_json::Value = serde_json::from_str(&config_content) .map_err(|e| format!("Failed to parse Claude Desktop config: {}", e))?; - + // Extract MCP servers - let mcp_servers = config.get("mcpServers") + let mcp_servers = config + .get("mcpServers") .and_then(|v| v.as_object()) .ok_or_else(|| "No MCP servers found in Claude Desktop config".to_string())?; - + let mut imported_count = 0; let mut failed_count = 0; let mut server_results = Vec::new(); - + // Import each server using add-json for (name, server_config) in mcp_servers { info!("Importing server: {}", name); - + // Convert Claude Desktop format to add-json format let mut json_config = serde_json::Map::new(); - + // All Claude Desktop servers are stdio type - json_config.insert("type".to_string(), serde_json::Value::String("stdio".to_string())); - + json_config.insert( + "type".to_string(), + serde_json::Value::String("stdio".to_string()), + ); + // Add command if let Some(command) = server_config.get("command").and_then(|v| v.as_str()) { - json_config.insert("command".to_string(), serde_json::Value::String(command.to_string())); + json_config.insert( + "command".to_string(), + serde_json::Value::String(command.to_string()), + ); } else { failed_count += 1; server_results.push(ImportServerResult { @@ -506,25 +541,28 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul }); continue; } - + // Add args if present if let Some(args) = server_config.get("args").and_then(|v| v.as_array()) { json_config.insert("args".to_string(), args.clone().into()); } else { json_config.insert("args".to_string(), serde_json::Value::Array(vec![])); } - + // Add env if present if let Some(env) = server_config.get("env").and_then(|v| v.as_object()) { json_config.insert("env".to_string(), env.clone().into()); } else { - json_config.insert("env".to_string(), serde_json::Value::Object(serde_json::Map::new())); + json_config.insert( + "env".to_string(), + serde_json::Value::Object(serde_json::Map::new()), + ); } - + // Convert to JSON string let json_str = serde_json::to_string(&json_config) .map_err(|e| format!("Failed to serialize config for {}: {}", name, e))?; - + // Call add-json command match mcp_add_json(app.clone(), name.clone(), json_str, scope.clone()).await { Ok(result) => { @@ -559,9 +597,12 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul } } } - - info!("Import complete: {} imported, {} failed", imported_count, failed_count); - + + info!( + "Import complete: {} imported, {} failed", + imported_count, failed_count + ); + Ok(ImportResult { imported_count, failed_count, @@ -573,7 +614,7 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul #[tauri::command] pub async fn mcp_serve(app: AppHandle) -> Result { info!("Starting Claude Code as MCP server"); - + // Start the server in a separate process let claude_path = match find_claude_binary(&app) { Ok(path) => path, @@ -582,10 +623,10 @@ pub async fn mcp_serve(app: AppHandle) -> Result { return Err(e.to_string()); } }; - + let mut cmd = create_command_with_env(&claude_path); cmd.arg("mcp").arg("serve"); - + match cmd.spawn() { Ok(_) => { info!("Successfully started Claude Code MCP server"); @@ -602,7 +643,7 @@ pub async fn mcp_serve(app: AppHandle) -> Result { #[tauri::command] pub async fn 
@@ -573,7 +614,7 @@
 #[tauri::command]
 pub async fn mcp_serve(app: AppHandle) -> Result {
     info!("Starting Claude Code as MCP server");
-    
+
     // Start the server in a separate process
     let claude_path = match find_claude_binary(&app) {
         Ok(path) => path,
@@ -582,10 +623,10 @@
             return Err(e.to_string());
         }
     };
-    
+
     let mut cmd = create_command_with_env(&claude_path);
     cmd.arg("mcp").arg("serve");
-    
+
     match cmd.spawn() {
         Ok(_) => {
             info!("Successfully started Claude Code MCP server");
@@ -602,7 +643,7 @@
 #[tauri::command]
 pub async fn mcp_test_connection(app: AppHandle, name: String) -> Result<String, String> {
     info!("Testing connection to MCP server: {}", name);
-    
+
     // For now, we'll use the get command to test if the server exists
     match execute_claude_mcp_command(&app, vec!["get", &name]) {
         Ok(_) => Ok(format!("Connection to {} successful", name)),
@@ -614,7 +655,7 @@ pub async fn mcp_test_connection(app: AppHandle, name: String) -> Result
 pub async fn mcp_reset_project_choices(app: AppHandle) -> Result {
     info!("Resetting MCP project choices");
-    
+
     match execute_claude_mcp_command(&app, vec!["reset-project-choices"]) {
         Ok(output) => {
             info!("Successfully reset MCP project choices");
@@ -631,7 +672,7 @@ pub async fn mcp_reset_project_choices(app: AppHandle) -> Result
 #[tauri::command]
 pub async fn mcp_get_server_status() -> Result, String> {
     info!("Getting MCP server status");
-    
+
     // TODO: Implement actual status checking
     // For now, return empty status
     Ok(HashMap::new())
@@ -641,25 +682,23 @@ pub async fn mcp_get_server_status() -> Result, St
 #[tauri::command]
 pub async fn mcp_read_project_config(project_path: String) -> Result<MCPProjectConfig, String> {
     info!("Reading .mcp.json from project: {}", project_path);
-    
+
     let mcp_json_path = PathBuf::from(&project_path).join(".mcp.json");
-    
+
     if !mcp_json_path.exists() {
         return Ok(MCPProjectConfig {
             mcp_servers: HashMap::new(),
         });
     }
-    
+
     match fs::read_to_string(&mcp_json_path) {
-        Ok(content) => {
-            match serde_json::from_str::<MCPProjectConfig>(&content) {
-                Ok(config) => Ok(config),
-                Err(e) => {
-                    error!("Failed to parse .mcp.json: {}", e);
-                    Err(format!("Failed to parse .mcp.json: {}", e))
-                }
+        Ok(content) => match serde_json::from_str::<MCPProjectConfig>(&content) {
+            Ok(config) => Ok(config),
+            Err(e) => {
+                error!("Failed to parse .mcp.json: {}", e);
+                Err(format!("Failed to parse .mcp.json: {}", e))
             }
-        }
+        },
         Err(e) => {
             error!("Failed to read .mcp.json: {}", e);
             Err(format!("Failed to read .mcp.json: {}", e))
@@ -674,14 +713,14 @@ pub async fn mcp_save_project_config(
     config: MCPProjectConfig,
 ) -> Result<String, String> {
     info!("Saving .mcp.json to project: {}", project_path);
-    
+
     let mcp_json_path = PathBuf::from(&project_path).join(".mcp.json");
-    
+
     let json_content = serde_json::to_string_pretty(&config)
         .map_err(|e| format!("Failed to serialize config: {}", e))?;
-    
+
     fs::write(&mcp_json_path, json_content)
         .map_err(|e| format!("Failed to write .mcp.json: {}", e))?;
-    
+
     Ok("Project MCP configuration saved".to_string())
-}
\ No newline at end of file
+}
diff --git a/src-tauri/src/commands/mod.rs b/src-tauri/src/commands/mod.rs
index 95384bd..fd44c79 100644
--- a/src-tauri/src/commands/mod.rs
+++ b/src-tauri/src/commands/mod.rs
@@ -1,6 +1,6 @@
-pub mod claude;
 pub mod agents;
-pub mod sandbox;
-pub mod usage;
+pub mod claude;
 pub mod mcp;
-pub mod screenshot;
\ No newline at end of file
+pub mod sandbox;
+pub mod screenshot;
+pub mod usage;
diff --git a/src-tauri/src/commands/sandbox.rs b/src-tauri/src/commands/sandbox.rs
index 1413cee..63224ab 100644
--- a/src-tauri/src/commands/sandbox.rs
+++ b/src-tauri/src/commands/sandbox.rs
@@ -52,11 +52,11 @@ pub struct ImportResult {
 #[tauri::command]
 pub async fn list_sandbox_profiles(db: State<'_, AgentDb>) -> Result<Vec<SandboxProfile>, String> {
     let conn = db.0.lock().map_err(|e| e.to_string())?;
-    
+
     let mut stmt = conn
         .prepare("SELECT id, name, description, is_active, is_default, created_at, updated_at FROM sandbox_profiles ORDER BY name")
         .map_err(|e| e.to_string())?;
-    
+
     let profiles = stmt
         .query_map([], |row| {
             Ok(SandboxProfile {
@@ -72,7 +72,7 @@ pub async fn list_sandbox_profiles(db: State<'_, AgentDb>) -> Result
         .collect::<Result<Vec<_>, _>>()
         .map_err(|e| e.to_string())?;
-    
+
     Ok(profiles)
 }

@@
-84,15 +84,15 @@ pub async fn create_sandbox_profile( description: Option, ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + conn.execute( "INSERT INTO sandbox_profiles (name, description) VALUES (?1, ?2)", params![name, description], ) .map_err(|e| e.to_string())?; - + let id = conn.last_insert_rowid(); - + // Fetch the created profile let profile = conn .query_row( @@ -111,7 +111,7 @@ pub async fn create_sandbox_profile( }, ) .map_err(|e| e.to_string())?; - + Ok(profile) } @@ -126,7 +126,7 @@ pub async fn update_sandbox_profile( is_default: bool, ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + // If setting as default, unset other defaults if is_default { conn.execute( @@ -135,13 +135,13 @@ pub async fn update_sandbox_profile( ) .map_err(|e| e.to_string())?; } - + conn.execute( "UPDATE sandbox_profiles SET name = ?1, description = ?2, is_active = ?3, is_default = ?4 WHERE id = ?5", params![name, description, is_active, is_default, id], ) .map_err(|e| e.to_string())?; - + // Fetch the updated profile let profile = conn .query_row( @@ -160,7 +160,7 @@ pub async fn update_sandbox_profile( }, ) .map_err(|e| e.to_string())?; - + Ok(profile) } @@ -168,7 +168,7 @@ pub async fn update_sandbox_profile( #[tauri::command] pub async fn delete_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<(), String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + // Check if it's the default profile let is_default: bool = conn .query_row( @@ -177,22 +177,25 @@ pub async fn delete_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<( |row| row.get(0), ) .map_err(|e| e.to_string())?; - + if is_default { return Err("Cannot delete the default profile".to_string()); } - + conn.execute("DELETE FROM sandbox_profiles WHERE id = ?1", params![id]) .map_err(|e| e.to_string())?; - + Ok(()) } /// Get a single sandbox profile by ID #[tauri::command] -pub async fn get_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result { +pub async fn get_sandbox_profile( + db: State<'_, AgentDb>, + id: i64, +) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + let profile = conn .query_row( "SELECT id, name, description, is_active, is_default, created_at, updated_at FROM sandbox_profiles WHERE id = ?1", @@ -210,7 +213,7 @@ pub async fn get_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result Result, String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + let mut stmt = conn .prepare("SELECT id, profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support, created_at FROM sandbox_rules WHERE profile_id = ?1 ORDER BY operation_type, pattern_value") .map_err(|e| e.to_string())?; - + let rules = stmt .query_map(params![profile_id], |row| { Ok(SandboxRule { @@ -242,7 +245,7 @@ pub async fn list_sandbox_rules( .map_err(|e| e.to_string())? 
.collect::, _>>() .map_err(|e| e.to_string())?; - + Ok(rules) } @@ -258,18 +261,18 @@ pub async fn create_sandbox_rule( platform_support: Option, ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + // Validate rule doesn't conflict // TODO: Add more validation logic here - + conn.execute( "INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", params![profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support], ) .map_err(|e| e.to_string())?; - + let id = conn.last_insert_rowid(); - + // Fetch the created rule let rule = conn .query_row( @@ -289,7 +292,7 @@ pub async fn create_sandbox_rule( }, ) .map_err(|e| e.to_string())?; - + Ok(rule) } @@ -305,13 +308,13 @@ pub async fn update_sandbox_rule( platform_support: Option, ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + conn.execute( "UPDATE sandbox_rules SET operation_type = ?1, pattern_type = ?2, pattern_value = ?3, enabled = ?4, platform_support = ?5 WHERE id = ?6", params![operation_type, pattern_type, pattern_value, enabled, platform_support, id], ) .map_err(|e| e.to_string())?; - + // Fetch the updated rule let rule = conn .query_row( @@ -331,7 +334,7 @@ pub async fn update_sandbox_rule( }, ) .map_err(|e| e.to_string())?; - + Ok(rule) } @@ -339,10 +342,10 @@ pub async fn update_sandbox_rule( #[tauri::command] pub async fn delete_sandbox_rule(db: State<'_, AgentDb>, id: i64) -> Result<(), String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + conn.execute("DELETE FROM sandbox_rules WHERE id = ?1", params![id]) .map_err(|e| e.to_string())?; - + Ok(()) } @@ -359,38 +362,38 @@ pub async fn test_sandbox_profile( profile_id: i64, ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + // Load the profile and rules let profile = crate::sandbox::profile::load_profile(&conn, profile_id) .map_err(|e| format!("Failed to load profile: {}", e))?; - + if !profile.is_active { return Ok(format!( "Profile '{}' is currently inactive. Activate it to use with agents.", profile.name )); } - + let rules = crate::sandbox::profile::load_profile_rules(&conn, profile_id) .map_err(|e| format!("Failed to load profile rules: {}", e))?; - + if rules.is_empty() { return Ok(format!( "Profile '{}' has no rules configured. 
Add rules to define sandbox permissions.", profile.name )); } - + // Try to build the gaol profile - let test_path = std::env::current_dir() - .unwrap_or_else(|_| std::path::PathBuf::from("/tmp")); - + let test_path = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("/tmp")); + let builder = crate::sandbox::profile::ProfileBuilder::new(test_path.clone()) .map_err(|e| format!("Failed to create profile builder: {}", e))?; - - let build_result = builder.build_profile_with_serialization(rules.clone()) + + let build_result = builder + .build_profile_with_serialization(rules.clone()) .map_err(|e| format!("Failed to build sandbox profile: {}", e))?; - + // Check platform support let platform_caps = crate::sandbox::platform::get_platform_capabilities(); if !platform_caps.sandboxing_supported { @@ -401,27 +404,23 @@ pub async fn test_sandbox_profile( platform_caps.os )); } - + // Try to execute a simple command in the sandbox let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization( - build_result.profile, + build_result.profile, test_path.clone(), - build_result.serialized + build_result.serialized, ); - + // Use a simple echo command for testing - let test_command = if cfg!(windows) { - "cmd" - } else { - "echo" - }; - + let test_command = if cfg!(windows) { "cmd" } else { "echo" }; + let test_args = if cfg!(windows) { vec!["/C", "echo", "sandbox test successful"] } else { vec!["sandbox test successful"] }; - + match executor.execute_sandboxed_spawn(test_command, &test_args, &test_path) { Ok(mut child) => { // Wait for the process to complete with a timeout @@ -452,19 +451,17 @@ pub async fn test_sandbox_profile( )) } } - Err(e) => { - Ok(format!( - "⚠️ Profile '{}' validated with warnings.\n\n\ + Err(e) => Ok(format!( + "⚠️ Profile '{}' validated with warnings.\n\n\ • {} rules loaded and validated\n\ • Sandbox activation: Partial\n\ • Test process: Could not get exit status ({})\n\ • Platform: {}", - profile.name, - rules.len(), - e, - platform_caps.os - )) - } + profile.name, + rules.len(), + e, + platform_caps.os + )), } } Err(e) => { @@ -509,176 +506,200 @@ pub async fn list_sandbox_violations( limit: Option, ) -> Result, String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + // Build dynamic query let mut query = String::from( "SELECT id, profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid, denied_at FROM sandbox_violations WHERE 1=1" ); - + let mut param_idx = 1; - + if profile_id.is_some() { query.push_str(&format!(" AND profile_id = ?{}", param_idx)); param_idx += 1; } - + if agent_id.is_some() { query.push_str(&format!(" AND agent_id = ?{}", param_idx)); param_idx += 1; } - + query.push_str(" ORDER BY denied_at DESC"); - + if limit.is_some() { query.push_str(&format!(" LIMIT ?{}", param_idx)); } - + // Execute query based on parameters let violations: Vec = if let Some(pid) = profile_id { if let Some(aid) = agent_id { if let Some(lim) = limit { // All three parameters let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; - let rows = stmt.query_map(params![pid, aid, lim], |row| { - Ok(SandboxViolation { - id: Some(row.get(0)?), - profile_id: row.get(1)?, - agent_id: row.get(2)?, - agent_run_id: row.get(3)?, - operation_type: row.get(4)?, - pattern_value: row.get(5)?, - process_name: row.get(6)?, - pid: row.get(7)?, - denied_at: row.get(8)?, + let rows = stmt + .query_map(params![pid, aid, lim], |row| { + Ok(SandboxViolation { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + agent_id: 
row.get(2)?, + agent_run_id: row.get(3)?, + operation_type: row.get(4)?, + pattern_value: row.get(5)?, + process_name: row.get(6)?, + pid: row.get(7)?, + denied_at: row.get(8)?, + }) }) - }).map_err(|e| e.to_string())?; - rows.collect::, _>>().map_err(|e| e.to_string())? + .map_err(|e| e.to_string())?; + rows.collect::, _>>() + .map_err(|e| e.to_string())? } else { // profile_id and agent_id only let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; - let rows = stmt.query_map(params![pid, aid], |row| { - Ok(SandboxViolation { - id: Some(row.get(0)?), - profile_id: row.get(1)?, - agent_id: row.get(2)?, - agent_run_id: row.get(3)?, - operation_type: row.get(4)?, - pattern_value: row.get(5)?, - process_name: row.get(6)?, - pid: row.get(7)?, - denied_at: row.get(8)?, + let rows = stmt + .query_map(params![pid, aid], |row| { + Ok(SandboxViolation { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + agent_id: row.get(2)?, + agent_run_id: row.get(3)?, + operation_type: row.get(4)?, + pattern_value: row.get(5)?, + process_name: row.get(6)?, + pid: row.get(7)?, + denied_at: row.get(8)?, + }) }) - }).map_err(|e| e.to_string())?; - rows.collect::, _>>().map_err(|e| e.to_string())? + .map_err(|e| e.to_string())?; + rows.collect::, _>>() + .map_err(|e| e.to_string())? } } else if let Some(lim) = limit { // profile_id and limit only let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; - let rows = stmt.query_map(params![pid, lim], |row| { - Ok(SandboxViolation { - id: Some(row.get(0)?), - profile_id: row.get(1)?, - agent_id: row.get(2)?, - agent_run_id: row.get(3)?, - operation_type: row.get(4)?, - pattern_value: row.get(5)?, - process_name: row.get(6)?, - pid: row.get(7)?, - denied_at: row.get(8)?, + let rows = stmt + .query_map(params![pid, lim], |row| { + Ok(SandboxViolation { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + agent_id: row.get(2)?, + agent_run_id: row.get(3)?, + operation_type: row.get(4)?, + pattern_value: row.get(5)?, + process_name: row.get(6)?, + pid: row.get(7)?, + denied_at: row.get(8)?, + }) }) - }).map_err(|e| e.to_string())?; - rows.collect::, _>>().map_err(|e| e.to_string())? + .map_err(|e| e.to_string())?; + rows.collect::, _>>() + .map_err(|e| e.to_string())? } else { // profile_id only let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; - let rows = stmt.query_map(params![pid], |row| { - Ok(SandboxViolation { - id: Some(row.get(0)?), - profile_id: row.get(1)?, - agent_id: row.get(2)?, - agent_run_id: row.get(3)?, - operation_type: row.get(4)?, - pattern_value: row.get(5)?, - process_name: row.get(6)?, - pid: row.get(7)?, - denied_at: row.get(8)?, + let rows = stmt + .query_map(params![pid], |row| { + Ok(SandboxViolation { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + agent_id: row.get(2)?, + agent_run_id: row.get(3)?, + operation_type: row.get(4)?, + pattern_value: row.get(5)?, + process_name: row.get(6)?, + pid: row.get(7)?, + denied_at: row.get(8)?, + }) }) - }).map_err(|e| e.to_string())?; - rows.collect::, _>>().map_err(|e| e.to_string())? + .map_err(|e| e.to_string())?; + rows.collect::, _>>() + .map_err(|e| e.to_string())? 
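                // The identical row-mapping closure recurs once per filter combination in
                // this match. One alternative (not what this formatting patch does) would be
                // to build the parameter list dynamically from the optional filters and pass
                // it through rusqlite's params_from_iter, collapsing the branches into one
                // query_map call at the cost of boxing the values.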
} } else if let Some(aid) = agent_id { if let Some(lim) = limit { // agent_id and limit only let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; - let rows = stmt.query_map(params![aid, lim], |row| { - Ok(SandboxViolation { - id: Some(row.get(0)?), - profile_id: row.get(1)?, - agent_id: row.get(2)?, - agent_run_id: row.get(3)?, - operation_type: row.get(4)?, - pattern_value: row.get(5)?, - process_name: row.get(6)?, - pid: row.get(7)?, - denied_at: row.get(8)?, + let rows = stmt + .query_map(params![aid, lim], |row| { + Ok(SandboxViolation { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + agent_id: row.get(2)?, + agent_run_id: row.get(3)?, + operation_type: row.get(4)?, + pattern_value: row.get(5)?, + process_name: row.get(6)?, + pid: row.get(7)?, + denied_at: row.get(8)?, + }) }) - }).map_err(|e| e.to_string())?; - rows.collect::, _>>().map_err(|e| e.to_string())? + .map_err(|e| e.to_string())?; + rows.collect::, _>>() + .map_err(|e| e.to_string())? } else { // agent_id only let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; - let rows = stmt.query_map(params![aid], |row| { - Ok(SandboxViolation { - id: Some(row.get(0)?), - profile_id: row.get(1)?, - agent_id: row.get(2)?, - agent_run_id: row.get(3)?, - operation_type: row.get(4)?, - pattern_value: row.get(5)?, - process_name: row.get(6)?, - pid: row.get(7)?, - denied_at: row.get(8)?, + let rows = stmt + .query_map(params![aid], |row| { + Ok(SandboxViolation { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + agent_id: row.get(2)?, + agent_run_id: row.get(3)?, + operation_type: row.get(4)?, + pattern_value: row.get(5)?, + process_name: row.get(6)?, + pid: row.get(7)?, + denied_at: row.get(8)?, + }) }) - }).map_err(|e| e.to_string())?; - rows.collect::, _>>().map_err(|e| e.to_string())? + .map_err(|e| e.to_string())?; + rows.collect::, _>>() + .map_err(|e| e.to_string())? } } else if let Some(lim) = limit { // limit only let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; - let rows = stmt.query_map(params![lim], |row| { - Ok(SandboxViolation { - id: Some(row.get(0)?), - profile_id: row.get(1)?, - agent_id: row.get(2)?, - agent_run_id: row.get(3)?, - operation_type: row.get(4)?, - pattern_value: row.get(5)?, - process_name: row.get(6)?, - pid: row.get(7)?, - denied_at: row.get(8)?, + let rows = stmt + .query_map(params![lim], |row| { + Ok(SandboxViolation { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + agent_id: row.get(2)?, + agent_run_id: row.get(3)?, + operation_type: row.get(4)?, + pattern_value: row.get(5)?, + process_name: row.get(6)?, + pid: row.get(7)?, + denied_at: row.get(8)?, + }) }) - }).map_err(|e| e.to_string())?; - rows.collect::, _>>().map_err(|e| e.to_string())? + .map_err(|e| e.to_string())?; + rows.collect::, _>>() + .map_err(|e| e.to_string())? 
} else { // No parameters let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?; - let rows = stmt.query_map([], |row| { - Ok(SandboxViolation { - id: Some(row.get(0)?), - profile_id: row.get(1)?, - agent_id: row.get(2)?, - agent_run_id: row.get(3)?, - operation_type: row.get(4)?, - pattern_value: row.get(5)?, - process_name: row.get(6)?, - pid: row.get(7)?, - denied_at: row.get(8)?, + let rows = stmt + .query_map([], |row| { + Ok(SandboxViolation { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + agent_id: row.get(2)?, + agent_run_id: row.get(3)?, + operation_type: row.get(4)?, + pattern_value: row.get(5)?, + process_name: row.get(6)?, + pid: row.get(7)?, + denied_at: row.get(8)?, + }) }) - }).map_err(|e| e.to_string())?; - rows.collect::, _>>().map_err(|e| e.to_string())? + .map_err(|e| e.to_string())?; + rows.collect::, _>>() + .map_err(|e| e.to_string())? }; - + Ok(violations) } @@ -695,14 +716,14 @@ pub async fn log_sandbox_violation( pid: Option, ) -> Result<(), String> { let conn = db.0.lock().map_err(|e| e.to_string())?; - + conn.execute( "INSERT INTO sandbox_violations (profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", params![profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid], ) .map_err(|e| e.to_string())?; - + Ok(()) } @@ -713,7 +734,7 @@ pub async fn clear_sandbox_violations( older_than_days: Option, ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + let query = if let Some(days) = older_than_days { format!( "DELETE FROM sandbox_violations WHERE denied_at < datetime('now', '-{} days')", @@ -722,10 +743,9 @@ pub async fn clear_sandbox_violations( } else { "DELETE FROM sandbox_violations".to_string() }; - - let deleted = conn.execute(&query, []) - .map_err(|e| e.to_string())?; - + + let deleted = conn.execute(&query, []).map_err(|e| e.to_string())?; + Ok(deleted as i64) } @@ -735,28 +755,30 @@ pub async fn get_sandbox_violation_stats( db: State<'_, AgentDb>, ) -> Result { let conn = db.0.lock().map_err(|e| e.to_string())?; - + // Get total violations let total: i64 = conn - .query_row("SELECT COUNT(*) FROM sandbox_violations", [], |row| row.get(0)) + .query_row("SELECT COUNT(*) FROM sandbox_violations", [], |row| { + row.get(0) + }) .map_err(|e| e.to_string())?; - + // Get violations by operation type let mut stmt = conn .prepare( "SELECT operation_type, COUNT(*) as count FROM sandbox_violations GROUP BY operation_type - ORDER BY count DESC" + ORDER BY count DESC", ) .map_err(|e| e.to_string())?; - + let by_operation: Vec<(String, i64)> = stmt .query_map([], |row| Ok((row.get(0)?, row.get(1)?))) .map_err(|e| e.to_string())? .collect::, _>>() .map_err(|e| e.to_string())?; - + // Get recent violations count (last 24 hours) let recent: i64 = conn .query_row( @@ -765,7 +787,7 @@ pub async fn get_sandbox_violation_stats( |row| row.get(0), ) .map_err(|e| e.to_string())?; - + Ok(serde_json::json!({ "total": total, "recent_24h": recent, @@ -789,10 +811,10 @@ pub async fn export_sandbox_profile( let conn = db.0.lock().map_err(|e| e.to_string())?; crate::sandbox::profile::load_profile(&conn, profile_id).map_err(|e| e.to_string())? 
}; - + // Get the rules let rules = list_sandbox_rules(db.clone(), profile_id).await?; - + Ok(SandboxProfileExport { version: 1, exported_at: chrono::Utc::now().to_rfc3339(), @@ -808,17 +830,14 @@ pub async fn export_all_sandbox_profiles( ) -> Result { let profiles = list_sandbox_profiles(db.clone()).await?; let mut profile_exports = Vec::new(); - + for profile in profiles { if let Some(id) = profile.id { let rules = list_sandbox_rules(db.clone(), id).await?; - profile_exports.push(SandboxProfileWithRules { - profile, - rules, - }); + profile_exports.push(SandboxProfileWithRules { profile, rules }); } } - + Ok(SandboxProfileExport { version: 1, exported_at: chrono::Utc::now().to_rfc3339(), @@ -834,16 +853,19 @@ pub async fn import_sandbox_profiles( export_data: SandboxProfileExport, ) -> Result, String> { let mut results = Vec::new(); - + // Validate version if export_data.version != 1 { - return Err(format!("Unsupported export version: {}", export_data.version)); + return Err(format!( + "Unsupported export version: {}", + export_data.version + )); } - + for profile_export in export_data.profiles { let mut profile = profile_export.profile; let original_name = profile.name.clone(); - + // Check for name conflicts let existing: Result = { let conn = db.0.lock().map_err(|e| e.to_string())?; @@ -853,29 +875,31 @@ pub async fn import_sandbox_profiles( |row| row.get(0), ) }; - + let (imported, new_name) = match existing { Ok(_) => { // Name conflict - append timestamp - let new_name = format!("{} (imported {})", profile.name, chrono::Utc::now().format("%Y-%m-%d %H:%M")); + let new_name = format!( + "{} (imported {})", + profile.name, + chrono::Utc::now().format("%Y-%m-%d %H:%M") + ); profile.name = new_name.clone(); (true, Some(new_name)) } Err(_) => (true, None), }; - + if imported { // Reset profile fields for new insert profile.id = None; profile.is_default = false; // Never import as default - + // Create the profile - let created_profile = create_sandbox_profile( - db.clone(), - profile.name.clone(), - profile.description, - ).await?; - + let created_profile = + create_sandbox_profile(db.clone(), profile.name.clone(), profile.description) + .await?; + if let Some(new_id) = created_profile.id { // Import rules for rule in profile_export.rules { @@ -889,10 +913,11 @@ pub async fn import_sandbox_profiles( rule.pattern_value, rule.enabled, rule.platform_support, - ).await; + ) + .await; } } - + // Update profile status if needed if profile.is_active { let _ = update_sandbox_profile( @@ -902,18 +927,21 @@ pub async fn import_sandbox_profiles( created_profile.description, profile.is_active, false, // Never set as default on import - ).await; + ) + .await; } } - + results.push(ImportResult { profile_name: original_name, imported: true, - reason: new_name.as_ref().map(|_| "Name conflict resolved".to_string()), + reason: new_name + .as_ref() + .map(|_| "Name conflict resolved".to_string()), new_name, }); } } - + Ok(results) -} \ No newline at end of file +} diff --git a/src-tauri/src/commands/screenshot.rs b/src-tauri/src/commands/screenshot.rs index 3fdb5c6..8fb6c53 100644 --- a/src-tauri/src/commands/screenshot.rs +++ b/src-tauri/src/commands/screenshot.rs @@ -1,20 +1,20 @@ -use headless_chrome::{Browser, LaunchOptions}; use headless_chrome::protocol::cdp::Page; +use headless_chrome::{Browser, LaunchOptions}; use std::fs; use std::time::Duration; use tauri::AppHandle; /// Captures a screenshot of a URL using headless Chrome -/// +/// /// This function launches a headless Chrome browser, navigates 
to the specified URL, /// and captures a screenshot of either the entire page or a specific element. -/// +/// /// # Arguments /// * `app` - The Tauri application handle /// * `url` - The URL to capture /// * `selector` - Optional CSS selector for a specific element to capture /// * `full_page` - Whether to capture the entire page or just the viewport -/// +/// /// # Returns /// * `Result` - The path to the saved screenshot file, or an error message #[tauri::command] @@ -32,11 +32,10 @@ pub async fn capture_url_screenshot( ); // Run the browser operations in a blocking task since headless_chrome is not async - let result = tokio::task::spawn_blocking(move || { - capture_screenshot_sync(url, selector, full_page) - }) - .await - .map_err(|e| format!("Failed to spawn blocking task: {}", e))?; + let result = + tokio::task::spawn_blocking(move || capture_screenshot_sync(url, selector, full_page)) + .await + .map_err(|e| format!("Failed to spawn blocking task: {}", e))?; // Log the result of the headless Chrome capture before returning match &result { @@ -61,8 +60,8 @@ fn capture_screenshot_sync( }; // Launch the browser - let browser = Browser::new(launch_options) - .map_err(|e| format!("Failed to launch browser: {}", e))?; + let browser = + Browser::new(launch_options).map_err(|e| format!("Failed to launch browser: {}", e))?; // Create a new tab let tab = browser @@ -86,14 +85,17 @@ fn capture_screenshot_sync( // Wait explicitly for the element to exist – this often prevents // "Unable to capture screenshot" CDP errors on some pages if let Err(e) = tab.wait_for_element("body") { - log::warn!("Timed out waiting for element: {} – continuing anyway", e); + log::warn!( + "Timed out waiting for element: {} – continuing anyway", + e + ); } // Capture the screenshot let screenshot_data = if let Some(selector) = selector { // Wait for the element and capture it log::info!("Waiting for element with selector: {}", selector); - + let element = tab .wait_for_element(&selector) .map_err(|e| format!("Failed to find element '{}': {}", selector, e))?; @@ -103,8 +105,11 @@ fn capture_screenshot_sync( .map_err(|e| format!("Failed to capture element screenshot: {}", e))? 
} else { // Capture the entire page or viewport - log::info!("Capturing {} screenshot", if full_page { "full page" } else { "viewport" }); - + log::info!( + "Capturing {} screenshot", + if full_page { "full page" } else { "viewport" } + ); + // Get the page dimensions for full page screenshot let clip = if full_page { // Execute JavaScript to get the full page dimensions @@ -132,30 +137,30 @@ fn capture_screenshot_sync( ) .map_err(|e| format!("Failed to get page dimensions: {}", e))?; - // Extract dimensions from the result - let width = dimensions - .value - .as_ref() - .and_then(|v| v.as_object()) - .and_then(|obj| obj.get("width")) - .and_then(|v| v.as_f64()) - .unwrap_or(1920.0); + // Extract dimensions from the result + let width = dimensions + .value + .as_ref() + .and_then(|v| v.as_object()) + .and_then(|obj| obj.get("width")) + .and_then(|v| v.as_f64()) + .unwrap_or(1920.0); - let height = dimensions - .value - .as_ref() - .and_then(|v| v.as_object()) - .and_then(|obj| obj.get("height")) - .and_then(|v| v.as_f64()) - .unwrap_or(1080.0); + let height = dimensions + .value + .as_ref() + .and_then(|v| v.as_object()) + .and_then(|obj| obj.get("height")) + .and_then(|v| v.as_f64()) + .unwrap_or(1080.0); - Some(Page::Viewport { - x: 0.0, - y: 0.0, - width, - height, - scale: 1.0, - }) + Some(Page::Viewport { + x: 0.0, + y: 0.0, + width, + height, + scale: 1.0, + }) } else { None }; @@ -176,13 +181,8 @@ fn capture_screenshot_sync( err ); - tab.capture_screenshot( - Page::CaptureScreenshotFormatOption::Png, - None, - clip, - true, - ) - .map_err(|e| format!("Failed to capture screenshot after retry: {}", e))? + tab.capture_screenshot(Page::CaptureScreenshotFormatOption::Png, None, clip, true) + .map_err(|e| format!("Failed to capture screenshot after retry: {}", e))? } } }; @@ -208,13 +208,13 @@ fn capture_screenshot_sync( } /// Cleans up old screenshot files from the temporary directory -/// +/// /// This function removes screenshot files older than the specified number of minutes /// to prevent accumulation of temporary files. 
-/// +/// /// # Arguments /// * `older_than_minutes` - Remove files older than this many minutes (default: 60) -/// +/// /// # Returns /// * `Result` - The number of files deleted, or an error message #[tauri::command] @@ -222,24 +222,29 @@ pub async fn cleanup_screenshot_temp_files( older_than_minutes: Option, ) -> Result { let minutes = older_than_minutes.unwrap_or(60); - log::info!("Cleaning up screenshot files older than {} minutes", minutes); - + log::info!( + "Cleaning up screenshot files older than {} minutes", + minutes + ); + let temp_dir = std::env::temp_dir(); let cutoff_time = chrono::Utc::now() - chrono::Duration::minutes(minutes as i64); let mut deleted_count = 0; - + // Read directory entries - let entries = fs::read_dir(&temp_dir) - .map_err(|e| format!("Failed to read temp directory: {}", e))?; - + let entries = + fs::read_dir(&temp_dir).map_err(|e| format!("Failed to read temp directory: {}", e))?; + for entry in entries { if let Ok(entry) = entry { let path = entry.path(); - + // Check if it's a claudia screenshot file if let Some(filename) = path.file_name() { if let Some(filename_str) = filename.to_str() { - if filename_str.starts_with("claudia_screenshot_") && filename_str.ends_with(".png") { + if filename_str.starts_with("claudia_screenshot_") + && filename_str.ends_with(".png") + { // Check file age if let Ok(metadata) = fs::metadata(&path) { if let Ok(modified) = metadata.modified() { @@ -258,7 +263,7 @@ pub async fn cleanup_screenshot_temp_files( } } } - + log::info!("Cleaned up {} old screenshot files", deleted_count); Ok(deleted_count) -} \ No newline at end of file +} diff --git a/src-tauri/src/commands/usage.rs b/src-tauri/src/commands/usage.rs index e40ae4d..b459c15 100644 --- a/src-tauri/src/commands/usage.rs +++ b/src-tauri/src/commands/usage.rs @@ -1,9 +1,9 @@ -use std::collections::{HashMap, HashSet}; -use std::fs; -use std::path::PathBuf; use chrono::{DateTime, Local, NaiveDate}; use serde::{Deserialize, Serialize}; use serde_json; +use std::collections::{HashMap, HashSet}; +use std::fs; +use std::path::PathBuf; use tauri::command; #[derive(Debug, Serialize, Deserialize, Clone)] @@ -108,11 +108,21 @@ fn calculate_cost(model: &str, usage: &UsageData) -> f64 { let cache_read_tokens = usage.cache_read_input_tokens.unwrap_or(0) as f64; // Calculate cost based on model - let (input_price, output_price, cache_write_price, cache_read_price) = + let (input_price, output_price, cache_write_price, cache_read_price) = if model.contains("opus-4") || model.contains("claude-opus-4") { - (OPUS_4_INPUT_PRICE, OPUS_4_OUTPUT_PRICE, OPUS_4_CACHE_WRITE_PRICE, OPUS_4_CACHE_READ_PRICE) + ( + OPUS_4_INPUT_PRICE, + OPUS_4_OUTPUT_PRICE, + OPUS_4_CACHE_WRITE_PRICE, + OPUS_4_CACHE_READ_PRICE, + ) } else if model.contains("sonnet-4") || model.contains("claude-sonnet-4") { - (SONNET_4_INPUT_PRICE, SONNET_4_OUTPUT_PRICE, SONNET_4_CACHE_WRITE_PRICE, SONNET_4_CACHE_READ_PRICE) + ( + SONNET_4_INPUT_PRICE, + SONNET_4_OUTPUT_PRICE, + SONNET_4_CACHE_WRITE_PRICE, + SONNET_4_CACHE_READ_PRICE, + ) } else { // Return 0 for unknown models to avoid incorrect cost estimations. 
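            // Worked example (assuming the *_PRICE constants are USD per million tokens,
            // matching Anthropic's published pricing): an opus-4 entry with 1,000 input and
            // 500 output tokens would cost roughly
            //   1_000.0 / 1e6 * OPUS_4_INPUT_PRICE + 500.0 / 1e6 * OPUS_4_OUTPUT_PRICE,
            // plus the corresponding cache-write and cache-read terms.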
(0.0, 0.0, 0.0, 0.0) @@ -134,10 +144,11 @@ fn parse_jsonl_file( ) -> Vec { let mut entries = Vec::new(); let mut actual_project_path: Option = None; - + if let Ok(content) = fs::read_to_string(path) { // Extract session ID from the file path - let session_id = path.parent() + let session_id = path + .parent() .and_then(|p| p.file_name()) .and_then(|n| n.to_str()) .unwrap_or("unknown") @@ -155,7 +166,7 @@ fn parse_jsonl_file( actual_project_path = Some(cwd.to_string()); } } - + // Try to parse as JsonlEntry for usage data if let Ok(entry) = serde_json::from_value::(json_value) { if let Some(message) = &entry.message { @@ -170,10 +181,11 @@ fn parse_jsonl_file( if let Some(usage) = &message.usage { // Skip entries without meaningful token usage - if usage.input_tokens.unwrap_or(0) == 0 && - usage.output_tokens.unwrap_or(0) == 0 && - usage.cache_creation_input_tokens.unwrap_or(0) == 0 && - usage.cache_read_input_tokens.unwrap_or(0) == 0 { + if usage.input_tokens.unwrap_or(0) == 0 + && usage.output_tokens.unwrap_or(0) == 0 + && usage.cache_creation_input_tokens.unwrap_or(0) == 0 + && usage.cache_read_input_tokens.unwrap_or(0) == 0 + { continue; } @@ -184,17 +196,23 @@ fn parse_jsonl_file( 0.0 } }); - + // Use actual project path if found, otherwise use encoded name - let project_path = actual_project_path.clone() + let project_path = actual_project_path + .clone() .unwrap_or_else(|| encoded_project_name.to_string()); - + entries.push(UsageEntry { timestamp: entry.timestamp, - model: message.model.clone().unwrap_or_else(|| "unknown".to_string()), + model: message + .model + .clone() + .unwrap_or_else(|| "unknown".to_string()), input_tokens: usage.input_tokens.unwrap_or(0), output_tokens: usage.output_tokens.unwrap_or(0), - cache_creation_tokens: usage.cache_creation_input_tokens.unwrap_or(0), + cache_creation_tokens: usage + .cache_creation_input_tokens + .unwrap_or(0), cache_read_tokens: usage.cache_read_input_tokens.unwrap_or(0), cost, session_id: entry.session_id.unwrap_or_else(|| session_id.clone()), @@ -263,10 +281,10 @@ fn get_all_usage_entries(claude_path: &PathBuf) -> Vec { let entries = parse_jsonl_file(&path, &project_name, &mut processed_hashes); all_entries.extend(entries); } - + // Sort by timestamp all_entries.sort_by(|a, b| a.timestamp.cmp(&b.timestamp)); - + all_entries } @@ -275,9 +293,9 @@ pub fn get_usage_stats(days: Option) -> Result { let claude_path = dirs::home_dir() .ok_or("Failed to get home directory")? 
.join(".claude"); - + let all_entries = get_all_usage_entries(&claude_path); - + if all_entries.is_empty() { return Ok(UsageStats { total_cost: 0.0, @@ -292,11 +310,12 @@ pub fn get_usage_stats(days: Option) -> Result { by_project: vec![], }); } - + // Filter by days if specified let filtered_entries = if let Some(days) = days { let cutoff = Local::now().naive_local().date() - chrono::Duration::days(days as i64); - all_entries.into_iter() + all_entries + .into_iter() .filter(|e| { if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) { dt.naive_local().date() >= cutoff @@ -308,18 +327,18 @@ pub fn get_usage_stats(days: Option) -> Result { } else { all_entries }; - + // Calculate aggregated stats let mut total_cost = 0.0; let mut total_input_tokens = 0u64; let mut total_output_tokens = 0u64; let mut total_cache_creation_tokens = 0u64; let mut total_cache_read_tokens = 0u64; - + let mut model_stats: HashMap = HashMap::new(); let mut daily_stats: HashMap = HashMap::new(); let mut project_stats: HashMap = HashMap::new(); - + for entry in &filtered_entries { // Update totals total_cost += entry.cost; @@ -327,18 +346,20 @@ pub fn get_usage_stats(days: Option) -> Result { total_output_tokens += entry.output_tokens; total_cache_creation_tokens += entry.cache_creation_tokens; total_cache_read_tokens += entry.cache_read_tokens; - + // Update model stats - let model_stat = model_stats.entry(entry.model.clone()).or_insert(ModelUsage { - model: entry.model.clone(), - total_cost: 0.0, - total_tokens: 0, - input_tokens: 0, - output_tokens: 0, - cache_creation_tokens: 0, - cache_read_tokens: 0, - session_count: 0, - }); + let model_stat = model_stats + .entry(entry.model.clone()) + .or_insert(ModelUsage { + model: entry.model.clone(), + total_cost: 0.0, + total_tokens: 0, + input_tokens: 0, + output_tokens: 0, + cache_creation_tokens: 0, + cache_read_tokens: 0, + session_count: 0, + }); model_stat.total_cost += entry.cost; model_stat.input_tokens += entry.input_tokens; model_stat.output_tokens += entry.output_tokens; @@ -346,9 +367,14 @@ pub fn get_usage_stats(days: Option) -> Result { model_stat.cache_read_tokens += entry.cache_read_tokens; model_stat.total_tokens = model_stat.input_tokens + model_stat.output_tokens; model_stat.session_count += 1; - + // Update daily stats - let date = entry.timestamp.split('T').next().unwrap_or(&entry.timestamp).to_string(); + let date = entry + .timestamp + .split('T') + .next() + .unwrap_or(&entry.timestamp) + .to_string(); let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage { date, total_cost: 0.0, @@ -356,43 +382,58 @@ pub fn get_usage_stats(days: Option) -> Result { models_used: vec![], }); daily_stat.total_cost += entry.cost; - daily_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens; + daily_stat.total_tokens += entry.input_tokens + + entry.output_tokens + + entry.cache_creation_tokens + + entry.cache_read_tokens; if !daily_stat.models_used.contains(&entry.model) { daily_stat.models_used.push(entry.model.clone()); } - + // Update project stats - let project_stat = project_stats.entry(entry.project_path.clone()).or_insert(ProjectUsage { - project_path: entry.project_path.clone(), - project_name: entry.project_path.split('/').last() - .unwrap_or(&entry.project_path) - .to_string(), - total_cost: 0.0, - total_tokens: 0, - session_count: 0, - last_used: entry.timestamp.clone(), - }); + let project_stat = + project_stats + .entry(entry.project_path.clone()) + 
.or_insert(ProjectUsage { + project_path: entry.project_path.clone(), + project_name: entry + .project_path + .split('/') + .last() + .unwrap_or(&entry.project_path) + .to_string(), + total_cost: 0.0, + total_tokens: 0, + session_count: 0, + last_used: entry.timestamp.clone(), + }); project_stat.total_cost += entry.cost; - project_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens; + project_stat.total_tokens += entry.input_tokens + + entry.output_tokens + + entry.cache_creation_tokens + + entry.cache_read_tokens; project_stat.session_count += 1; if entry.timestamp > project_stat.last_used { project_stat.last_used = entry.timestamp.clone(); } } - - let total_tokens = total_input_tokens + total_output_tokens + total_cache_creation_tokens + total_cache_read_tokens; + + let total_tokens = total_input_tokens + + total_output_tokens + + total_cache_creation_tokens + + total_cache_read_tokens; let total_sessions = filtered_entries.len() as u64; - + // Convert hashmaps to sorted vectors let mut by_model: Vec = model_stats.into_values().collect(); by_model.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap()); - + let mut by_date: Vec = daily_stats.into_values().collect(); by_date.sort_by(|a, b| b.date.cmp(&a.date)); - + let mut by_project: Vec = project_stats.into_values().collect(); by_project.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap()); - + Ok(UsageStats { total_cost, total_tokens, @@ -412,27 +453,26 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result = all_entries.into_iter() + let filtered_entries: Vec<_> = all_entries + .into_iter() .filter(|e| { if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) { let date = dt.naive_local().date(); @@ -442,7 +482,7 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result Result = HashMap::new(); let mut daily_stats: HashMap = HashMap::new(); let mut project_stats: HashMap = HashMap::new(); - + for entry in &filtered_entries { // Update totals total_cost += entry.cost; @@ -476,18 +516,20 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result Result Result project_stat.last_used { project_stat.last_used = entry.timestamp.clone(); } } - - let total_tokens = total_input_tokens + total_output_tokens + total_cache_creation_tokens + total_cache_read_tokens; + + let total_tokens = total_input_tokens + + total_output_tokens + + total_cache_creation_tokens + + total_cache_read_tokens; let total_sessions = filtered_entries.len() as u64; - + // Convert hashmaps to sorted vectors let mut by_model: Vec = model_stats.into_values().collect(); by_model.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap()); - + let mut by_date: Vec = daily_stats.into_values().collect(); by_date.sort_by(|a, b| b.date.cmp(&a.date)); - + let mut by_project: Vec = project_stats.into_values().collect(); by_project.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap()); - + Ok(UsageStats { total_cost, total_tokens, @@ -557,23 +619,26 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result, date: Option) -> Result, String> { +pub fn get_usage_details( + project_path: Option, + date: Option, +) -> Result, String> { let claude_path = dirs::home_dir() .ok_or("Failed to get home directory")? 
.join(".claude"); - + let mut all_entries = get_all_usage_entries(&claude_path); - + // Filter by project if specified if let Some(project) = project_path { all_entries.retain(|e| e.project_path == project); } - + // Filter by date if specified if let Some(date) = date { all_entries.retain(|e| e.timestamp.starts_with(&date)); } - + Ok(all_entries) } @@ -586,7 +651,7 @@ pub fn get_session_stats( let claude_path = dirs::home_dir() .ok_or("Failed to get home directory")? .join(".claude"); - + let all_entries = get_all_usage_entries(&claude_path); let since_date = since.and_then(|s| NaiveDate::parse_from_str(&s, "%Y%m%d").ok()); @@ -609,14 +674,16 @@ pub fn get_session_stats( let mut session_stats: HashMap = HashMap::new(); for entry in &filtered_entries { let session_key = format!("{}/{}", entry.project_path, entry.session_id); - let project_stat = session_stats.entry(session_key).or_insert_with(|| ProjectUsage { - project_path: entry.project_path.clone(), - project_name: entry.session_id.clone(), // Using session_id as project_name for session view - total_cost: 0.0, - total_tokens: 0, - session_count: 0, // In this context, this will count entries per session - last_used: " ".to_string(), - }); + let project_stat = session_stats + .entry(session_key) + .or_insert_with(|| ProjectUsage { + project_path: entry.project_path.clone(), + project_name: entry.session_id.clone(), // Using session_id as project_name for session view + total_cost: 0.0, + total_tokens: 0, + session_count: 0, // In this context, this will count entries per session + last_used: " ".to_string(), + }); project_stat.total_cost += entry.cost; project_stat.total_tokens += entry.input_tokens @@ -643,6 +710,5 @@ pub fn get_session_stats( by_session.sort_by(|a, b| b.last_used.cmp(&a.last_used)); } - Ok(by_session) -} \ No newline at end of file +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 047dfbd..cf1f90a 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,11 +1,11 @@ // Learn more about Tauri commands at https://tauri.app/develop/calling-rust/ // Declare modules -pub mod commands; -pub mod sandbox; pub mod checkpoint; -pub mod process; pub mod claude_binary; +pub mod commands; +pub mod process; +pub mod sandbox; #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index a5dad57..af26492 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -1,57 +1,52 @@ // Prevents additional console window on Windows in release, DO NOT REMOVE!! 
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] -mod commands; -mod sandbox; mod checkpoint; -mod process; mod claude_binary; +mod commands; +mod process; +mod sandbox; -use tauri::Manager; -use commands::claude::{ - get_claude_settings, get_project_sessions, get_system_prompt, list_projects, open_new_session, - check_claude_version, save_system_prompt, save_claude_settings, - find_claude_md_files, read_claude_md_file, save_claude_md_file, - load_session_history, execute_claude_code, continue_claude_code, resume_claude_code, - list_directory_contents, search_files, - create_checkpoint, restore_checkpoint, list_checkpoints, fork_from_checkpoint, - get_session_timeline, update_checkpoint_settings, get_checkpoint_diff, - track_checkpoint_message, track_session_messages, check_auto_checkpoint, cleanup_old_checkpoints, - get_checkpoint_settings, clear_checkpoint_manager, get_checkpoint_state_stats, - get_recently_modified_files, cancel_claude_execution, ClaudeProcessState, -}; +use checkpoint::state::CheckpointState; use commands::agents::{ - init_database, list_agents, create_agent, update_agent, delete_agent, - get_agent, execute_agent, list_agent_runs, get_agent_run, - get_agent_run_with_real_time_metrics, list_agent_runs_with_metrics, - list_running_sessions, kill_agent_session, - get_session_status, cleanup_finished_processes, get_session_output, - get_live_session_output, stream_session_output, get_claude_binary_path, - set_claude_binary_path, export_agent, export_agent_to_file, import_agent, - import_agent_from_file, fetch_github_agents, fetch_github_agent_content, - import_agent_from_github, list_claude_installations, AgentDb + cleanup_finished_processes, create_agent, delete_agent, execute_agent, export_agent, + export_agent_to_file, fetch_github_agent_content, fetch_github_agents, get_agent, + get_agent_run, get_agent_run_with_real_time_metrics, get_claude_binary_path, + get_live_session_output, get_session_output, get_session_status, import_agent, + import_agent_from_file, import_agent_from_github, init_database, kill_agent_session, + list_agent_runs, list_agent_runs_with_metrics, list_agents, list_claude_installations, + list_running_sessions, set_claude_binary_path, stream_session_output, update_agent, AgentDb, }; -use commands::sandbox::{ - list_sandbox_profiles, create_sandbox_profile, update_sandbox_profile, delete_sandbox_profile, - get_sandbox_profile, list_sandbox_rules, create_sandbox_rule, update_sandbox_rule, - delete_sandbox_rule, get_platform_capabilities, test_sandbox_profile, - list_sandbox_violations, log_sandbox_violation, clear_sandbox_violations, get_sandbox_violation_stats, - export_sandbox_profile, export_all_sandbox_profiles, import_sandbox_profiles, -}; -use commands::screenshot::{ - capture_url_screenshot, cleanup_screenshot_temp_files, -}; -use commands::usage::{ - get_usage_stats, get_usage_by_date_range, get_usage_details, get_session_stats, +use commands::claude::{ + cancel_claude_execution, check_auto_checkpoint, check_claude_version, cleanup_old_checkpoints, + clear_checkpoint_manager, continue_claude_code, create_checkpoint, execute_claude_code, + find_claude_md_files, fork_from_checkpoint, get_checkpoint_diff, get_checkpoint_settings, + get_checkpoint_state_stats, get_claude_settings, get_project_sessions, + get_recently_modified_files, get_session_timeline, get_system_prompt, list_checkpoints, + list_directory_contents, list_projects, load_session_history, open_new_session, + read_claude_md_file, restore_checkpoint, 
resume_claude_code, save_claude_md_file, + save_claude_settings, save_system_prompt, search_files, track_checkpoint_message, + track_session_messages, update_checkpoint_settings, ClaudeProcessState, }; use commands::mcp::{ - mcp_add, mcp_list, mcp_get, mcp_remove, mcp_add_json, mcp_add_from_claude_desktop, - mcp_serve, mcp_test_connection, mcp_reset_project_choices, mcp_get_server_status, - mcp_read_project_config, mcp_save_project_config, + mcp_add, mcp_add_from_claude_desktop, mcp_add_json, mcp_get, mcp_get_server_status, mcp_list, + mcp_read_project_config, mcp_remove, mcp_reset_project_choices, mcp_save_project_config, + mcp_serve, mcp_test_connection, +}; +use commands::sandbox::{ + clear_sandbox_violations, create_sandbox_profile, create_sandbox_rule, delete_sandbox_profile, + delete_sandbox_rule, export_all_sandbox_profiles, export_sandbox_profile, + get_platform_capabilities, get_sandbox_profile, get_sandbox_violation_stats, + import_sandbox_profiles, list_sandbox_profiles, list_sandbox_rules, list_sandbox_violations, + log_sandbox_violation, test_sandbox_profile, update_sandbox_profile, update_sandbox_rule, +}; +use commands::screenshot::{capture_url_screenshot, cleanup_screenshot_temp_files}; +use commands::usage::{ + get_session_stats, get_usage_by_date_range, get_usage_details, get_usage_stats, }; -use std::sync::Mutex; -use checkpoint::state::CheckpointState; use process::ProcessRegistryState; +use std::sync::Mutex; +use tauri::Manager; fn main() { // Initialize logger @@ -72,32 +67,34 @@ fn main() { // Initialize agents database let conn = init_database(&app.handle()).expect("Failed to initialize agents database"); app.manage(AgentDb(Mutex::new(conn))); - + // Initialize checkpoint state let checkpoint_state = CheckpointState::new(); - + // Set the Claude directory path if let Ok(claude_dir) = dirs::home_dir() .ok_or_else(|| "Could not find home directory") .and_then(|home| { let claude_path = home.join(".claude"); - claude_path.canonicalize() + claude_path + .canonicalize() .map_err(|_| "Could not find ~/.claude directory") - }) { + }) + { let state_clone = checkpoint_state.clone(); tauri::async_runtime::spawn(async move { state_clone.set_claude_dir(claude_dir).await; }); } - + app.manage(checkpoint_state); - + // Initialize process registry app.manage(ProcessRegistryState::default()); - + // Initialize Claude process state app.manage(ClaudeProcessState::default()); - + Ok(()) }) .invoke_handler(tauri::generate_handler![ diff --git a/src-tauri/src/process/mod.rs b/src-tauri/src/process/mod.rs index 7f8af66..b1b2d8e 100644 --- a/src-tauri/src/process/mod.rs +++ b/src-tauri/src/process/mod.rs @@ -1,3 +1,3 @@ pub mod registry; -pub use registry::*; \ No newline at end of file +pub use registry::*; diff --git a/src-tauri/src/process/registry.rs b/src-tauri/src/process/registry.rs index 7021989..3fd58fb 100644 --- a/src-tauri/src/process/registry.rs +++ b/src-tauri/src/process/registry.rs @@ -1,8 +1,8 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; -use serde::{Deserialize, Serialize}; use tokio::process::Child; -use chrono::{DateTime, Utc}; /// Information about a running agent process #[derive(Debug, Clone, Serialize, Deserialize)] @@ -50,7 +50,7 @@ impl ProcessRegistry { child: Child, ) -> Result<(), String> { let mut processes = self.processes.lock().map_err(|e| e.to_string())?; - + let process_info = ProcessInfo { run_id, agent_id, @@ -84,7 +84,10 @@ impl ProcessRegistry { #[allow(dead_code)] 
pub fn get_running_processes(&self) -> Result, String> { let processes = self.processes.lock().map_err(|e| e.to_string())?; - Ok(processes.values().map(|handle| handle.info.clone()).collect()) + Ok(processes + .values() + .map(|handle| handle.info.clone()) + .collect()) } /// Get a specific running process @@ -96,8 +99,8 @@ impl ProcessRegistry { /// Kill a running process with proper cleanup pub async fn kill_process(&self, run_id: i64) -> Result { - use log::{info, warn, error}; - + use log::{error, info, warn}; + // First check if the process exists and get its PID let (pid, child_arc) = { let processes = self.processes.lock().map_err(|e| e.to_string())?; @@ -107,9 +110,12 @@ impl ProcessRegistry { return Ok(false); // Process not found } }; - - info!("Attempting graceful shutdown of process {} (PID: {})", run_id, pid); - + + info!( + "Attempting graceful shutdown of process {} (PID: {})", + run_id, pid + ); + // Send kill signal to the process let kill_sent = { let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?; @@ -128,52 +134,50 @@ impl ProcessRegistry { false // Process already killed } }; - + if !kill_sent { return Ok(false); } - + // Wait for the process to exit (with timeout) - let wait_result = tokio::time::timeout( - tokio::time::Duration::from_secs(5), - async { - loop { - // Check if process has exited - let status = { - let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?; - if let Some(child) = child_guard.as_mut() { - match child.try_wait() { - Ok(Some(status)) => { - info!("Process {} exited with status: {:?}", run_id, status); - *child_guard = None; // Clear the child handle - Some(Ok::<(), String>(())) - } - Ok(None) => { - // Still running - None - } - Err(e) => { - error!("Error checking process status: {}", e); - Some(Err(e.to_string())) - } + let wait_result = tokio::time::timeout(tokio::time::Duration::from_secs(5), async { + loop { + // Check if process has exited + let status = { + let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?; + if let Some(child) = child_guard.as_mut() { + match child.try_wait() { + Ok(Some(status)) => { + info!("Process {} exited with status: {:?}", run_id, status); + *child_guard = None; // Clear the child handle + Some(Ok::<(), String>(())) + } + Ok(None) => { + // Still running + None + } + Err(e) => { + error!("Error checking process status: {}", e); + Some(Err(e.to_string())) } - } else { - // Process already gone - Some(Ok(())) - } - }; - - match status { - Some(result) => return result, - None => { - // Still running, wait a bit - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; } + } else { + // Process already gone + Some(Ok(())) + } + }; + + match status { + Some(result) => return result, + None => { + // Still running, wait a bit + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; } } } - ).await; - + }) + .await; + match wait_result { Ok(Ok(_)) => { info!("Process {} exited gracefully", run_id); @@ -189,19 +193,19 @@ impl ProcessRegistry { } } } - + // Remove from registry after killing self.unregister_process(run_id)?; - + Ok(true) } /// Kill a process by PID using system commands (fallback method) pub fn kill_process_by_pid(&self, run_id: i64, pid: u32) -> Result { - use log::{info, warn, error}; - + use log::{error, info, warn}; + info!("Attempting to kill process {} by PID {}", run_id, pid); - + let kill_result = if cfg!(target_os = "windows") { std::process::Command::new("taskkill") .args(["/F", "/PID", &pid.to_string()]) @@ -211,22 +215,25 
@@ impl ProcessRegistry { let term_result = std::process::Command::new("kill") .args(["-TERM", &pid.to_string()]) .output(); - + match &term_result { Ok(output) if output.status.success() => { info!("Sent SIGTERM to PID {}", pid); // Give it 2 seconds to exit gracefully std::thread::sleep(std::time::Duration::from_secs(2)); - + // Check if still running let check_result = std::process::Command::new("kill") .args(["-0", &pid.to_string()]) .output(); - + if let Ok(output) = check_result { if output.status.success() { // Still running, send SIGKILL - warn!("Process {} still running after SIGTERM, sending SIGKILL", pid); + warn!( + "Process {} still running after SIGTERM, sending SIGKILL", + pid + ); std::process::Command::new("kill") .args(["-KILL", &pid.to_string()]) .output() @@ -246,7 +253,7 @@ impl ProcessRegistry { } } }; - + match kill_result { Ok(output) => { if output.status.success() { @@ -271,11 +278,11 @@ impl ProcessRegistry { #[allow(dead_code)] pub async fn is_process_running(&self, run_id: i64) -> Result { let processes = self.processes.lock().map_err(|e| e.to_string())?; - + if let Some(handle) = processes.get(&run_id) { let child_arc = handle.child.clone(); drop(processes); // Release the lock before async operation - + let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?; if let Some(ref mut child) = child_guard.as_mut() { match child.try_wait() { @@ -329,20 +336,20 @@ impl ProcessRegistry { pub async fn cleanup_finished_processes(&self) -> Result, String> { let mut finished_runs = Vec::new(); let processes_lock = self.processes.clone(); - + // First, identify finished processes { let processes = processes_lock.lock().map_err(|e| e.to_string())?; let run_ids: Vec = processes.keys().cloned().collect(); drop(processes); - + for run_id in run_ids { if !self.is_process_running(run_id).await? 
{ finished_runs.push(run_id); } } } - + // Then remove them from the registry { let mut processes = processes_lock.lock().map_err(|e| e.to_string())?; @@ -350,7 +357,7 @@ impl ProcessRegistry { processes.remove(run_id); } } - + Ok(finished_runs) } } @@ -368,4 +375,4 @@ impl Default for ProcessRegistryState { fn default() -> Self { Self(Arc::new(ProcessRegistry::new())) } -} \ No newline at end of file +} diff --git a/src-tauri/src/sandbox/defaults.rs b/src-tauri/src/sandbox/defaults.rs index 7285ac1..e9f1828 100644 --- a/src-tauri/src/sandbox/defaults.rs +++ b/src-tauri/src/sandbox/defaults.rs @@ -4,26 +4,24 @@ use rusqlite::{params, Connection, Result}; /// Create default sandbox profiles for initial setup pub fn create_default_profiles(conn: &Connection) -> Result<()> { // Check if we already have profiles - let count: i64 = conn.query_row( - "SELECT COUNT(*) FROM sandbox_profiles", - [], - |row| row.get(0), - )?; - + let count: i64 = conn.query_row("SELECT COUNT(*) FROM sandbox_profiles", [], |row| { + row.get(0) + })?; + if count > 0 { // Already have profiles, don't create defaults return Ok(()); } - + // Create Standard Profile create_standard_profile(conn)?; - - // Create Minimal Profile + + // Create Minimal Profile create_minimal_profile(conn)?; - + // Create Development Profile create_development_profile(conn)?; - + Ok(()) } @@ -38,22 +36,57 @@ fn create_standard_profile(conn: &Connection) -> Result<()> { true // Set as default ], )?; - + let profile_id = conn.last_insert_rowid(); - + // Add rules let rules = vec![ // File access - ("file_read_all", "subpath", "{{PROJECT_PATH}}", true, Some(r#"["linux", "macos"]"#)), - ("file_read_all", "subpath", "/usr/lib", true, Some(r#"["linux", "macos"]"#)), - ("file_read_all", "subpath", "/usr/local/lib", true, Some(r#"["linux", "macos"]"#)), - ("file_read_all", "subpath", "/System/Library", true, Some(r#"["macos"]"#)), - ("file_read_metadata", "subpath", "/", true, Some(r#"["macos"]"#)), - + ( + "file_read_all", + "subpath", + "{{PROJECT_PATH}}", + true, + Some(r#"["linux", "macos"]"#), + ), + ( + "file_read_all", + "subpath", + "/usr/lib", + true, + Some(r#"["linux", "macos"]"#), + ), + ( + "file_read_all", + "subpath", + "/usr/local/lib", + true, + Some(r#"["linux", "macos"]"#), + ), + ( + "file_read_all", + "subpath", + "/System/Library", + true, + Some(r#"["macos"]"#), + ), + ( + "file_read_metadata", + "subpath", + "/", + true, + Some(r#"["macos"]"#), + ), // Network access - ("network_outbound", "all", "", true, Some(r#"["linux", "macos"]"#)), + ( + "network_outbound", + "all", + "", + true, + Some(r#"["linux", "macos"]"#), + ), ]; - + for (op_type, pattern_type, pattern_value, enabled, platforms) in rules { conn.execute( "INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support) @@ -61,7 +94,7 @@ fn create_standard_profile(conn: &Connection) -> Result<()> { params![profile_id, op_type, pattern_type, pattern_value, enabled, platforms], )?; } - + Ok(()) } @@ -76,9 +109,9 @@ fn create_minimal_profile(conn: &Connection) -> Result<()> { false ], )?; - + let profile_id = conn.last_insert_rowid(); - + // Add minimal rules - only project access conn.execute( "INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support) @@ -92,7 +125,7 @@ fn create_minimal_profile(conn: &Connection) -> Result<()> { Some(r#"["linux", "macos", "windows"]"#) ], )?; - + Ok(()) } @@ -107,26 +140,66 @@ fn create_development_profile(conn: &Connection) -> 
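// [Illustrative sketch, not part of the patch] create_default_profiles() above is a
// "seed once" migration: count existing rows, bail out early if any exist, otherwise
// insert the defaults. The same guard in isolation, assuming rusqlite and the
// sandbox_profiles table defined by this codebase:
fn seed_once(conn: &rusqlite::Connection) -> rusqlite::Result<()> {
    let count: i64 =
        conn.query_row("SELECT COUNT(*) FROM sandbox_profiles", [], |row| row.get(0))?;
    if count > 0 {
        return Ok(()); // defaults already present; never overwrite user edits
    }
    conn.execute(
        "INSERT INTO sandbox_profiles (name, description, is_active, is_default) \
         VALUES (?1, ?2, ?3, ?4)",
        rusqlite::params!["Standard", "Default profile", true, true],
    )?;
    Ok(())
}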
Result<()> { false ], )?; - + let profile_id = conn.last_insert_rowid(); - + // Add development rules let rules = vec![ // Broad file access - ("file_read_all", "subpath", "{{PROJECT_PATH}}", true, Some(r#"["linux", "macos"]"#)), - ("file_read_all", "subpath", "{{HOME}}", true, Some(r#"["linux", "macos"]"#)), - ("file_read_all", "subpath", "/usr", true, Some(r#"["linux", "macos"]"#)), - ("file_read_all", "subpath", "/opt", true, Some(r#"["linux", "macos"]"#)), - ("file_read_all", "subpath", "/Applications", true, Some(r#"["macos"]"#)), - ("file_read_metadata", "subpath", "/", true, Some(r#"["macos"]"#)), - + ( + "file_read_all", + "subpath", + "{{PROJECT_PATH}}", + true, + Some(r#"["linux", "macos"]"#), + ), + ( + "file_read_all", + "subpath", + "{{HOME}}", + true, + Some(r#"["linux", "macos"]"#), + ), + ( + "file_read_all", + "subpath", + "/usr", + true, + Some(r#"["linux", "macos"]"#), + ), + ( + "file_read_all", + "subpath", + "/opt", + true, + Some(r#"["linux", "macos"]"#), + ), + ( + "file_read_all", + "subpath", + "/Applications", + true, + Some(r#"["macos"]"#), + ), + ( + "file_read_metadata", + "subpath", + "/", + true, + Some(r#"["macos"]"#), + ), // Network access - ("network_outbound", "all", "", true, Some(r#"["linux", "macos"]"#)), - + ( + "network_outbound", + "all", + "", + true, + Some(r#"["linux", "macos"]"#), + ), // System info (macOS only) ("system_info_read", "all", "", true, Some(r#"["macos"]"#)), ]; - + for (op_type, pattern_type, pattern_value, enabled, platforms) in rules { conn.execute( "INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support) @@ -134,6 +207,6 @@ fn create_development_profile(conn: &Connection) -> Result<()> { params![profile_id, op_type, pattern_type, pattern_value, enabled, platforms], )?; } - + Ok(()) -} \ No newline at end of file +} diff --git a/src-tauri/src/sandbox/executor.rs b/src-tauri/src/sandbox/executor.rs index 19aec89..3ee8df6 100644 --- a/src-tauri/src/sandbox/executor.rs +++ b/src-tauri/src/sandbox/executor.rs @@ -1,7 +1,9 @@ use anyhow::{Context, Result}; #[cfg(unix)] -use gaol::sandbox::{ChildSandbox, ChildSandboxMethods, Command as GaolCommand, Sandbox, SandboxMethods}; -use log::{info, warn, error, debug}; +use gaol::sandbox::{ + ChildSandbox, ChildSandboxMethods, Command as GaolCommand, Sandbox, SandboxMethods, +}; +use log::{debug, error, info, warn}; use std::env; use std::path::{Path, PathBuf}; use std::process::Stdio; @@ -25,12 +27,12 @@ impl SandboxExecutor { serialized_profile: None, } } - + /// Create a new sandbox executor with serialized profile for child process communication pub fn new_with_serialization( - profile: gaol::profile::Profile, + profile: gaol::profile::Profile, project_path: PathBuf, - serialized_profile: SerializedProfile + serialized_profile: SerializedProfile, ) -> Self { Self { profile, @@ -41,15 +43,23 @@ impl SandboxExecutor { /// Execute a command in the sandbox (for the parent process) /// This is used when we need to spawn a child process with sandbox - pub fn execute_sandboxed_spawn(&self, command: &str, args: &[&str], cwd: &Path) -> Result { + pub fn execute_sandboxed_spawn( + &self, + command: &str, + args: &[&str], + cwd: &Path, + ) -> Result { info!("Executing sandboxed command: {} {:?}", command, args); - + // On macOS, we need to check if the command is allowed by the system #[cfg(target_os = "macos")] { // For testing purposes, we'll skip actual sandboxing for simple commands like echo if command == "echo" || command == "/bin/echo" 
{ - debug!("Using direct execution for simple test command: {}", command); + debug!( + "Using direct execution for simple test command: {}", + command + ); return std::process::Command::new(command) .args(args) .current_dir(cwd) @@ -60,44 +70,55 @@ impl SandboxExecutor { .context("Failed to spawn test command"); } } - + // Create the sandbox let sandbox = Sandbox::new(self.profile.clone()); - + // Create the command let mut gaol_command = GaolCommand::new(command); for arg in args { gaol_command.arg(arg); } - + // Set environment variables gaol_command.env("GAOL_CHILD_PROCESS", "1"); gaol_command.env("GAOL_SANDBOX_ACTIVE", "1"); - gaol_command.env("GAOL_PROJECT_PATH", self.project_path.to_string_lossy().as_ref()); - + gaol_command.env( + "GAOL_PROJECT_PATH", + self.project_path.to_string_lossy().as_ref(), + ); + // Inherit specific parent environment variables that are safe for (key, value) in env::vars() { // Only pass through safe environment variables - if key.starts_with("PATH") || key.starts_with("HOME") || key.starts_with("USER") - || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_") { + if key.starts_with("PATH") + || key.starts_with("HOME") + || key.starts_with("USER") + || key == "SHELL" + || key == "LANG" + || key == "LC_ALL" + || key.starts_with("LC_") + { gaol_command.env(&key, &value); } } - + // Try to start the sandboxed process using gaol match sandbox.start(&mut gaol_command) { Ok(process) => { debug!("Successfully started sandboxed process using gaol"); // Unfortunately, gaol doesn't expose the underlying Child process // So we need to use a different approach for now - + // This is a limitation of the gaol library - we can't get the Child back // For now, we'll have to use the fallback approach - warn!("Gaol started the process but we can't get the Child handle - using fallback"); - + warn!( + "Gaol started the process but we can't get the Child handle - using fallback" + ); + // Drop the process to avoid zombie drop(process); - + // Fall through to fallback } Err(e) => { @@ -105,10 +126,10 @@ impl SandboxExecutor { debug!("Gaol error details: {:?}", e); } } - + // Fallback: Use regular process spawn with sandbox activation in child info!("Using child-side sandbox activation as fallback"); - + // Serialize the sandbox rules for the child process let rules_json = if let Some(ref serialized) = self.serialized_profile { serde_json::to_string(serialized)? @@ -116,50 +137,70 @@ impl SandboxExecutor { let serialized_rules = self.extract_sandbox_rules()?; serde_json::to_string(&serialized_rules)? 
}; - + let mut std_command = std::process::Command::new(command); - std_command.args(args) + std_command + .args(args) .current_dir(cwd) .env("GAOL_SANDBOX_ACTIVE", "1") - .env("GAOL_PROJECT_PATH", self.project_path.to_string_lossy().as_ref()) + .env( + "GAOL_PROJECT_PATH", + self.project_path.to_string_lossy().as_ref(), + ) .env("GAOL_SANDBOX_RULES", rules_json) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()); - - std_command.spawn() + + std_command + .spawn() .context("Failed to spawn process with sandbox environment") } - + /// Prepare a tokio Command for sandboxed execution /// The sandbox will be activated in the child process pub fn prepare_sandboxed_command(&self, command: &str, args: &[&str], cwd: &Path) -> Command { info!("Preparing sandboxed command: {} {:?}", command, args); - + let mut cmd = Command::new(command); - cmd.args(args) - .current_dir(cwd); - + cmd.args(args).current_dir(cwd); + // Inherit essential environment variables from parent process // This is crucial for commands like Claude that need to find Node.js for (key, value) in env::vars() { // Pass through PATH and other essential environment variables - if key == "PATH" || key == "HOME" || key == "USER" - || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_") - || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN" { + if key == "PATH" + || key == "HOME" + || key == "USER" + || key == "SHELL" + || key == "LANG" + || key == "LC_ALL" + || key.starts_with("LC_") + || key == "NODE_PATH" + || key == "NVM_DIR" + || key == "NVM_BIN" + { debug!("Inheriting env var: {}={}", key, value); cmd.env(&key, &value); } } - + // Serialize the sandbox rules for the child process let rules_json = if let Some(ref serialized) = self.serialized_profile { let json = serde_json::to_string(serialized).ok(); - info!("🔧 Using serialized sandbox profile with {} operations", serialized.operations.len()); + info!( + "🔧 Using serialized sandbox profile with {} operations", + serialized.operations.len() + ); for (i, op) in serialized.operations.iter().enumerate() { match op { SerializedOperation::FileReadAll { path, is_subpath } => { - info!(" Rule {}: FileReadAll {} (subpath: {})", i, path.display(), is_subpath); + info!( + " Rule {}: FileReadAll {} (subpath: {})", + i, + path.display(), + is_subpath + ); } SerializedOperation::NetworkOutbound { pattern } => { info!(" Rule {}: NetworkOutbound {}", i, pattern); @@ -179,7 +220,7 @@ impl SandboxExecutor { .ok() .and_then(|r| serde_json::to_string(&r).ok()) }; - + if let Some(json) = rules_json { // TEMPORARILY DISABLED: Claude Code might not understand these env vars and could hang // cmd.env("GAOL_SANDBOX_ACTIVE", "1"); @@ -188,19 +229,22 @@ impl SandboxExecutor { warn!("🚨 TEMPORARILY DISABLED sandbox environment variables for debugging"); info!("🔧 Would have set sandbox environment variables for child process"); info!(" GAOL_SANDBOX_ACTIVE=1 (disabled)"); - info!(" GAOL_PROJECT_PATH={} (disabled)", self.project_path.display()); + info!( + " GAOL_PROJECT_PATH={} (disabled)", + self.project_path.display() + ); info!(" GAOL_SANDBOX_RULES={} chars (disabled)", json.len()); } else { warn!("🚨 Failed to serialize sandbox rules - running without sandbox!"); } - - cmd.stdin(Stdio::null()) // Don't pipe stdin - we have no input to send + + cmd.stdin(Stdio::null()) // Don't pipe stdin - we have no input to send .stdout(Stdio::piped()) .stderr(Stdio::piped()); - + cmd } - + /// Extract sandbox rules from the profile /// This is a workaround since gaol 
doesn't expose the operations fn extract_sandbox_rules(&self) -> Result<SerializedProfile> { @@ -208,18 +252,18 @@ impl SandboxExecutor { // For now, return a default set based on what we know // This should be improved by tracking rules during profile creation let operations = vec![ - SerializedOperation::FileReadAll { - path: self.project_path.clone(), - is_subpath: true + SerializedOperation::FileReadAll { + path: self.project_path.clone(), + is_subpath: true, }, - SerializedOperation::NetworkOutbound { - pattern: "all".to_string() + SerializedOperation::NetworkOutbound { + pattern: "all".to_string(), }, ]; - + Ok(SerializedProfile { operations }) } - + /// Activate sandbox in the current process (for child processes) /// This should be called early in the child process pub fn activate_sandbox_in_child() -> Result<()> { @@ -227,21 +271,23 @@ impl SandboxExecutor { if !should_activate_sandbox() { return Ok(()); } - + info!("Activating sandbox in child process"); - + // Get project path - let project_path = env::var("GAOL_PROJECT_PATH") - .context("GAOL_PROJECT_PATH not set")?; + let project_path = env::var("GAOL_PROJECT_PATH").context("GAOL_PROJECT_PATH not set")?; let project_path = PathBuf::from(project_path); - + // Try to deserialize the sandbox rules from environment let profile = if let Ok(rules_json) = env::var("GAOL_SANDBOX_RULES") { match serde_json::from_str::<SerializedProfile>(&rules_json) { Ok(serialized) => { - debug!("Deserializing {} sandbox rules", serialized.operations.len()); + debug!( + "Deserializing {} sandbox rules", + serialized.operations.len() + ); deserialize_profile(serialized, &project_path)? - }, + } Err(e) => { warn!("Failed to deserialize sandbox rules: {}", e); // Fallback to minimal profile @@ -253,10 +299,10 @@ impl SandboxExecutor { // Fallback to minimal profile create_minimal_profile(project_path)?
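// [Illustrative sketch, not part of the patch] The parent/child handshake above rides
// on environment variables: the parent serializes the profile to JSON in
// GAOL_SANDBOX_RULES, and the child deserializes it early in startup, falling back to
// a minimal profile if parsing fails. Shape only; the Rules type here is hypothetical,
// assuming serde (with derive) and serde_json.
#[derive(serde::Serialize, serde::Deserialize)]
struct Rules {
    project_path: std::path::PathBuf,
    allow_network: bool,
}

fn rules_from_env() -> Option<Rules> {
    let json = std::env::var("GAOL_SANDBOX_RULES").ok()?;
    serde_json::from_str(&json).ok() // None => caller builds the minimal fallback profile
}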
}; - + // Create and activate the child sandbox let sandbox = ChildSandbox::new(profile); - + match sandbox.activate() { Ok(_) => { info!("Sandbox activated successfully"); @@ -280,12 +326,12 @@ impl SandboxExecutor { serialized_profile: None, } } - + /// Create a new sandbox executor with serialized profile (no-op on Windows) pub fn new_with_serialization( - _profile: (), + _profile: (), project_path: PathBuf, - serialized_profile: SerializedProfile + serialized_profile: SerializedProfile, ) -> Self { Self { project_path, @@ -294,9 +340,17 @@ impl SandboxExecutor { } /// Execute a command in the sandbox (Windows - no sandboxing) - pub fn execute_sandboxed_spawn(&self, command: &str, args: &[&str], cwd: &Path) -> Result { - info!("Executing command without sandbox on Windows: {} {:?}", command, args); - + pub fn execute_sandboxed_spawn( + &self, + command: &str, + args: &[&str], + cwd: &Path, + ) -> Result { + info!( + "Executing command without sandbox on Windows: {} {:?}", + command, args + ); + std::process::Command::new(command) .args(args) .current_dir(cwd) @@ -309,23 +363,26 @@ impl SandboxExecutor { /// Prepare a sandboxed tokio Command (Windows - no sandboxing) pub fn prepare_sandboxed_command(&self, command: &str, args: &[&str], cwd: &Path) -> Command { - info!("Preparing command without sandbox on Windows: {} {:?}", command, args); - + info!( + "Preparing command without sandbox on Windows: {} {:?}", + command, args + ); + let mut cmd = Command::new(command); cmd.args(args) .current_dir(cwd) .stdin(Stdio::null()) .stdout(Stdio::piped()) .stderr(Stdio::piped()); - + cmd } - + /// Extract sandbox rules (no-op on Windows) fn extract_sandbox_rules(&self) -> Result { Ok(SerializedProfile { operations: vec![] }) } - + /// Activate sandbox in child process (no-op on Windows) pub fn activate_sandbox_in_child() -> Result<()> { debug!("Sandbox activation skipped on Windows"); @@ -341,11 +398,11 @@ pub fn should_activate_sandbox() -> bool { /// Helper to create a sandboxed tokio Command #[cfg(unix)] pub fn create_sandboxed_command( - command: &str, - args: &[&str], + command: &str, + args: &[&str], cwd: &Path, profile: gaol::profile::Profile, - project_path: PathBuf + project_path: PathBuf, ) -> Command { let executor = SandboxExecutor::new(profile, project_path); executor.prepare_sandboxed_command(command, args, cwd) @@ -368,9 +425,12 @@ pub enum SerializedOperation { } #[cfg(unix)] -fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Result { +fn deserialize_profile( + serialized: SerializedProfile, + project_path: &Path, +) -> Result { let mut operations = Vec::new(); - + for op in serialized.operations { match op { SerializedOperation::FileReadAll { path, is_subpath } => { @@ -401,12 +461,12 @@ fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Re } SerializedOperation::NetworkTcp { port } => { operations.push(gaol::profile::Operation::NetworkOutbound( - gaol::profile::AddressPattern::Tcp(port) + gaol::profile::AddressPattern::Tcp(port), )); } SerializedOperation::NetworkLocalSocket { path } => { operations.push(gaol::profile::Operation::NetworkOutbound( - gaol::profile::AddressPattern::LocalSocket(path) + gaol::profile::AddressPattern::LocalSocket(path), )); } SerializedOperation::SystemInfoRead => { @@ -414,40 +474,38 @@ fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Re } } } - + // Always ensure project path access let has_project_access = operations.iter().any(|op| { matches!(op, 
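// [Illustrative sketch, not part of the patch] The Windows variants above keep one
// public API with per-platform bodies selected at compile time, so callers never need
// cfg knowledge of their own:
#[cfg(unix)]
fn activate_sandbox() {
    // real sandbox activation would run here
}

#[cfg(not(unix))]
fn activate_sandbox() {
    // no-op: sandboxing is unsupported on this platform
}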
gaol::profile::Operation::FileReadAll(gaol::profile::PathPattern::Subpath(p)) if p == project_path) }); - + if !has_project_access { operations.push(gaol::profile::Operation::FileReadAll( - gaol::profile::PathPattern::Subpath(project_path.to_path_buf()) + gaol::profile::PathPattern::Subpath(project_path.to_path_buf()), )); } - + let op_count = operations.len(); - gaol::profile::Profile::new(operations) - .map_err(|e| { - error!("Failed to create profile: {:?}", e); - anyhow::anyhow!("Failed to create profile from {} operations: {:?}", op_count, e) - }) + gaol::profile::Profile::new(operations).map_err(|e| { + error!("Failed to create profile: {:?}", e); + anyhow::anyhow!( + "Failed to create profile from {} operations: {:?}", + op_count, + e + ) + }) } #[cfg(unix)] fn create_minimal_profile(project_path: PathBuf) -> Result { let operations = vec![ - gaol::profile::Operation::FileReadAll( - gaol::profile::PathPattern::Subpath(project_path) - ), - gaol::profile::Operation::NetworkOutbound( - gaol::profile::AddressPattern::All - ), + gaol::profile::Operation::FileReadAll(gaol::profile::PathPattern::Subpath(project_path)), + gaol::profile::Operation::NetworkOutbound(gaol::profile::AddressPattern::All), ]; - - gaol::profile::Profile::new(operations) - .map_err(|e| { - error!("Failed to create minimal profile: {:?}", e); - anyhow::anyhow!("Failed to create minimal sandbox profile: {:?}", e) - }) -} \ No newline at end of file + + gaol::profile::Profile::new(operations).map_err(|e| { + error!("Failed to create minimal profile: {:?}", e); + anyhow::anyhow!("Failed to create minimal sandbox profile: {:?}", e) + }) +} diff --git a/src-tauri/src/sandbox/mod.rs b/src-tauri/src/sandbox/mod.rs index 6e0ce10..289ee63 100644 --- a/src-tauri/src/sandbox/mod.rs +++ b/src-tauri/src/sandbox/mod.rs @@ -1,21 +1,21 @@ #[allow(unused)] -pub mod profile; +pub mod defaults; #[allow(unused)] pub mod executor; #[allow(unused)] pub mod platform; #[allow(unused)] -pub mod defaults; +pub mod profile; // These are used in agents.rs and claude.rs via direct module paths #[allow(unused)] -pub use profile::{SandboxProfile, SandboxRule, ProfileBuilder}; +pub use profile::{ProfileBuilder, SandboxProfile, SandboxRule}; // These are used in main.rs and sandbox.rs #[allow(unused)] -pub use executor::{SandboxExecutor, should_activate_sandbox}; +pub use executor::{should_activate_sandbox, SandboxExecutor}; // These are used in sandbox.rs #[allow(unused)] -pub use platform::{PlatformCapabilities, get_platform_capabilities}; +pub use platform::{get_platform_capabilities, PlatformCapabilities}; // Used for initial setup #[allow(unused)] -pub use defaults::create_default_profiles; \ No newline at end of file +pub use defaults::create_default_profiles; diff --git a/src-tauri/src/sandbox/platform.rs b/src-tauri/src/sandbox/platform.rs index bb54a80..2aac002 100644 --- a/src-tauri/src/sandbox/platform.rs +++ b/src-tauri/src/sandbox/platform.rs @@ -28,7 +28,7 @@ pub struct OperationSupport { /// Get the platform capabilities for sandboxing pub fn get_platform_capabilities() -> PlatformCapabilities { let os = env::consts::OS; - + match os { "linux" => get_linux_capabilities(), "macos" => get_macos_capabilities(), @@ -176,4 +176,4 @@ fn get_unsupported_capabilities(os: &str) -> PlatformCapabilities { /// Check if sandboxing is available on the current platform pub fn is_sandboxing_available() -> bool { matches!(env::consts::OS, "linux" | "macos" | "freebsd") -} \ No newline at end of file +} diff --git a/src-tauri/src/sandbox/profile.rs 
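// [Illustrative sketch, not part of the patch] is_sandboxing_available() above is a
// single matches! over std::env::consts::OS; the same gate in isolation:
fn sandboxing_available() -> bool {
    matches!(std::env::consts::OS, "linux" | "macos" | "freebsd")
}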
b/src-tauri/src/sandbox/profile.rs index 9a17a5f..6e37d23 100644 --- a/src-tauri/src/sandbox/profile.rs +++ b/src-tauri/src/sandbox/profile.rs @@ -1,3 +1,4 @@ +use crate::sandbox::executor::{SerializedOperation, SerializedProfile}; use anyhow::{Context, Result}; #[cfg(unix)] use gaol::profile::{AddressPattern, Operation, OperationSupport, PathPattern, Profile}; @@ -5,7 +6,6 @@ use log::{debug, info, warn}; use rusqlite::{params, Connection}; use serde::{Deserialize, Serialize}; use std::path::PathBuf; -use crate::sandbox::executor::{SerializedOperation, SerializedProfile}; /// Represents a sandbox profile from the database #[derive(Debug, Clone, Serialize, Deserialize)] @@ -37,7 +37,7 @@ pub struct ProfileBuildResult { #[cfg(unix)] pub profile: Profile, #[cfg(not(unix))] - pub profile: (), // Placeholder for Windows + pub profile: (), // Placeholder for Windows pub serialized: SerializedProfile, } @@ -50,56 +50,63 @@ pub struct ProfileBuilder { impl ProfileBuilder { /// Create a new profile builder pub fn new(project_path: PathBuf) -> Result { - let home_dir = dirs::home_dir() - .context("Could not determine home directory")?; - + let home_dir = dirs::home_dir().context("Could not determine home directory")?; + Ok(Self { project_path, home_dir, }) } - + /// Build a gaol Profile from database rules filtered by agent permissions - pub fn build_agent_profile(&self, rules: Vec, sandbox_enabled: bool, enable_file_read: bool, enable_file_write: bool, enable_network: bool) -> Result { + pub fn build_agent_profile( + &self, + rules: Vec, + sandbox_enabled: bool, + enable_file_read: bool, + enable_file_write: bool, + enable_network: bool, + ) -> Result { // If sandbox is completely disabled, return an empty profile if !sandbox_enabled { return Ok(ProfileBuildResult { #[cfg(unix)] - profile: Profile::new(vec![]).map_err(|_| anyhow::anyhow!("Failed to create empty profile"))?, + profile: Profile::new(vec![]) + .map_err(|_| anyhow::anyhow!("Failed to create empty profile"))?, #[cfg(not(unix))] profile: (), serialized: SerializedProfile { operations: vec![] }, }); } - + let mut filtered_rules = Vec::new(); - + for rule in rules { if !rule.enabled { continue; } - + // Filter rules based on agent permissions let include_rule = match rule.operation_type.as_str() { "file_read_all" | "file_read_metadata" => enable_file_read, "network_outbound" => enable_network, "system_info_read" => true, // Always allow system info reading - _ => true // Include unknown rule types by default + _ => true, // Include unknown rule types by default }; - + if include_rule { filtered_rules.push(rule); } } - + // Always ensure project path access if file reading is enabled if enable_file_read { let has_project_access = filtered_rules.iter().any(|rule| { - rule.operation_type == "file_read_all" && - rule.pattern_type == "subpath" && - rule.pattern_value.contains("{{PROJECT_PATH}}") + rule.operation_type == "file_read_all" + && rule.pattern_type == "subpath" + && rule.pattern_value.contains("{{PROJECT_PATH}}") }); - + if !has_project_access { // Add a default project access rule filtered_rules.push(SandboxRule { @@ -114,78 +121,99 @@ impl ProfileBuilder { }); } } - + self.build_profile_with_serialization(filtered_rules) } - + /// Build a gaol Profile from database rules #[cfg(unix)] pub fn build_profile(&self, rules: Vec) -> Result { let result = self.build_profile_with_serialization(rules)?; Ok(result.profile) } - + /// Build a gaol Profile from database rules (Windows stub) #[cfg(not(unix))] pub fn build_profile(&self, _rules: 
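// [Illustrative sketch, not part of the patch] build_agent_profile() above gates each
// rule on the agent's coarse permissions before any platform filtering happens. The
// core of that filter, over a simplified rule type:
struct Rule {
    operation_type: String,
    enabled: bool,
}

fn filter_by_permissions(rules: Vec<Rule>, file_read: bool, network: bool) -> Vec<Rule> {
    rules
        .into_iter()
        .filter(|r| r.enabled)
        .filter(|r| match r.operation_type.as_str() {
            "file_read_all" | "file_read_metadata" => file_read,
            "network_outbound" => network,
            "system_info_read" => true, // always allowed
            _ => true,                  // unknown rule types pass through by default
        })
        .collect()
}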
Vec) -> Result<()> { warn!("Sandbox profiles are not supported on Windows"); Ok(()) } - + /// Build a gaol Profile from database rules and return serialized operations - pub fn build_profile_with_serialization(&self, rules: Vec) -> Result { + pub fn build_profile_with_serialization( + &self, + rules: Vec, + ) -> Result { #[cfg(unix)] { let mut operations = Vec::new(); let mut serialized_operations = Vec::new(); - + for rule in rules { if !rule.enabled { continue; } - + // Check platform support if !self.is_rule_supported_on_platform(&rule) { - debug!("Skipping rule {} - not supported on current platform", rule.operation_type); + debug!( + "Skipping rule {} - not supported on current platform", + rule.operation_type + ); continue; } - + match self.build_operation_with_serialization(&rule) { Ok(Some((op, serialized))) => { // Check if operation is supported on current platform - if matches!(op.support(), gaol::profile::OperationSupportLevel::CanBeAllowed) { + if matches!( + op.support(), + gaol::profile::OperationSupportLevel::CanBeAllowed + ) { operations.push(op); serialized_operations.push(serialized); } else { - warn!("Operation {:?} not supported at desired level on current platform", rule.operation_type); + warn!( + "Operation {:?} not supported at desired level on current platform", + rule.operation_type + ); } - }, + } Ok(None) => { - debug!("Skipping unsupported operation type: {}", rule.operation_type); - }, + debug!( + "Skipping unsupported operation type: {}", + rule.operation_type + ); + } Err(e) => { - warn!("Failed to build operation for rule {}: {}", rule.id.unwrap_or(0), e); + warn!( + "Failed to build operation for rule {}: {}", + rule.id.unwrap_or(0), + e + ); } } } - + // Ensure project path access is included let has_project_access = serialized_operations.iter().any(|op| { matches!(op, SerializedOperation::FileReadAll { path, is_subpath: true } if path == &self.project_path) }); - + if !has_project_access { - operations.push(Operation::FileReadAll(PathPattern::Subpath(self.project_path.clone()))); + operations.push(Operation::FileReadAll(PathPattern::Subpath( + self.project_path.clone(), + ))); serialized_operations.push(SerializedOperation::FileReadAll { path: self.project_path.clone(), is_subpath: true, }); } - + // Create the profile let profile = Profile::new(operations) .map_err(|_| anyhow::anyhow!("Failed to create sandbox profile - some operations may not be supported on this platform"))?; - + Ok(ProfileBuildResult { profile, serialized: SerializedProfile { @@ -193,22 +221,22 @@ impl ProfileBuilder { }, }) } - + #[cfg(not(unix))] { // On Windows, we just create a serialized profile without actual sandboxing let mut serialized_operations = Vec::new(); - + for rule in rules { if !rule.enabled { continue; } - + if let Ok(Some(serialized)) = self.build_serialized_operation(&rule) { serialized_operations.push(serialized); } } - + Ok(ProfileBuildResult { profile: (), serialized: SerializedProfile { @@ -217,7 +245,7 @@ impl ProfileBuilder { }) } } - + /// Build a gaol Operation from a database rule #[cfg(unix)] fn build_operation(&self, rule: &SandboxRule) -> Result> { @@ -227,97 +255,125 @@ impl ProfileBuilder { Err(e) => Err(e), } } - + /// Build a gaol Operation and its serialized form from a database rule #[cfg(unix)] - fn build_operation_with_serialization(&self, rule: &SandboxRule) -> Result> { + fn build_operation_with_serialization( + &self, + rule: &SandboxRule, + ) -> Result> { match rule.operation_type.as_str() { "file_read_all" => { - let (pattern, path, 
is_subpath) = self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?; + let (pattern, path, is_subpath) = + self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?; Ok(Some(( Operation::FileReadAll(pattern), - SerializedOperation::FileReadAll { path, is_subpath } + SerializedOperation::FileReadAll { path, is_subpath }, ))) - }, + } "file_read_metadata" => { - let (pattern, path, is_subpath) = self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?; + let (pattern, path, is_subpath) = + self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?; Ok(Some(( Operation::FileReadMetadata(pattern), - SerializedOperation::FileReadMetadata { path, is_subpath } + SerializedOperation::FileReadMetadata { path, is_subpath }, ))) - }, + } "network_outbound" => { - let (pattern, serialized) = self.build_address_pattern_with_serialization(&rule.pattern_type, &rule.pattern_value)?; + let (pattern, serialized) = self.build_address_pattern_with_serialization( + &rule.pattern_type, + &rule.pattern_value, + )?; Ok(Some((Operation::NetworkOutbound(pattern), serialized))) - }, - "system_info_read" => { - Ok(Some(( - Operation::SystemInfoRead, - SerializedOperation::SystemInfoRead - ))) - }, - _ => Ok(None) + } + "system_info_read" => Ok(Some(( + Operation::SystemInfoRead, + SerializedOperation::SystemInfoRead, + ))), + _ => Ok(None), } } - + /// Build a PathPattern from pattern type and value #[cfg(unix)] fn build_path_pattern(&self, pattern_type: &str, pattern_value: &str) -> Result { let (pattern, _, _) = self.build_path_pattern_with_info(pattern_type, pattern_value)?; Ok(pattern) } - + /// Build a PathPattern and return additional info for serialization #[cfg(unix)] - fn build_path_pattern_with_info(&self, pattern_type: &str, pattern_value: &str) -> Result<(PathPattern, PathBuf, bool)> { + fn build_path_pattern_with_info( + &self, + pattern_type: &str, + pattern_value: &str, + ) -> Result<(PathPattern, PathBuf, bool)> { // Replace template variables let expanded_value = pattern_value .replace("{{PROJECT_PATH}}", &self.project_path.to_string_lossy()) .replace("{{HOME}}", &self.home_dir.to_string_lossy()); - + let path = PathBuf::from(expanded_value); - + match pattern_type { "literal" => Ok((PathPattern::Literal(path.clone()), path, false)), "subpath" => Ok((PathPattern::Subpath(path.clone()), path, true)), - _ => Err(anyhow::anyhow!("Unknown path pattern type: {}", pattern_type)) + _ => Err(anyhow::anyhow!( + "Unknown path pattern type: {}", + pattern_type + )), } } - + /// Build an AddressPattern from pattern type and value #[cfg(unix)] - fn build_address_pattern(&self, pattern_type: &str, pattern_value: &str) -> Result { - let (pattern, _) = self.build_address_pattern_with_serialization(pattern_type, pattern_value)?; + fn build_address_pattern( + &self, + pattern_type: &str, + pattern_value: &str, + ) -> Result { + let (pattern, _) = + self.build_address_pattern_with_serialization(pattern_type, pattern_value)?; Ok(pattern) } - + /// Build an AddressPattern and its serialized form #[cfg(unix)] - fn build_address_pattern_with_serialization(&self, pattern_type: &str, pattern_value: &str) -> Result<(AddressPattern, SerializedOperation)> { + fn build_address_pattern_with_serialization( + &self, + pattern_type: &str, + pattern_value: &str, + ) -> Result<(AddressPattern, SerializedOperation)> { match pattern_type { "all" => Ok(( AddressPattern::All, - SerializedOperation::NetworkOutbound { pattern: "all".to_string() } + 
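// [Illustrative sketch, not part of the patch] Pattern values are templated: the
// builder substitutes {{PROJECT_PATH}} and {{HOME}} before constructing a path
// pattern, exactly as build_path_pattern_with_info() does above.
fn expand_pattern(value: &str, project: &std::path::Path, home: &std::path::Path) -> String {
    value
        .replace("{{PROJECT_PATH}}", &project.to_string_lossy())
        .replace("{{HOME}}", &home.to_string_lossy())
}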
SerializedOperation::NetworkOutbound { + pattern: "all".to_string(), + }, )), "tcp" => { - let port = pattern_value.parse::() + let port = pattern_value + .parse::() .context("Invalid TCP port number")?; Ok(( AddressPattern::Tcp(port), - SerializedOperation::NetworkTcp { port } + SerializedOperation::NetworkTcp { port }, )) - }, + } "local_socket" => { let path = PathBuf::from(pattern_value); Ok(( AddressPattern::LocalSocket(path.clone()), - SerializedOperation::NetworkLocalSocket { path } + SerializedOperation::NetworkLocalSocket { path }, )) - }, - _ => Err(anyhow::anyhow!("Unknown address pattern type: {}", pattern_type)) + } + _ => Err(anyhow::anyhow!( + "Unknown address pattern type: {}", + pattern_type + )), } } - + /// Check if a rule is supported on the current platform fn is_rule_supported_on_platform(&self, rule: &SandboxRule) -> bool { if let Some(platforms_json) = &rule.platform_support { @@ -332,37 +388,42 @@ impl ProfileBuilder { /// Build only the serialized operation (for Windows) #[cfg(not(unix))] - fn build_serialized_operation(&self, rule: &SandboxRule) -> Result> { + fn build_serialized_operation( + &self, + rule: &SandboxRule, + ) -> Result> { let pattern_value = self.expand_pattern_value(&rule.pattern_value); - + match rule.operation_type.as_str() { "file_read_all" => { - let (path, is_subpath) = self.parse_path_pattern(&rule.pattern_type, &pattern_value)?; + let (path, is_subpath) = + self.parse_path_pattern(&rule.pattern_type, &pattern_value)?; Ok(Some(SerializedOperation::FileReadAll { path, is_subpath })) } "file_read_metadata" => { - let (path, is_subpath) = self.parse_path_pattern(&rule.pattern_type, &pattern_value)?; - Ok(Some(SerializedOperation::FileReadMetadata { path, is_subpath })) - } - "network_outbound" => { - Ok(Some(SerializedOperation::NetworkOutbound { pattern: pattern_value })) + let (path, is_subpath) = + self.parse_path_pattern(&rule.pattern_type, &pattern_value)?; + Ok(Some(SerializedOperation::FileReadMetadata { + path, + is_subpath, + })) } + "network_outbound" => Ok(Some(SerializedOperation::NetworkOutbound { + pattern: pattern_value, + })), "network_tcp" => { - let port = pattern_value.parse::() - .context("Invalid TCP port")?; + let port = pattern_value.parse::().context("Invalid TCP port")?; Ok(Some(SerializedOperation::NetworkTcp { port })) } "network_local_socket" => { let path = PathBuf::from(pattern_value); Ok(Some(SerializedOperation::NetworkLocalSocket { path })) } - "system_info_read" => { - Ok(Some(SerializedOperation::SystemInfoRead)) - } + "system_info_read" => Ok(Some(SerializedOperation::SystemInfoRead)), _ => Ok(None), } } - + /// Helper method to expand pattern values (Windows version) #[cfg(not(unix))] fn expand_pattern_value(&self, pattern_value: &str) -> String { @@ -370,16 +431,23 @@ impl ProfileBuilder { .replace("{{PROJECT_PATH}}", &self.project_path.to_string_lossy()) .replace("{{HOME}}", &self.home_dir.to_string_lossy()) } - + /// Helper method to parse path patterns (Windows version) #[cfg(not(unix))] - fn parse_path_pattern(&self, pattern_type: &str, pattern_value: &str) -> Result<(PathBuf, bool)> { + fn parse_path_pattern( + &self, + pattern_type: &str, + pattern_value: &str, + ) -> Result<(PathBuf, bool)> { let path = PathBuf::from(pattern_value); - + match pattern_type { "literal" => Ok((path, false)), "subpath" => Ok((path, true)), - _ => Err(anyhow::anyhow!("Unknown path pattern type: {}", pattern_type)) + _ => Err(anyhow::anyhow!( + "Unknown path pattern type: {}", + pattern_type + )), } } } @@ -400,7 +468,7 
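// [Illustrative sketch, not part of the patch] is_rule_supported_on_platform() above
// treats platform_support as an optional JSON array of OS names; absence means "all
// platforms". A standalone version assuming serde_json; failing closed on unparseable
// metadata is a choice made here, not taken from the patch.
fn rule_supported(platform_support: Option<&str>) -> bool {
    match platform_support {
        None => true, // no restriction recorded
        Some(json) => serde_json::from_str::<Vec<String>>(json)
            .map(|list| list.iter().any(|os| os == std::env::consts::OS))
            .unwrap_or(false), // unparseable metadata: fail closed
    }
}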
@@ pub fn load_profile(conn: &Connection, profile_id: i64) -> Result Result { created_at: row.get(5)?, updated_at: row.get(6)?, }) - } + }, ) .context("Failed to load default sandbox profile") } @@ -432,40 +500,45 @@ pub fn load_profile_rules(conn: &Connection, profile_id: i64) -> Result, _>>()?; - + + let rules = stmt + .query_map(params![profile_id], |row| { + Ok(SandboxRule { + id: Some(row.get(0)?), + profile_id: row.get(1)?, + operation_type: row.get(2)?, + pattern_type: row.get(3)?, + pattern_value: row.get(4)?, + enabled: row.get(5)?, + platform_support: row.get(6)?, + created_at: row.get(7)?, + }) + })? + .collect::, _>>()?; + Ok(rules) } /// Get or create the gaol Profile for execution #[cfg(unix)] -pub fn get_gaol_profile(conn: &Connection, profile_id: Option, project_path: PathBuf) -> Result { +pub fn get_gaol_profile( + conn: &Connection, + profile_id: Option, + project_path: PathBuf, +) -> Result { // Load the profile let profile = if let Some(id) = profile_id { load_profile(conn, id)? } else { load_default_profile(conn)? }; - + info!("Using sandbox profile: {}", profile.name); - + // Load the rules let rules = load_profile_rules(conn, profile.id.unwrap())?; info!("Loaded {} sandbox rules", rules.len()); - + // Build the gaol profile let builder = ProfileBuilder::new(project_path)?; builder.build_profile(rules) @@ -473,7 +546,11 @@ pub fn get_gaol_profile(conn: &Connection, profile_id: Option, project_path /// Get or create the gaol Profile for execution (Windows stub) #[cfg(not(unix))] -pub fn get_gaol_profile(_conn: &Connection, _profile_id: Option, _project_path: PathBuf) -> Result<()> { +pub fn get_gaol_profile( + _conn: &Connection, + _profile_id: Option, + _project_path: PathBuf, +) -> Result<()> { warn!("Sandbox profiles are not supported on Windows"); Ok(()) -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/common/claude_real.rs b/src-tauri/tests/sandbox/common/claude_real.rs index ec52ca0..d84f482 100644 --- a/src-tauri/tests/sandbox/common/claude_real.rs +++ b/src-tauri/tests/sandbox/common/claude_real.rs @@ -14,36 +14,37 @@ pub fn execute_claude_task( timeout_secs: u64, ) -> Result { let mut cmd = Command::new("claude"); - + // Add task cmd.arg("-p").arg(task); - + // Add system prompt if provided if let Some(prompt) = system_prompt { cmd.arg("--system-prompt").arg(prompt); } - + // Add model if provided if let Some(m) = model { cmd.arg("--model").arg(m); } - + // Always add these flags for testing - cmd.arg("--output-format").arg("stream-json") - .arg("--verbose") - .arg("--dangerously-skip-permissions") - .current_dir(project_path) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()); - + cmd.arg("--output-format") + .arg("stream-json") + .arg("--verbose") + .arg("--dangerously-skip-permissions") + .current_dir(project_path) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + // Add sandbox profile ID if provided if let Some(profile_id) = sandbox_profile_id { cmd.env("CLAUDIA_SANDBOX_PROFILE_ID", profile_id.to_string()); } - + // Execute with timeout (use gtimeout on macOS, timeout on Linux) let start = std::time::Instant::now(); - + let timeout_cmd = if cfg!(target_os = "macos") { // On macOS, try gtimeout (from GNU coreutils) first, fallback to direct execution if std::process::Command::new("which") @@ -60,15 +61,15 @@ pub fn execute_claude_task( } else { "timeout" }; - + let output = if timeout_cmd.is_empty() { // Run without timeout wrapper - cmd.output() - .context("Failed to execute Claude command")? 
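// [Illustrative sketch, not part of the patch] The timeout wrapper above has to be
// discovered at runtime: plain `timeout` on Linux, `gtimeout` (GNU coreutils) on
// macOS, and no wrapper at all if neither is installed. The probe via `which` mirrors
// the test helper's approach; the function name is hypothetical.
fn find_timeout_wrapper() -> Option<&'static str> {
    let candidate = if cfg!(target_os = "macos") { "gtimeout" } else { "timeout" };
    let found = std::process::Command::new("which")
        .arg(candidate)
        .output()
        .map(|out| out.status.success())
        .unwrap_or(false);
    found.then_some(candidate)
}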
+ cmd.output().context("Failed to execute Claude command")? } else { // Run with timeout wrapper let mut timeout_cmd = Command::new(timeout_cmd); - timeout_cmd.arg(timeout_secs.to_string()) + timeout_cmd + .arg(timeout_secs.to_string()) .arg("claude") .args(cmd.get_args()) .current_dir(project_path) @@ -78,9 +79,9 @@ pub fn execute_claude_task( .output() .context("Failed to execute Claude command with timeout")? }; - + let duration = start.elapsed(); - + Ok(ClaudeOutput { stdout: String::from_utf8_lossy(&output.stdout).to_string(), stderr: String::from_utf8_lossy(&output.stderr).to_string(), @@ -103,7 +104,7 @@ impl ClaudeOutput { pub fn contains_operation(&self, operation: &str) -> bool { self.stdout.contains(operation) || self.stderr.contains(operation) } - + /// Check if operation was blocked (look for permission denied, sandbox violation, etc) pub fn operation_was_blocked(&self, operation: &str) -> bool { let blocked_patterns = [ @@ -114,16 +115,16 @@ impl ClaudeOutput { "access denied", "sandbox violation", ]; - + let output = format!("{}\n{}", self.stdout, self.stderr).to_lowercase(); let op_lower = operation.to_lowercase(); - + // Check if operation was mentioned along with a block pattern - blocked_patterns.iter().any(|pattern| { - output.contains(&op_lower) && output.contains(pattern) - }) + blocked_patterns + .iter() + .any(|pattern| output.contains(&op_lower) && output.contains(pattern)) } - + /// Check if file read was successful pub fn file_read_succeeded(&self, filename: &str) -> bool { // Look for patterns indicating successful file read @@ -133,10 +134,12 @@ impl ClaudeOutput { &format!("Contents of {}", filename), "test content", // Our test files contain this ]; - - patterns.iter().any(|pattern| self.contains_operation(pattern)) + + patterns + .iter() + .any(|pattern| self.contains_operation(pattern)) } - + /// Check if network connection was attempted pub fn network_attempted(&self, host: &str) -> bool { let patterns = [ @@ -145,8 +148,10 @@ impl ClaudeOutput { &format!("connect to {}", host), host, ]; - - patterns.iter().any(|pattern| self.contains_operation(pattern)) + + patterns + .iter() + .any(|pattern| self.contains_operation(pattern)) } } @@ -156,24 +161,27 @@ pub mod tasks { pub fn read_file(filename: &str) -> String { format!("Read the file {} and show me its contents", filename) } - + /// Task to attempt network connection pub fn connect_network(host: &str) -> String { format!("Try to connect to {} and tell me if it works", host) } - + /// Task to do multiple operations pub fn multi_operation() -> String { "Read the file ./test.txt in the current directory and show its contents".to_string() } - + /// Task to test file write pub fn write_file(filename: &str, content: &str) -> String { - format!("Create a file called {} with the content '{}'", filename, content) + format!( + "Create a file called {} with the content '{}'", + filename, content + ) } - + /// Task to test process spawning pub fn spawn_process(command: &str) -> String { format!("Run the command '{}' and show me the output", command) } -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/common/fixtures.rs b/src-tauri/tests/sandbox/common/fixtures.rs index 0f6c73c..c3feb89 100644 --- a/src-tauri/tests/sandbox/common/fixtures.rs +++ b/src-tauri/tests/sandbox/common/fixtures.rs @@ -10,9 +10,8 @@ use tempfile::{tempdir, TempDir}; /// Using parking_lot::Mutex which doesn't poison on panic use parking_lot::Mutex; -pub static TEST_DB: Lazy> = Lazy::new(|| { - 
Mutex::new(TestDatabase::new().expect("Failed to create test database")) -}); +pub static TEST_DB: Lazy> = + Lazy::new(|| Mutex::new(TestDatabase::new().expect("Failed to create test database"))); /// Test database manager pub struct TestDatabase { @@ -26,13 +25,13 @@ impl TestDatabase { let temp_dir = tempdir()?; let db_path = temp_dir.path().join("test_sandbox.db"); let conn = Connection::open(&db_path)?; - + // Initialize schema Self::init_schema(&conn)?; - + Ok(Self { conn, temp_dir }) } - + /// Initialize database schema fn init_schema(conn: &Connection) -> Result<()> { // Create sandbox profiles table @@ -48,7 +47,7 @@ impl TestDatabase { )", [], )?; - + // Create sandbox rules table conn.execute( "CREATE TABLE IF NOT EXISTS sandbox_rules ( @@ -64,7 +63,7 @@ impl TestDatabase { )", [], )?; - + // Create agents table conn.execute( "CREATE TABLE IF NOT EXISTS agents ( @@ -80,7 +79,7 @@ impl TestDatabase { )", [], )?; - + // Create agent_runs table conn.execute( "CREATE TABLE IF NOT EXISTS agent_runs ( @@ -101,7 +100,7 @@ impl TestDatabase { )", [], )?; - + // Create sandbox violations table conn.execute( "CREATE TABLE IF NOT EXISTS sandbox_violations ( @@ -120,7 +119,7 @@ impl TestDatabase { )", [], )?; - + // Create trigger to update the updated_at timestamp for agents conn.execute( "CREATE TRIGGER IF NOT EXISTS update_agent_timestamp @@ -131,7 +130,7 @@ impl TestDatabase { END", [], )?; - + // Create trigger to update sandbox profile timestamp conn.execute( "CREATE TRIGGER IF NOT EXISTS update_sandbox_profile_timestamp @@ -142,10 +141,10 @@ impl TestDatabase { END", [], )?; - + Ok(()) } - + /// Create a test profile with rules pub fn create_test_profile(&self, name: &str, rules: Vec) -> Result { // Insert profile @@ -153,9 +152,9 @@ impl TestDatabase { "INSERT INTO sandbox_profiles (name, description, is_active, is_default) VALUES (?1, ?2, ?3, ?4)", params![name, format!("Test profile: {name}"), true, false], )?; - + let profile_id = self.conn.last_insert_rowid(); - + // Insert rules for rule in rules { self.conn.execute( @@ -171,10 +170,10 @@ impl TestDatabase { ], )?; } - + Ok(profile_id) } - + /// Reset database to clean state pub fn reset(&self) -> Result<()> { // Delete in the correct order to respect foreign key constraints @@ -208,7 +207,7 @@ impl TestRule { platform_support: Some(r#"["linux", "macos"]"#.to_string()), } } - + /// Create a network rule pub fn network_all() -> Self { Self { @@ -219,7 +218,7 @@ impl TestRule { platform_support: Some(r#"["linux", "macos"]"#.to_string()), } } - + /// Create a network TCP rule pub fn network_tcp(port: u16) -> Self { Self { @@ -230,7 +229,7 @@ impl TestRule { platform_support: Some(r#"["macos"]"#.to_string()), } } - + /// Create a system info read rule pub fn system_info_read() -> Self { Self { @@ -256,25 +255,28 @@ impl TestFileSystem { pub fn new() -> Result { let root = tempdir()?; let root_path = root.path(); - + // Create project directory let project_path = root_path.join("test_project"); std::fs::create_dir_all(&project_path)?; - + // Create allowed directory let allowed_path = root_path.join("allowed"); std::fs::create_dir_all(&allowed_path)?; std::fs::write(allowed_path.join("test.txt"), "allowed content")?; - + // Create forbidden directory let forbidden_path = root_path.join("forbidden"); std::fs::create_dir_all(&forbidden_path)?; std::fs::write(forbidden_path.join("secret.txt"), "forbidden content")?; - + // Create project files std::fs::write(project_path.join("main.rs"), "fn main() {}")?; - 
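// [Illustrative sketch, not part of the patch] TEST_DB above is the classic shared
// test fixture: once_cell::sync::Lazy for one-time construction plus
// parking_lot::Mutex, which does not poison when a test panics while holding the
// lock. A reduced form, assuming the once_cell and parking_lot crates:
use once_cell::sync::Lazy;
use parking_lot::Mutex;

static SHARED_COUNTER: Lazy<Mutex<u64>> = Lazy::new(|| Mutex::new(0));

fn bump() -> u64 {
    let mut n = SHARED_COUNTER.lock(); // lock() returns the guard directly, no Result
    *n += 1;
    *n
}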
std::fs::write(project_path.join("Cargo.toml"), "[package]\nname = \"test\"")?; - + std::fs::write( + project_path.join("Cargo.toml"), + "[package]\nname = \"test\"", + )?; + Ok(Self { root, project_path, @@ -287,14 +289,12 @@ impl TestFileSystem { /// Standard test profiles pub mod profiles { use super::*; - + /// Minimal profile - only project access pub fn minimal(project_path: &str) -> Vec { - vec![ - TestRule::file_read(project_path, true), - ] + vec![TestRule::file_read(project_path, true)] } - + /// Standard profile - project + system libraries pub fn standard(project_path: &str) -> Vec { vec![ @@ -304,7 +304,7 @@ pub mod profiles { TestRule::network_all(), ] } - + /// Development profile - more permissive pub fn development(project_path: &str, home_dir: &str) -> Vec { vec![ @@ -316,18 +316,17 @@ pub mod profiles { TestRule::system_info_read(), ] } - + /// Network-only profile pub fn network_only() -> Vec { - vec![ - TestRule::network_all(), - ] + vec![TestRule::network_all()] } - + /// File-only profile pub fn file_only(paths: Vec<&str>) -> Vec { - paths.into_iter() + paths + .into_iter() .map(|path| TestRule::file_read(path, true)) .collect() } -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/common/helpers.rs b/src-tauri/tests/sandbox/common/helpers.rs index b1035c3..493a2d6 100644 --- a/src-tauri/tests/sandbox/common/helpers.rs +++ b/src-tauri/tests/sandbox/common/helpers.rs @@ -15,7 +15,10 @@ pub fn is_sandboxing_supported() -> bool { macro_rules! skip_if_unsupported { () => { if !$crate::sandbox::common::is_sandboxing_supported() { - eprintln!("Skipping test: sandboxing not supported on {}", std::env::consts::OS); + eprintln!( + "Skipping test: sandboxing not supported on {}", + std::env::consts::OS + ); return; } }; @@ -39,7 +42,7 @@ impl PlatformConfig { supports_file_read: true, supports_metadata_read: false, // Cannot be precisely controlled supports_network_all: true, - supports_network_tcp: false, // Cannot filter by port + supports_network_tcp: false, // Cannot filter by port supports_network_local: false, // Cannot filter by path supports_system_info: false, }, @@ -89,54 +92,53 @@ impl TestCommand { working_dir: None, } } - + /// Add an argument pub fn arg(mut self, arg: &str) -> Self { self.args.push(arg.to_string()); self } - + /// Add multiple arguments pub fn args(mut self, args: &[&str]) -> Self { self.args.extend(args.iter().map(|s| s.to_string())); self } - + /// Set an environment variable pub fn env(mut self, key: &str, value: &str) -> Self { self.env_vars.push((key.to_string(), value.to_string())); self } - + /// Set working directory pub fn current_dir(mut self, dir: &Path) -> Self { self.working_dir = Some(dir.to_path_buf()); self } - + /// Execute the command with timeout pub fn execute_with_timeout(&self, timeout: Duration) -> Result { let mut cmd = Command::new(&self.command); - + cmd.args(&self.args); - + for (key, value) in &self.env_vars { cmd.env(key, value); } - + if let Some(dir) = &self.working_dir { cmd.current_dir(dir); } - + // On Unix, we can use a timeout mechanism #[cfg(unix)] { use std::time::Instant; - + let start = Instant::now(); - let mut child = cmd.spawn() - .context("Failed to spawn command")?; - + let mut child = cmd.spawn().context("Failed to spawn command")?; + loop { match child.try_wait() { Ok(Some(status)) => { @@ -158,19 +160,18 @@ impl TestCommand { } } } - + #[cfg(not(unix))] { // Fallback for non-Unix platforms - cmd.output() - .context("Failed to execute command") + cmd.output().context("Failed to 
execute command") } } - + /// Execute and expect success pub fn execute_expect_success(&self) -> Result { let output = self.execute_with_timeout(Duration::from_secs(10))?; - + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(anyhow::anyhow!( @@ -178,31 +179,27 @@ impl TestCommand { output.status.code() )); } - + Ok(String::from_utf8_lossy(&output.stdout).to_string()) } - + /// Execute and expect failure pub fn execute_expect_failure(&self) -> Result { let output = self.execute_with_timeout(Duration::from_secs(10))?; - + if output.status.success() { let stdout = String::from_utf8_lossy(&output.stdout); return Err(anyhow::anyhow!( "Command unexpectedly succeeded. Stdout: {stdout}" )); } - + Ok(String::from_utf8_lossy(&output.stderr).to_string()) } } /// Create a simple test binary that attempts an operation -pub fn create_test_binary( - name: &str, - code: &str, - test_dir: &Path, -) -> Result { +pub fn create_test_binary(name: &str, code: &str, test_dir: &Path) -> Result { create_test_binary_with_deps(name, code, test_dir, &[]) } @@ -215,7 +212,7 @@ pub fn create_test_binary_with_deps( ) -> Result { let src_dir = test_dir.join("src"); std::fs::create_dir_all(&src_dir)?; - + // Build dependencies section let deps_section = if dependencies.is_empty() { String::new() @@ -226,7 +223,7 @@ pub fn create_test_binary_with_deps( } deps }; - + // Create Cargo.toml let cargo_toml = format!( r#"[package] @@ -240,10 +237,10 @@ path = "src/main.rs" {deps_section}"# ); std::fs::write(test_dir.join("Cargo.toml"), cargo_toml)?; - + // Create main.rs std::fs::write(src_dir.join("main.rs"), code)?; - + // Build the binary let output = Command::new("cargo") .arg("build") @@ -251,12 +248,12 @@ path = "src/main.rs" .current_dir(test_dir) .output() .context("Failed to build test binary")?; - + if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(anyhow::anyhow!("Failed to build test binary: {stderr}")); } - + let binary_path = test_dir.join("target/release").join(name); Ok(binary_path) } @@ -281,7 +278,7 @@ fn main() {{ "# ) } - + /// Code that reads file metadata pub fn file_metadata(path: &str) -> String { format!( @@ -300,7 +297,7 @@ fn main() {{ "# ) } - + /// Code that makes a network connection pub fn network_connect(addr: &str) -> String { format!( @@ -321,7 +318,7 @@ fn main() {{ "# ) } - + /// Code that reads system information pub fn system_info() -> &'static str { r#" @@ -368,7 +365,7 @@ fn main() { } "# } - + /// Code that tries to spawn a process pub fn spawn_process() -> &'static str { r#" @@ -387,7 +384,7 @@ fn main() { } "# } - + /// Code that uses fork (requires libc) pub fn fork_process() -> &'static str { r#" @@ -418,7 +415,7 @@ fn main() { } "# } - + /// Code that uses exec (requires libc) pub fn exec_process() -> &'static str { r#" @@ -446,7 +443,7 @@ fn main() { } "# } - + /// Code that tries to write a file pub fn file_write(path: &str) -> String { format!( @@ -483,4 +480,4 @@ pub fn assert_sandbox_success(output: &str) { /// Assert that a command output indicates failure pub fn assert_sandbox_failure(output: &str) { assert_output_contains(output, "FAILURE:"); -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/common/mod.rs b/src-tauri/tests/sandbox/common/mod.rs index 2ced385..34aeafe 100644 --- a/src-tauri/tests/sandbox/common/mod.rs +++ b/src-tauri/tests/sandbox/common/mod.rs @@ -1,8 +1,8 @@ //! 
Common test utilities and helpers for sandbox testing +pub mod claude_real; pub mod fixtures; pub mod helpers; -pub mod claude_real; +pub use claude_real::*; pub use fixtures::*; pub use helpers::*; -pub use claude_real::*; \ No newline at end of file diff --git a/src-tauri/tests/sandbox/e2e/agent_sandbox.rs b/src-tauri/tests/sandbox/e2e/agent_sandbox.rs index 1766ec8..c5c487e 100644 --- a/src-tauri/tests/sandbox/e2e/agent_sandbox.rs +++ b/src-tauri/tests/sandbox/e2e/agent_sandbox.rs @@ -8,17 +8,18 @@ use serial_test::serial; #[serial] fn test_agent_with_minimal_profile() { skip_if_unsupported!(); - + // Create test environment let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create minimal sandbox profile let rules = profiles::minimal(&test_fs.project_path.to_string_lossy()); - let profile_id = test_db.create_test_profile("minimal_agent_test", rules) + let profile_id = test_db + .create_test_profile("minimal_agent_test", rules) .expect("Failed to create test profile"); - + // Create test agent test_db.conn.execute( "INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)", @@ -30,9 +31,9 @@ fn test_agent_with_minimal_profile() { profile_id ], ).expect("Failed to create agent"); - + let _agent_id = test_db.conn.last_insert_rowid(); - + // Execute real Claude command with minimal profile let result = execute_claude_task( &test_fs.project_path, @@ -41,8 +42,9 @@ fn test_agent_with_minimal_profile() { Some("sonnet"), Some(profile_id), 20, // 20 second timeout - ).expect("Failed to execute Claude command"); - + ) + .expect("Failed to execute Claude command"); + // Debug output eprintln!("=== Claude Output ==="); eprintln!("Exit code: {}", result.exit_code); @@ -50,10 +52,13 @@ fn test_agent_with_minimal_profile() { eprintln!("STDERR:\n{}", result.stderr); eprintln!("Duration: {:?}", result.duration); eprintln!("==================="); - + // Basic verification - just check Claude ran - assert!(result.exit_code == 0 || result.exit_code == 124, // 0 = success, 124 = timeout - "Claude should execute (exit code: {})", result.exit_code); + assert!( + result.exit_code == 0 || result.exit_code == 124, // 0 = success, 124 = timeout + "Claude should execute (exit code: {})", + result.exit_code + ); } /// Test agent execution with standard sandbox profile @@ -61,17 +66,18 @@ fn test_agent_with_minimal_profile() { #[serial] fn test_agent_with_standard_profile() { skip_if_unsupported!(); - + // Create test environment let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create standard sandbox profile let rules = profiles::standard(&test_fs.project_path.to_string_lossy()); - let profile_id = test_db.create_test_profile("standard_agent_test", rules) + let profile_id = test_db + .create_test_profile("standard_agent_test", rules) .expect("Failed to create test profile"); - + // Create test agent test_db.conn.execute( "INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)", @@ -83,9 +89,9 @@ fn test_agent_with_standard_profile() { profile_id ], ).expect("Failed to create agent"); - + let _agent_id = test_db.conn.last_insert_rowid(); - + // Execute real Claude command with standard profile let result = execute_claude_task( &test_fs.project_path, @@ -94,18 +100,22 @@ fn 
test_agent_with_standard_profile() { Some("sonnet"), Some(profile_id), 20, // 20 second timeout - ).expect("Failed to execute Claude command"); - + ) + .expect("Failed to execute Claude command"); + // Debug output eprintln!("=== Claude Output (Standard Profile) ==="); eprintln!("Exit code: {}", result.exit_code); eprintln!("STDOUT:\n{}", result.stdout); eprintln!("STDERR:\n{}", result.stderr); eprintln!("==================="); - + // Basic verification - assert!(result.exit_code == 0 || result.exit_code == 124, - "Claude should execute with standard profile (exit code: {})", result.exit_code); + assert!( + result.exit_code == 0 || result.exit_code == 124, + "Claude should execute with standard profile (exit code: {})", + result.exit_code + ); } /// Test agent execution without sandbox (control test) @@ -113,25 +123,28 @@ fn test_agent_with_standard_profile() { #[serial] fn test_agent_without_sandbox() { skip_if_unsupported!(); - + // Create test environment let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create agent without sandbox profile - test_db.conn.execute( - "INSERT INTO agents (name, icon, system_prompt, model) VALUES (?1, ?2, ?3, ?4)", - rusqlite::params![ - "Unsandboxed Agent", - "⚠️", - "You are a test agent without sandbox restrictions.", - "sonnet" - ], - ).expect("Failed to create agent"); - + test_db + .conn + .execute( + "INSERT INTO agents (name, icon, system_prompt, model) VALUES (?1, ?2, ?3, ?4)", + rusqlite::params![ + "Unsandboxed Agent", + "⚠️", + "You are a test agent without sandbox restrictions.", + "sonnet" + ], + ) + .expect("Failed to create agent"); + let _agent_id = test_db.conn.last_insert_rowid(); - + // Execute real Claude command without sandbox profile let result = execute_claude_task( &test_fs.project_path, @@ -139,19 +152,23 @@ fn test_agent_without_sandbox() { Some("You are a test agent without sandbox restrictions."), Some("sonnet"), None, // No sandbox profile - 20, // 20 second timeout - ).expect("Failed to execute Claude command"); - + 20, // 20 second timeout + ) + .expect("Failed to execute Claude command"); + // Debug output eprintln!("=== Claude Output (No Sandbox) ==="); eprintln!("Exit code: {}", result.exit_code); eprintln!("STDOUT:\n{}", result.stdout); eprintln!("STDERR:\n{}", result.stderr); eprintln!("==================="); - + // Basic verification - assert!(result.exit_code == 0 || result.exit_code == 124, - "Claude should execute without sandbox (exit code: {})", result.exit_code); + assert!( + result.exit_code == 0 || result.exit_code == 124, + "Claude should execute without sandbox (exit code: {})", + result.exit_code + ); } /// Test agent run violation logging @@ -159,15 +176,16 @@ fn test_agent_without_sandbox() { #[serial] fn test_agent_run_violation_logging() { skip_if_unsupported!(); - + // Create test environment let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create a test profile first - let profile_id = test_db.create_test_profile("violation_test", vec![]) + let profile_id = test_db + .create_test_profile("violation_test", vec![]) .expect("Failed to create test profile"); - + // Create a test agent test_db.conn.execute( "INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)", @@ -179,9 +197,9 @@ fn test_agent_run_violation_logging() { profile_id ], ).expect("Failed to create agent"); - + let agent_id = 
test_db.conn.last_insert_rowid(); - + // Create a test agent run test_db.conn.execute( "INSERT INTO agent_runs (agent_id, agent_name, agent_icon, task, model, project_path) VALUES (?1, ?2, ?3, ?4, ?5, ?6)", @@ -194,23 +212,26 @@ fn test_agent_run_violation_logging() { "/test/path" ], ).expect("Failed to create agent run"); - + let agent_run_id = test_db.conn.last_insert_rowid(); - + // Insert test violations test_db.conn.execute( "INSERT INTO sandbox_violations (profile_id, agent_id, agent_run_id, operation_type, pattern_value) VALUES (?1, ?2, ?3, ?4, ?5)", rusqlite::params![profile_id, agent_id, agent_run_id, "file_read_all", "/etc/passwd"], ).expect("Failed to insert violation"); - + // Query violations - let count: i64 = test_db.conn.query_row( - "SELECT COUNT(*) FROM sandbox_violations WHERE agent_id = ?1", - rusqlite::params![agent_id], - |row| row.get(0), - ).expect("Failed to query violations"); - + let count: i64 = test_db + .conn + .query_row( + "SELECT COUNT(*) FROM sandbox_violations WHERE agent_id = ?1", + rusqlite::params![agent_id], + |row| row.get(0), + ) + .expect("Failed to query violations"); + assert_eq!(count, 1, "Should have recorded one violation"); } @@ -219,21 +240,23 @@ fn test_agent_run_violation_logging() { #[serial] fn test_profile_switching() { skip_if_unsupported!(); - + // Create test environment let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create two different profiles let minimal_rules = profiles::minimal(&test_fs.project_path.to_string_lossy()); - let minimal_id = test_db.create_test_profile("minimal_switch", minimal_rules) + let minimal_id = test_db + .create_test_profile("minimal_switch", minimal_rules) .expect("Failed to create minimal profile"); - + let standard_rules = profiles::standard(&test_fs.project_path.to_string_lossy()); - let standard_id = test_db.create_test_profile("standard_switch", standard_rules) + let standard_id = test_db + .create_test_profile("standard_switch", standard_rules) .expect("Failed to create standard profile"); - + // Create agent initially with minimal profile test_db.conn.execute( "INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)", @@ -245,21 +268,27 @@ fn test_profile_switching() { minimal_id ], ).expect("Failed to create agent"); - + let agent_id = test_db.conn.last_insert_rowid(); - + // Update agent to use standard profile - test_db.conn.execute( - "UPDATE agents SET sandbox_profile_id = ?1 WHERE id = ?2", - rusqlite::params![standard_id, agent_id], - ).expect("Failed to update agent profile"); - + test_db + .conn + .execute( + "UPDATE agents SET sandbox_profile_id = ?1 WHERE id = ?2", + rusqlite::params![standard_id, agent_id], + ) + .expect("Failed to update agent profile"); + // Verify profile was updated - let current_profile: i64 = test_db.conn.query_row( - "SELECT sandbox_profile_id FROM agents WHERE id = ?1", - rusqlite::params![agent_id], - |row| row.get(0), - ).expect("Failed to query agent profile"); - + let current_profile: i64 = test_db + .conn + .query_row( + "SELECT sandbox_profile_id FROM agents WHERE id = ?1", + rusqlite::params![agent_id], + |row| row.get(0), + ) + .expect("Failed to query agent profile"); + assert_eq!(current_profile, standard_id, "Profile should be updated"); -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/e2e/claude_sandbox.rs b/src-tauri/tests/sandbox/e2e/claude_sandbox.rs index 
2d6e3e2..f446318 100644 --- a/src-tauri/tests/sandbox/e2e/claude_sandbox.rs +++ b/src-tauri/tests/sandbox/e2e/claude_sandbox.rs @@ -8,23 +8,27 @@ use serial_test::serial; #[serial] fn test_claude_with_default_sandbox() { skip_if_unsupported!(); - + // Create test environment let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create default sandbox profile let rules = profiles::standard(&test_fs.project_path.to_string_lossy()); - let profile_id = test_db.create_test_profile("claude_default", rules) + let profile_id = test_db + .create_test_profile("claude_default", rules) .expect("Failed to create test profile"); - + // Set as default and active - test_db.conn.execute( - "UPDATE sandbox_profiles SET is_default = 1, is_active = 1 WHERE id = ?1", - rusqlite::params![profile_id], - ).expect("Failed to set default profile"); - + test_db + .conn + .execute( + "UPDATE sandbox_profiles SET is_default = 1, is_active = 1 WHERE id = ?1", + rusqlite::params![profile_id], + ) + .expect("Failed to set default profile"); + // Execute real Claude command with default sandbox profile let result = execute_claude_task( &test_fs.project_path, @@ -33,18 +37,22 @@ fn test_claude_with_default_sandbox() { Some("sonnet"), Some(profile_id), 20, // 20 second timeout - ).expect("Failed to execute Claude command"); - + ) + .expect("Failed to execute Claude command"); + // Debug output eprintln!("=== Claude Output (Default Sandbox) ==="); eprintln!("Exit code: {}", result.exit_code); eprintln!("STDOUT:\n{}", result.stdout); eprintln!("STDERR:\n{}", result.stderr); eprintln!("==================="); - + // Basic verification - assert!(result.exit_code == 0 || result.exit_code == 124, - "Claude should execute with default sandbox (exit code: {})", result.exit_code); + assert!( + result.exit_code == 0 || result.exit_code == 124, + "Claude should execute with default sandbox (exit code: {})", + result.exit_code + ); } /// Test Claude Code with sandboxing disabled @@ -52,23 +60,27 @@ fn test_claude_with_default_sandbox() { #[serial] fn test_claude_sandbox_disabled() { skip_if_unsupported!(); - + // Create test environment let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create profile but mark as inactive let rules = profiles::standard(&test_fs.project_path.to_string_lossy()); - let profile_id = test_db.create_test_profile("claude_inactive", rules) + let profile_id = test_db + .create_test_profile("claude_inactive", rules) .expect("Failed to create test profile"); - + // Set as default but inactive - test_db.conn.execute( - "UPDATE sandbox_profiles SET is_default = 1, is_active = 0 WHERE id = ?1", - rusqlite::params![profile_id], - ).expect("Failed to set inactive profile"); - + test_db + .conn + .execute( + "UPDATE sandbox_profiles SET is_default = 1, is_active = 0 WHERE id = ?1", + rusqlite::params![profile_id], + ) + .expect("Failed to set inactive profile"); + // Execute real Claude command without active sandbox let result = execute_claude_task( &test_fs.project_path, @@ -76,19 +88,23 @@ fn test_claude_sandbox_disabled() { Some("You are Claude. 
Only perform the requested task."), Some("sonnet"), None, // No sandbox since profile is inactive - 20, // 20 second timeout - ).expect("Failed to execute Claude command"); - + 20, // 20 second timeout + ) + .expect("Failed to execute Claude command"); + // Debug output eprintln!("=== Claude Output (Inactive Sandbox) ==="); eprintln!("Exit code: {}", result.exit_code); eprintln!("STDOUT:\n{}", result.stdout); eprintln!("STDERR:\n{}", result.stderr); eprintln!("==================="); - + // Basic verification - assert!(result.exit_code == 0 || result.exit_code == 124, - "Claude should execute without active sandbox (exit code: {})", result.exit_code); + assert!( + result.exit_code == 0 || result.exit_code == 124, + "Claude should execute without active sandbox (exit code: {})", + result.exit_code + ); } /// Test Claude Code session operations @@ -96,31 +112,31 @@ fn test_claude_sandbox_disabled() { #[serial] fn test_claude_session_operations() { // This test doesn't require actual Claude execution - + // Create test environment let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create mock session structure let claude_dir = test_fs.root.path().join(".claude"); let projects_dir = claude_dir.join("projects"); let project_id = test_fs.project_path.to_string_lossy().replace('/', "-"); let session_dir = projects_dir.join(&project_id); - + std::fs::create_dir_all(&session_dir).expect("Failed to create session dir"); - + // Create mock session file let session_id = "test-session-123"; let session_file = session_dir.join(format!("{}.jsonl", session_id)); - + let session_data = serde_json::json!({ "type": "session_start", "cwd": test_fs.project_path.to_string_lossy(), "timestamp": "2024-01-01T00:00:00Z" }); - + std::fs::write(&session_file, format!("{}\n", session_data)) .expect("Failed to write session file"); - + // Verify session file exists assert!(session_file.exists(), "Session file should exist"); } @@ -131,11 +147,11 @@ fn test_claude_session_operations() { fn test_claude_settings_sandbox_config() { // Create test environment let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create mock settings let claude_dir = test_fs.root.path().join(".claude"); std::fs::create_dir_all(&claude_dir).expect("Failed to create claude dir"); - + let settings_file = claude_dir.join("settings.json"); let settings = serde_json::json!({ "sandboxEnabled": true, @@ -143,18 +159,23 @@ fn test_claude_settings_sandbox_config() { "theme": "dark", "model": "sonnet" }); - - std::fs::write(&settings_file, serde_json::to_string_pretty(&settings).unwrap()) - .expect("Failed to write settings"); - + + std::fs::write( + &settings_file, + serde_json::to_string_pretty(&settings).unwrap(), + ) + .expect("Failed to write settings"); + // Read and verify settings - let content = std::fs::read_to_string(&settings_file) - .expect("Failed to read settings"); - let parsed: serde_json::Value = serde_json::from_str(&content) - .expect("Failed to parse settings"); - + let content = std::fs::read_to_string(&settings_file).expect("Failed to read settings"); + let parsed: serde_json::Value = + serde_json::from_str(&content).expect("Failed to parse settings"); + assert_eq!(parsed["sandboxEnabled"], true, "Sandbox should be enabled"); - assert_eq!(parsed["defaultSandboxProfile"], "standard", "Default profile should be standard"); + assert_eq!( + parsed["defaultSandboxProfile"], "standard", + "Default profile should be standard" + ); } /// Test profile-based file access 
restrictions @@ -162,22 +183,23 @@ fn test_claude_settings_sandbox_config() { #[serial] fn test_profile_file_access_simulation() { skip_if_unsupported!(); - + // Create test environment let _test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create a custom profile with specific file access let custom_rules = vec![ TestRule::file_read("{{PROJECT_PATH}}", true), TestRule::file_read("/usr/local/bin", true), TestRule::file_read("/etc/hosts", false), // Literal file ]; - - let profile_id = test_db.create_test_profile("file_access_test", custom_rules) + + let profile_id = test_db + .create_test_profile("file_access_test", custom_rules) .expect("Failed to create test profile"); - + // Load the profile rules let loaded_rules: Vec<(String, String, String)> = test_db.conn .prepare("SELECT operation_type, pattern_type, pattern_value FROM sandbox_rules WHERE profile_id = ?1") @@ -188,9 +210,11 @@ fn test_profile_file_access_simulation() { .expect("Failed to query rules") .collect::<Result<Vec<_>, _>>() .expect("Failed to collect rules"); - + // Verify rules were created correctly assert_eq!(loaded_rules.len(), 3, "Should have 3 rules"); - assert!(loaded_rules.iter().any(|(op, _, _)| op == "file_read_all"), - "Should have file_read_all operation"); -} \ No newline at end of file + assert!( + loaded_rules.iter().any(|(op, _, _)| op == "file_read_all"), + "Should have file_read_all operation" + ); +} diff --git a/src-tauri/tests/sandbox/e2e/mod.rs b/src-tauri/tests/sandbox/e2e/mod.rs index 755d3c0..ffb8c33 100644 --- a/src-tauri/tests/sandbox/e2e/mod.rs +++ b/src-tauri/tests/sandbox/e2e/mod.rs @@ -2,4 +2,4 @@ #[cfg(test)] mod agent_sandbox; #[cfg(test)] -mod claude_sandbox; \ No newline at end of file +mod claude_sandbox; diff --git a/src-tauri/tests/sandbox/integration/file_operations.rs b/src-tauri/tests/sandbox/integration/file_operations.rs index dcc01bf..9d5b802 100644 --- a/src-tauri/tests/sandbox/integration/file_operations.rs +++ b/src-tauri/tests/sandbox/integration/file_operations.rs @@ -3,7 +3,7 @@ use crate::sandbox::common::*; use crate::skip_if_unsupported; use claudia_lib::sandbox::executor::SandboxExecutor; use claudia_lib::sandbox::profile::ProfileBuilder; -use gaol::profile::{Profile, Operation, PathPattern}; +use gaol::profile::{Operation, PathPattern, Profile}; use serial_test::serial; use tempfile::TempDir; @@ -12,21 +12,21 @@ use tempfile::TempDir; #[serial] fn test_allowed_file_read() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); if !platform.supports_file_read { eprintln!("Skipping test: file read not supported on this platform"); return; } - + // Create test file system let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile allowing project path access - let operations = vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -34,13 +34,13 @@ fn test_allowed_file_read() { return; } }; - + // Create test binary that reads from allowed path let test_code = test_code::file_read(&test_fs.project_path.join("main.rs").to_string_lossy()); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_file_read", &test_code, binary_dir.path())
.expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -63,21 +63,21 @@ fn test_allowed_file_read() { #[serial] fn test_forbidden_file_read() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); if !platform.supports_file_read { eprintln!("Skipping test: file read not supported on this platform"); return; } - + // Create test file system let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile allowing only project path (not forbidden path) - let operations = vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -85,14 +85,14 @@ fn test_forbidden_file_read() { return; } }; - + // Create test binary that reads from forbidden path let forbidden_file = test_fs.forbidden_path.join("secret.txt"); let test_code = test_code::file_read(&forbidden_file.to_string_lossy()); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_forbidden_read", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -105,7 +105,9 @@ fn test_forbidden_file_read() { // On some platforms (like macOS), gaol might not block all file reads // so we check if the operation failed OR if it's a platform limitation if status.success() { - eprintln!("WARNING: File read was not blocked - this might be a platform limitation"); + eprintln!( + "WARNING: File read was not blocked - this might be a platform limitation" + ); // Check if we're on a platform where this is expected let platform_config = PlatformConfig::current(); if !platform_config.supports_file_read { @@ -124,15 +126,15 @@ fn test_forbidden_file_read() { #[serial] fn test_file_write_always_forbidden() { skip_if_unsupported!(); - + // Create test file system let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile with file read permissions (write should still be blocked) - let operations = vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -140,14 +142,14 @@ fn test_file_write_always_forbidden() { return; } }; - + // Create test binary that tries to write a file let write_path = test_fs.project_path.join("test_write.txt"); let test_code = test_code::file_write(&write_path.to_string_lossy()); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_file_write", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -177,28 +179,28 @@ fn test_file_write_always_forbidden() { #[serial] fn test_file_metadata_operations() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); if !platform.supports_metadata_read && !platform.supports_file_read { 
eprintln!("Skipping test: metadata read not supported on this platform"); return; } - + // Create test file system let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile with metadata read permission let operations = if platform.supports_metadata_read { - vec![ - Operation::FileReadMetadata(PathPattern::Subpath(test_fs.project_path.clone())), - ] + vec![Operation::FileReadMetadata(PathPattern::Subpath( + test_fs.project_path.clone(), + ))] } else { // On Linux, metadata is allowed if file read is allowed - vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ] + vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))] }; - + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -206,14 +208,14 @@ fn test_file_metadata_operations() { return; } }; - + // Create test binary that reads file metadata let test_file = test_fs.project_path.join("main.rs"); let test_code = test_code::file_metadata(&test_file.to_string_lossy()); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_metadata", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -224,7 +226,10 @@ fn test_file_metadata_operations() { Ok(mut child) => { let status = child.wait().expect("Failed to wait for child"); if platform.supports_metadata_read || platform.supports_file_read { - assert!(status.success(), "Metadata read should succeed when allowed"); + assert!( + status.success(), + "Metadata read should succeed when allowed" + ); } } Err(e) => { @@ -238,33 +243,32 @@ fn test_file_metadata_operations() { #[serial] fn test_template_variable_expansion() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); if !platform.supports_file_read { eprintln!("Skipping test: file read not supported on this platform"); return; } - + // Create test database and profile let test_db = TEST_DB.lock(); test_db.reset().expect("Failed to reset database"); - + // Create a profile with template variables let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - let rules = vec![ - TestRule::file_read("{{PROJECT_PATH}}", true), - ]; - - let profile_id = test_db.create_test_profile("template_test", rules) + let rules = vec![TestRule::file_read("{{PROJECT_PATH}}", true)]; + + let profile_id = test_db + .create_test_profile("template_test", rules) .expect("Failed to create test profile"); - + // Load and build the profile let db_rules = claudia_lib::sandbox::profile::load_profile_rules(&test_db.conn, profile_id) .expect("Failed to load profile rules"); - + let builder = ProfileBuilder::new(test_fs.project_path.clone()) .expect("Failed to create profile builder"); - + let profile = match builder.build_profile(db_rules) { Ok(p) => p, Err(_) => { @@ -272,13 +276,13 @@ fn test_template_variable_expansion() { return; } }; - + // Create test binary that reads from project path let test_code = test_code::file_read(&test_fs.project_path.join("main.rs").to_string_lossy()); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_template", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match 
executor.execute_sandboxed_spawn( @@ -294,4 +298,4 @@ fn test_template_variable_expansion() { eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); } } -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/integration/mod.rs b/src-tauri/tests/sandbox/integration/mod.rs index 1a41814..0f5f8be 100644 --- a/src-tauri/tests/sandbox/integration/mod.rs +++ b/src-tauri/tests/sandbox/integration/mod.rs @@ -4,8 +4,8 @@ mod file_operations; #[cfg(test)] mod network_operations; #[cfg(test)] -mod system_info; -#[cfg(test)] mod process_isolation; #[cfg(test)] -mod violations; \ No newline at end of file +mod system_info; +#[cfg(test)] +mod violations; diff --git a/src-tauri/tests/sandbox/integration/network_operations.rs b/src-tauri/tests/sandbox/integration/network_operations.rs index 95b62a9..171938d 100644 --- a/src-tauri/tests/sandbox/integration/network_operations.rs +++ b/src-tauri/tests/sandbox/integration/network_operations.rs @@ -2,7 +2,7 @@ use crate::sandbox::common::*; use crate::skip_if_unsupported; use claudia_lib::sandbox::executor::SandboxExecutor; -use gaol::profile::{Profile, Operation, AddressPattern}; +use gaol::profile::{AddressPattern, Operation, Profile}; use serial_test::serial; use std::net::TcpListener; use tempfile::TempDir; @@ -10,7 +10,10 @@ use tempfile::TempDir; /// Get an available port for testing fn get_available_port() -> u16 { let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind to 0"); - let port = listener.local_addr().expect("Failed to get local addr").port(); + let port = listener + .local_addr() + .expect("Failed to get local addr") + .port(); drop(listener); // Release the port port } @@ -20,21 +23,19 @@ fn get_available_port() -> u16 { #[serial] fn test_allowed_network_all() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); if !platform.supports_network_all { eprintln!("Skipping test: network all not supported on this platform"); return; } - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile allowing all network access - let operations = vec![ - Operation::NetworkOutbound(AddressPattern::All), - ]; - + let operations = vec![Operation::NetworkOutbound(AddressPattern::All)]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -42,18 +43,18 @@ fn test_allowed_network_all() { return; } }; - + // Create test binary that connects to localhost let port = get_available_port(); let test_code = test_code::network_connect(&format!("127.0.0.1:{}", port)); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_network", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Start a listener on the port - let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) - .expect("Failed to bind listener"); - + let listener = + TcpListener::bind(format!("127.0.0.1:{}", port)).expect("Failed to bind listener"); + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -66,9 +67,12 @@ fn test_allowed_network_all() { std::thread::spawn(move || { let _ = listener.accept(); }); - + let status = child.wait().expect("Failed to wait for child"); - assert!(status.success(), "Network connection should succeed when allowed"); + assert!( + status.success(), + "Network connection should succeed when allowed" + ); } Err(e) => { eprintln!("Sandbox execution 
failed: {} (may be expected in CI)", e); @@ -81,15 +85,15 @@ fn test_allowed_network_all() { #[serial] fn test_forbidden_network() { skip_if_unsupported!(); - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile without network permissions - let operations = vec![ - Operation::FileReadAll(gaol::profile::PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(gaol::profile::PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -97,13 +101,13 @@ fn test_forbidden_network() { return; } }; - + // Create test binary that tries to connect let test_code = test_code::network_connect("google.com:80"); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_no_network", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -137,19 +141,19 @@ fn test_network_tcp_port_specific() { eprintln!("Skipping test: TCP port filtering not supported"); return; } - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Get two ports - one allowed, one forbidden let allowed_port = get_available_port(); let forbidden_port = get_available_port(); - + // Create profile allowing only specific port - let operations = vec![ - Operation::NetworkOutbound(AddressPattern::Tcp(allowed_port)), - ]; - + let operations = vec![Operation::NetworkOutbound(AddressPattern::Tcp( + allowed_port, + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -157,17 +161,17 @@ fn test_network_tcp_port_specific() { return; } }; - + // Test 1: Allowed port { let test_code = test_code::network_connect(&format!("127.0.0.1:{}", allowed_port)); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_allowed_port", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + let listener = TcpListener::bind(format!("127.0.0.1:{}", allowed_port)) .expect("Failed to bind listener"); - + let executor = SandboxExecutor::new(profile.clone(), test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( &binary_path.to_string_lossy(), @@ -178,23 +182,26 @@ fn test_network_tcp_port_specific() { std::thread::spawn(move || { let _ = listener.accept(); }); - + let status = child.wait().expect("Failed to wait for child"); - assert!(status.success(), "Connection to allowed port should succeed"); + assert!( + status.success(), + "Connection to allowed port should succeed" + ); } Err(e) => { eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); } } } - + // Test 2: Forbidden port { let test_code = test_code::network_connect(&format!("127.0.0.1:{}", forbidden_port)); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_forbidden_port", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( &binary_path.to_string_lossy(), @@ -203,7 +210,10 @@ fn test_network_tcp_port_specific() { ) { Ok(mut child) => { let status = child.wait().expect("Failed to wait for child"); - 
assert!(!status.success(), "Connection to forbidden port should fail"); + assert!( + !status.success(), + "Connection to forbidden port should fail" + ); } Err(e) => { eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); @@ -218,28 +228,26 @@ fn test_network_tcp_port_specific() { #[cfg(unix)] fn test_local_socket_connections() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let socket_path = test_fs.project_path.join("test.sock"); - + // Create appropriate profile based on platform let operations = if platform.supports_network_local { - vec![ - Operation::NetworkOutbound(AddressPattern::LocalSocket(socket_path.clone())), - ] + vec![Operation::NetworkOutbound(AddressPattern::LocalSocket( + socket_path.clone(), + ))] } else if platform.supports_network_all { // Fallback to allowing all network - vec![ - Operation::NetworkOutbound(AddressPattern::All), - ] + vec![Operation::NetworkOutbound(AddressPattern::All)] } else { eprintln!("Skipping test: no network support on this platform"); return; }; - + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -247,7 +255,7 @@ fn test_local_socket_connections() { return; } }; - + // Create test binary that connects to local socket let test_code = format!( r#" @@ -267,15 +275,15 @@ fn main() {{ "#, socket_path.to_string_lossy() ); - + let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_local_socket", &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Create Unix socket listener use std::os::unix::net::UnixListener; let listener = UnixListener::bind(&socket_path).expect("Failed to bind Unix socket"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -287,15 +295,18 @@ fn main() {{ std::thread::spawn(move || { let _ = listener.accept(); }); - + let status = child.wait().expect("Failed to wait for child"); - assert!(status.success(), "Local socket connection should succeed when allowed"); + assert!( + status.success(), + "Local socket connection should succeed when allowed" + ); } Err(e) => { eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); } } - + // Clean up socket file let _ = std::fs::remove_file(&socket_path); -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/integration/process_isolation.rs b/src-tauri/tests/sandbox/integration/process_isolation.rs index c579864..b3f7bee 100644 --- a/src-tauri/tests/sandbox/integration/process_isolation.rs +++ b/src-tauri/tests/sandbox/integration/process_isolation.rs @@ -2,7 +2,7 @@ use crate::sandbox::common::*; use crate::skip_if_unsupported; use claudia_lib::sandbox::executor::SandboxExecutor; -use gaol::profile::{Profile, Operation, PathPattern, AddressPattern}; +use gaol::profile::{AddressPattern, Operation, PathPattern, Profile}; use serial_test::serial; use tempfile::TempDir; @@ -11,16 +11,16 @@ use tempfile::TempDir; #[serial] fn test_process_spawn_forbidden() { skip_if_unsupported!(); - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile with various permissions (process spawn should still be blocked) let operations = vec![ Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), Operation::NetworkOutbound(AddressPattern::All), 
]; - + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -28,13 +28,13 @@ fn test_process_spawn_forbidden() { return; } }; - + // Create test binary that tries to spawn a process let test_code = test_code::spawn_process(); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_spawn", test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -49,7 +49,10 @@ fn test_process_spawn_forbidden() { eprintln!("WARNING: Process spawning was not blocked"); // macOS sandbox might have limitations if std::env::consts::OS != "linux" { - eprintln!("Process spawning might not be fully blocked on {}", std::env::consts::OS); + eprintln!( + "Process spawning might not be fully blocked on {}", + std::env::consts::OS + ); } else { panic!("Process spawning should be blocked on Linux"); } @@ -67,15 +70,15 @@ fn test_process_spawn_forbidden() { #[cfg(unix)] fn test_fork_forbidden() { skip_if_unsupported!(); - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create minimal profile - let operations = vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -83,14 +86,19 @@ fn test_fork_forbidden() { return; } }; - + // Create test binary that tries to fork let test_code = test_code::fork_process(); - + let binary_dir = TempDir::new().expect("Failed to create temp dir"); - let binary_path = create_test_binary_with_deps("test_fork", test_code, binary_dir.path(), &[("libc", "0.2")]) - .expect("Failed to create test binary"); - + let binary_path = create_test_binary_with_deps( + "test_fork", + test_code, + binary_dir.path(), + &[("libc", "0.2")], + ) + .expect("Failed to create test binary"); + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -120,15 +128,15 @@ fn test_fork_forbidden() { #[cfg(unix)] fn test_exec_forbidden() { skip_if_unsupported!(); - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create minimal profile - let operations = vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -136,14 +144,19 @@ fn test_exec_forbidden() { return; } }; - + // Create test binary that tries to exec let test_code = test_code::exec_process(); - + let binary_dir = TempDir::new().expect("Failed to create temp dir"); - let binary_path = create_test_binary_with_deps("test_exec", test_code, binary_dir.path(), &[("libc", "0.2")]) - .expect("Failed to create test binary"); - + let binary_path = create_test_binary_with_deps( + "test_exec", + test_code, + binary_dir.path(), + &[("libc", "0.2")], + ) + .expect("Failed to create test binary"); + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -172,15 +185,15 @@ fn test_exec_forbidden() { #[serial] fn 
test_thread_creation_allowed() { skip_if_unsupported!(); - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create minimal profile - let operations = vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -188,7 +201,7 @@ fn test_thread_creation_allowed() { return; } }; - + // Create test binary that creates threads let test_code = r#" use std::thread; @@ -211,11 +224,11 @@ fn main() { } } "#; - + let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_thread", test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -231,4 +244,4 @@ fn main() { eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); } } -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/integration/system_info.rs b/src-tauri/tests/sandbox/integration/system_info.rs index a207270..9c41b4e 100644 --- a/src-tauri/tests/sandbox/integration/system_info.rs +++ b/src-tauri/tests/sandbox/integration/system_info.rs @@ -2,7 +2,7 @@ use crate::sandbox::common::*; use crate::skip_if_unsupported; use claudia_lib::sandbox::executor::SandboxExecutor; -use gaol::profile::{Profile, Operation}; +use gaol::profile::{Operation, Profile}; use serial_test::serial; use tempfile::TempDir; @@ -11,21 +11,19 @@ use tempfile::TempDir; #[serial] fn test_system_info_read() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); if !platform.supports_system_info { eprintln!("Skipping test: system info read not supported on this platform"); return; } - + // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile allowing system info read - let operations = vec![ - Operation::SystemInfoRead, - ]; - + let operations = vec![Operation::SystemInfoRead]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -33,13 +31,13 @@ fn test_system_info_read() { return; } }; - + // Create test binary that reads system info let test_code = test_code::system_info(); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_sysinfo", test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -49,7 +47,10 @@ fn test_system_info_read() { ) { Ok(mut child) => { let status = child.wait().expect("Failed to wait for child"); - assert!(status.success(), "System info read should succeed when allowed"); + assert!( + status.success(), + "System info read should succeed when allowed" + ); } Err(e) => { eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); @@ -64,12 +65,12 @@ fn test_system_info_read() { fn test_forbidden_system_info() { // Create test project let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile without system info permission - let operations = vec![ - Operation::FileReadAll(gaol::profile::PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = 
vec![Operation::FileReadAll(gaol::profile::PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -77,13 +78,13 @@ fn test_forbidden_system_info() { return; } }; - + // Create test binary that reads system info let test_code = test_code::system_info(); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_no_sysinfo", test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -118,27 +119,33 @@ fn test_forbidden_system_info() { #[serial] fn test_platform_specific_system_info() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); - + match std::env::consts::OS { "linux" => { // On Linux, system info is never allowed - assert!(!platform.supports_system_info, - "Linux should not support system info read"); + assert!( + !platform.supports_system_info, + "Linux should not support system info read" + ); } "macos" => { // On macOS, system info can be allowed - assert!(platform.supports_system_info, - "macOS should support system info read"); + assert!( + platform.supports_system_info, + "macOS should support system info read" + ); } "freebsd" => { // On FreeBSD, system info is always allowed (can't be restricted) - assert!(platform.supports_system_info, - "FreeBSD always allows system info read"); + assert!( + platform.supports_system_info, + "FreeBSD always allows system info read" + ); } _ => { eprintln!("Unknown platform behavior for system info"); } } -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/integration/violations.rs b/src-tauri/tests/sandbox/integration/violations.rs index 0a7b9d3..dbc8001 100644 --- a/src-tauri/tests/sandbox/integration/violations.rs +++ b/src-tauri/tests/sandbox/integration/violations.rs @@ -2,7 +2,7 @@ use crate::sandbox::common::*; use crate::skip_if_unsupported; use claudia_lib::sandbox::executor::SandboxExecutor; -use gaol::profile::{Profile, Operation, PathPattern}; +use gaol::profile::{Operation, PathPattern, Profile}; use serial_test::serial; use std::sync::{Arc, Mutex}; use tempfile::TempDir; @@ -27,19 +27,19 @@ impl ViolationCollector { violations: Arc::new(Mutex::new(Vec::new())), } } - + fn record(&self, operation_type: &str, pattern_value: Option<&str>, process_name: &str) { let event = ViolationEvent { operation_type: operation_type.to_string(), pattern_value: pattern_value.map(|s| s.to_string()), process_name: process_name.to_string(), }; - + if let Ok(mut violations) = self.violations.lock() { violations.push(event); } } - + fn get_violations(&self) -> Vec<ViolationEvent> { self.violations.lock().unwrap().clone() } } @@ -50,22 +50,22 @@ impl ViolationCollector { #[serial] fn test_violation_detection() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); if !platform.supports_file_read { eprintln!("Skipping test: file read not supported on this platform"); return; } - + // Create test file system let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); let collector = ViolationCollector::new(); - + // Create profile allowing only project path - let operations = vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match
Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -73,19 +73,31 @@ fn test_violation_detection() { return; } }; - + // Test various forbidden operations let test_cases = vec![ - ("file_read", test_code::file_read(&test_fs.forbidden_path.join("secret.txt").to_string_lossy()), "file_read_forbidden"), - ("file_write", test_code::file_write(&test_fs.project_path.join("new.txt").to_string_lossy()), "file_write_forbidden"), - ("process_spawn", test_code::spawn_process().to_string(), "process_spawn_forbidden"), + ( + "file_read", + test_code::file_read(&test_fs.forbidden_path.join("secret.txt").to_string_lossy()), + "file_read_forbidden", + ), + ( + "file_write", + test_code::file_write(&test_fs.project_path.join("new.txt").to_string_lossy()), + "file_write_forbidden", + ), + ( + "process_spawn", + test_code::spawn_process().to_string(), + "process_spawn_forbidden", + ), ]; - + for (op_type, test_code, binary_name) in test_cases { let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary(binary_name, &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + let executor = SandboxExecutor::new(profile.clone(), test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( &binary_path.to_string_lossy(), @@ -104,7 +116,7 @@ fn test_violation_detection() { } } } - + // Verify violations were detected let violations = collector.get_violations(); // On some platforms (like macOS), sandbox might not block all operations @@ -122,25 +134,25 @@ fn test_violation_detection() { #[serial] fn test_violation_patterns() { skip_if_unsupported!(); - + let platform = PlatformConfig::current(); if !platform.supports_file_read { eprintln!("Skipping test: file read not supported on this platform"); return; } - + // Create test file system let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create profile with specific allowed paths let allowed_dir = test_fs.root.path().join("allowed_specific"); std::fs::create_dir_all(&allowed_dir).expect("Failed to create allowed dir"); - + let operations = vec![ Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), Operation::FileReadAll(PathPattern::Literal(allowed_dir.join("file.txt"))), ]; - + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -148,21 +160,25 @@ fn test_violation_patterns() { return; } }; - + // Test accessing different forbidden paths - let forbidden_db_path = test_fs.forbidden_path.join("data.db").to_string_lossy().to_string(); + let forbidden_db_path = test_fs + .forbidden_path + .join("data.db") + .to_string_lossy() + .to_string(); let forbidden_paths = vec![ ("/etc/passwd", "system_file"), ("/tmp/test.txt", "temp_file"), (forbidden_db_path.as_str(), "forbidden_db"), ]; - + for (path, test_name) in forbidden_paths { let test_code = test_code::file_read(path); let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary(test_name, &test_code, binary_dir.path()) .expect("Failed to create test binary"); - + let executor = SandboxExecutor::new(profile.clone(), test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( &binary_path.to_string_lossy(), @@ -173,7 +189,10 @@ fn test_violation_patterns() { let status = child.wait().expect("Failed to wait for child"); // Some platforms might not block all file access if status.success() { - eprintln!("WARNING: Access to {} was allowed (possible platform limitation)", path); + eprintln!( + "WARNING: Access to 
{} was allowed (possible platform limitation)", + path + ); if std::env::consts::OS == "linux" && path.starts_with("/etc") { panic!("Access to {} should be denied on Linux", path); } @@ -191,15 +210,15 @@ fn test_violation_patterns() { #[serial] fn test_multiple_violations_sequence() { skip_if_unsupported!(); - + // Create test file system let test_fs = TestFileSystem::new().expect("Failed to create test filesystem"); - + // Create minimal profile - let operations = vec![ - Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())), - ]; - + let operations = vec![Operation::FileReadAll(PathPattern::Subpath( + test_fs.project_path.clone(), + ))]; + let profile = match Profile::new(operations) { Ok(p) => p, Err(_) => { @@ -207,7 +226,7 @@ fn test_multiple_violations_sequence() { return; } }; - + // Create test binary that attempts multiple forbidden operations let test_code = r#" use std::fs; @@ -249,11 +268,11 @@ fn main() {{ }} }} "#; - + let binary_dir = TempDir::new().expect("Failed to create temp dir"); let binary_path = create_test_binary("test_multi_violations", test_code, binary_dir.path()) .expect("Failed to create test binary"); - + // Execute in sandbox let executor = SandboxExecutor::new(profile, test_fs.project_path.clone()); match executor.execute_sandboxed_spawn( @@ -275,4 +294,4 @@ fn main() {{ eprintln!("Sandbox execution failed: {} (may be expected in CI)", e); } } -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/mod.rs b/src-tauri/tests/sandbox/mod.rs index bd6c63e..9841a12 100644 --- a/src-tauri/tests/sandbox/mod.rs +++ b/src-tauri/tests/sandbox/mod.rs @@ -1,5 +1,5 @@ //! Comprehensive test suite for sandbox functionality -//! +//! //! This test suite validates the sandboxing capabilities across different platforms, //! ensuring that security policies are correctly enforced. @@ -14,4 +14,4 @@ pub mod unit; pub mod integration; #[cfg(unix)] -pub mod e2e; \ No newline at end of file +pub mod e2e; diff --git a/src-tauri/tests/sandbox/unit/executor.rs b/src-tauri/tests/sandbox/unit/executor.rs index 3adce93..e65ae5e 100644 --- a/src-tauri/tests/sandbox/unit/executor.rs +++ b/src-tauri/tests/sandbox/unit/executor.rs @@ -1,6 +1,6 @@ //! 
Unit tests for SandboxExecutor -use claudia_lib::sandbox::executor::{SandboxExecutor, should_activate_sandbox}; -use gaol::profile::{Profile, Operation, PathPattern, AddressPattern}; +use claudia_lib::sandbox::executor::{should_activate_sandbox, SandboxExecutor}; +use gaol::profile::{AddressPattern, Operation, PathPattern, Profile}; use std::env; use std::path::PathBuf; @@ -10,7 +10,7 @@ fn create_test_profile(project_path: PathBuf) -> Profile { Operation::FileReadAll(PathPattern::Subpath(project_path)), Operation::NetworkOutbound(AddressPattern::All), ]; - + Profile::new(operations).expect("Failed to create test profile") } @@ -18,7 +18,7 @@ fn create_test_profile(project_path: PathBuf) -> Profile { fn test_executor_creation() { let project_path = PathBuf::from("/test/project"); let profile = create_test_profile(project_path.clone()); - + let _executor = SandboxExecutor::new(profile, project_path); // Executor should be created successfully } @@ -27,16 +27,25 @@ fn test_executor_creation() { fn test_should_activate_sandbox_env_var() { // Test when env var is not set env::remove_var("GAOL_SANDBOX_ACTIVE"); - assert!(!should_activate_sandbox(), "Should not activate when env var is not set"); - + assert!( + !should_activate_sandbox(), + "Should not activate when env var is not set" + ); + // Test when env var is set to "1" env::set_var("GAOL_SANDBOX_ACTIVE", "1"); - assert!(should_activate_sandbox(), "Should activate when env var is '1'"); - + assert!( + should_activate_sandbox(), + "Should activate when env var is '1'" + ); + // Test when env var is set to other value env::set_var("GAOL_SANDBOX_ACTIVE", "0"); - assert!(!should_activate_sandbox(), "Should not activate when env var is not '1'"); - + assert!( + !should_activate_sandbox(), + "Should not activate when env var is not '1'" + ); + // Clean up env::remove_var("GAOL_SANDBOX_ACTIVE"); } @@ -46,9 +55,9 @@ fn test_prepare_sandboxed_command() { let project_path = PathBuf::from("/test/project"); let profile = create_test_profile(project_path.clone()); let executor = SandboxExecutor::new(profile, project_path.clone()); - + let _cmd = executor.prepare_sandboxed_command("echo", &["hello"], &project_path); - + // The command should have sandbox environment variables set // Note: We can't easily test Command internals, but we can verify it doesn't panic } @@ -57,10 +66,10 @@ fn test_prepare_sandboxed_command() { fn test_executor_with_empty_profile() { let project_path = PathBuf::from("/test/project"); let profile = Profile::new(vec![]).expect("Failed to create empty profile"); - + let executor = SandboxExecutor::new(profile, project_path.clone()); let _cmd = executor.prepare_sandboxed_command("echo", &["test"], &project_path); - + // Should handle empty profile gracefully } @@ -76,15 +85,16 @@ fn test_executor_with_complex_profile() { Operation::NetworkOutbound(AddressPattern::Tcp(443)), Operation::SystemInfoRead, ]; - + // Only create profile with supported operations - let filtered_ops: Vec<_> = operations.into_iter() + let filtered_ops: Vec<_> = operations + .into_iter() .filter(|op| { use gaol::profile::{OperationSupport, OperationSupportLevel}; matches!(op.support(), OperationSupportLevel::CanBeAllowed) }) .collect(); - + if !filtered_ops.is_empty() { let profile = Profile::new(filtered_ops).expect("Failed to create complex profile"); let executor = SandboxExecutor::new(profile, project_path.clone()); @@ -97,12 +107,12 @@ fn test_command_environment_setup() { let project_path = PathBuf::from("/test/project"); let profile = 
create_test_profile(project_path.clone()); let executor = SandboxExecutor::new(profile, project_path.clone()); - + // Test with various arguments let _cmd1 = executor.prepare_sandboxed_command("ls", &[], &project_path); let _cmd2 = executor.prepare_sandboxed_command("cat", &["file.txt"], &project_path); let _cmd3 = executor.prepare_sandboxed_command("grep", &["-r", "pattern", "."], &project_path); - + // Commands should be prepared without panic } @@ -110,18 +120,18 @@ fn test_command_environment_setup() { #[cfg(unix)] fn test_spawn_sandboxed_process() { use crate::sandbox::common::is_sandboxing_supported; - + if !is_sandboxing_supported() { return; } - + let project_path = env::current_dir().unwrap_or_else(|_| PathBuf::from("/tmp")); let profile = create_test_profile(project_path.clone()); let executor = SandboxExecutor::new(profile, project_path.clone()); - + // Try to spawn a simple command let result = executor.execute_sandboxed_spawn("echo", &["sandbox test"], &project_path); - + // On supported platforms, this should either succeed or fail gracefully match result { Ok(mut child) => { @@ -133,4 +143,4 @@ fn test_spawn_sandboxed_process() { println!("Sandbox spawn failed (expected in some environments): {e}"); } } -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/unit/mod.rs b/src-tauri/tests/sandbox/unit/mod.rs index bcd8f0a..b3a4f3d 100644 --- a/src-tauri/tests/sandbox/unit/mod.rs +++ b/src-tauri/tests/sandbox/unit/mod.rs @@ -1,7 +1,7 @@ //! Unit tests for sandbox components #[cfg(test)] -mod profile_builder; +mod executor; #[cfg(test)] mod platform; #[cfg(test)] -mod executor; \ No newline at end of file +mod profile_builder; diff --git a/src-tauri/tests/sandbox/unit/platform.rs b/src-tauri/tests/sandbox/unit/platform.rs index 4ca9cbf..fefd393 100644 --- a/src-tauri/tests/sandbox/unit/platform.rs +++ b/src-tauri/tests/sandbox/unit/platform.rs @@ -1,13 +1,13 @@ //! 
Unit tests for platform capabilities use claudia_lib::sandbox::platform::{get_platform_capabilities, is_sandboxing_available}; -use std::env; use pretty_assertions::assert_eq; +use std::env; #[test] fn test_sandboxing_availability() { let is_available = is_sandboxing_available(); let expected = matches!(env::consts::OS, "linux" | "macos" | "freebsd"); - + assert_eq!( is_available, expected, "Sandboxing availability should match platform support" @@ -17,44 +17,59 @@ fn test_sandboxing_availability() { #[test] fn test_platform_capabilities_structure() { let caps = get_platform_capabilities(); - + // Verify basic structure assert_eq!(caps.os, env::consts::OS, "OS should match current platform"); - assert!(!caps.operations.is_empty() || !caps.sandboxing_supported, - "Should have operations if sandboxing is supported"); - assert!(!caps.notes.is_empty(), "Should have platform-specific notes"); + assert!( + !caps.operations.is_empty() || !caps.sandboxing_supported, + "Should have operations if sandboxing is supported" + ); + assert!( + !caps.notes.is_empty(), + "Should have platform-specific notes" + ); } #[test] #[cfg(target_os = "linux")] fn test_linux_capabilities() { let caps = get_platform_capabilities(); - + assert_eq!(caps.os, "linux"); assert!(caps.sandboxing_supported); - + // Verify Linux-specific capabilities - let file_read = caps.operations.iter() + let file_read = caps + .operations + .iter() .find(|op| op.operation == "file_read_all") .expect("file_read_all should be present"); assert_eq!(file_read.support_level, "can_be_allowed"); - - let metadata_read = caps.operations.iter() + + let metadata_read = caps + .operations + .iter() .find(|op| op.operation == "file_read_metadata") .expect("file_read_metadata should be present"); assert_eq!(metadata_read.support_level, "cannot_be_precisely"); - - let network_all = caps.operations.iter() + + let network_all = caps + .operations + .iter() .find(|op| op.operation == "network_outbound_all") .expect("network_outbound_all should be present"); assert_eq!(network_all.support_level, "can_be_allowed"); - - let network_tcp = caps.operations.iter() + + let network_tcp = caps + .operations + .iter() .find(|op| op.operation == "network_outbound_tcp") .expect("network_outbound_tcp should be present"); assert_eq!(network_tcp.support_level, "cannot_be_precisely"); - - let system_info = caps.operations.iter() + + let system_info = caps + .operations + .iter() .find(|op| op.operation == "system_info_read") .expect("system_info_read should be present"); assert_eq!(system_info.support_level, "never"); @@ -64,27 +79,35 @@ fn test_linux_capabilities() { #[cfg(target_os = "macos")] fn test_macos_capabilities() { let caps = get_platform_capabilities(); - + assert_eq!(caps.os, "macos"); assert!(caps.sandboxing_supported); - + // Verify macOS-specific capabilities - let file_read = caps.operations.iter() + let file_read = caps + .operations + .iter() .find(|op| op.operation == "file_read_all") .expect("file_read_all should be present"); assert_eq!(file_read.support_level, "can_be_allowed"); - - let metadata_read = caps.operations.iter() + + let metadata_read = caps + .operations + .iter() .find(|op| op.operation == "file_read_metadata") .expect("file_read_metadata should be present"); assert_eq!(metadata_read.support_level, "can_be_allowed"); - - let network_tcp = caps.operations.iter() + + let network_tcp = caps + .operations + .iter() .find(|op| op.operation == "network_outbound_tcp") .expect("network_outbound_tcp should be present"); 
assert_eq!(network_tcp.support_level, "can_be_allowed"); - - let system_info = caps.operations.iter() + + let system_info = caps + .operations + .iter() .find(|op| op.operation == "system_info_read") .expect("system_info_read should be present"); assert_eq!(system_info.support_level, "can_be_allowed"); @@ -94,17 +117,21 @@ fn test_macos_capabilities() { #[cfg(target_os = "freebsd")] fn test_freebsd_capabilities() { let caps = get_platform_capabilities(); - + assert_eq!(caps.os, "freebsd"); assert!(caps.sandboxing_supported); - + // Verify FreeBSD-specific capabilities - let file_read = caps.operations.iter() + let file_read = caps + .operations + .iter() .find(|op| op.operation == "file_read_all") .expect("file_read_all should be present"); assert_eq!(file_read.support_level, "never"); - - let system_info = caps.operations.iter() + + let system_info = caps + .operations + .iter() .find(|op| op.operation == "system_info_read") .expect("system_info_read should be present"); assert_eq!(system_info.support_level, "always"); @@ -114,7 +141,7 @@ fn test_freebsd_capabilities() { #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "freebsd")))] fn test_unsupported_platform_capabilities() { let caps = get_platform_capabilities(); - + assert!(!caps.sandboxing_supported); assert_eq!(caps.operations.len(), 0); assert!(caps.notes.iter().any(|note| note.contains("not supported"))); @@ -123,12 +150,18 @@ #[test] fn test_all_operations_have_descriptions() { let caps = get_platform_capabilities(); - + for op in &caps.operations { - assert!(!op.description.is_empty(), - "Operation {} should have a description", op.operation); - assert!(!op.support_level.is_empty(), - "Operation {} should have a support level", op.operation); + assert!( + !op.description.is_empty(), + "Operation {} should have a description", + op.operation + ); + assert!( + !op.support_level.is_empty(), + "Operation {} should have a support level", + op.operation + ); } } @@ -136,7 +169,7 @@ fn test_all_operations_have_descriptions() { fn test_support_level_values() { let caps = get_platform_capabilities(); let valid_levels = ["never", "can_be_allowed", "cannot_be_precisely", "always"]; - + for op in &caps.operations { assert!( valid_levels.contains(&op.support_level.as_str()), @@ -145,4 +178,4 @@ op.support_level ); } -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox/unit/profile_builder.rs index 0a7d940..1703eca 100644 --- a/src-tauri/tests/sandbox/unit/profile_builder.rs +++ b/src-tauri/tests/sandbox/unit/profile_builder.rs @@ -18,8 +18,7 @@ fn make_rule( pattern_value: pattern_value.to_string(), enabled: true, platform_support: platforms.map(|p| { - serde_json::to_string(&p.iter().map(|s| s.to_string()).collect::<Vec<_>>()) - .unwrap() + serde_json::to_string(&p.iter().map(|s| s.to_string()).collect::<Vec<_>>()).unwrap() }), created_at: String::new(), } @@ -29,34 +28,53 @@ fn test_profile_builder_creation() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path.clone()); - - assert!(builder.is_ok(), "ProfileBuilder should be created successfully"); + + assert!( + builder.is_ok(), + "ProfileBuilder should be created successfully" + ); } #[test] fn test_empty_rules_creates_empty_profile() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - + let profile = 
builder.build_profile(vec![]); - assert!(profile.is_ok(), "Empty rules should create valid empty profile"); + assert!( + profile.is_ok(), + "Empty rules should create valid empty profile" + ); } #[test] fn test_file_read_rule_parsing() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path.clone()).unwrap(); - + let rules = vec![ - make_rule("file_read_all", "literal", "/usr/lib/test.so", Some(&["linux", "macos"])), - make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"])), + make_rule( + "file_read_all", + "literal", + "/usr/lib/test.so", + Some(&["linux", "macos"]), + ), + make_rule( + "file_read_all", + "subpath", + "/usr/lib", + Some(&["linux", "macos"]), + ), ]; - + let _profile = builder.build_profile(rules); - + // Profile creation might fail on unsupported platforms, but parsing should work if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" { - assert!(_profile.is_ok(), "File read rules should be parsed on supported platforms"); + assert!( + _profile.is_ok(), + "File read rules should be parsed on supported platforms" + ); } } @@ -64,17 +82,25 @@ fn test_file_read_rule_parsing() { fn test_network_rule_parsing() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - + let rules = vec![ make_rule("network_outbound", "all", "", Some(&["linux", "macos"])), make_rule("network_outbound", "tcp", "8080", Some(&["macos"])), - make_rule("network_outbound", "local_socket", "/tmp/socket", Some(&["macos"])), + make_rule( + "network_outbound", + "local_socket", + "/tmp/socket", + Some(&["macos"]), + ), ]; - + let _profile = builder.build_profile(rules); - + if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" { - assert!(_profile.is_ok(), "Network rules should be parsed on supported platforms"); + assert!( + _profile.is_ok(), + "Network rules should be parsed on supported platforms" + ); } } @@ -82,15 +108,16 @@ fn test_network_rule_parsing() { fn test_system_info_rule_parsing() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - - let rules = vec![ - make_rule("system_info_read", "all", "", Some(&["macos"])), - ]; - + + let rules = vec![make_rule("system_info_read", "all", "", Some(&["macos"]))]; + let _profile = builder.build_profile(rules); - + if std::env::consts::OS == "macos" { - assert!(_profile.is_ok(), "System info rule should be parsed on macOS"); + assert!( + _profile.is_ok(), + "System info rule should be parsed on macOS" + ); } } @@ -98,12 +125,22 @@ fn test_system_info_rule_parsing() { fn test_template_variable_replacement() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path.clone()).unwrap(); - + let rules = vec![ - make_rule("file_read_all", "subpath", "{{PROJECT_PATH}}/src", Some(&["linux", "macos"])), - make_rule("file_read_all", "subpath", "{{HOME}}/.config", Some(&["linux", "macos"])), + make_rule( + "file_read_all", + "subpath", + "{{PROJECT_PATH}}/src", + Some(&["linux", "macos"]), + ), + make_rule( + "file_read_all", + "subpath", + "{{HOME}}/.config", + Some(&["linux", "macos"]), + ), ]; - + let _profile = builder.build_profile(rules); // We can't easily verify the exact paths without inspecting the Profile internals, // but this test ensures template replacement doesn't panic @@ -113,10 +150,15 @@ fn test_template_variable_replacement() { fn test_disabled_rules_are_ignored() { let project_path = 
PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - - let mut rule = make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"])); + + let mut rule = make_rule( + "file_read_all", + "subpath", + "/usr/lib", + Some(&["linux", "macos"]), + ); rule.enabled = false; - + let profile = builder.build_profile(vec![rule]); assert!(profile.is_ok(), "Disabled rules should be ignored"); } @@ -125,21 +167,30 @@ fn test_disabled_rules_are_ignored() { fn test_platform_filtering() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - + let current_os = std::env::consts::OS; - let other_os = if current_os == "linux" { "macos" } else { "linux" }; - + let other_os = if current_os == "linux" { + "macos" + } else { + "linux" + }; + let rules = vec![ // Rule for current platform make_rule("file_read_all", "subpath", "/test1", Some(&[current_os])), // Rule for other platform make_rule("file_read_all", "subpath", "/test2", Some(&[other_os])), // Rule for both platforms - make_rule("file_read_all", "subpath", "/test3", Some(&["linux", "macos"])), + make_rule( + "file_read_all", + "subpath", + "/test3", + Some(&["linux", "macos"]), + ), // Rule with no platform specification (should be included) make_rule("file_read_all", "subpath", "/test4", None), ]; - + let _profile = builder.build_profile(rules); // Rules for other platforms should be filtered out } @@ -148,11 +199,14 @@ fn test_platform_filtering() { fn test_invalid_operation_type() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - - let rules = vec![ - make_rule("invalid_operation", "subpath", "/test", Some(&["linux", "macos"])), - ]; - + + let rules = vec![make_rule( + "invalid_operation", + "subpath", + "/test", + Some(&["linux", "macos"]), + )]; + let _profile = builder.build_profile(rules); assert!(_profile.is_ok(), "Invalid operations should be skipped"); } @@ -161,11 +215,14 @@ fn test_invalid_operation_type() { fn test_invalid_pattern_type() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - - let rules = vec![ - make_rule("file_read_all", "invalid_pattern", "/test", Some(&["linux", "macos"])), - ]; - + + let rules = vec![make_rule( + "file_read_all", + "invalid_pattern", + "/test", + Some(&["linux", "macos"]), + )]; + let _profile = builder.build_profile(rules); // Should either skip the rule or fail gracefully } @@ -174,11 +231,14 @@ fn test_invalid_pattern_type() { fn test_invalid_tcp_port() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - - let rules = vec![ - make_rule("network_outbound", "tcp", "not_a_number", Some(&["macos"])), - ]; - + + let rules = vec![make_rule( + "network_outbound", + "tcp", + "not_a_number", + Some(&["macos"]), + )]; + let _profile = builder.build_profile(rules); // Should handle invalid port gracefully } @@ -188,13 +248,12 @@ fn test_invalid_tcp_port() { #[test_case("network_outbound", "all", "" ; "network all operation")] #[test_case("system_info_read", "all", "" ; "system info operation")] fn test_operation_support_level(operation_type: &str, pattern_type: &str, pattern_value: &str) { - let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - + let rule = make_rule(operation_type, pattern_type, pattern_value, None); let rules = vec![rule]; - + match 
builder.build_profile(rules) { Ok(_) => { // Profile created successfully - operation is supported @@ -211,27 +270,43 @@ fn test_operation_support_level(operation_type: &str, pattern_type: &str, patter fn test_complex_profile_with_multiple_rules() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path.clone()).unwrap(); - + let rules = vec![ // File operations - make_rule("file_read_all", "subpath", "{{PROJECT_PATH}}", Some(&["linux", "macos"])), - make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"])), - make_rule("file_read_all", "literal", "/etc/hosts", Some(&["linux", "macos"])), + make_rule( + "file_read_all", + "subpath", + "{{PROJECT_PATH}}", + Some(&["linux", "macos"]), + ), + make_rule( + "file_read_all", + "subpath", + "/usr/lib", + Some(&["linux", "macos"]), + ), + make_rule( + "file_read_all", + "literal", + "/etc/hosts", + Some(&["linux", "macos"]), + ), make_rule("file_read_metadata", "subpath", "/", Some(&["macos"])), - // Network operations make_rule("network_outbound", "all", "", Some(&["linux", "macos"])), make_rule("network_outbound", "tcp", "443", Some(&["macos"])), make_rule("network_outbound", "tcp", "80", Some(&["macos"])), - // System info make_rule("system_info_read", "all", "", Some(&["macos"])), ]; - + let _profile = builder.build_profile(rules); - + if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" { - assert!(_profile.is_ok(), "Complex profile should be created on supported platforms"); + assert!( + _profile.is_ok(), + "Complex profile should be created on supported platforms" + ); } } @@ -239,14 +314,24 @@ fn test_complex_profile_with_multiple_rules() { fn test_rule_order_preservation() { let project_path = PathBuf::from("/test/project"); let builder = ProfileBuilder::new(project_path).unwrap(); - + // Create rules with specific order let rules = vec![ - make_rule("file_read_all", "subpath", "/first", Some(&["linux", "macos"])), + make_rule( + "file_read_all", + "subpath", + "/first", + Some(&["linux", "macos"]), + ), make_rule("network_outbound", "all", "", Some(&["linux", "macos"])), - make_rule("file_read_all", "subpath", "/second", Some(&["linux", "macos"])), + make_rule( + "file_read_all", + "subpath", + "/second", + Some(&["linux", "macos"]), + ), ]; - + let _profile = builder.build_profile(rules); // Order should be preserved in the resulting profile -} \ No newline at end of file +} diff --git a/src-tauri/tests/sandbox_tests.rs b/src-tauri/tests/sandbox_tests.rs index 79991e2..4997805 100644 --- a/src-tauri/tests/sandbox_tests.rs +++ b/src-tauri/tests/sandbox_tests.rs @@ -1,5 +1,5 @@ //! Main entry point for sandbox tests -//! +//! //! This file integrates all the sandbox test modules and provides //! a central location for running the comprehensive test suite. #![allow(dead_code)] @@ -8,4 +8,4 @@ mod sandbox; // Re-export test modules to make them discoverable -pub use sandbox::*; \ No newline at end of file +pub use sandbox::*;