style: apply cargo fmt across entire Rust codebase

- Remove Rust formatting check from CI workflow since formatting is now applied
- Standardize import ordering and organization throughout codebase
- Fix indentation, spacing, and line breaks for consistency
- Clean up trailing whitespace and formatting inconsistencies
- Apply rustfmt to all Rust source files including checkpoint, sandbox, commands, and test modules

This establishes a consistent code style baseline for the project.
This commit is contained in:
Mufeed VH
2025-06-25 03:45:59 +05:30
parent bb48a32784
commit bcffce0a08
41 changed files with 3617 additions and 2662 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,12 @@
use tauri::AppHandle;
use anyhow::{Context, Result};
use dirs;
use log::{error, info};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::process::Command;
use log::{info, error};
use dirs;
use tauri::AppHandle;
/// Helper function to create a std::process::Command with proper environment variables
/// This ensures commands like Claude can find Node.js and other dependencies
@@ -17,8 +17,7 @@ fn create_command_with_env(program: &str) -> Command {
/// Finds the full path to the claude binary
/// This is necessary because macOS apps have a limited PATH environment
fn find_claude_binary(app_handle: &AppHandle) -> Result<String> {
crate::claude_binary::find_claude_binary(app_handle)
.map_err(|e| anyhow::anyhow!(e))
crate::claude_binary::find_claude_binary(app_handle).map_err(|e| anyhow::anyhow!(e))
}
/// Represents an MCP server configuration
@@ -99,17 +98,16 @@ pub struct ImportServerResult {
/// Executes a claude mcp command
fn execute_claude_mcp_command(app_handle: &AppHandle, args: Vec<&str>) -> Result<String> {
info!("Executing claude mcp command with args: {:?}", args);
let claude_path = find_claude_binary(app_handle)?;
let mut cmd = create_command_with_env(&claude_path);
cmd.arg("mcp");
for arg in args {
cmd.arg(arg);
}
let output = cmd.output()
.context("Failed to execute claude command")?;
let output = cmd.output().context("Failed to execute claude command")?;
if output.status.success() {
Ok(String::from_utf8_lossy(&output.stdout).to_string())
} else {
@@ -131,33 +129,34 @@ pub async fn mcp_add(
scope: String,
) -> Result<AddServerResult, String> {
info!("Adding MCP server: {} with transport: {}", name, transport);
// Prepare owned strings for environment variables
let env_args: Vec<String> = env.iter()
let env_args: Vec<String> = env
.iter()
.map(|(key, value)| format!("{}={}", key, value))
.collect();
let mut cmd_args = vec!["add"];
// Add scope flag
cmd_args.push("-s");
cmd_args.push(&scope);
// Add transport flag for SSE
if transport == "sse" {
cmd_args.push("--transport");
cmd_args.push("sse");
}
// Add environment variables
for (i, _) in env.iter().enumerate() {
cmd_args.push("-e");
cmd_args.push(&env_args[i]);
}
// Add name
cmd_args.push(&name);
// Add command/URL based on transport
if transport == "stdio" {
if let Some(cmd) = &command {
@@ -188,7 +187,7 @@ pub async fn mcp_add(
});
}
}
match execute_claude_mcp_command(&app, cmd_args) {
Ok(output) => {
info!("Successfully added MCP server: {}", name);
@@ -213,19 +212,19 @@ pub async fn mcp_add(
#[tauri::command]
pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
info!("Listing MCP servers");
match execute_claude_mcp_command(&app, vec!["list"]) {
Ok(output) => {
info!("Raw output from 'claude mcp list': {:?}", output);
let trimmed = output.trim();
info!("Trimmed output: {:?}", trimmed);
// Check if no servers are configured
if trimmed.contains("No MCP servers configured") || trimmed.is_empty() {
info!("No servers found - empty or 'No MCP servers' message");
return Ok(vec![]);
}
// Parse the text output, handling multi-line commands
let mut servers = Vec::new();
let lines: Vec<&str> = trimmed.lines().collect();
@@ -233,13 +232,13 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
for (idx, line) in lines.iter().enumerate() {
info!("Line {}: {:?}", idx, line);
}
let mut i = 0;
while i < lines.len() {
let line = lines[i];
info!("Processing line {}: {:?}", i, line);
// Check if this line starts a new server entry
if let Some(colon_pos) = line.find(':') {
info!("Found colon at position {} in line: {:?}", colon_pos, line);
@@ -247,26 +246,31 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
// Server names typically don't contain '/' or '\'
let potential_name = line[..colon_pos].trim();
info!("Potential server name: {:?}", potential_name);
if !potential_name.contains('/') && !potential_name.contains('\\') {
info!("Valid server name detected: {:?}", potential_name);
let name = potential_name.to_string();
let mut command_parts = vec![line[colon_pos + 1..].trim().to_string()];
info!("Initial command part: {:?}", command_parts[0]);
// Check if command continues on next lines
i += 1;
while i < lines.len() {
let next_line = lines[i];
info!("Checking next line {} for continuation: {:?}", i, next_line);
// If the next line starts with a server name pattern, break
if next_line.contains(':') {
let potential_next_name = next_line.split(':').next().unwrap_or("").trim();
info!("Found colon in next line, potential name: {:?}", potential_next_name);
if !potential_next_name.is_empty() &&
!potential_next_name.contains('/') &&
!potential_next_name.contains('\\') {
let potential_next_name =
next_line.split(':').next().unwrap_or("").trim();
info!(
"Found colon in next line, potential name: {:?}",
potential_next_name
);
if !potential_next_name.is_empty()
&& !potential_next_name.contains('/')
&& !potential_next_name.contains('\\')
{
info!("Next line is a new server, breaking");
break;
}
@@ -276,11 +280,11 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
command_parts.push(next_line.trim().to_string());
i += 1;
}
// Join all command parts
let full_command = command_parts.join(" ");
info!("Full command for server '{}': {:?}", name, full_command);
// For now, we'll create a basic server entry
servers.push(MCPServer {
name: name.clone(),
@@ -298,7 +302,7 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
},
});
info!("Added server: {:?}", name);
continue;
} else {
info!("Skipping line - name contains path separators");
@@ -306,13 +310,16 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
} else {
info!("No colon found in line {}", i);
}
i += 1;
}
info!("Found {} MCP servers total", servers.len());
for (idx, server) in servers.iter().enumerate() {
info!("Server {}: name='{}', command={:?}", idx, server.name, server.command);
info!(
"Server {}: name='{}', command={:?}",
idx, server.name, server.command
);
}
Ok(servers)
}
@@ -327,7 +334,7 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
#[tauri::command]
pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String> {
info!("Getting MCP server details for: {}", name);
match execute_claude_mcp_command(&app, vec!["get", &name]) {
Ok(output) => {
// Parse the structured text output
@@ -337,17 +344,19 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String>
let mut args = vec![];
let env = HashMap::new();
let mut url = None;
for line in output.lines() {
let line = line.trim();
if line.starts_with("Scope:") {
let scope_part = line.replace("Scope:", "").trim().to_string();
if scope_part.to_lowercase().contains("local") {
scope = "local".to_string();
} else if scope_part.to_lowercase().contains("project") {
scope = "project".to_string();
} else if scope_part.to_lowercase().contains("user") || scope_part.to_lowercase().contains("global") {
} else if scope_part.to_lowercase().contains("user")
|| scope_part.to_lowercase().contains("global")
{
scope = "user".to_string();
}
} else if line.starts_with("Type:") {
@@ -366,7 +375,7 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String>
// For now, we'll leave it empty
}
}
Ok(MCPServer {
name,
transport,
@@ -394,7 +403,7 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String>
#[tauri::command]
pub async fn mcp_remove(app: AppHandle, name: String) -> Result<String, String> {
info!("Removing MCP server: {}", name);
match execute_claude_mcp_command(&app, vec!["remove", &name]) {
Ok(output) => {
info!("Successfully removed MCP server: {}", name);
@@ -409,17 +418,25 @@ pub async fn mcp_remove(app: AppHandle, name: String) -> Result<String, String>
/// Adds an MCP server from JSON configuration
#[tauri::command]
pub async fn mcp_add_json(app: AppHandle, name: String, json_config: String, scope: String) -> Result<AddServerResult, String> {
info!("Adding MCP server from JSON: {} with scope: {}", name, scope);
pub async fn mcp_add_json(
app: AppHandle,
name: String,
json_config: String,
scope: String,
) -> Result<AddServerResult, String> {
info!(
"Adding MCP server from JSON: {} with scope: {}",
name, scope
);
// Build command args
let mut cmd_args = vec!["add-json", &name, &json_config];
// Add scope flag
let scope_flag = "-s";
cmd_args.push(scope_flag);
cmd_args.push(&scope);
match execute_claude_mcp_command(&app, cmd_args) {
Ok(output) => {
info!("Successfully added MCP server from JSON: {}", name);
@@ -442,9 +459,15 @@ pub async fn mcp_add_json(app: AppHandle, name: String, json_config: String, sco
/// Imports MCP servers from Claude Desktop
#[tauri::command]
pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Result<ImportResult, String> {
info!("Importing MCP servers from Claude Desktop with scope: {}", scope);
pub async fn mcp_add_from_claude_desktop(
app: AppHandle,
scope: String,
) -> Result<ImportResult, String> {
info!(
"Importing MCP servers from Claude Desktop with scope: {}",
scope
);
// Get Claude Desktop config path based on platform
let config_path = if cfg!(target_os = "macos") {
dirs::home_dir()
@@ -460,43 +483,55 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
.join("Claude")
.join("claude_desktop_config.json")
} else {
return Err("Import from Claude Desktop is only supported on macOS and Linux/WSL".to_string());
return Err(
"Import from Claude Desktop is only supported on macOS and Linux/WSL".to_string(),
);
};
// Check if config file exists
if !config_path.exists() {
return Err("Claude Desktop configuration not found. Make sure Claude Desktop is installed.".to_string());
return Err(
"Claude Desktop configuration not found. Make sure Claude Desktop is installed."
.to_string(),
);
}
// Read and parse the config file
let config_content = fs::read_to_string(&config_path)
.map_err(|e| format!("Failed to read Claude Desktop config: {}", e))?;
let config: serde_json::Value = serde_json::from_str(&config_content)
.map_err(|e| format!("Failed to parse Claude Desktop config: {}", e))?;
// Extract MCP servers
let mcp_servers = config.get("mcpServers")
let mcp_servers = config
.get("mcpServers")
.and_then(|v| v.as_object())
.ok_or_else(|| "No MCP servers found in Claude Desktop config".to_string())?;
let mut imported_count = 0;
let mut failed_count = 0;
let mut server_results = Vec::new();
// Import each server using add-json
for (name, server_config) in mcp_servers {
info!("Importing server: {}", name);
// Convert Claude Desktop format to add-json format
let mut json_config = serde_json::Map::new();
// All Claude Desktop servers are stdio type
json_config.insert("type".to_string(), serde_json::Value::String("stdio".to_string()));
json_config.insert(
"type".to_string(),
serde_json::Value::String("stdio".to_string()),
);
// Add command
if let Some(command) = server_config.get("command").and_then(|v| v.as_str()) {
json_config.insert("command".to_string(), serde_json::Value::String(command.to_string()));
json_config.insert(
"command".to_string(),
serde_json::Value::String(command.to_string()),
);
} else {
failed_count += 1;
server_results.push(ImportServerResult {
@@ -506,25 +541,28 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
});
continue;
}
// Add args if present
if let Some(args) = server_config.get("args").and_then(|v| v.as_array()) {
json_config.insert("args".to_string(), args.clone().into());
} else {
json_config.insert("args".to_string(), serde_json::Value::Array(vec![]));
}
// Add env if present
if let Some(env) = server_config.get("env").and_then(|v| v.as_object()) {
json_config.insert("env".to_string(), env.clone().into());
} else {
json_config.insert("env".to_string(), serde_json::Value::Object(serde_json::Map::new()));
json_config.insert(
"env".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
}
// Convert to JSON string
let json_str = serde_json::to_string(&json_config)
.map_err(|e| format!("Failed to serialize config for {}: {}", name, e))?;
// Call add-json command
match mcp_add_json(app.clone(), name.clone(), json_str, scope.clone()).await {
Ok(result) => {
@@ -559,9 +597,12 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
}
}
}
info!("Import complete: {} imported, {} failed", imported_count, failed_count);
info!(
"Import complete: {} imported, {} failed",
imported_count, failed_count
);
Ok(ImportResult {
imported_count,
failed_count,
@@ -573,7 +614,7 @@ pub async fn mcp_add_from_claude_desktop(app: AppHandle, scope: String) -> Resul
#[tauri::command]
pub async fn mcp_serve(app: AppHandle) -> Result<String, String> {
info!("Starting Claude Code as MCP server");
// Start the server in a separate process
let claude_path = match find_claude_binary(&app) {
Ok(path) => path,
@@ -582,10 +623,10 @@ pub async fn mcp_serve(app: AppHandle) -> Result<String, String> {
return Err(e.to_string());
}
};
let mut cmd = create_command_with_env(&claude_path);
cmd.arg("mcp").arg("serve");
match cmd.spawn() {
Ok(_) => {
info!("Successfully started Claude Code MCP server");
@@ -602,7 +643,7 @@ pub async fn mcp_serve(app: AppHandle) -> Result<String, String> {
#[tauri::command]
pub async fn mcp_test_connection(app: AppHandle, name: String) -> Result<String, String> {
info!("Testing connection to MCP server: {}", name);
// For now, we'll use the get command to test if the server exists
match execute_claude_mcp_command(&app, vec!["get", &name]) {
Ok(_) => Ok(format!("Connection to {} successful", name)),
@@ -614,7 +655,7 @@ pub async fn mcp_test_connection(app: AppHandle, name: String) -> Result<String,
#[tauri::command]
pub async fn mcp_reset_project_choices(app: AppHandle) -> Result<String, String> {
info!("Resetting MCP project choices");
match execute_claude_mcp_command(&app, vec!["reset-project-choices"]) {
Ok(output) => {
info!("Successfully reset MCP project choices");
@@ -631,7 +672,7 @@ pub async fn mcp_reset_project_choices(app: AppHandle) -> Result<String, String>
#[tauri::command]
pub async fn mcp_get_server_status() -> Result<HashMap<String, ServerStatus>, String> {
info!("Getting MCP server status");
// TODO: Implement actual status checking
// For now, return empty status
Ok(HashMap::new())
@@ -641,25 +682,23 @@ pub async fn mcp_get_server_status() -> Result<HashMap<String, ServerStatus>, St
#[tauri::command]
pub async fn mcp_read_project_config(project_path: String) -> Result<MCPProjectConfig, String> {
info!("Reading .mcp.json from project: {}", project_path);
let mcp_json_path = PathBuf::from(&project_path).join(".mcp.json");
if !mcp_json_path.exists() {
return Ok(MCPProjectConfig {
mcp_servers: HashMap::new(),
});
}
match fs::read_to_string(&mcp_json_path) {
Ok(content) => {
match serde_json::from_str::<MCPProjectConfig>(&content) {
Ok(config) => Ok(config),
Err(e) => {
error!("Failed to parse .mcp.json: {}", e);
Err(format!("Failed to parse .mcp.json: {}", e))
}
Ok(content) => match serde_json::from_str::<MCPProjectConfig>(&content) {
Ok(config) => Ok(config),
Err(e) => {
error!("Failed to parse .mcp.json: {}", e);
Err(format!("Failed to parse .mcp.json: {}", e))
}
}
},
Err(e) => {
error!("Failed to read .mcp.json: {}", e);
Err(format!("Failed to read .mcp.json: {}", e))
@@ -674,14 +713,14 @@ pub async fn mcp_save_project_config(
config: MCPProjectConfig,
) -> Result<String, String> {
info!("Saving .mcp.json to project: {}", project_path);
let mcp_json_path = PathBuf::from(&project_path).join(".mcp.json");
let json_content = serde_json::to_string_pretty(&config)
.map_err(|e| format!("Failed to serialize config: {}", e))?;
fs::write(&mcp_json_path, json_content)
.map_err(|e| format!("Failed to write .mcp.json: {}", e))?;
Ok("Project MCP configuration saved".to_string())
}
}

View File

@@ -1,6 +1,6 @@
pub mod claude;
pub mod agents;
pub mod sandbox;
pub mod usage;
pub mod claude;
pub mod mcp;
pub mod screenshot;
pub mod sandbox;
pub mod screenshot;
pub mod usage;

View File

@@ -52,11 +52,11 @@ pub struct ImportResult {
#[tauri::command]
pub async fn list_sandbox_profiles(db: State<'_, AgentDb>) -> Result<Vec<SandboxProfile>, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
let mut stmt = conn
.prepare("SELECT id, name, description, is_active, is_default, created_at, updated_at FROM sandbox_profiles ORDER BY name")
.map_err(|e| e.to_string())?;
let profiles = stmt
.query_map([], |row| {
Ok(SandboxProfile {
@@ -72,7 +72,7 @@ pub async fn list_sandbox_profiles(db: State<'_, AgentDb>) -> Result<Vec<Sandbox
.map_err(|e| e.to_string())?
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?;
Ok(profiles)
}
@@ -84,15 +84,15 @@ pub async fn create_sandbox_profile(
description: Option<String>,
) -> Result<SandboxProfile, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
conn.execute(
"INSERT INTO sandbox_profiles (name, description) VALUES (?1, ?2)",
params![name, description],
)
.map_err(|e| e.to_string())?;
let id = conn.last_insert_rowid();
// Fetch the created profile
let profile = conn
.query_row(
@@ -111,7 +111,7 @@ pub async fn create_sandbox_profile(
},
)
.map_err(|e| e.to_string())?;
Ok(profile)
}
@@ -126,7 +126,7 @@ pub async fn update_sandbox_profile(
is_default: bool,
) -> Result<SandboxProfile, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
// If setting as default, unset other defaults
if is_default {
conn.execute(
@@ -135,13 +135,13 @@ pub async fn update_sandbox_profile(
)
.map_err(|e| e.to_string())?;
}
conn.execute(
"UPDATE sandbox_profiles SET name = ?1, description = ?2, is_active = ?3, is_default = ?4 WHERE id = ?5",
params![name, description, is_active, is_default, id],
)
.map_err(|e| e.to_string())?;
// Fetch the updated profile
let profile = conn
.query_row(
@@ -160,7 +160,7 @@ pub async fn update_sandbox_profile(
},
)
.map_err(|e| e.to_string())?;
Ok(profile)
}
@@ -168,7 +168,7 @@ pub async fn update_sandbox_profile(
#[tauri::command]
pub async fn delete_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<(), String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
// Check if it's the default profile
let is_default: bool = conn
.query_row(
@@ -177,22 +177,25 @@ pub async fn delete_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<(
|row| row.get(0),
)
.map_err(|e| e.to_string())?;
if is_default {
return Err("Cannot delete the default profile".to_string());
}
conn.execute("DELETE FROM sandbox_profiles WHERE id = ?1", params![id])
.map_err(|e| e.to_string())?;
Ok(())
}
/// Get a single sandbox profile by ID
#[tauri::command]
pub async fn get_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<SandboxProfile, String> {
pub async fn get_sandbox_profile(
db: State<'_, AgentDb>,
id: i64,
) -> Result<SandboxProfile, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
let profile = conn
.query_row(
"SELECT id, name, description, is_active, is_default, created_at, updated_at FROM sandbox_profiles WHERE id = ?1",
@@ -210,7 +213,7 @@ pub async fn get_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<Sand
},
)
.map_err(|e| e.to_string())?;
Ok(profile)
}
@@ -221,11 +224,11 @@ pub async fn list_sandbox_rules(
profile_id: i64,
) -> Result<Vec<SandboxRule>, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
let mut stmt = conn
.prepare("SELECT id, profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support, created_at FROM sandbox_rules WHERE profile_id = ?1 ORDER BY operation_type, pattern_value")
.map_err(|e| e.to_string())?;
let rules = stmt
.query_map(params![profile_id], |row| {
Ok(SandboxRule {
@@ -242,7 +245,7 @@ pub async fn list_sandbox_rules(
.map_err(|e| e.to_string())?
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?;
Ok(rules)
}
@@ -258,18 +261,18 @@ pub async fn create_sandbox_rule(
platform_support: Option<String>,
) -> Result<SandboxRule, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
// Validate rule doesn't conflict
// TODO: Add more validation logic here
conn.execute(
"INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
params![profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support],
)
.map_err(|e| e.to_string())?;
let id = conn.last_insert_rowid();
// Fetch the created rule
let rule = conn
.query_row(
@@ -289,7 +292,7 @@ pub async fn create_sandbox_rule(
},
)
.map_err(|e| e.to_string())?;
Ok(rule)
}
@@ -305,13 +308,13 @@ pub async fn update_sandbox_rule(
platform_support: Option<String>,
) -> Result<SandboxRule, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
conn.execute(
"UPDATE sandbox_rules SET operation_type = ?1, pattern_type = ?2, pattern_value = ?3, enabled = ?4, platform_support = ?5 WHERE id = ?6",
params![operation_type, pattern_type, pattern_value, enabled, platform_support, id],
)
.map_err(|e| e.to_string())?;
// Fetch the updated rule
let rule = conn
.query_row(
@@ -331,7 +334,7 @@ pub async fn update_sandbox_rule(
},
)
.map_err(|e| e.to_string())?;
Ok(rule)
}
@@ -339,10 +342,10 @@ pub async fn update_sandbox_rule(
#[tauri::command]
pub async fn delete_sandbox_rule(db: State<'_, AgentDb>, id: i64) -> Result<(), String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
conn.execute("DELETE FROM sandbox_rules WHERE id = ?1", params![id])
.map_err(|e| e.to_string())?;
Ok(())
}
@@ -359,38 +362,38 @@ pub async fn test_sandbox_profile(
profile_id: i64,
) -> Result<String, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
// Load the profile and rules
let profile = crate::sandbox::profile::load_profile(&conn, profile_id)
.map_err(|e| format!("Failed to load profile: {}", e))?;
if !profile.is_active {
return Ok(format!(
"Profile '{}' is currently inactive. Activate it to use with agents.",
profile.name
));
}
let rules = crate::sandbox::profile::load_profile_rules(&conn, profile_id)
.map_err(|e| format!("Failed to load profile rules: {}", e))?;
if rules.is_empty() {
return Ok(format!(
"Profile '{}' has no rules configured. Add rules to define sandbox permissions.",
profile.name
));
}
// Try to build the gaol profile
let test_path = std::env::current_dir()
.unwrap_or_else(|_| std::path::PathBuf::from("/tmp"));
let test_path = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("/tmp"));
let builder = crate::sandbox::profile::ProfileBuilder::new(test_path.clone())
.map_err(|e| format!("Failed to create profile builder: {}", e))?;
let build_result = builder.build_profile_with_serialization(rules.clone())
let build_result = builder
.build_profile_with_serialization(rules.clone())
.map_err(|e| format!("Failed to build sandbox profile: {}", e))?;
// Check platform support
let platform_caps = crate::sandbox::platform::get_platform_capabilities();
if !platform_caps.sandboxing_supported {
@@ -401,27 +404,23 @@ pub async fn test_sandbox_profile(
platform_caps.os
));
}
// Try to execute a simple command in the sandbox
let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization(
build_result.profile,
build_result.profile,
test_path.clone(),
build_result.serialized
build_result.serialized,
);
// Use a simple echo command for testing
let test_command = if cfg!(windows) {
"cmd"
} else {
"echo"
};
let test_command = if cfg!(windows) { "cmd" } else { "echo" };
let test_args = if cfg!(windows) {
vec!["/C", "echo", "sandbox test successful"]
} else {
vec!["sandbox test successful"]
};
match executor.execute_sandboxed_spawn(test_command, &test_args, &test_path) {
Ok(mut child) => {
// Wait for the process to complete with a timeout
@@ -452,19 +451,17 @@ pub async fn test_sandbox_profile(
))
}
}
Err(e) => {
Ok(format!(
"⚠️ Profile '{}' validated with warnings.\n\n\
Err(e) => Ok(format!(
"⚠️ Profile '{}' validated with warnings.\n\n\
{} rules loaded and validated\n\
• Sandbox activation: Partial\n\
• Test process: Could not get exit status ({})\n\
• Platform: {}",
profile.name,
rules.len(),
e,
platform_caps.os
))
}
profile.name,
rules.len(),
e,
platform_caps.os
)),
}
}
Err(e) => {
@@ -509,176 +506,200 @@ pub async fn list_sandbox_violations(
limit: Option<i64>,
) -> Result<Vec<SandboxViolation>, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
// Build dynamic query
let mut query = String::from(
"SELECT id, profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid, denied_at
FROM sandbox_violations WHERE 1=1"
);
let mut param_idx = 1;
if profile_id.is_some() {
query.push_str(&format!(" AND profile_id = ?{}", param_idx));
param_idx += 1;
}
if agent_id.is_some() {
query.push_str(&format!(" AND agent_id = ?{}", param_idx));
param_idx += 1;
}
query.push_str(" ORDER BY denied_at DESC");
if limit.is_some() {
query.push_str(&format!(" LIMIT ?{}", param_idx));
}
// Execute query based on parameters
let violations: Vec<SandboxViolation> = if let Some(pid) = profile_id {
if let Some(aid) = agent_id {
if let Some(lim) = limit {
// All three parameters
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![pid, aid, lim], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
let rows = stmt
.query_map(params![pid, aid, lim], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
})
})
}).map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())?
.map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} else {
// profile_id and agent_id only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![pid, aid], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
let rows = stmt
.query_map(params![pid, aid], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
})
})
}).map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())?
.map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
}
} else if let Some(lim) = limit {
// profile_id and limit only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![pid, lim], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
let rows = stmt
.query_map(params![pid, lim], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
})
})
}).map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())?
.map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} else {
// profile_id only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![pid], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
let rows = stmt
.query_map(params![pid], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
})
})
}).map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())?
.map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
}
} else if let Some(aid) = agent_id {
if let Some(lim) = limit {
// agent_id and limit only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![aid, lim], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
let rows = stmt
.query_map(params![aid, lim], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
})
})
}).map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())?
.map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} else {
// agent_id only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![aid], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
let rows = stmt
.query_map(params![aid], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
})
})
}).map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())?
.map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
}
} else if let Some(lim) = limit {
// limit only
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map(params![lim], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
let rows = stmt
.query_map(params![lim], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
})
})
}).map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())?
.map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
} else {
// No parameters
let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
let rows = stmt.query_map([], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
let rows = stmt
.query_map([], |row| {
Ok(SandboxViolation {
id: Some(row.get(0)?),
profile_id: row.get(1)?,
agent_id: row.get(2)?,
agent_run_id: row.get(3)?,
operation_type: row.get(4)?,
pattern_value: row.get(5)?,
process_name: row.get(6)?,
pid: row.get(7)?,
denied_at: row.get(8)?,
})
})
}).map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>().map_err(|e| e.to_string())?
.map_err(|e| e.to_string())?;
rows.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?
};
Ok(violations)
}
@@ -695,14 +716,14 @@ pub async fn log_sandbox_violation(
pid: Option<i32>,
) -> Result<(), String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
conn.execute(
"INSERT INTO sandbox_violations (profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
params![profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid],
)
.map_err(|e| e.to_string())?;
Ok(())
}
@@ -713,7 +734,7 @@ pub async fn clear_sandbox_violations(
older_than_days: Option<i64>,
) -> Result<i64, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
let query = if let Some(days) = older_than_days {
format!(
"DELETE FROM sandbox_violations WHERE denied_at < datetime('now', '-{} days')",
@@ -722,10 +743,9 @@ pub async fn clear_sandbox_violations(
} else {
"DELETE FROM sandbox_violations".to_string()
};
let deleted = conn.execute(&query, [])
.map_err(|e| e.to_string())?;
let deleted = conn.execute(&query, []).map_err(|e| e.to_string())?;
Ok(deleted as i64)
}
@@ -735,28 +755,30 @@ pub async fn get_sandbox_violation_stats(
db: State<'_, AgentDb>,
) -> Result<serde_json::Value, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
// Get total violations
let total: i64 = conn
.query_row("SELECT COUNT(*) FROM sandbox_violations", [], |row| row.get(0))
.query_row("SELECT COUNT(*) FROM sandbox_violations", [], |row| {
row.get(0)
})
.map_err(|e| e.to_string())?;
// Get violations by operation type
let mut stmt = conn
.prepare(
"SELECT operation_type, COUNT(*) as count
FROM sandbox_violations
GROUP BY operation_type
ORDER BY count DESC"
ORDER BY count DESC",
)
.map_err(|e| e.to_string())?;
let by_operation: Vec<(String, i64)> = stmt
.query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
.map_err(|e| e.to_string())?
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?;
// Get recent violations count (last 24 hours)
let recent: i64 = conn
.query_row(
@@ -765,7 +787,7 @@ pub async fn get_sandbox_violation_stats(
|row| row.get(0),
)
.map_err(|e| e.to_string())?;
Ok(serde_json::json!({
"total": total,
"recent_24h": recent,
@@ -789,10 +811,10 @@ pub async fn export_sandbox_profile(
let conn = db.0.lock().map_err(|e| e.to_string())?;
crate::sandbox::profile::load_profile(&conn, profile_id).map_err(|e| e.to_string())?
};
// Get the rules
let rules = list_sandbox_rules(db.clone(), profile_id).await?;
Ok(SandboxProfileExport {
version: 1,
exported_at: chrono::Utc::now().to_rfc3339(),
@@ -808,17 +830,14 @@ pub async fn export_all_sandbox_profiles(
) -> Result<SandboxProfileExport, String> {
let profiles = list_sandbox_profiles(db.clone()).await?;
let mut profile_exports = Vec::new();
for profile in profiles {
if let Some(id) = profile.id {
let rules = list_sandbox_rules(db.clone(), id).await?;
profile_exports.push(SandboxProfileWithRules {
profile,
rules,
});
profile_exports.push(SandboxProfileWithRules { profile, rules });
}
}
Ok(SandboxProfileExport {
version: 1,
exported_at: chrono::Utc::now().to_rfc3339(),
@@ -834,16 +853,19 @@ pub async fn import_sandbox_profiles(
export_data: SandboxProfileExport,
) -> Result<Vec<ImportResult>, String> {
let mut results = Vec::new();
// Validate version
if export_data.version != 1 {
return Err(format!("Unsupported export version: {}", export_data.version));
return Err(format!(
"Unsupported export version: {}",
export_data.version
));
}
for profile_export in export_data.profiles {
let mut profile = profile_export.profile;
let original_name = profile.name.clone();
// Check for name conflicts
let existing: Result<i64, _> = {
let conn = db.0.lock().map_err(|e| e.to_string())?;
@@ -853,29 +875,31 @@ pub async fn import_sandbox_profiles(
|row| row.get(0),
)
};
let (imported, new_name) = match existing {
Ok(_) => {
// Name conflict - append timestamp
let new_name = format!("{} (imported {})", profile.name, chrono::Utc::now().format("%Y-%m-%d %H:%M"));
let new_name = format!(
"{} (imported {})",
profile.name,
chrono::Utc::now().format("%Y-%m-%d %H:%M")
);
profile.name = new_name.clone();
(true, Some(new_name))
}
Err(_) => (true, None),
};
if imported {
// Reset profile fields for new insert
profile.id = None;
profile.is_default = false; // Never import as default
// Create the profile
let created_profile = create_sandbox_profile(
db.clone(),
profile.name.clone(),
profile.description,
).await?;
let created_profile =
create_sandbox_profile(db.clone(), profile.name.clone(), profile.description)
.await?;
if let Some(new_id) = created_profile.id {
// Import rules
for rule in profile_export.rules {
@@ -889,10 +913,11 @@ pub async fn import_sandbox_profiles(
rule.pattern_value,
rule.enabled,
rule.platform_support,
).await;
)
.await;
}
}
// Update profile status if needed
if profile.is_active {
let _ = update_sandbox_profile(
@@ -902,18 +927,21 @@ pub async fn import_sandbox_profiles(
created_profile.description,
profile.is_active,
false, // Never set as default on import
).await;
)
.await;
}
}
results.push(ImportResult {
profile_name: original_name,
imported: true,
reason: new_name.as_ref().map(|_| "Name conflict resolved".to_string()),
reason: new_name
.as_ref()
.map(|_| "Name conflict resolved".to_string()),
new_name,
});
}
}
Ok(results)
}
}

View File

@@ -1,20 +1,20 @@
use headless_chrome::{Browser, LaunchOptions};
use headless_chrome::protocol::cdp::Page;
use headless_chrome::{Browser, LaunchOptions};
use std::fs;
use std::time::Duration;
use tauri::AppHandle;
/// Captures a screenshot of a URL using headless Chrome
///
///
/// This function launches a headless Chrome browser, navigates to the specified URL,
/// and captures a screenshot of either the entire page or a specific element.
///
///
/// # Arguments
/// * `app` - The Tauri application handle
/// * `url` - The URL to capture
/// * `selector` - Optional CSS selector for a specific element to capture
/// * `full_page` - Whether to capture the entire page or just the viewport
///
///
/// # Returns
/// * `Result<String, String>` - The path to the saved screenshot file, or an error message
#[tauri::command]
@@ -32,11 +32,10 @@ pub async fn capture_url_screenshot(
);
// Run the browser operations in a blocking task since headless_chrome is not async
let result = tokio::task::spawn_blocking(move || {
capture_screenshot_sync(url, selector, full_page)
})
.await
.map_err(|e| format!("Failed to spawn blocking task: {}", e))?;
let result =
tokio::task::spawn_blocking(move || capture_screenshot_sync(url, selector, full_page))
.await
.map_err(|e| format!("Failed to spawn blocking task: {}", e))?;
// Log the result of the headless Chrome capture before returning
match &result {
@@ -61,8 +60,8 @@ fn capture_screenshot_sync(
};
// Launch the browser
let browser = Browser::new(launch_options)
.map_err(|e| format!("Failed to launch browser: {}", e))?;
let browser =
Browser::new(launch_options).map_err(|e| format!("Failed to launch browser: {}", e))?;
// Create a new tab
let tab = browser
@@ -86,14 +85,17 @@ fn capture_screenshot_sync(
// Wait explicitly for the <body> element to exist this often prevents
// "Unable to capture screenshot" CDP errors on some pages
if let Err(e) = tab.wait_for_element("body") {
log::warn!("Timed out waiting for <body> element: {} continuing anyway", e);
log::warn!(
"Timed out waiting for <body> element: {} continuing anyway",
e
);
}
// Capture the screenshot
let screenshot_data = if let Some(selector) = selector {
// Wait for the element and capture it
log::info!("Waiting for element with selector: {}", selector);
let element = tab
.wait_for_element(&selector)
.map_err(|e| format!("Failed to find element '{}': {}", selector, e))?;
@@ -103,8 +105,11 @@ fn capture_screenshot_sync(
.map_err(|e| format!("Failed to capture element screenshot: {}", e))?
} else {
// Capture the entire page or viewport
log::info!("Capturing {} screenshot", if full_page { "full page" } else { "viewport" });
log::info!(
"Capturing {} screenshot",
if full_page { "full page" } else { "viewport" }
);
// Get the page dimensions for full page screenshot
let clip = if full_page {
// Execute JavaScript to get the full page dimensions
@@ -132,30 +137,30 @@ fn capture_screenshot_sync(
)
.map_err(|e| format!("Failed to get page dimensions: {}", e))?;
// Extract dimensions from the result
let width = dimensions
.value
.as_ref()
.and_then(|v| v.as_object())
.and_then(|obj| obj.get("width"))
.and_then(|v| v.as_f64())
.unwrap_or(1920.0);
// Extract dimensions from the result
let width = dimensions
.value
.as_ref()
.and_then(|v| v.as_object())
.and_then(|obj| obj.get("width"))
.and_then(|v| v.as_f64())
.unwrap_or(1920.0);
let height = dimensions
.value
.as_ref()
.and_then(|v| v.as_object())
.and_then(|obj| obj.get("height"))
.and_then(|v| v.as_f64())
.unwrap_or(1080.0);
let height = dimensions
.value
.as_ref()
.and_then(|v| v.as_object())
.and_then(|obj| obj.get("height"))
.and_then(|v| v.as_f64())
.unwrap_or(1080.0);
Some(Page::Viewport {
x: 0.0,
y: 0.0,
width,
height,
scale: 1.0,
})
Some(Page::Viewport {
x: 0.0,
y: 0.0,
width,
height,
scale: 1.0,
})
} else {
None
};
@@ -176,13 +181,8 @@ fn capture_screenshot_sync(
err
);
tab.capture_screenshot(
Page::CaptureScreenshotFormatOption::Png,
None,
clip,
true,
)
.map_err(|e| format!("Failed to capture screenshot after retry: {}", e))?
tab.capture_screenshot(Page::CaptureScreenshotFormatOption::Png, None, clip, true)
.map_err(|e| format!("Failed to capture screenshot after retry: {}", e))?
}
}
};
@@ -208,13 +208,13 @@ fn capture_screenshot_sync(
}
/// Cleans up old screenshot files from the temporary directory
///
///
/// This function removes screenshot files older than the specified number of minutes
/// to prevent accumulation of temporary files.
///
///
/// # Arguments
/// * `older_than_minutes` - Remove files older than this many minutes (default: 60)
///
///
/// # Returns
/// * `Result<usize, String>` - The number of files deleted, or an error message
#[tauri::command]
@@ -222,24 +222,29 @@ pub async fn cleanup_screenshot_temp_files(
older_than_minutes: Option<u64>,
) -> Result<usize, String> {
let minutes = older_than_minutes.unwrap_or(60);
log::info!("Cleaning up screenshot files older than {} minutes", minutes);
log::info!(
"Cleaning up screenshot files older than {} minutes",
minutes
);
let temp_dir = std::env::temp_dir();
let cutoff_time = chrono::Utc::now() - chrono::Duration::minutes(minutes as i64);
let mut deleted_count = 0;
// Read directory entries
let entries = fs::read_dir(&temp_dir)
.map_err(|e| format!("Failed to read temp directory: {}", e))?;
let entries =
fs::read_dir(&temp_dir).map_err(|e| format!("Failed to read temp directory: {}", e))?;
for entry in entries {
if let Ok(entry) = entry {
let path = entry.path();
// Check if it's a claudia screenshot file
if let Some(filename) = path.file_name() {
if let Some(filename_str) = filename.to_str() {
if filename_str.starts_with("claudia_screenshot_") && filename_str.ends_with(".png") {
if filename_str.starts_with("claudia_screenshot_")
&& filename_str.ends_with(".png")
{
// Check file age
if let Ok(metadata) = fs::metadata(&path) {
if let Ok(modified) = metadata.modified() {
@@ -258,7 +263,7 @@ pub async fn cleanup_screenshot_temp_files(
}
}
}
log::info!("Cleaned up {} old screenshot files", deleted_count);
Ok(deleted_count)
}
}

View File

@@ -1,9 +1,9 @@
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::PathBuf;
use chrono::{DateTime, Local, NaiveDate};
use serde::{Deserialize, Serialize};
use serde_json;
use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::PathBuf;
use tauri::command;
#[derive(Debug, Serialize, Deserialize, Clone)]
@@ -108,11 +108,21 @@ fn calculate_cost(model: &str, usage: &UsageData) -> f64 {
let cache_read_tokens = usage.cache_read_input_tokens.unwrap_or(0) as f64;
// Calculate cost based on model
let (input_price, output_price, cache_write_price, cache_read_price) =
let (input_price, output_price, cache_write_price, cache_read_price) =
if model.contains("opus-4") || model.contains("claude-opus-4") {
(OPUS_4_INPUT_PRICE, OPUS_4_OUTPUT_PRICE, OPUS_4_CACHE_WRITE_PRICE, OPUS_4_CACHE_READ_PRICE)
(
OPUS_4_INPUT_PRICE,
OPUS_4_OUTPUT_PRICE,
OPUS_4_CACHE_WRITE_PRICE,
OPUS_4_CACHE_READ_PRICE,
)
} else if model.contains("sonnet-4") || model.contains("claude-sonnet-4") {
(SONNET_4_INPUT_PRICE, SONNET_4_OUTPUT_PRICE, SONNET_4_CACHE_WRITE_PRICE, SONNET_4_CACHE_READ_PRICE)
(
SONNET_4_INPUT_PRICE,
SONNET_4_OUTPUT_PRICE,
SONNET_4_CACHE_WRITE_PRICE,
SONNET_4_CACHE_READ_PRICE,
)
} else {
// Return 0 for unknown models to avoid incorrect cost estimations.
(0.0, 0.0, 0.0, 0.0)
@@ -134,10 +144,11 @@ fn parse_jsonl_file(
) -> Vec<UsageEntry> {
let mut entries = Vec::new();
let mut actual_project_path: Option<String> = None;
if let Ok(content) = fs::read_to_string(path) {
// Extract session ID from the file path
let session_id = path.parent()
let session_id = path
.parent()
.and_then(|p| p.file_name())
.and_then(|n| n.to_str())
.unwrap_or("unknown")
@@ -155,7 +166,7 @@ fn parse_jsonl_file(
actual_project_path = Some(cwd.to_string());
}
}
// Try to parse as JsonlEntry for usage data
if let Ok(entry) = serde_json::from_value::<JsonlEntry>(json_value) {
if let Some(message) = &entry.message {
@@ -170,10 +181,11 @@ fn parse_jsonl_file(
if let Some(usage) = &message.usage {
// Skip entries without meaningful token usage
if usage.input_tokens.unwrap_or(0) == 0 &&
usage.output_tokens.unwrap_or(0) == 0 &&
usage.cache_creation_input_tokens.unwrap_or(0) == 0 &&
usage.cache_read_input_tokens.unwrap_or(0) == 0 {
if usage.input_tokens.unwrap_or(0) == 0
&& usage.output_tokens.unwrap_or(0) == 0
&& usage.cache_creation_input_tokens.unwrap_or(0) == 0
&& usage.cache_read_input_tokens.unwrap_or(0) == 0
{
continue;
}
@@ -184,17 +196,23 @@ fn parse_jsonl_file(
0.0
}
});
// Use actual project path if found, otherwise use encoded name
let project_path = actual_project_path.clone()
let project_path = actual_project_path
.clone()
.unwrap_or_else(|| encoded_project_name.to_string());
entries.push(UsageEntry {
timestamp: entry.timestamp,
model: message.model.clone().unwrap_or_else(|| "unknown".to_string()),
model: message
.model
.clone()
.unwrap_or_else(|| "unknown".to_string()),
input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
cache_creation_tokens: usage
.cache_creation_input_tokens
.unwrap_or(0),
cache_read_tokens: usage.cache_read_input_tokens.unwrap_or(0),
cost,
session_id: entry.session_id.unwrap_or_else(|| session_id.clone()),
@@ -263,10 +281,10 @@ fn get_all_usage_entries(claude_path: &PathBuf) -> Vec<UsageEntry> {
let entries = parse_jsonl_file(&path, &project_name, &mut processed_hashes);
all_entries.extend(entries);
}
// Sort by timestamp
all_entries.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));
all_entries
}
@@ -275,9 +293,9 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
let claude_path = dirs::home_dir()
.ok_or("Failed to get home directory")?
.join(".claude");
let all_entries = get_all_usage_entries(&claude_path);
if all_entries.is_empty() {
return Ok(UsageStats {
total_cost: 0.0,
@@ -292,11 +310,12 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
by_project: vec![],
});
}
// Filter by days if specified
let filtered_entries = if let Some(days) = days {
let cutoff = Local::now().naive_local().date() - chrono::Duration::days(days as i64);
all_entries.into_iter()
all_entries
.into_iter()
.filter(|e| {
if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) {
dt.naive_local().date() >= cutoff
@@ -308,18 +327,18 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
} else {
all_entries
};
// Calculate aggregated stats
let mut total_cost = 0.0;
let mut total_input_tokens = 0u64;
let mut total_output_tokens = 0u64;
let mut total_cache_creation_tokens = 0u64;
let mut total_cache_read_tokens = 0u64;
let mut model_stats: HashMap<String, ModelUsage> = HashMap::new();
let mut daily_stats: HashMap<String, DailyUsage> = HashMap::new();
let mut project_stats: HashMap<String, ProjectUsage> = HashMap::new();
for entry in &filtered_entries {
// Update totals
total_cost += entry.cost;
@@ -327,18 +346,20 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
total_output_tokens += entry.output_tokens;
total_cache_creation_tokens += entry.cache_creation_tokens;
total_cache_read_tokens += entry.cache_read_tokens;
// Update model stats
let model_stat = model_stats.entry(entry.model.clone()).or_insert(ModelUsage {
model: entry.model.clone(),
total_cost: 0.0,
total_tokens: 0,
input_tokens: 0,
output_tokens: 0,
cache_creation_tokens: 0,
cache_read_tokens: 0,
session_count: 0,
});
let model_stat = model_stats
.entry(entry.model.clone())
.or_insert(ModelUsage {
model: entry.model.clone(),
total_cost: 0.0,
total_tokens: 0,
input_tokens: 0,
output_tokens: 0,
cache_creation_tokens: 0,
cache_read_tokens: 0,
session_count: 0,
});
model_stat.total_cost += entry.cost;
model_stat.input_tokens += entry.input_tokens;
model_stat.output_tokens += entry.output_tokens;
@@ -346,9 +367,14 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
model_stat.cache_read_tokens += entry.cache_read_tokens;
model_stat.total_tokens = model_stat.input_tokens + model_stat.output_tokens;
model_stat.session_count += 1;
// Update daily stats
let date = entry.timestamp.split('T').next().unwrap_or(&entry.timestamp).to_string();
let date = entry
.timestamp
.split('T')
.next()
.unwrap_or(&entry.timestamp)
.to_string();
let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage {
date,
total_cost: 0.0,
@@ -356,43 +382,58 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
models_used: vec![],
});
daily_stat.total_cost += entry.cost;
daily_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens;
daily_stat.total_tokens += entry.input_tokens
+ entry.output_tokens
+ entry.cache_creation_tokens
+ entry.cache_read_tokens;
if !daily_stat.models_used.contains(&entry.model) {
daily_stat.models_used.push(entry.model.clone());
}
// Update project stats
let project_stat = project_stats.entry(entry.project_path.clone()).or_insert(ProjectUsage {
project_path: entry.project_path.clone(),
project_name: entry.project_path.split('/').last()
.unwrap_or(&entry.project_path)
.to_string(),
total_cost: 0.0,
total_tokens: 0,
session_count: 0,
last_used: entry.timestamp.clone(),
});
let project_stat =
project_stats
.entry(entry.project_path.clone())
.or_insert(ProjectUsage {
project_path: entry.project_path.clone(),
project_name: entry
.project_path
.split('/')
.last()
.unwrap_or(&entry.project_path)
.to_string(),
total_cost: 0.0,
total_tokens: 0,
session_count: 0,
last_used: entry.timestamp.clone(),
});
project_stat.total_cost += entry.cost;
project_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens;
project_stat.total_tokens += entry.input_tokens
+ entry.output_tokens
+ entry.cache_creation_tokens
+ entry.cache_read_tokens;
project_stat.session_count += 1;
if entry.timestamp > project_stat.last_used {
project_stat.last_used = entry.timestamp.clone();
}
}
let total_tokens = total_input_tokens + total_output_tokens + total_cache_creation_tokens + total_cache_read_tokens;
let total_tokens = total_input_tokens
+ total_output_tokens
+ total_cache_creation_tokens
+ total_cache_read_tokens;
let total_sessions = filtered_entries.len() as u64;
// Convert hashmaps to sorted vectors
let mut by_model: Vec<ModelUsage> = model_stats.into_values().collect();
by_model.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap());
let mut by_date: Vec<DailyUsage> = daily_stats.into_values().collect();
by_date.sort_by(|a, b| b.date.cmp(&a.date));
let mut by_project: Vec<ProjectUsage> = project_stats.into_values().collect();
by_project.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap());
Ok(UsageStats {
total_cost,
total_tokens,
@@ -412,27 +453,26 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
let claude_path = dirs::home_dir()
.ok_or("Failed to get home directory")?
.join(".claude");
let all_entries = get_all_usage_entries(&claude_path);
// Parse dates
let start = NaiveDate::parse_from_str(&start_date, "%Y-%m-%d")
.or_else(|_| {
// Try parsing ISO datetime format
DateTime::parse_from_rfc3339(&start_date)
.map(|dt| dt.naive_local().date())
.map_err(|e| format!("Invalid start date: {}", e))
})?;
let end = NaiveDate::parse_from_str(&end_date, "%Y-%m-%d")
.or_else(|_| {
// Try parsing ISO datetime format
DateTime::parse_from_rfc3339(&end_date)
.map(|dt| dt.naive_local().date())
.map_err(|e| format!("Invalid end date: {}", e))
})?;
let start = NaiveDate::parse_from_str(&start_date, "%Y-%m-%d").or_else(|_| {
// Try parsing ISO datetime format
DateTime::parse_from_rfc3339(&start_date)
.map(|dt| dt.naive_local().date())
.map_err(|e| format!("Invalid start date: {}", e))
})?;
let end = NaiveDate::parse_from_str(&end_date, "%Y-%m-%d").or_else(|_| {
// Try parsing ISO datetime format
DateTime::parse_from_rfc3339(&end_date)
.map(|dt| dt.naive_local().date())
.map_err(|e| format!("Invalid end date: {}", e))
})?;
// Filter entries by date range
let filtered_entries: Vec<_> = all_entries.into_iter()
let filtered_entries: Vec<_> = all_entries
.into_iter()
.filter(|e| {
if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) {
let date = dt.naive_local().date();
@@ -442,7 +482,7 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
}
})
.collect();
if filtered_entries.is_empty() {
return Ok(UsageStats {
total_cost: 0.0,
@@ -457,18 +497,18 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
by_project: vec![],
});
}
// Calculate aggregated stats (same logic as get_usage_stats)
let mut total_cost = 0.0;
let mut total_input_tokens = 0u64;
let mut total_output_tokens = 0u64;
let mut total_cache_creation_tokens = 0u64;
let mut total_cache_read_tokens = 0u64;
let mut model_stats: HashMap<String, ModelUsage> = HashMap::new();
let mut daily_stats: HashMap<String, DailyUsage> = HashMap::new();
let mut project_stats: HashMap<String, ProjectUsage> = HashMap::new();
for entry in &filtered_entries {
// Update totals
total_cost += entry.cost;
@@ -476,18 +516,20 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
total_output_tokens += entry.output_tokens;
total_cache_creation_tokens += entry.cache_creation_tokens;
total_cache_read_tokens += entry.cache_read_tokens;
// Update model stats
let model_stat = model_stats.entry(entry.model.clone()).or_insert(ModelUsage {
model: entry.model.clone(),
total_cost: 0.0,
total_tokens: 0,
input_tokens: 0,
output_tokens: 0,
cache_creation_tokens: 0,
cache_read_tokens: 0,
session_count: 0,
});
let model_stat = model_stats
.entry(entry.model.clone())
.or_insert(ModelUsage {
model: entry.model.clone(),
total_cost: 0.0,
total_tokens: 0,
input_tokens: 0,
output_tokens: 0,
cache_creation_tokens: 0,
cache_read_tokens: 0,
session_count: 0,
});
model_stat.total_cost += entry.cost;
model_stat.input_tokens += entry.input_tokens;
model_stat.output_tokens += entry.output_tokens;
@@ -495,9 +537,14 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
model_stat.cache_read_tokens += entry.cache_read_tokens;
model_stat.total_tokens = model_stat.input_tokens + model_stat.output_tokens;
model_stat.session_count += 1;
// Update daily stats
let date = entry.timestamp.split('T').next().unwrap_or(&entry.timestamp).to_string();
let date = entry
.timestamp
.split('T')
.next()
.unwrap_or(&entry.timestamp)
.to_string();
let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage {
date,
total_cost: 0.0,
@@ -505,43 +552,58 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
models_used: vec![],
});
daily_stat.total_cost += entry.cost;
daily_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens;
daily_stat.total_tokens += entry.input_tokens
+ entry.output_tokens
+ entry.cache_creation_tokens
+ entry.cache_read_tokens;
if !daily_stat.models_used.contains(&entry.model) {
daily_stat.models_used.push(entry.model.clone());
}
// Update project stats
let project_stat = project_stats.entry(entry.project_path.clone()).or_insert(ProjectUsage {
project_path: entry.project_path.clone(),
project_name: entry.project_path.split('/').last()
.unwrap_or(&entry.project_path)
.to_string(),
total_cost: 0.0,
total_tokens: 0,
session_count: 0,
last_used: entry.timestamp.clone(),
});
let project_stat =
project_stats
.entry(entry.project_path.clone())
.or_insert(ProjectUsage {
project_path: entry.project_path.clone(),
project_name: entry
.project_path
.split('/')
.last()
.unwrap_or(&entry.project_path)
.to_string(),
total_cost: 0.0,
total_tokens: 0,
session_count: 0,
last_used: entry.timestamp.clone(),
});
project_stat.total_cost += entry.cost;
project_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens;
project_stat.total_tokens += entry.input_tokens
+ entry.output_tokens
+ entry.cache_creation_tokens
+ entry.cache_read_tokens;
project_stat.session_count += 1;
if entry.timestamp > project_stat.last_used {
project_stat.last_used = entry.timestamp.clone();
}
}
let total_tokens = total_input_tokens + total_output_tokens + total_cache_creation_tokens + total_cache_read_tokens;
let total_tokens = total_input_tokens
+ total_output_tokens
+ total_cache_creation_tokens
+ total_cache_read_tokens;
let total_sessions = filtered_entries.len() as u64;
// Convert hashmaps to sorted vectors
let mut by_model: Vec<ModelUsage> = model_stats.into_values().collect();
by_model.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap());
let mut by_date: Vec<DailyUsage> = daily_stats.into_values().collect();
by_date.sort_by(|a, b| b.date.cmp(&a.date));
let mut by_project: Vec<ProjectUsage> = project_stats.into_values().collect();
by_project.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap());
Ok(UsageStats {
total_cost,
total_tokens,
@@ -557,23 +619,26 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
}
#[command]
pub fn get_usage_details(project_path: Option<String>, date: Option<String>) -> Result<Vec<UsageEntry>, String> {
pub fn get_usage_details(
project_path: Option<String>,
date: Option<String>,
) -> Result<Vec<UsageEntry>, String> {
let claude_path = dirs::home_dir()
.ok_or("Failed to get home directory")?
.join(".claude");
let mut all_entries = get_all_usage_entries(&claude_path);
// Filter by project if specified
if let Some(project) = project_path {
all_entries.retain(|e| e.project_path == project);
}
// Filter by date if specified
if let Some(date) = date {
all_entries.retain(|e| e.timestamp.starts_with(&date));
}
Ok(all_entries)
}
@@ -586,7 +651,7 @@ pub fn get_session_stats(
let claude_path = dirs::home_dir()
.ok_or("Failed to get home directory")?
.join(".claude");
let all_entries = get_all_usage_entries(&claude_path);
let since_date = since.and_then(|s| NaiveDate::parse_from_str(&s, "%Y%m%d").ok());
@@ -609,14 +674,16 @@ pub fn get_session_stats(
let mut session_stats: HashMap<String, ProjectUsage> = HashMap::new();
for entry in &filtered_entries {
let session_key = format!("{}/{}", entry.project_path, entry.session_id);
let project_stat = session_stats.entry(session_key).or_insert_with(|| ProjectUsage {
project_path: entry.project_path.clone(),
project_name: entry.session_id.clone(), // Using session_id as project_name for session view
total_cost: 0.0,
total_tokens: 0,
session_count: 0, // In this context, this will count entries per session
last_used: " ".to_string(),
});
let project_stat = session_stats
.entry(session_key)
.or_insert_with(|| ProjectUsage {
project_path: entry.project_path.clone(),
project_name: entry.session_id.clone(), // Using session_id as project_name for session view
total_cost: 0.0,
total_tokens: 0,
session_count: 0, // In this context, this will count entries per session
last_used: " ".to_string(),
});
project_stat.total_cost += entry.cost;
project_stat.total_tokens += entry.input_tokens
@@ -643,6 +710,5 @@ pub fn get_session_stats(
by_session.sort_by(|a, b| b.last_used.cmp(&a.last_used));
}
Ok(by_session)
}
}