style: apply cargo fmt across entire Rust codebase
- Remove Rust formatting check from CI workflow since formatting is now applied
- Standardize import ordering and organization throughout codebase
- Fix indentation, spacing, and line breaks for consistency
- Clean up trailing whitespace and formatting inconsistencies
- Apply rustfmt to all Rust source files including checkpoint, sandbox, commands, and test modules

This establishes a consistent code style baseline for the project.
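The baseline can be reproduced locally with the same commands the deleted CI step used: `rustup component add rustfmt`, then `cargo fmt` inside `src-tauri` to rewrite files in place, or `cargo fmt -- --check` to verify without modifying anything.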
.github/workflows/build-test.yml (8 changed lines)
@@ -100,14 +100,6 @@ jobs:
       - name: Build frontend
         run: bun run build

-      # Check Rust formatting
-      - name: Check Rust formatting
-        if: matrix.platform.os == 'ubuntu-latest'
-        working-directory: ./src-tauri
-        run: |
-          rustup component add rustfmt
-          cargo fmt -- --check
-
       # Run Rust linter
       - name: Run Rust linter
         if: matrix.platform.os == 'ubuntu-latest'

src-tauri/src/checkpoint/manager.rs
@@ -1,16 +1,16 @@
 use anyhow::{Context, Result};
+use chrono::{DateTime, TimeZone, Utc};
+use log;
 use std::collections::HashMap;
 use std::fs;
 use std::path::PathBuf;
 use std::sync::Arc;
-use chrono::{Utc, TimeZone, DateTime};
 use tokio::sync::RwLock;
-use log;

 use super::{
-    Checkpoint, CheckpointMetadata, FileSnapshot, FileTracker, FileState,
-    CheckpointResult, SessionTimeline, CheckpointStrategy, CheckpointPaths,
-    storage::{CheckpointStorage, self},
+    storage::{self, CheckpointStorage},
+    Checkpoint, CheckpointMetadata, CheckpointPaths, CheckpointResult, CheckpointStrategy,
+    FileSnapshot, FileState, FileTracker, SessionTimeline,
 };

 /// Manages checkpoint operations for a session
@@ -33,10 +33,10 @@ impl CheckpointManager {
         claude_dir: PathBuf,
     ) -> Result<Self> {
         let storage = Arc::new(CheckpointStorage::new(claude_dir.clone()));

         // Initialize storage
         storage.init_storage(&project_id, &session_id)?;

         // Load or create timeline
         let paths = CheckpointPaths::new(&claude_dir, &project_id, &session_id);
         let timeline = if paths.timeline_file.exists() {
@@ -44,11 +44,11 @@ impl CheckpointManager {
         } else {
             SessionTimeline::new(session_id.clone())
         };

         let file_tracker = FileTracker {
             tracked_files: HashMap::new(),
         };

         Ok(Self {
             project_id,
             session_id,
@@ -59,12 +59,12 @@ impl CheckpointManager {
             current_messages: Arc::new(RwLock::new(Vec::new())),
         })
     }

     /// Track a new message in the session
     pub async fn track_message(&self, jsonl_message: String) -> Result<()> {
         let mut messages = self.current_messages.write().await;
         messages.push(jsonl_message.clone());

         // Parse message to check for tool usage
         if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&jsonl_message) {
             if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
@@ -81,10 +81,10 @@ impl CheckpointManager {
                 }
             }
         }

         Ok(())
     }

     /// Track file operations from tool usage
     async fn track_tool_operation(&self, tool: &str, input: &serde_json::Value) -> Result<()> {
         match tool.to_lowercase().as_str() {
@@ -103,47 +103,51 @@ impl CheckpointManager {
             }
         Ok(())
     }

     /// Track a file modification
     pub async fn track_file_modification(&self, file_path: &str) -> Result<()> {
         let mut tracker = self.file_tracker.write().await;
         let full_path = self.project_path.join(file_path);

         // Read current file state
         let (hash, exists, _size, modified) = if full_path.exists() {
-            let content = fs::read_to_string(&full_path)
-                .unwrap_or_default();
+            let content = fs::read_to_string(&full_path).unwrap_or_default();
             let metadata = fs::metadata(&full_path)?;
-            let modified = metadata.modified()
+            let modified = metadata
+                .modified()
                 .ok()
                 .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
-                .map(|d| Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos()).unwrap())
+                .map(|d| {
+                    Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos())
+                        .unwrap()
+                })
                 .unwrap_or_else(Utc::now);

             (
                 storage::CheckpointStorage::calculate_file_hash(&content),
                 true,
                 metadata.len(),
-                modified
+                modified,
             )
         } else {
             (String::new(), false, 0, Utc::now())
         };

         // Check if file has actually changed
-        let is_modified = if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) {
-            // File is modified if:
-            // 1. Hash has changed
-            // 2. Existence state has changed
-            // 3. It was already marked as modified
-            existing_state.last_hash != hash ||
-            existing_state.exists != exists ||
-            existing_state.is_modified
-        } else {
-            // New file is always considered modified
-            true
-        };
+        let is_modified =
+            if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) {
+                // File is modified if:
+                // 1. Hash has changed
+                // 2. Existence state has changed
+                // 3. It was already marked as modified
+                existing_state.last_hash != hash
+                    || existing_state.exists != exists
+                    || existing_state.is_modified
+            } else {
+                // New file is always considered modified
+                true
+            };

         tracker.tracked_files.insert(
             PathBuf::from(file_path),
             FileState {
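The reflowed chain above is the usual `SystemTime` to `chrono::DateTime<Utc>` conversion. A minimal self-contained sketch of the same logic, assuming only the `chrono` crate (the `to_utc` helper is illustrative, not part of the codebase):

```rust
use chrono::{DateTime, TimeZone, Utc};
use std::time::{SystemTime, UNIX_EPOCH};

fn to_utc(t: SystemTime) -> DateTime<Utc> {
    t.duration_since(UNIX_EPOCH)
        .ok()
        // `timestamp_opt` returns a `LocalResult`; `unwrap` is fine for in-range values
        .map(|d| Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos()).unwrap())
        // fall back to "now" if the clock reads before the epoch
        .unwrap_or_else(Utc::now)
}

fn main() {
    println!("{}", to_utc(SystemTime::now()));
}
```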
@@ -153,18 +157,18 @@ impl CheckpointManager {
                 exists,
             },
         );

         Ok(())
     }

     /// Track potential file changes from bash commands
     async fn track_bash_side_effects(&self, command: &str) -> Result<()> {
         // Common file-modifying commands
         let file_commands = [
-            "echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk",
-            "npm", "yarn", "pnpm", "bun", "cargo", "make", "gcc", "g++",
+            "echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk", "npm", "yarn", "pnpm", "bun",
+            "cargo", "make", "gcc", "g++",
         ];

         // Simple heuristic: if command contains file-modifying operations
         for cmd in &file_commands {
             if command.contains(cmd) {
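Note that this heuristic is a plain substring test over the whole command line, so it can over-match (for example, `rm` occurs inside `format`). A standalone sketch of the check with the same command list:

```rust
/// Returns true if the command line mentions a known file-modifying program.
fn may_modify_files(command: &str) -> bool {
    const FILE_COMMANDS: [&str; 16] = [
        "echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk", "npm", "yarn", "pnpm", "bun",
        "cargo", "make", "gcc", "g++",
    ];
    // Substring match, as in track_bash_side_effects above
    FILE_COMMANDS.iter().any(|cmd| command.contains(cmd))
}

fn main() {
    assert!(may_modify_files("rm -rf target"));
    assert!(may_modify_files("git format-patch HEAD~1")); // false positive: "rm" in "format"
    assert!(!may_modify_files("ls"));
}
```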
@@ -176,10 +180,10 @@ impl CheckpointManager {
                 break;
             }
         }

         Ok(())
     }

     /// Create a checkpoint
     pub async fn create_checkpoint(
         &self,
@@ -188,13 +192,18 @@ impl CheckpointManager {
     ) -> Result<CheckpointResult> {
         let messages = self.current_messages.read().await;
         let message_index = messages.len().saturating_sub(1);

         // Extract metadata from the last user message
-        let (user_prompt, model_used, total_tokens) = self.extract_checkpoint_metadata(&messages).await?;
+        let (user_prompt, model_used, total_tokens) =
+            self.extract_checkpoint_metadata(&messages).await?;

         // Ensure every file in the project is tracked so new checkpoints include all files
         // Recursively walk the project directory and track each file
-        fn collect_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<std::path::PathBuf>) -> Result<(), std::io::Error> {
+        fn collect_files(
+            dir: &std::path::Path,
+            base: &std::path::Path,
+            files: &mut Vec<std::path::PathBuf>,
+        ) -> Result<(), std::io::Error> {
             for entry in std::fs::read_dir(dir)? {
                 let entry = entry?;
                 let path = entry.path();
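`collect_files` is a helper declared inside `create_checkpoint` that walks the project tree and records paths relative to the project root. A runnable reduction of the same walk (any filtering of hidden or build directories done by the real helper is omitted here):

```rust
use std::path::{Path, PathBuf};

fn collect_files(dir: &Path, base: &Path, files: &mut Vec<PathBuf>) -> std::io::Result<()> {
    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if path.is_dir() {
            // Recurse into subdirectories
            collect_files(&path, base, files)?;
        } else if let Ok(rel) = path.strip_prefix(base) {
            // Store the path relative to the project root
            files.push(rel.to_path_buf());
        }
    }
    Ok(())
}

fn main() -> std::io::Result<()> {
    let base = Path::new(".");
    let mut files = Vec::new();
    collect_files(base, base, &mut files)?;
    println!("{} files under {:?}", files.len(), base);
    Ok(())
}
```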
@@ -224,13 +233,13 @@ impl CheckpointManager {
                 let _ = self.track_file_modification(p).await;
             }
         }

         // Generate checkpoint ID early so snapshots reference it
         let checkpoint_id = storage::CheckpointStorage::generate_checkpoint_id();

         // Create file snapshots
         let file_snapshots = self.create_file_snapshots(&checkpoint_id).await?;

         // Generate checkpoint struct
         let checkpoint = Checkpoint {
             id: checkpoint_id.clone(),
@@ -259,7 +268,7 @@ impl CheckpointManager {
                 ),
             },
         };

         // Save checkpoint
         let messages_content = messages.join("\n");
         let result = self.storage.save_checkpoint(
@@ -269,7 +278,7 @@ impl CheckpointManager {
             file_snapshots,
             &messages_content,
         )?;

         // Reload timeline from disk so in-memory timeline has updated nodes and total_checkpoints
         let claude_dir = self.storage.claude_dir.clone();
         let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id);
@@ -278,20 +287,20 @@ impl CheckpointManager {
             let mut timeline_lock = self.timeline.write().await;
             *timeline_lock = updated_timeline;
         }

         // Update timeline (current checkpoint only)
         let mut timeline = self.timeline.write().await;
         timeline.current_checkpoint_id = Some(checkpoint_id);

         // Reset file tracker
         let mut tracker = self.file_tracker.write().await;
         for (_, state) in tracker.tracked_files.iter_mut() {
             state.is_modified = false;
         }

         Ok(result)
     }

     /// Extract metadata from messages for checkpoint
     async fn extract_checkpoint_metadata(
         &self,
@@ -300,13 +309,14 @@ impl CheckpointManager {
         let mut user_prompt = String::new();
         let mut model_used = String::from("unknown");
         let mut total_tokens = 0u64;

         // Iterate through messages in reverse to find the last user prompt
         for msg_str in messages.iter().rev() {
             if let Ok(msg) = serde_json::from_str::<serde_json::Value>(msg_str) {
                 // Check for user message
                 if msg.get("type").and_then(|t| t.as_str()) == Some("user") {
-                    if let Some(content) = msg.get("message")
+                    if let Some(content) = msg
+                        .get("message")
                         .and_then(|m| m.get("content"))
                         .and_then(|c| c.as_array())
                     {
@@ -320,19 +330,19 @@ impl CheckpointManager {
                     }
                 }
             }

             // Extract model info
             if let Some(model) = msg.get("model").and_then(|m| m.as_str()) {
                 model_used = model.to_string();
             }

             // Also check for model in message.model (assistant messages)
             if let Some(message) = msg.get("message") {
                 if let Some(model) = message.get("model").and_then(|m| m.as_str()) {
                     model_used = model.to_string();
                 }
             }

             // Count tokens - check both top-level and nested usage
             // First check for usage in message.usage (assistant messages)
             if let Some(message) = msg.get("message") {
@@ -344,15 +354,21 @@ impl CheckpointManager {
                             total_tokens += output;
                         }
                         // Also count cache tokens
-                        if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) {
+                        if let Some(cache_creation) = usage
+                            .get("cache_creation_input_tokens")
+                            .and_then(|t| t.as_u64())
+                        {
                             total_tokens += cache_creation;
                         }
-                        if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) {
+                        if let Some(cache_read) = usage
+                            .get("cache_read_input_tokens")
+                            .and_then(|t| t.as_u64())
+                        {
                             total_tokens += cache_read;
                         }
                     }
                 }
             }

             // Then check for top-level usage (result messages)
             if let Some(usage) = msg.get("usage") {
                 if let Some(input) = usage.get("input_tokens").and_then(|t| t.as_u64()) {
@@ -362,40 +378,45 @@ impl CheckpointManager {
                     total_tokens += output;
                 }
                 // Also count cache tokens
-                if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) {
+                if let Some(cache_creation) = usage
+                    .get("cache_creation_input_tokens")
+                    .and_then(|t| t.as_u64())
+                {
                     total_tokens += cache_creation;
                 }
-                if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) {
+                if let Some(cache_read) = usage
+                    .get("cache_read_input_tokens")
+                    .and_then(|t| t.as_u64())
+                {
                     total_tokens += cache_read;
                 }
             }
         }
     }

     Ok((user_prompt, model_used, total_tokens))
 }

     /// Create file snapshots for all tracked modified files
     async fn create_file_snapshots(&self, checkpoint_id: &str) -> Result<Vec<FileSnapshot>> {
         let tracker = self.file_tracker.read().await;
         let mut snapshots = Vec::new();

         for (rel_path, state) in &tracker.tracked_files {
             // Skip files that haven't been modified
             if !state.is_modified {
                 continue;
             }

             let full_path = self.project_path.join(rel_path);

             let (content, exists, permissions, size, current_hash) = if full_path.exists() {
-                let content = fs::read_to_string(&full_path)
-                    .unwrap_or_default();
+                let content = fs::read_to_string(&full_path).unwrap_or_default();
                 let current_hash = storage::CheckpointStorage::calculate_file_hash(&content);

                 // Don't skip based on hash - if is_modified is true, we should snapshot it
                 // The hash check in track_file_modification already determined if it changed

                 let metadata = fs::metadata(&full_path)?;
                 let permissions = {
                     #[cfg(unix)]
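Token accounting reads the same four counters from a `usage` object, whether it appears nested under `message` or at the top level. A minimal sketch with a hand-written message (the field names come from the code above; the sample JSON itself is illustrative):

```rust
use serde_json::Value;

fn add_usage(usage: &Value, total: &mut u64) {
    for key in [
        "input_tokens",
        "output_tokens",
        "cache_creation_input_tokens",
        "cache_read_input_tokens",
    ] {
        if let Some(n) = usage.get(key).and_then(|t| t.as_u64()) {
            *total += n;
        }
    }
}

fn main() {
    let msg: Value = serde_json::from_str(
        r#"{"message":{"usage":{"input_tokens":10,"output_tokens":5,"cache_read_input_tokens":3}}}"#,
    )
    .unwrap();
    let mut total = 0u64;
    // Nested usage (assistant messages); top-level usage is handled the same way
    if let Some(usage) = msg.get("message").and_then(|m| m.get("usage")) {
        add_usage(usage, &mut total);
    }
    assert_eq!(total, 18);
}
```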
@@ -412,7 +433,7 @@ impl CheckpointManager {
             } else {
                 (String::new(), false, None, 0, String::new())
             };

             snapshots.push(FileSnapshot {
                 checkpoint_id: checkpoint_id.to_string(),
                 file_path: rel_path.clone(),
@@ -423,21 +444,23 @@ impl CheckpointManager {
                 size,
             });
         }

         Ok(snapshots)
     }

     /// Restore a checkpoint
     pub async fn restore_checkpoint(&self, checkpoint_id: &str) -> Result<CheckpointResult> {
         // Load checkpoint data
-        let (checkpoint, file_snapshots, messages) = self.storage.load_checkpoint(
-            &self.project_id,
-            &self.session_id,
-            checkpoint_id,
-        )?;
+        let (checkpoint, file_snapshots, messages) =
+            self.storage
+                .load_checkpoint(&self.project_id, &self.session_id, checkpoint_id)?;

         // First, collect all files currently in the project to handle deletions
-        fn collect_all_project_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<std::path::PathBuf>) -> Result<(), std::io::Error> {
+        fn collect_all_project_files(
+            dir: &std::path::Path,
+            base: &std::path::Path,
+            files: &mut Vec<std::path::PathBuf>,
+        ) -> Result<(), std::io::Error> {
             for entry in std::fs::read_dir(dir)? {
                 let entry = entry?;
                 let path = entry.path();
@@ -458,10 +481,11 @@ impl CheckpointManager {
             }
             Ok(())
         }

         let mut current_files = Vec::new();
-        let _ = collect_all_project_files(&self.project_path, &self.project_path, &mut current_files);
+        let _ =
+            collect_all_project_files(&self.project_path, &self.project_path, &mut current_files);

         // Create a set of files that should exist after restore
         let mut checkpoint_files = std::collections::HashSet::new();
         for snapshot in &file_snapshots {
@@ -469,11 +493,11 @@ impl CheckpointManager {
                 checkpoint_files.insert(snapshot.file_path.clone());
             }
         }

         // Delete files that exist now but shouldn't exist in the checkpoint
         let mut warnings = Vec::new();
         let mut files_processed = 0;

         for current_file in current_files {
             if !checkpoint_files.contains(&current_file) {
                 // This file exists now but not in the checkpoint, so delete it
@@ -484,18 +508,25 @@ impl CheckpointManager {
                         log::info!("Deleted file not in checkpoint: {:?}", current_file);
                     }
                     Err(e) => {
-                        warnings.push(format!("Failed to delete {}: {}", current_file.display(), e));
+                        warnings.push(format!(
+                            "Failed to delete {}: {}",
+                            current_file.display(),
+                            e
+                        ));
                     }
                 }
             }
         }

         // Clean up empty directories
-        fn remove_empty_dirs(dir: &std::path::Path, base: &std::path::Path) -> Result<bool, std::io::Error> {
+        fn remove_empty_dirs(
+            dir: &std::path::Path,
+            base: &std::path::Path,
+        ) -> Result<bool, std::io::Error> {
             if dir == base {
                 return Ok(false); // Don't remove the base directory
             }

             let mut is_empty = true;
             for entry in fs::read_dir(dir)? {
                 let entry = entry?;
@@ -508,7 +539,7 @@ impl CheckpointManager {
                     is_empty = false;
                 }
             }

             if is_empty {
                 fs::remove_dir(dir)?;
                 Ok(true)
@@ -516,30 +547,33 @@ impl CheckpointManager {
                 Ok(false)
             }
         }

         // Clean up any empty directories left after file deletion
         let _ = remove_empty_dirs(&self.project_path, &self.project_path);

         // Restore files from checkpoint
         for snapshot in &file_snapshots {
             match self.restore_file_snapshot(snapshot).await {
                 Ok(_) => files_processed += 1,
-                Err(e) => warnings.push(format!("Failed to restore {}: {}",
-                    snapshot.file_path.display(), e)),
+                Err(e) => warnings.push(format!(
+                    "Failed to restore {}: {}",
+                    snapshot.file_path.display(),
+                    e
+                )),
             }
         }

         // Update current messages
         let mut current_messages = self.current_messages.write().await;
         current_messages.clear();
         for line in messages.lines() {
             current_messages.push(line.to_string());
         }

         // Update timeline
         let mut timeline = self.timeline.write().await;
         timeline.current_checkpoint_id = Some(checkpoint_id.to_string());

         // Update file tracker
         let mut tracker = self.file_tracker.write().await;
         tracker.tracked_files.clear();
@@ -556,35 +590,32 @@ impl CheckpointManager {
                 );
             }
         }

         Ok(CheckpointResult {
             checkpoint: checkpoint.clone(),
             files_processed,
             warnings,
         })
     }

     /// Restore a single file from snapshot
     async fn restore_file_snapshot(&self, snapshot: &FileSnapshot) -> Result<()> {
         let full_path = self.project_path.join(&snapshot.file_path);

         if snapshot.is_deleted {
             // Delete the file if it exists
             if full_path.exists() {
-                fs::remove_file(&full_path)
-                    .context("Failed to delete file")?;
+                fs::remove_file(&full_path).context("Failed to delete file")?;
             }
         } else {
             // Create parent directories if needed
             if let Some(parent) = full_path.parent() {
-                fs::create_dir_all(parent)
-                    .context("Failed to create parent directories")?;
+                fs::create_dir_all(parent).context("Failed to create parent directories")?;
             }

             // Write file content
-            fs::write(&full_path, &snapshot.content)
-                .context("Failed to write file")?;
+            fs::write(&full_path, &snapshot.content).context("Failed to write file")?;

             // Restore permissions if available
             #[cfg(unix)]
             if let Some(mode) = snapshot.permissions {
@@ -594,35 +625,38 @@ impl CheckpointManager {
                     .context("Failed to set file permissions")?;
             }
         }

         Ok(())
     }

     /// Get the current timeline
     pub async fn get_timeline(&self) -> SessionTimeline {
         self.timeline.read().await.clone()
     }

     /// List all checkpoints
     pub async fn list_checkpoints(&self) -> Vec<Checkpoint> {
         let timeline = self.timeline.read().await;
         let mut checkpoints = Vec::new();

         if let Some(root) = &timeline.root_node {
             Self::collect_checkpoints_from_node(root, &mut checkpoints);
         }

         checkpoints
     }

     /// Recursively collect checkpoints from timeline tree
-    fn collect_checkpoints_from_node(node: &super::TimelineNode, checkpoints: &mut Vec<Checkpoint>) {
+    fn collect_checkpoints_from_node(
+        node: &super::TimelineNode,
+        checkpoints: &mut Vec<Checkpoint>,
+    ) {
         checkpoints.push(node.checkpoint.clone());
         for child in &node.children {
             Self::collect_checkpoints_from_node(child, checkpoints);
         }
     }

     /// Fork from a checkpoint
     pub async fn fork_from_checkpoint(
         &self,
@@ -630,31 +664,29 @@ impl CheckpointManager {
         description: Option<String>,
     ) -> Result<CheckpointResult> {
         // Load the checkpoint to fork from
-        let (_base_checkpoint, _, _) = self.storage.load_checkpoint(
-            &self.project_id,
-            &self.session_id,
-            checkpoint_id,
-        )?;
+        let (_base_checkpoint, _, _) =
+            self.storage
+                .load_checkpoint(&self.project_id, &self.session_id, checkpoint_id)?;

         // Restore to that checkpoint first
         self.restore_checkpoint(checkpoint_id).await?;

         // Create a new checkpoint with the fork
-        let fork_description = description.unwrap_or_else(|| {
-            format!("Fork from checkpoint {}", &checkpoint_id[..8])
-        });
+        let fork_description =
+            description.unwrap_or_else(|| format!("Fork from checkpoint {}", &checkpoint_id[..8]));

-        self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string())).await
+        self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string()))
+            .await
     }

     /// Check if auto-checkpoint should be triggered
     pub async fn should_auto_checkpoint(&self, message: &str) -> bool {
         let timeline = self.timeline.read().await;

         if !timeline.auto_checkpoint_enabled {
             return false;
         }

         match timeline.checkpoint_strategy {
             CheckpointStrategy::Manual => false,
             CheckpointStrategy::PerPrompt => {
@@ -668,7 +700,11 @@ impl CheckpointManager {
             CheckpointStrategy::PerToolUse => {
                 // Check if message contains tool use
                 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
-                    if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) {
+                    if let Some(content) = msg
+                        .get("message")
+                        .and_then(|m| m.get("content"))
+                        .and_then(|c| c.as_array())
+                    {
                         content.iter().any(|item| {
                             item.get("type").and_then(|t| t.as_str()) == Some("tool_use")
                         })
@@ -682,12 +718,19 @@ impl CheckpointManager {
             CheckpointStrategy::Smart => {
                 // Smart strategy: checkpoint after destructive operations
                 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
-                    if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) {
+                    if let Some(content) = msg
+                        .get("message")
+                        .and_then(|m| m.get("content"))
+                        .and_then(|c| c.as_array())
+                    {
                         content.iter().any(|item| {
                             if item.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
-                                let tool_name = item.get("name").and_then(|n| n.as_str()).unwrap_or("");
-                                matches!(tool_name.to_lowercase().as_str(),
-                                    "write" | "edit" | "multiedit" | "bash" | "rm" | "delete")
+                                let tool_name =
+                                    item.get("name").and_then(|n| n.as_str()).unwrap_or("");
+                                matches!(
+                                    tool_name.to_lowercase().as_str(),
+                                    "write" | "edit" | "multiedit" | "bash" | "rm" | "delete"
+                                )
                             } else {
                                 false
                             }
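Both strategies scan the message's `content` array for `tool_use` items; `Smart` additionally lower-cases the tool name and matches it against a destructive-tool list. A condensed, runnable sketch of that check:

```rust
use serde_json::Value;

fn is_destructive_tool_use(message: &str) -> bool {
    let Ok(msg) = serde_json::from_str::<Value>(message) else {
        return false;
    };
    msg.get("message")
        .and_then(|m| m.get("content"))
        .and_then(|c| c.as_array())
        .is_some_and(|content| {
            content.iter().any(|item| {
                item.get("type").and_then(|t| t.as_str()) == Some("tool_use")
                    && matches!(
                        item.get("name").and_then(|n| n.as_str()).unwrap_or("").to_lowercase().as_str(),
                        "write" | "edit" | "multiedit" | "bash" | "rm" | "delete"
                    )
            })
        })
}

fn main() {
    let msg = r#"{"message":{"content":[{"type":"tool_use","name":"Write"}]}}"#;
    assert!(is_destructive_tool_use(msg));
}
```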
@@ -701,7 +744,7 @@ impl CheckpointManager {
             }
         }
     }

     /// Update checkpoint settings
     pub async fn update_settings(
         &self,
@@ -711,31 +754,34 @@ impl CheckpointManager {
         let mut timeline = self.timeline.write().await;
         timeline.auto_checkpoint_enabled = auto_checkpoint_enabled;
         timeline.checkpoint_strategy = checkpoint_strategy;

         // Save updated timeline
         let claude_dir = self.storage.claude_dir.clone();
         let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id);
-        self.storage.save_timeline(&paths.timeline_file, &timeline)?;
+        self.storage
+            .save_timeline(&paths.timeline_file, &timeline)?;

         Ok(())
     }

     /// Get files modified since a given timestamp
     pub async fn get_files_modified_since(&self, since: DateTime<Utc>) -> Vec<PathBuf> {
         let tracker = self.file_tracker.read().await;
-        tracker.tracked_files
+        tracker
+            .tracked_files
             .iter()
             .filter(|(_, state)| state.last_modified > since && state.is_modified)
             .map(|(path, _)| path.clone())
             .collect()
     }

     /// Get the last modification time of any tracked file
     pub async fn get_last_modification_time(&self) -> Option<DateTime<Utc>> {
         let tracker = self.file_tracker.read().await;
-        tracker.tracked_files
+        tracker
+            .tracked_files
             .values()
             .map(|state| state.last_modified)
             .max()
     }
 }

src-tauri/src/checkpoint/mod.rs
@@ -1,11 +1,11 @@
+use chrono::{DateTime, Utc};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::path::PathBuf;
-use chrono::{DateTime, Utc};

 pub mod manager;
-pub mod storage;
 pub mod state;
+pub mod storage;

 /// Represents a checkpoint in the session timeline
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -188,24 +188,25 @@ impl SessionTimeline {
             total_checkpoints: 0,
         }
     }

     /// Find a checkpoint by ID in the timeline tree
     pub fn find_checkpoint(&self, checkpoint_id: &str) -> Option<&TimelineNode> {
-        self.root_node.as_ref()
+        self.root_node
+            .as_ref()
             .and_then(|root| Self::find_in_tree(root, checkpoint_id))
     }

     fn find_in_tree<'a>(node: &'a TimelineNode, checkpoint_id: &str) -> Option<&'a TimelineNode> {
         if node.checkpoint.id == checkpoint_id {
             return Some(node);
         }

         for child in &node.children {
             if let Some(found) = Self::find_in_tree(child, checkpoint_id) {
                 return Some(found);
             }
         }

         None
     }
 }
@@ -224,35 +225,38 @@ impl CheckpointPaths {
             .join(project_id)
             .join(".timelines")
             .join(session_id);

         Self {
             timeline_file: base_dir.join("timeline.json"),
             checkpoints_dir: base_dir.join("checkpoints"),
             files_dir: base_dir.join("files"),
         }
     }

     pub fn checkpoint_dir(&self, checkpoint_id: &str) -> PathBuf {
         self.checkpoints_dir.join(checkpoint_id)
     }

     pub fn checkpoint_metadata_file(&self, checkpoint_id: &str) -> PathBuf {
         self.checkpoint_dir(checkpoint_id).join("metadata.json")
     }

     pub fn checkpoint_messages_file(&self, checkpoint_id: &str) -> PathBuf {
         self.checkpoint_dir(checkpoint_id).join("messages.jsonl")
     }

     #[allow(dead_code)]
     pub fn file_snapshot_path(&self, _checkpoint_id: &str, file_hash: &str) -> PathBuf {
         // In content-addressable storage, files are stored by hash in the content pool
         self.files_dir.join("content_pool").join(file_hash)
     }

     #[allow(dead_code)]
     pub fn file_reference_path(&self, checkpoint_id: &str, safe_filename: &str) -> PathBuf {
         // References are stored per checkpoint
-        self.files_dir.join("refs").join(checkpoint_id).join(format!("{}.json", safe_filename))
+        self.files_dir
+            .join("refs")
+            .join(checkpoint_id)
+            .join(format!("{}.json", safe_filename))
     }
 }
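Taken together, these helpers encode a content-addressable layout: file contents live once in a shared pool keyed by hash, and each checkpoint stores lightweight references into it. Sketching the resulting tree (directory and file names are from the code above; path segments not shown in this diff are elided):

    <claude_dir>/…/<project_id>/.timelines/<session_id>/
        timeline.json
        checkpoints/<checkpoint_id>/metadata.json
        checkpoints/<checkpoint_id>/messages.jsonl
        files/content_pool/<file_hash>
        files/refs/<checkpoint_id>/<safe_filename>.json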

src-tauri/src/checkpoint/state.rs
@@ -1,13 +1,13 @@
+use anyhow::Result;
 use std::collections::HashMap;
 use std::path::PathBuf;
 use std::sync::Arc;
 use tokio::sync::RwLock;
-use anyhow::Result;

 use super::manager::CheckpointManager;

 /// Manages checkpoint managers for active sessions
 ///
 /// This struct maintains a stateful collection of CheckpointManager instances,
 /// one per active session, to avoid recreating them on every command invocation.
 /// It provides thread-safe access to managers and handles their lifecycle.
@@ -28,25 +28,25 @@ impl CheckpointState {
             claude_dir: Arc::new(RwLock::new(None)),
         }
     }

     /// Sets the Claude directory path
     ///
     /// This should be called once during application initialization
     pub async fn set_claude_dir(&self, claude_dir: PathBuf) {
         let mut dir = self.claude_dir.write().await;
         *dir = Some(claude_dir);
     }

     /// Gets or creates a CheckpointManager for a session
     ///
     /// If a manager already exists for the session, it returns the existing one.
     /// Otherwise, it creates a new manager and stores it for future use.
     ///
     /// # Arguments
     /// * `session_id` - The session identifier
     /// * `project_id` - The project identifier
     /// * `project_path` - The path to the project directory
     ///
     /// # Returns
     /// An Arc reference to the CheckpointManager for thread-safe sharing
     pub async fn get_or_create_manager(
@@ -56,12 +56,12 @@ impl CheckpointState {
         project_path: PathBuf,
     ) -> Result<Arc<CheckpointManager>> {
         let mut managers = self.managers.write().await;

         // Check if manager already exists
         if let Some(manager) = managers.get(&session_id) {
             return Ok(Arc::clone(manager));
         }

         // Get Claude directory
         let claude_dir = {
             let dir = self.claude_dir.read().await;
@@ -69,65 +69,62 @@ impl CheckpointState {
                 .ok_or_else(|| anyhow::anyhow!("Claude directory not set"))?
                 .clone()
         };

         // Create new manager
-        let manager = CheckpointManager::new(
-            project_id,
-            session_id.clone(),
-            project_path,
-            claude_dir,
-        ).await?;
+        let manager =
+            CheckpointManager::new(project_id, session_id.clone(), project_path, claude_dir)
+                .await?;

         let manager_arc = Arc::new(manager);
         managers.insert(session_id, Arc::clone(&manager_arc));

         Ok(manager_arc)
     }

     /// Gets an existing CheckpointManager for a session
     ///
     /// Returns None if no manager exists for the session
     #[allow(dead_code)]
     pub async fn get_manager(&self, session_id: &str) -> Option<Arc<CheckpointManager>> {
         let managers = self.managers.read().await;
         managers.get(session_id).map(Arc::clone)
     }

     /// Removes a CheckpointManager for a session
     ///
     /// This should be called when a session ends to free resources
     pub async fn remove_manager(&self, session_id: &str) -> Option<Arc<CheckpointManager>> {
         let mut managers = self.managers.write().await;
         managers.remove(session_id)
     }

     /// Clears all managers
     ///
     /// This is useful for cleanup during application shutdown
     #[allow(dead_code)]
     pub async fn clear_all(&self) {
         let mut managers = self.managers.write().await;
         managers.clear();
     }

     /// Gets the number of active managers
     pub async fn active_count(&self) -> usize {
         let managers = self.managers.read().await;
         managers.len()
     }

     /// Lists all active session IDs
     pub async fn list_active_sessions(&self) -> Vec<String> {
         let managers = self.managers.read().await;
         managers.keys().cloned().collect()
     }

     /// Checks if a session has an active manager
     #[allow(dead_code)]
     pub async fn has_active_manager(&self, session_id: &str) -> bool {
         self.get_manager(session_id).await.is_some()
     }

     /// Clears all managers and returns the count that were cleared
     #[allow(dead_code)]
     pub async fn clear_all_and_count(&self) -> usize {
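Because `get_or_create_manager` holds the write lock across both the lookup and the insert, two tasks asking for the same session cannot race to build duplicate managers. A stripped-down sketch of the pattern (the `Manager` type here is a stand-in, not the real `CheckpointManager`):

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

struct Manager; // stand-in for CheckpointManager

#[derive(Default)]
struct State {
    managers: RwLock<HashMap<String, Arc<Manager>>>,
}

impl State {
    async fn get_or_create(&self, session_id: String) -> Arc<Manager> {
        // The write lock covers check + insert, so creation is race-free
        let mut managers = self.managers.write().await;
        if let Some(m) = managers.get(&session_id) {
            return Arc::clone(m);
        }
        let m = Arc::new(Manager);
        managers.insert(session_id, Arc::clone(&m));
        m
    }
}

#[tokio::main]
async fn main() {
    let state = State::default();
    let a = state.get_or_create("s1".into()).await;
    let b = state.get_or_create("s1".into()).await;
    assert!(Arc::ptr_eq(&a, &b)); // same manager reused
}
```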
@@ -141,50 +138,47 @@
 mod tests {
     use super::*;
     use tempfile::TempDir;

     #[tokio::test]
     async fn test_checkpoint_state_lifecycle() {
         let state = CheckpointState::new();
         let temp_dir = TempDir::new().unwrap();
         let claude_dir = temp_dir.path().to_path_buf();

         // Set Claude directory
         state.set_claude_dir(claude_dir.clone()).await;

         // Create a manager
         let session_id = "test-session-123".to_string();
         let project_id = "test-project".to_string();
         let project_path = temp_dir.path().join("project");
         std::fs::create_dir_all(&project_path).unwrap();

-        let manager1 = state.get_or_create_manager(
-            session_id.clone(),
-            project_id.clone(),
-            project_path.clone(),
-        ).await.unwrap();
+        let manager1 = state
+            .get_or_create_manager(session_id.clone(), project_id.clone(), project_path.clone())
+            .await
+            .unwrap();

         // Getting the same session should return the same manager
-        let manager2 = state.get_or_create_manager(
-            session_id.clone(),
-            project_id.clone(),
-            project_path.clone(),
-        ).await.unwrap();
+        let manager2 = state
+            .get_or_create_manager(session_id.clone(), project_id.clone(), project_path.clone())
+            .await
+            .unwrap();

         assert!(Arc::ptr_eq(&manager1, &manager2));
         assert_eq!(state.active_count().await, 1);

         // Remove the manager
         let removed = state.remove_manager(&session_id).await;
         assert!(removed.is_some());
         assert_eq!(state.active_count().await, 0);

         // Getting after removal should create a new one
-        let manager3 = state.get_or_create_manager(
-            session_id.clone(),
-            project_id,
-            project_path,
-        ).await.unwrap();
+        let manager3 = state
+            .get_or_create_manager(session_id.clone(), project_id, project_path)
+            .await
+            .unwrap();

         assert!(!Arc::ptr_eq(&manager1, &manager3));
     }
 }

src-tauri/src/checkpoint/storage.rs
@@ -1,13 +1,12 @@
 use anyhow::{Context, Result};
+use sha2::{Digest, Sha256};
 use std::fs;
 use std::path::{Path, PathBuf};
-use sha2::{Sha256, Digest};
-use zstd::stream::{encode_all, decode_all};
 use uuid::Uuid;
+use zstd::stream::{decode_all, encode_all};

 use super::{
-    Checkpoint, FileSnapshot, SessionTimeline,
-    TimelineNode, CheckpointPaths, CheckpointResult
+    Checkpoint, CheckpointPaths, CheckpointResult, FileSnapshot, SessionTimeline, TimelineNode,
 };

 /// Manages checkpoint storage operations
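`calculate_file_hash`, used throughout the manager, is defined outside the hunks shown here; given the `sha2` import above it is presumably a SHA-256 digest of the file content. A plausible sketch, offered as an assumption rather than the actual implementation:

```rust
use sha2::{Digest, Sha256};

// Hypothetical reconstruction of CheckpointStorage::calculate_file_hash
fn calculate_file_hash(content: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(content.as_bytes());
    // Hex-encode the 32-byte digest
    hasher
        .finalize()
        .iter()
        .map(|b| format!("{:02x}", b))
        .collect()
}

fn main() {
    assert_eq!(calculate_file_hash("a").len(), 64);
}
```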
@@ -24,26 +23,25 @@ impl CheckpointStorage {
             compression_level: 3, // Default zstd compression level
         }
     }

     /// Initialize checkpoint storage for a session
     pub fn init_storage(&self, project_id: &str, session_id: &str) -> Result<()> {
         let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);

         // Create directory structure
         fs::create_dir_all(&paths.checkpoints_dir)
             .context("Failed to create checkpoints directory")?;
-        fs::create_dir_all(&paths.files_dir)
-            .context("Failed to create files directory")?;
+        fs::create_dir_all(&paths.files_dir).context("Failed to create files directory")?;

         // Initialize empty timeline if it doesn't exist
         if !paths.timeline_file.exists() {
             let timeline = SessionTimeline::new(session_id.to_string());
             self.save_timeline(&paths.timeline_file, &timeline)?;
         }

         Ok(())
     }

     /// Save a checkpoint to disk
     pub fn save_checkpoint(
         &self,
@@ -55,76 +53,73 @@ impl CheckpointStorage {
|
|||||||
) -> Result<CheckpointResult> {
|
) -> Result<CheckpointResult> {
|
||||||
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
|
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
|
||||||
let checkpoint_dir = paths.checkpoint_dir(&checkpoint.id);
|
let checkpoint_dir = paths.checkpoint_dir(&checkpoint.id);
|
||||||
|
|
||||||
// Create checkpoint directory
|
// Create checkpoint directory
|
||||||
fs::create_dir_all(&checkpoint_dir)
|
fs::create_dir_all(&checkpoint_dir).context("Failed to create checkpoint directory")?;
|
||||||
.context("Failed to create checkpoint directory")?;
|
|
||||||
|
|
||||||
// Save checkpoint metadata
|
// Save checkpoint metadata
|
||||||
let metadata_path = paths.checkpoint_metadata_file(&checkpoint.id);
|
let metadata_path = paths.checkpoint_metadata_file(&checkpoint.id);
|
||||||
let metadata_json = serde_json::to_string_pretty(checkpoint)
|
let metadata_json = serde_json::to_string_pretty(checkpoint)
|
||||||
.context("Failed to serialize checkpoint metadata")?;
|
.context("Failed to serialize checkpoint metadata")?;
|
||||||
fs::write(&metadata_path, metadata_json)
|
fs::write(&metadata_path, metadata_json).context("Failed to write checkpoint metadata")?;
|
||||||
.context("Failed to write checkpoint metadata")?;
|
|
||||||
|
|
||||||
// Save messages (compressed)
|
// Save messages (compressed)
|
||||||
let messages_path = paths.checkpoint_messages_file(&checkpoint.id);
|
let messages_path = paths.checkpoint_messages_file(&checkpoint.id);
|
||||||
let compressed_messages = encode_all(messages.as_bytes(), self.compression_level)
|
let compressed_messages = encode_all(messages.as_bytes(), self.compression_level)
|
||||||
.context("Failed to compress messages")?;
|
.context("Failed to compress messages")?;
|
||||||
fs::write(&messages_path, compressed_messages)
|
fs::write(&messages_path, compressed_messages)
|
||||||
.context("Failed to write compressed messages")?;
|
.context("Failed to write compressed messages")?;
|
||||||
|
|
||||||
// Save file snapshots
|
// Save file snapshots
|
||||||
let mut warnings = Vec::new();
|
let mut warnings = Vec::new();
|
||||||
let mut files_processed = 0;
|
let mut files_processed = 0;
|
||||||
|
|
||||||
for snapshot in &file_snapshots {
|
for snapshot in &file_snapshots {
|
||||||
match self.save_file_snapshot(&paths, snapshot) {
|
match self.save_file_snapshot(&paths, snapshot) {
|
||||||
Ok(_) => files_processed += 1,
|
Ok(_) => files_processed += 1,
|
||||||
Err(e) => warnings.push(format!("Failed to save {}: {}",
|
Err(e) => warnings.push(format!(
|
||||||
snapshot.file_path.display(), e)),
|
"Failed to save {}: {}",
|
||||||
|
snapshot.file_path.display(),
|
||||||
|
e
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update timeline
|
// Update timeline
|
||||||
self.update_timeline_with_checkpoint(
|
self.update_timeline_with_checkpoint(&paths.timeline_file, checkpoint, &file_snapshots)?;
|
||||||
&paths.timeline_file,
|
|
||||||
checkpoint,
|
|
||||||
&file_snapshots
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(CheckpointResult {
|
Ok(CheckpointResult {
|
||||||
checkpoint: checkpoint.clone(),
|
checkpoint: checkpoint.clone(),
|
||||||
files_processed,
|
files_processed,
|
||||||
warnings,
|
warnings,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Save a single file snapshot
|
/// Save a single file snapshot
|
||||||
fn save_file_snapshot(&self, paths: &CheckpointPaths, snapshot: &FileSnapshot) -> Result<()> {
|
fn save_file_snapshot(&self, paths: &CheckpointPaths, snapshot: &FileSnapshot) -> Result<()> {
|
||||||
// Use content-addressable storage: store files by their hash
|
// Use content-addressable storage: store files by their hash
|
||||||
// This prevents duplication of identical file content across checkpoints
|
// This prevents duplication of identical file content across checkpoints
|
||||||
let content_pool_dir = paths.files_dir.join("content_pool");
|
let content_pool_dir = paths.files_dir.join("content_pool");
|
||||||
fs::create_dir_all(&content_pool_dir)
|
fs::create_dir_all(&content_pool_dir).context("Failed to create content pool directory")?;
|
||||||
.context("Failed to create content pool directory")?;
|
|
||||||
|
|
||||||
// Store the actual content in the content pool
|
// Store the actual content in the content pool
|
||||||
let content_file = content_pool_dir.join(&snapshot.hash);
|
let content_file = content_pool_dir.join(&snapshot.hash);
|
||||||
|
|
||||||
// Only write the content if it doesn't already exist
|
// Only write the content if it doesn't already exist
|
||||||
if !content_file.exists() {
|
if !content_file.exists() {
|
||||||
// Compress and save file content
|
// Compress and save file content
|
||||||
let compressed_content = encode_all(snapshot.content.as_bytes(), self.compression_level)
|
let compressed_content =
|
||||||
.context("Failed to compress file content")?;
|
encode_all(snapshot.content.as_bytes(), self.compression_level)
|
||||||
|
.context("Failed to compress file content")?;
|
||||||
fs::write(&content_file, compressed_content)
|
fs::write(&content_file, compressed_content)
|
||||||
.context("Failed to write file content to pool")?;
|
.context("Failed to write file content to pool")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a reference in the checkpoint-specific directory
|
// Create a reference in the checkpoint-specific directory
|
||||||
let checkpoint_refs_dir = paths.files_dir.join("refs").join(&snapshot.checkpoint_id);
|
let checkpoint_refs_dir = paths.files_dir.join("refs").join(&snapshot.checkpoint_id);
|
||||||
fs::create_dir_all(&checkpoint_refs_dir)
|
fs::create_dir_all(&checkpoint_refs_dir)
|
||||||
.context("Failed to create checkpoint refs directory")?;
|
.context("Failed to create checkpoint refs directory")?;
|
||||||
|
|
||||||
// Save file metadata with reference to content
|
// Save file metadata with reference to content
|
||||||
let ref_metadata = serde_json::json!({
|
let ref_metadata = serde_json::json!({
|
||||||
"path": snapshot.file_path,
|
"path": snapshot.file_path,
|
||||||
@@ -133,20 +128,21 @@ impl CheckpointStorage {
|
|||||||
"permissions": snapshot.permissions,
|
"permissions": snapshot.permissions,
|
||||||
"size": snapshot.size,
|
"size": snapshot.size,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Use a sanitized filename for the reference
|
// Use a sanitized filename for the reference
|
||||||
let safe_filename = snapshot.file_path
|
let safe_filename = snapshot
|
||||||
|
.file_path
|
||||||
.to_string_lossy()
|
.to_string_lossy()
|
||||||
.replace('/', "_")
|
.replace('/', "_")
|
||||||
.replace('\\', "_");
|
.replace('\\', "_");
|
||||||
let ref_path = checkpoint_refs_dir.join(format!("{}.json", safe_filename));
|
let ref_path = checkpoint_refs_dir.join(format!("{}.json", safe_filename));
|
||||||
|
|
||||||
fs::write(&ref_path, serde_json::to_string_pretty(&ref_metadata)?)
|
fs::write(&ref_path, serde_json::to_string_pretty(&ref_metadata)?)
|
||||||
.context("Failed to write file reference")?;
|
.context("Failed to write file reference")?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load a checkpoint from disk
|
/// Load a checkpoint from disk
|
||||||
pub fn load_checkpoint(
|
pub fn load_checkpoint(
|
||||||
&self,
|
&self,
|
||||||
@@ -155,75 +151,78 @@ impl CheckpointStorage {
|
|||||||
checkpoint_id: &str,
|
checkpoint_id: &str,
|
||||||
) -> Result<(Checkpoint, Vec<FileSnapshot>, String)> {
|
) -> Result<(Checkpoint, Vec<FileSnapshot>, String)> {
|
||||||
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
|
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
|
||||||
|
|
||||||
// Load checkpoint metadata
|
// Load checkpoint metadata
|
||||||
let metadata_path = paths.checkpoint_metadata_file(checkpoint_id);
|
let metadata_path = paths.checkpoint_metadata_file(checkpoint_id);
|
||||||
let metadata_json = fs::read_to_string(&metadata_path)
|
let metadata_json =
|
||||||
.context("Failed to read checkpoint metadata")?;
|
fs::read_to_string(&metadata_path).context("Failed to read checkpoint metadata")?;
|
||||||
let checkpoint: Checkpoint = serde_json::from_str(&metadata_json)
|
let checkpoint: Checkpoint =
|
||||||
.context("Failed to parse checkpoint metadata")?;
|
serde_json::from_str(&metadata_json).context("Failed to parse checkpoint metadata")?;
|
||||||
|
|
||||||
// Load messages
|
// Load messages
|
||||||
let messages_path = paths.checkpoint_messages_file(checkpoint_id);
|
let messages_path = paths.checkpoint_messages_file(checkpoint_id);
|
||||||
let compressed_messages = fs::read(&messages_path)
|
let compressed_messages =
|
||||||
.context("Failed to read compressed messages")?;
|
fs::read(&messages_path).context("Failed to read compressed messages")?;
|
||||||
let messages = String::from_utf8(decode_all(&compressed_messages[..])
|
let messages = String::from_utf8(
|
||||||
.context("Failed to decompress messages")?)
|
decode_all(&compressed_messages[..]).context("Failed to decompress messages")?,
|
||||||
.context("Invalid UTF-8 in messages")?;
|
)
|
||||||
|
.context("Invalid UTF-8 in messages")?;
|
||||||
|
|
||||||
// Load file snapshots
|
// Load file snapshots
|
||||||
let file_snapshots = self.load_file_snapshots(&paths, checkpoint_id)?;
|
let file_snapshots = self.load_file_snapshots(&paths, checkpoint_id)?;
|
||||||
|
|
||||||
Ok((checkpoint, file_snapshots, messages))
|
Ok((checkpoint, file_snapshots, messages))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load all file snapshots for a checkpoint
|
/// Load all file snapshots for a checkpoint
|
||||||
fn load_file_snapshots(
|
fn load_file_snapshots(
|
||||||
&self,
|
&self,
|
||||||
paths: &CheckpointPaths,
|
paths: &CheckpointPaths,
|
||||||
checkpoint_id: &str
|
checkpoint_id: &str,
|
||||||
) -> Result<Vec<FileSnapshot>> {
|
) -> Result<Vec<FileSnapshot>> {
|
||||||
let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
|
let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
|
||||||
if !refs_dir.exists() {
|
if !refs_dir.exists() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
let content_pool_dir = paths.files_dir.join("content_pool");
|
let content_pool_dir = paths.files_dir.join("content_pool");
|
||||||
let mut snapshots = Vec::new();
|
let mut snapshots = Vec::new();
|
||||||
|
|
||||||
// Read all reference files
|
// Read all reference files
|
||||||
for entry in fs::read_dir(&refs_dir)? {
|
for entry in fs::read_dir(&refs_dir)? {
|
||||||
let entry = entry?;
|
let entry = entry?;
|
||||||
let path = entry.path();
|
let path = entry.path();
|
||||||
|
|
||||||
// Skip non-JSON files
|
// Skip non-JSON files
|
||||||
if path.extension().and_then(|e| e.to_str()) != Some("json") {
|
if path.extension().and_then(|e| e.to_str()) != Some("json") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load reference metadata
|
// Load reference metadata
|
||||||
let ref_json = fs::read_to_string(&path)
|
let ref_json = fs::read_to_string(&path).context("Failed to read file reference")?;
|
||||||
.context("Failed to read file reference")?;
|
let ref_metadata: serde_json::Value =
|
||||||
let ref_metadata: serde_json::Value = serde_json::from_str(&ref_json)
|
serde_json::from_str(&ref_json).context("Failed to parse file reference")?;
|
||||||
.context("Failed to parse file reference")?;
|
|
||||||
|
let hash = ref_metadata["hash"]
|
||||||
let hash = ref_metadata["hash"].as_str()
|
.as_str()
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing hash in reference"))?;
|
.ok_or_else(|| anyhow::anyhow!("Missing hash in reference"))?;
|
||||||
|
|
||||||
// Load content from pool
|
// Load content from pool
|
||||||
let content_file = content_pool_dir.join(hash);
|
let content_file = content_pool_dir.join(hash);
|
||||||
let content = if content_file.exists() {
|
let content = if content_file.exists() {
|
||||||
let compressed_content = fs::read(&content_file)
|
let compressed_content =
|
||||||
.context("Failed to read file content from pool")?;
|
fs::read(&content_file).context("Failed to read file content from pool")?;
|
||||||
String::from_utf8(decode_all(&compressed_content[..])
|
String::from_utf8(
|
||||||
.context("Failed to decompress file content")?)
|
decode_all(&compressed_content[..])
|
||||||
.context("Invalid UTF-8 in file content")?
|
.context("Failed to decompress file content")?,
|
||||||
|
)
|
||||||
|
.context("Invalid UTF-8 in file content")?
|
||||||
} else {
|
} else {
|
||||||
// Handle missing content gracefully
|
// Handle missing content gracefully
|
||||||
log::warn!("Content file missing for hash: {}", hash);
|
log::warn!("Content file missing for hash: {}", hash);
|
||||||
String::new()
|
String::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
snapshots.push(FileSnapshot {
|
snapshots.push(FileSnapshot {
|
||||||
checkpoint_id: checkpoint_id.to_string(),
|
checkpoint_id: checkpoint_id.to_string(),
|
||||||
file_path: PathBuf::from(ref_metadata["path"].as_str().unwrap_or("")),
|
file_path: PathBuf::from(ref_metadata["path"].as_str().unwrap_or("")),
|
||||||
@@ -234,28 +233,26 @@ impl CheckpointStorage {
|
|||||||
size: ref_metadata["size"].as_u64().unwrap_or(0),
|
size: ref_metadata["size"].as_u64().unwrap_or(0),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(snapshots)
|
Ok(snapshots)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Save timeline to disk
|
/// Save timeline to disk
|
||||||
pub fn save_timeline(&self, timeline_path: &Path, timeline: &SessionTimeline) -> Result<()> {
|
pub fn save_timeline(&self, timeline_path: &Path, timeline: &SessionTimeline) -> Result<()> {
|
||||||
let timeline_json = serde_json::to_string_pretty(timeline)
|
let timeline_json =
|
||||||
.context("Failed to serialize timeline")?;
|
serde_json::to_string_pretty(timeline).context("Failed to serialize timeline")?;
|
||||||
fs::write(timeline_path, timeline_json)
|
fs::write(timeline_path, timeline_json).context("Failed to write timeline")?;
|
||||||
.context("Failed to write timeline")?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Load timeline from disk
|
/// Load timeline from disk
|
||||||
pub fn load_timeline(&self, timeline_path: &Path) -> Result<SessionTimeline> {
|
pub fn load_timeline(&self, timeline_path: &Path) -> Result<SessionTimeline> {
|
||||||
let timeline_json = fs::read_to_string(timeline_path)
|
let timeline_json = fs::read_to_string(timeline_path).context("Failed to read timeline")?;
|
||||||
.context("Failed to read timeline")?;
|
let timeline: SessionTimeline =
|
||||||
let timeline: SessionTimeline = serde_json::from_str(&timeline_json)
|
serde_json::from_str(&timeline_json).context("Failed to parse timeline")?;
|
||||||
.context("Failed to parse timeline")?;
|
|
||||||
Ok(timeline)
|
Ok(timeline)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Update timeline with a new checkpoint
|
/// Update timeline with a new checkpoint
|
||||||
fn update_timeline_with_checkpoint(
|
fn update_timeline_with_checkpoint(
|
||||||
&self,
|
&self,
|
||||||
@@ -264,15 +261,13 @@ impl CheckpointStorage {
|
|||||||
file_snapshots: &[FileSnapshot],
|
file_snapshots: &[FileSnapshot],
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut timeline = self.load_timeline(timeline_path)?;
|
let mut timeline = self.load_timeline(timeline_path)?;
|
||||||
|
|
||||||
let new_node = TimelineNode {
|
let new_node = TimelineNode {
|
||||||
checkpoint: checkpoint.clone(),
|
checkpoint: checkpoint.clone(),
|
||||||
children: Vec::new(),
|
children: Vec::new(),
|
||||||
file_snapshot_ids: file_snapshots.iter()
|
file_snapshot_ids: file_snapshots.iter().map(|s| s.hash.clone()).collect(),
|
||||||
.map(|s| s.hash.clone())
|
|
||||||
.collect(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// If this is the first checkpoint
|
// If this is the first checkpoint
|
||||||
if timeline.root_node.is_none() {
|
if timeline.root_node.is_none() {
|
||||||
timeline.root_node = Some(new_node);
|
timeline.root_node = Some(new_node);
|
||||||
@@ -280,7 +275,7 @@ impl CheckpointStorage {
|
|||||||
} else if let Some(parent_id) = &checkpoint.parent_checkpoint_id {
|
} else if let Some(parent_id) = &checkpoint.parent_checkpoint_id {
|
||||||
// Check if parent exists before modifying
|
// Check if parent exists before modifying
|
||||||
let parent_exists = timeline.find_checkpoint(parent_id).is_some();
|
let parent_exists = timeline.find_checkpoint(parent_id).is_some();
|
||||||
|
|
||||||
if parent_exists {
|
if parent_exists {
|
||||||
if let Some(root) = &mut timeline.root_node {
|
if let Some(root) = &mut timeline.root_node {
|
||||||
Self::add_child_to_node(root, parent_id, new_node)?;
|
Self::add_child_to_node(root, parent_id, new_node)?;
|
||||||
@@ -290,59 +285,54 @@ impl CheckpointStorage {
|
|||||||
anyhow::bail!("Parent checkpoint not found: {}", parent_id);
|
anyhow::bail!("Parent checkpoint not found: {}", parent_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
timeline.total_checkpoints += 1;
|
timeline.total_checkpoints += 1;
|
||||||
self.save_timeline(timeline_path, &timeline)?;
|
self.save_timeline(timeline_path, &timeline)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Recursively add a child node to the timeline tree
|
/// Recursively add a child node to the timeline tree
|
||||||
fn add_child_to_node(
|
fn add_child_to_node(
|
||||||
node: &mut TimelineNode,
|
node: &mut TimelineNode,
|
||||||
parent_id: &str,
|
parent_id: &str,
|
||||||
child: TimelineNode
|
child: TimelineNode,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
if node.checkpoint.id == parent_id {
|
if node.checkpoint.id == parent_id {
|
||||||
node.children.push(child);
|
node.children.push(child);
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
for child_node in &mut node.children {
|
for child_node in &mut node.children {
|
||||||
if Self::add_child_to_node(child_node, parent_id, child.clone()).is_ok() {
|
if Self::add_child_to_node(child_node, parent_id, child.clone()).is_ok() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::bail!("Parent checkpoint not found: {}", parent_id)
|
anyhow::bail!("Parent checkpoint not found: {}", parent_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Calculate hash of file content
|
/// Calculate hash of file content
|
||||||
pub fn calculate_file_hash(content: &str) -> String {
|
pub fn calculate_file_hash(content: &str) -> String {
|
||||||
let mut hasher = Sha256::new();
|
let mut hasher = Sha256::new();
|
||||||
hasher.update(content.as_bytes());
|
hasher.update(content.as_bytes());
|
||||||
format!("{:x}", hasher.finalize())
|
format!("{:x}", hasher.finalize())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generate a new checkpoint ID
|
/// Generate a new checkpoint ID
|
||||||
pub fn generate_checkpoint_id() -> String {
|
pub fn generate_checkpoint_id() -> String {
|
||||||
Uuid::new_v4().to_string()
|
Uuid::new_v4().to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Estimate storage size for a checkpoint
|
/// Estimate storage size for a checkpoint
|
||||||
pub fn estimate_checkpoint_size(
|
pub fn estimate_checkpoint_size(messages: &str, file_snapshots: &[FileSnapshot]) -> u64 {
|
||||||
messages: &str,
|
|
||||||
file_snapshots: &[FileSnapshot],
|
|
||||||
) -> u64 {
|
|
||||||
let messages_size = messages.len() as u64;
|
let messages_size = messages.len() as u64;
|
||||||
let files_size: u64 = file_snapshots.iter()
|
let files_size: u64 = file_snapshots.iter().map(|s| s.content.len() as u64).sum();
|
||||||
.map(|s| s.content.len() as u64)
|
|
||||||
.sum();
|
|
||||||
|
|
||||||
// Estimate compressed size (typically 20-30% of original for text)
|
// Estimate compressed size (typically 20-30% of original for text)
|
||||||
(messages_size + files_size) / 4
|
(messages_size + files_size) / 4
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Clean up old checkpoints based on retention policy
|
/// Clean up old checkpoints based on retention policy
|
||||||
pub fn cleanup_old_checkpoints(
|
pub fn cleanup_old_checkpoints(
|
||||||
&self,
|
&self,
|
||||||
@@ -352,26 +342,26 @@ impl CheckpointStorage {
|
|||||||
) -> Result<usize> {
|
) -> Result<usize> {
|
||||||
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
|
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
|
||||||
let timeline = self.load_timeline(&paths.timeline_file)?;
|
let timeline = self.load_timeline(&paths.timeline_file)?;
|
||||||
|
|
||||||
// Collect all checkpoint IDs in chronological order
|
// Collect all checkpoint IDs in chronological order
|
||||||
let mut all_checkpoints = Vec::new();
|
let mut all_checkpoints = Vec::new();
|
||||||
if let Some(root) = &timeline.root_node {
|
if let Some(root) = &timeline.root_node {
|
||||||
Self::collect_checkpoints(root, &mut all_checkpoints);
|
Self::collect_checkpoints(root, &mut all_checkpoints);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort by timestamp (oldest first)
|
// Sort by timestamp (oldest first)
|
||||||
all_checkpoints.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));
|
all_checkpoints.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));
|
||||||
|
|
||||||
// Keep only the most recent checkpoints
|
// Keep only the most recent checkpoints
|
||||||
let to_remove = all_checkpoints.len().saturating_sub(keep_count);
|
let to_remove = all_checkpoints.len().saturating_sub(keep_count);
|
||||||
let mut removed_count = 0;
|
let mut removed_count = 0;
|
||||||
|
|
||||||
for checkpoint in all_checkpoints.into_iter().take(to_remove) {
|
for checkpoint in all_checkpoints.into_iter().take(to_remove) {
|
||||||
if self.remove_checkpoint(&paths, &checkpoint.id).is_ok() {
|
if self.remove_checkpoint(&paths, &checkpoint.id).is_ok() {
|
||||||
removed_count += 1;
|
removed_count += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run garbage collection to clean up orphaned content
|
// Run garbage collection to clean up orphaned content
|
||||||
if removed_count > 0 {
|
if removed_count > 0 {
|
||||||
match self.garbage_collect_content(project_id, session_id) {
|
match self.garbage_collect_content(project_id, session_id) {
|
||||||
@@ -383,10 +373,10 @@ impl CheckpointStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(removed_count)
|
Ok(removed_count)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Collect all checkpoints from the tree in order
|
/// Collect all checkpoints from the tree in order
|
||||||
fn collect_checkpoints(node: &TimelineNode, checkpoints: &mut Vec<Checkpoint>) {
|
fn collect_checkpoints(node: &TimelineNode, checkpoints: &mut Vec<Checkpoint>) {
|
||||||
checkpoints.push(node.checkpoint.clone());
|
checkpoints.push(node.checkpoint.clone());
|
||||||
@@ -394,46 +384,40 @@ impl CheckpointStorage {
|
|||||||
Self::collect_checkpoints(child, checkpoints);
|
Self::collect_checkpoints(child, checkpoints);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Remove a checkpoint and its associated files
|
/// Remove a checkpoint and its associated files
|
||||||
fn remove_checkpoint(&self, paths: &CheckpointPaths, checkpoint_id: &str) -> Result<()> {
|
fn remove_checkpoint(&self, paths: &CheckpointPaths, checkpoint_id: &str) -> Result<()> {
|
||||||
// Remove checkpoint metadata directory
|
// Remove checkpoint metadata directory
|
||||||
let checkpoint_dir = paths.checkpoint_dir(checkpoint_id);
|
let checkpoint_dir = paths.checkpoint_dir(checkpoint_id);
|
||||||
if checkpoint_dir.exists() {
|
if checkpoint_dir.exists() {
|
||||||
fs::remove_dir_all(&checkpoint_dir)
|
fs::remove_dir_all(&checkpoint_dir).context("Failed to remove checkpoint directory")?;
|
||||||
.context("Failed to remove checkpoint directory")?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove file references for this checkpoint
|
// Remove file references for this checkpoint
|
||||||
let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
|
let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
|
||||||
if refs_dir.exists() {
|
if refs_dir.exists() {
|
||||||
fs::remove_dir_all(&refs_dir)
|
fs::remove_dir_all(&refs_dir).context("Failed to remove file references")?;
|
||||||
.context("Failed to remove file references")?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: We don't remove content from the pool here as it might be
|
// Note: We don't remove content from the pool here as it might be
|
||||||
// referenced by other checkpoints. Use garbage_collect_content() for that.
|
// referenced by other checkpoints. Use garbage_collect_content() for that.
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Garbage collect unreferenced content from the content pool
|
/// Garbage collect unreferenced content from the content pool
|
||||||
pub fn garbage_collect_content(
|
pub fn garbage_collect_content(&self, project_id: &str, session_id: &str) -> Result<usize> {
|
||||||
&self,
|
|
||||||
project_id: &str,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<usize> {
|
|
||||||
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
|
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
|
||||||
let content_pool_dir = paths.files_dir.join("content_pool");
|
let content_pool_dir = paths.files_dir.join("content_pool");
|
||||||
let refs_dir = paths.files_dir.join("refs");
|
let refs_dir = paths.files_dir.join("refs");
|
||||||
|
|
||||||
if !content_pool_dir.exists() {
|
if !content_pool_dir.exists() {
|
||||||
return Ok(0);
|
return Ok(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect all referenced hashes
|
// Collect all referenced hashes
|
||||||
let mut referenced_hashes = std::collections::HashSet::new();
|
let mut referenced_hashes = std::collections::HashSet::new();
|
||||||
|
|
||||||
if refs_dir.exists() {
|
if refs_dir.exists() {
|
||||||
for checkpoint_entry in fs::read_dir(&refs_dir)? {
|
for checkpoint_entry in fs::read_dir(&refs_dir)? {
|
||||||
let checkpoint_dir = checkpoint_entry?.path();
|
let checkpoint_dir = checkpoint_entry?.path();
|
||||||
@@ -442,7 +426,9 @@ impl CheckpointStorage {
|
|||||||
let ref_path = ref_entry?.path();
|
let ref_path = ref_entry?.path();
|
||||||
if ref_path.extension().and_then(|e| e.to_str()) == Some("json") {
|
if ref_path.extension().and_then(|e| e.to_str()) == Some("json") {
|
||||||
if let Ok(ref_json) = fs::read_to_string(&ref_path) {
|
if let Ok(ref_json) = fs::read_to_string(&ref_path) {
|
||||||
if let Ok(ref_metadata) = serde_json::from_str::<serde_json::Value>(&ref_json) {
|
if let Ok(ref_metadata) =
|
||||||
|
serde_json::from_str::<serde_json::Value>(&ref_json)
|
||||||
|
{
|
||||||
if let Some(hash) = ref_metadata["hash"].as_str() {
|
if let Some(hash) = ref_metadata["hash"].as_str() {
|
||||||
referenced_hashes.insert(hash.to_string());
|
referenced_hashes.insert(hash.to_string());
|
||||||
}
|
}
|
||||||
@@ -453,7 +439,7 @@ impl CheckpointStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove unreferenced content
|
// Remove unreferenced content
|
||||||
let mut removed_count = 0;
|
let mut removed_count = 0;
|
||||||
for entry in fs::read_dir(&content_pool_dir)? {
|
for entry in fs::read_dir(&content_pool_dir)? {
|
||||||
@@ -468,7 +454,7 @@ impl CheckpointStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(removed_count)
|
Ok(removed_count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
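A note on the storage scheme in the file above: save_file_snapshot keys the content pool by SHA-256, so identical file content is compressed and written exactly once no matter how many checkpoints reference it, and garbage_collect_content later sweeps any hash that no ref file mentions. A minimal standalone sketch of that write path is below; pool_dir and the function name are assumptions for illustration, not code from this commit.

    use anyhow::{Context, Result};
    use sha2::{Digest, Sha256};
    use std::path::Path;
    use zstd::stream::encode_all;

    // Content-addressable write: hash the content, then compress and write it
    // only when no file for that hash exists yet. Each checkpoint that contains
    // the same content just records the hash in its refs directory.
    fn write_to_pool(pool_dir: &Path, content: &str) -> Result<String> {
        let mut hasher = Sha256::new();
        hasher.update(content.as_bytes());
        let hash = format!("{:x}", hasher.finalize());

        let content_file = pool_dir.join(&hash);
        if !content_file.exists() {
            // Level 3 mirrors the default compression_level in CheckpointStorage.
            let compressed =
                encode_all(content.as_bytes(), 3).context("Failed to compress content")?;
            std::fs::write(&content_file, compressed).context("Failed to write content")?;
        }
        Ok(hash)
    }

The exists() check is what makes repeated checkpoints of an unchanged file essentially free in storage.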
@@ -1,12 +1,12 @@
+use anyhow::Result;
+use log::{debug, error, info, warn};
+use serde::{Deserialize, Serialize};
+use std::cmp::Ordering;
 /// Shared module for detecting Claude Code binary installations
 /// Supports NVM installations, aliased paths, and version-based selection
 use std::path::PathBuf;
 use std::process::Command;
-use log::{info, warn, debug, error};
-use anyhow::Result;
-use std::cmp::Ordering;
 use tauri::Manager;
-use serde::{Serialize, Deserialize};

 /// Represents a Claude installation with metadata
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -23,7 +23,7 @@ pub struct ClaudeInstallation {
 /// Checks database first, then discovers all installations and selects the best one
 pub fn find_claude_binary(app_handle: &tauri::AppHandle) -> Result<String, String> {
     info!("Searching for claude binary...");

     // First check if we have a stored path in the database
     if let Ok(app_data_dir) = app_handle.path().app_data_dir() {
         let db_path = app_data_dir.join("agents.db");
@@ -45,24 +45,26 @@ pub fn find_claude_binary(app_handle: &tauri::AppHandle) -> Result<String, Strin
             }
         }
     }

     // Discover all available installations
     let installations = discover_all_installations();

     if installations.is_empty() {
         error!("Could not find claude binary in any location");
         return Err("Claude Code not found. Please ensure it's installed in one of these locations: PATH, /usr/local/bin, /opt/homebrew/bin, ~/.nvm/versions/node/*/bin, ~/.claude/local, ~/.local/bin".to_string());
     }

     // Log all found installations
     for installation in &installations {
         info!("Found Claude installation: {:?}", installation);
     }

     // Select the best installation (highest version)
     if let Some(best) = select_best_installation(installations) {
-        info!("Selected Claude installation: path={}, version={:?}, source={}",
-            best.path, best.version, best.source);
+        info!(
+            "Selected Claude installation: path={}, version={:?}, source={}",
+            best.path, best.version, best.source
+        );
         Ok(best.path)
     } else {
         Err("No valid Claude installation found".to_string())
@@ -73,9 +75,9 @@ pub fn find_claude_binary(app_handle: &tauri::AppHandle) -> Result<String, Strin
 /// This allows UI to show a version selector
 pub fn discover_claude_installations() -> Vec<ClaudeInstallation> {
     info!("Discovering all Claude installations...");

     let installations = discover_all_installations();

     // Sort by version (highest first), then by source preference
     let mut sorted = installations;
     sorted.sort_by(|a, b| {
@@ -87,15 +89,15 @@ pub fn discover_claude_installations() -> Vec<ClaudeInstallation> {
                     // If versions are equal, prefer by source
                     source_preference(a).cmp(&source_preference(b))
                 }
-                other => other
+                other => other,
             }
         }
         (Some(_), None) => Ordering::Less, // Version comes before no version
         (None, Some(_)) => Ordering::Greater,
-        (None, None) => source_preference(a).cmp(&source_preference(b))
+        (None, None) => source_preference(a).cmp(&source_preference(b)),
         }
     });

     sorted
 }

@@ -121,57 +123,58 @@ fn source_preference(installation: &ClaudeInstallation) -> u8 {
 /// Discovers all Claude installations on the system
 fn discover_all_installations() -> Vec<ClaudeInstallation> {
     let mut installations = Vec::new();

     // 1. Try 'which' command first (now works in production)
     if let Some(installation) = try_which_command() {
         installations.push(installation);
     }

     // 2. Check NVM paths
     installations.extend(find_nvm_installations());

     // 3. Check standard paths
     installations.extend(find_standard_installations());

     // Remove duplicates by path
     let mut unique_paths = std::collections::HashSet::new();
     installations.retain(|install| unique_paths.insert(install.path.clone()));

     installations
 }

 /// Try using the 'which' command to find Claude
 fn try_which_command() -> Option<ClaudeInstallation> {
     debug!("Trying 'which claude' to find binary...");

     match Command::new("which").arg("claude").output() {
         Ok(output) if output.status.success() => {
             let output_str = String::from_utf8_lossy(&output.stdout).trim().to_string();

             if output_str.is_empty() {
                 return None;
             }

             // Parse aliased output: "claude: aliased to /path/to/claude"
             let path = if output_str.starts_with("claude:") && output_str.contains("aliased to") {
-                output_str.split("aliased to")
+                output_str
+                    .split("aliased to")
                     .nth(1)
                     .map(|s| s.trim().to_string())
             } else {
                 Some(output_str)
             }?;

             debug!("'which' found claude at: {}", path);

             // Verify the path exists
             if !PathBuf::from(&path).exists() {
                 warn!("Path from 'which' does not exist: {}", path);
                 return None;
             }

             // Get version
             let version = get_claude_version(&path).ok().flatten();

             Some(ClaudeInstallation {
                 path,
                 version,
@@ -185,26 +188,29 @@ fn try_which_command() -> Option<ClaudeInstallation> {
 /// Find Claude installations in NVM directories
 fn find_nvm_installations() -> Vec<ClaudeInstallation> {
     let mut installations = Vec::new();

     if let Ok(home) = std::env::var("HOME") {
-        let nvm_dir = PathBuf::from(&home).join(".nvm").join("versions").join("node");
+        let nvm_dir = PathBuf::from(&home)
+            .join(".nvm")
+            .join("versions")
+            .join("node");

         debug!("Checking NVM directory: {:?}", nvm_dir);

         if let Ok(entries) = std::fs::read_dir(&nvm_dir) {
             for entry in entries.flatten() {
                 if entry.file_type().map(|t| t.is_dir()).unwrap_or(false) {
                     let claude_path = entry.path().join("bin").join("claude");

                     if claude_path.exists() && claude_path.is_file() {
                         let path_str = claude_path.to_string_lossy().to_string();
                         let node_version = entry.file_name().to_string_lossy().to_string();

                         debug!("Found Claude in NVM node {}: {}", node_version, path_str);

                         // Get Claude version
                         let version = get_claude_version(&path_str).ok().flatten();

                         installations.push(ClaudeInstallation {
                             path: path_str,
                             version,
@@ -215,46 +221,64 @@ fn find_nvm_installations() -> Vec<ClaudeInstallation> {
             }
         }
     }

     installations
 }

 /// Check standard installation paths
 fn find_standard_installations() -> Vec<ClaudeInstallation> {
     let mut installations = Vec::new();

     // Common installation paths for claude
     let mut paths_to_check: Vec<(String, String)> = vec![
         ("/usr/local/bin/claude".to_string(), "system".to_string()),
-        ("/opt/homebrew/bin/claude".to_string(), "homebrew".to_string()),
+        (
+            "/opt/homebrew/bin/claude".to_string(),
+            "homebrew".to_string(),
+        ),
         ("/usr/bin/claude".to_string(), "system".to_string()),
         ("/bin/claude".to_string(), "system".to_string()),
     ];

     // Also check user-specific paths
     if let Ok(home) = std::env::var("HOME") {
         paths_to_check.extend(vec![
-            (format!("{}/.claude/local/claude", home), "claude-local".to_string()),
-            (format!("{}/.local/bin/claude", home), "local-bin".to_string()),
-            (format!("{}/.npm-global/bin/claude", home), "npm-global".to_string()),
+            (
+                format!("{}/.claude/local/claude", home),
+                "claude-local".to_string(),
+            ),
+            (
+                format!("{}/.local/bin/claude", home),
+                "local-bin".to_string(),
+            ),
+            (
+                format!("{}/.npm-global/bin/claude", home),
+                "npm-global".to_string(),
+            ),
             (format!("{}/.yarn/bin/claude", home), "yarn".to_string()),
             (format!("{}/.bun/bin/claude", home), "bun".to_string()),
             (format!("{}/bin/claude", home), "home-bin".to_string()),
             // Check common node_modules locations
-            (format!("{}/node_modules/.bin/claude", home), "node-modules".to_string()),
-            (format!("{}/.config/yarn/global/node_modules/.bin/claude", home), "yarn-global".to_string()),
+            (
+                format!("{}/node_modules/.bin/claude", home),
+                "node-modules".to_string(),
+            ),
+            (
+                format!("{}/.config/yarn/global/node_modules/.bin/claude", home),
+                "yarn-global".to_string(),
+            ),
         ]);
     }

     // Check each path
     for (path, source) in paths_to_check {
         let path_buf = PathBuf::from(&path);
         if path_buf.exists() && path_buf.is_file() {
             debug!("Found claude at standard path: {} ({})", path, source);

             // Get version
             let version = get_claude_version(&path).ok().flatten();

             installations.push(ClaudeInstallation {
                 path,
                 version,
@@ -262,13 +286,13 @@ fn find_standard_installations() -> Vec<ClaudeInstallation> {
             });
         }
     }

     // Also check if claude is available in PATH (without full path)
     if let Ok(output) = Command::new("claude").arg("--version").output() {
         if output.status.success() {
             debug!("claude is available in PATH");
             let version = extract_version_from_output(&output.stdout);

             installations.push(ClaudeInstallation {
                 path: "claude".to_string(),
                 version,
@@ -276,7 +300,7 @@ fn find_standard_installations() -> Vec<ClaudeInstallation> {
             });
         }
     }

     installations
 }

@@ -300,13 +324,13 @@ fn get_claude_version(path: &str) -> Result<Option<String>, String> {
 /// Extract version string from command output
 fn extract_version_from_output(stdout: &[u8]) -> Option<String> {
     let output_str = String::from_utf8_lossy(stdout);

     // Extract version: first token before whitespace that looks like a version
-    output_str.split_whitespace()
+    output_str
+        .split_whitespace()
         .find(|token| {
             // Version usually contains dots and numbers
-            token.chars().any(|c| c == '.') &&
-            token.chars().any(|c| c.is_numeric())
+            token.chars().any(|c| c == '.') && token.chars().any(|c| c.is_numeric())
         })
         .map(|s| s.to_string())
 }
@@ -320,34 +344,34 @@ fn select_best_installation(installations: Vec<ClaudeInstallation>) -> Option<Cl
     // prefer binaries with version information when it is available so that
     // in development builds we keep the previous behaviour of picking the
     // most recent version.
-    installations.into_iter()
-        .max_by(|a, b| {
-            match (&a.version, &b.version) {
-                // If both have versions, compare them semantically.
-                (Some(v1), Some(v2)) => compare_versions(v1, v2),
-                // Prefer the entry that actually has version information.
-                (Some(_), None) => Ordering::Greater,
-                (None, Some(_)) => Ordering::Less,
-                // Neither have version info: prefer the one that is not just
-                // the bare "claude" lookup from PATH, because that may fail
-                // at runtime if PATH is sandbox-stripped.
-                (None, None) => {
-                    if a.path == "claude" && b.path != "claude" {
-                        Ordering::Less
-                    } else if a.path != "claude" && b.path == "claude" {
-                        Ordering::Greater
-                    } else {
-                        Ordering::Equal
-                    }
-                }
-            }
-        })
+    installations.into_iter().max_by(|a, b| {
+        match (&a.version, &b.version) {
+            // If both have versions, compare them semantically.
+            (Some(v1), Some(v2)) => compare_versions(v1, v2),
+            // Prefer the entry that actually has version information.
+            (Some(_), None) => Ordering::Greater,
+            (None, Some(_)) => Ordering::Less,
+            // Neither have version info: prefer the one that is not just
+            // the bare "claude" lookup from PATH, because that may fail
+            // at runtime if PATH is sandbox-stripped.
+            (None, None) => {
+                if a.path == "claude" && b.path != "claude" {
+                    Ordering::Less
+                } else if a.path != "claude" && b.path == "claude" {
+                    Ordering::Greater
+                } else {
+                    Ordering::Equal
+                }
+            }
+        }
+    })
 }

 /// Compare two version strings
 fn compare_versions(a: &str, b: &str) -> Ordering {
     // Simple semantic version comparison
-    let a_parts: Vec<u32> = a.split('.')
+    let a_parts: Vec<u32> = a
+        .split('.')
         .filter_map(|s| {
             // Handle versions like "1.0.17-beta" by taking only numeric part
             s.chars()
@@ -357,8 +381,9 @@ fn compare_versions(a: &str, b: &str) -> Ordering {
                 .ok()
         })
         .collect();

-    let b_parts: Vec<u32> = b.split('.')
+    let b_parts: Vec<u32> = b
+        .split('.')
         .filter_map(|s| {
             s.chars()
                 .take_while(|c| c.is_numeric())
@@ -367,7 +392,7 @@ fn compare_versions(a: &str, b: &str) -> Ordering {
                 .ok()
         })
         .collect();

     // Compare each part
     for i in 0..std::cmp::max(a_parts.len(), b_parts.len()) {
         let a_val = a_parts.get(i).unwrap_or(&0);
@@ -377,7 +402,7 @@ fn compare_versions(a: &str, b: &str) -> Ordering {
             other => return other,
         }
     }

     Ordering::Equal
 }

@@ -385,19 +410,28 @@ fn compare_versions(a: &str, b: &str) -> Ordering {
 /// This ensures commands like Claude can find Node.js and other dependencies
 pub fn create_command_with_env(program: &str) -> Command {
     let mut cmd = Command::new(program);

     // Inherit essential environment variables from parent process
     for (key, value) in std::env::vars() {
         // Pass through PATH and other essential environment variables
-        if key == "PATH" || key == "HOME" || key == "USER"
-            || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_")
-            || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN"
-            || key == "HOMEBREW_PREFIX" || key == "HOMEBREW_CELLAR" {
+        if key == "PATH"
+            || key == "HOME"
+            || key == "USER"
+            || key == "SHELL"
+            || key == "LANG"
+            || key == "LC_ALL"
+            || key.starts_with("LC_")
+            || key == "NODE_PATH"
+            || key == "NVM_DIR"
+            || key == "NVM_BIN"
+            || key == "HOMEBREW_PREFIX"
+            || key == "HOMEBREW_CELLAR"
+        {
             debug!("Inheriting env var: {}={}", key, value);
             cmd.env(&key, &value);
         }
     }

     // Add NVM support if the program is in an NVM directory
     if program.contains("/.nvm/versions/node/") {
         if let Some(node_bin_dir) = std::path::Path::new(program).parent() {
@@ -411,6 +445,6 @@ pub fn create_command_with_env(program: &str) -> Command {
             }
         }
     }

     cmd
 }
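One observation on compare_versions above: it compares numeric prefixes component-wise, dropping pre-release suffixes and padding missing components with zero, so a pre-release ties with its release. A small illustrative test of the resulting ordering (not part of this commit):

    use std::cmp::Ordering;

    #[test]
    fn compare_versions_examples() {
        // Component-wise numeric comparison, not lexicographic: 17 > 9.
        assert_eq!(compare_versions("1.0.17", "1.0.9"), Ordering::Greater);
        // "17-beta" keeps only its numeric prefix, so the pre-release tag is ignored.
        assert_eq!(compare_versions("1.0.17-beta", "1.0.17"), Ordering::Equal);
        // Missing components are treated as zero.
        assert_eq!(compare_versions("1.2", "1.2.0"), Ordering::Equal);
    }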
File diff suppressed because it is too large
@@ -1,12 +1,12 @@
|
|||||||
use tauri::AppHandle;
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
|
use dirs;
|
||||||
|
use log::{error, info};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fs;
|
use std::fs;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
use log::{info, error};
|
use tauri::AppHandle;
|
||||||
use dirs;
|
|
||||||
|
|
||||||
/// Helper function to create a std::process::Command with proper environment variables
|
/// Helper function to create a std::process::Command with proper environment variables
|
||||||
/// This ensures commands like Claude can find Node.js and other dependencies
|
/// This ensures commands like Claude can find Node.js and other dependencies
|
||||||
@@ -17,8 +17,7 @@ fn create_command_with_env(program: &str) -> Command {
|
|||||||
/// Finds the full path to the claude binary
|
/// Finds the full path to the claude binary
|
||||||
/// This is necessary because macOS apps have a limited PATH environment
|
/// This is necessary because macOS apps have a limited PATH environment
|
||||||
fn find_claude_binary(app_handle: &AppHandle) -> Result<String> {
|
fn find_claude_binary(app_handle: &AppHandle) -> Result<String> {
|
||||||
crate::claude_binary::find_claude_binary(app_handle)
|
crate::claude_binary::find_claude_binary(app_handle).map_err(|e| anyhow::anyhow!(e))
|
||||||
.map_err(|e| anyhow::anyhow!(e))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Represents an MCP server configuration
|
/// Represents an MCP server configuration
|
||||||
@@ -99,17 +98,16 @@ pub struct ImportServerResult {
|
|||||||
/// Executes a claude mcp command
|
/// Executes a claude mcp command
|
||||||
fn execute_claude_mcp_command(app_handle: &AppHandle, args: Vec<&str>) -> Result<String> {
|
fn execute_claude_mcp_command(app_handle: &AppHandle, args: Vec<&str>) -> Result<String> {
|
||||||
info!("Executing claude mcp command with args: {:?}", args);
|
info!("Executing claude mcp command with args: {:?}", args);
|
||||||
|
|
||||||
let claude_path = find_claude_binary(app_handle)?;
|
let claude_path = find_claude_binary(app_handle)?;
|
||||||
let mut cmd = create_command_with_env(&claude_path);
|
let mut cmd = create_command_with_env(&claude_path);
|
||||||
cmd.arg("mcp");
|
cmd.arg("mcp");
|
||||||
for arg in args {
|
for arg in args {
|
||||||
cmd.arg(arg);
|
cmd.arg(arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
let output = cmd.output()
|
let output = cmd.output().context("Failed to execute claude command")?;
|
||||||
.context("Failed to execute claude command")?;
|
|
||||||
|
|
||||||
if output.status.success() {
|
if output.status.success() {
|
||||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||||
} else {
|
} else {
|
||||||
@@ -131,33 +129,34 @@ pub async fn mcp_add(
    scope: String,
) -> Result<AddServerResult, String> {
    info!("Adding MCP server: {} with transport: {}", name, transport);

    // Prepare owned strings for environment variables
    let env_args: Vec<String> = env
        .iter()
        .map(|(key, value)| format!("{}={}", key, value))
        .collect();

    let mut cmd_args = vec!["add"];

    // Add scope flag
    cmd_args.push("-s");
    cmd_args.push(&scope);

    // Add transport flag for SSE
    if transport == "sse" {
        cmd_args.push("--transport");
        cmd_args.push("sse");
    }

    // Add environment variables
    for (i, _) in env.iter().enumerate() {
        cmd_args.push("-e");
        cmd_args.push(&env_args[i]);
    }

    // Add name
    cmd_args.push(&name);

    // Add command/URL based on transport
    if transport == "stdio" {
        if let Some(cmd) = &command {
@@ -188,7 +187,7 @@ pub async fn mcp_add(
            });
        }
    }

    match execute_claude_mcp_command(&app, cmd_args) {
        Ok(output) => {
            info!("Successfully added MCP server: {}", name);
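
// Why `env_args` is collected into owned Strings before `cmd_args` is
// built (sketch): `cmd_args` holds `&str` borrows, so each formatted
// KEY=value pair must outlive the argument vector. Hypothetical helper:
fn build_env_flags(env: &std::collections::HashMap<String, String>) -> Vec<String> {
    env.iter().map(|(k, v)| format!("{}={}", k, v)).collect()
}
// A resulting invocation then looks roughly like:
//   claude mcp add -s user -e API_KEY=xyz my-server /path/to/server
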
@@ -213,19 +212,19 @@ pub async fn mcp_add(
#[tauri::command]
pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
    info!("Listing MCP servers");

    match execute_claude_mcp_command(&app, vec!["list"]) {
        Ok(output) => {
            info!("Raw output from 'claude mcp list': {:?}", output);
            let trimmed = output.trim();
            info!("Trimmed output: {:?}", trimmed);

            // Check if no servers are configured
            if trimmed.contains("No MCP servers configured") || trimmed.is_empty() {
                info!("No servers found - empty or 'No MCP servers' message");
                return Ok(vec![]);
            }

            // Parse the text output, handling multi-line commands
            let mut servers = Vec::new();
            let lines: Vec<&str> = trimmed.lines().collect();
@@ -233,13 +232,13 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
            for (idx, line) in lines.iter().enumerate() {
                info!("Line {}: {:?}", idx, line);
            }

            let mut i = 0;

            while i < lines.len() {
                let line = lines[i];
                info!("Processing line {}: {:?}", i, line);

                // Check if this line starts a new server entry
                if let Some(colon_pos) = line.find(':') {
                    info!("Found colon at position {} in line: {:?}", colon_pos, line);
@@ -247,26 +246,31 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
                    // Server names typically don't contain '/' or '\'
                    let potential_name = line[..colon_pos].trim();
                    info!("Potential server name: {:?}", potential_name);

                    if !potential_name.contains('/') && !potential_name.contains('\\') {
                        info!("Valid server name detected: {:?}", potential_name);
                        let name = potential_name.to_string();
                        let mut command_parts = vec![line[colon_pos + 1..].trim().to_string()];
                        info!("Initial command part: {:?}", command_parts[0]);

                        // Check if command continues on next lines
                        i += 1;
                        while i < lines.len() {
                            let next_line = lines[i];
                            info!("Checking next line {} for continuation: {:?}", i, next_line);

                            // If the next line starts with a server name pattern, break
                            if next_line.contains(':') {
                                let potential_next_name =
                                    next_line.split(':').next().unwrap_or("").trim();
                                info!(
                                    "Found colon in next line, potential name: {:?}",
                                    potential_next_name
                                );
                                if !potential_next_name.is_empty()
                                    && !potential_next_name.contains('/')
                                    && !potential_next_name.contains('\\')
                                {
                                    info!("Next line is a new server, breaking");
                                    break;
                                }
@@ -276,11 +280,11 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
                            command_parts.push(next_line.trim().to_string());
                            i += 1;
                        }

                        // Join all command parts
                        let full_command = command_parts.join(" ");
                        info!("Full command for server '{}': {:?}", name, full_command);

                        // For now, we'll create a basic server entry
                        servers.push(MCPServer {
                            name: name.clone(),
@@ -298,7 +302,7 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
                            },
                        });
                        info!("Added server: {:?}", name);

                        continue;
                    } else {
                        info!("Skipping line - name contains path separators");
@@ -306,13 +310,16 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
                } else {
                    info!("No colon found in line {}", i);
                }

                i += 1;
            }

            info!("Found {} MCP servers total", servers.len());
            for (idx, server) in servers.iter().enumerate() {
                info!(
                    "Server {}: name='{}', command={:?}",
                    idx, server.name, server.command
                );
            }
            Ok(servers)
        }
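
// The parsing rule above, condensed into a hypothetical predicate (a
// sketch, not a function from this file): a line opens a new server entry
// when the text before its first ':' is non-empty and free of path
// separators; anything else is treated as a continuation of the previous
// command.
fn looks_like_server_header(line: &str) -> bool {
    match line.find(':') {
        Some(pos) => {
            let name = line[..pos].trim();
            !name.is_empty() && !name.contains('/') && !name.contains('\\')
        }
        None => false,
    }
}
// e.g. looks_like_server_header("filesystem: npx -y server-fs") == true,
// while a wrapped argument line such as "--root /home/user" returns false.
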
@@ -327,7 +334,7 @@ pub async fn mcp_list(app: AppHandle) -> Result<Vec<MCPServer>, String> {
#[tauri::command]
pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String> {
    info!("Getting MCP server details for: {}", name);

    match execute_claude_mcp_command(&app, vec!["get", &name]) {
        Ok(output) => {
            // Parse the structured text output
@@ -337,17 +344,19 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String>
            let mut args = vec![];
            let env = HashMap::new();
            let mut url = None;

            for line in output.lines() {
                let line = line.trim();

                if line.starts_with("Scope:") {
                    let scope_part = line.replace("Scope:", "").trim().to_string();
                    if scope_part.to_lowercase().contains("local") {
                        scope = "local".to_string();
                    } else if scope_part.to_lowercase().contains("project") {
                        scope = "project".to_string();
                    } else if scope_part.to_lowercase().contains("user")
                        || scope_part.to_lowercase().contains("global")
                    {
                        scope = "user".to_string();
                    }
                } else if line.starts_with("Type:") {
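
// The Scope parsing above as a small pure function (illustrative sketch;
// the "local" fallback is an assumption, since the initializer of `scope`
// sits outside this hunk):
fn normalize_scope(raw: &str) -> &'static str {
    let lower = raw.to_lowercase();
    if lower.contains("local") {
        "local"
    } else if lower.contains("project") {
        "project"
    } else if lower.contains("user") || lower.contains("global") {
        "user"
    } else {
        "local"
    }
}
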
@@ -366,7 +375,7 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String>
                    // For now, we'll leave it empty
                }
            }

            Ok(MCPServer {
                name,
                transport,
@@ -394,7 +403,7 @@ pub async fn mcp_get(app: AppHandle, name: String) -> Result<MCPServer, String>
#[tauri::command]
pub async fn mcp_remove(app: AppHandle, name: String) -> Result<String, String> {
    info!("Removing MCP server: {}", name);

    match execute_claude_mcp_command(&app, vec!["remove", &name]) {
        Ok(output) => {
            info!("Successfully removed MCP server: {}", name);
@@ -409,17 +418,25 @@ pub async fn mcp_remove(app: AppHandle, name: String) -> Result<String, String>

/// Adds an MCP server from JSON configuration
#[tauri::command]
pub async fn mcp_add_json(
    app: AppHandle,
    name: String,
    json_config: String,
    scope: String,
) -> Result<AddServerResult, String> {
    info!(
        "Adding MCP server from JSON: {} with scope: {}",
        name, scope
    );

    // Build command args
    let mut cmd_args = vec!["add-json", &name, &json_config];

    // Add scope flag
    let scope_flag = "-s";
    cmd_args.push(scope_flag);
    cmd_args.push(&scope);

    match execute_claude_mcp_command(&app, cmd_args) {
        Ok(output) => {
            info!("Successfully added MCP server from JSON: {}", name);
@@ -442,9 +459,15 @@ pub async fn mcp_add_json(

/// Imports MCP servers from Claude Desktop
#[tauri::command]
pub async fn mcp_add_from_claude_desktop(
    app: AppHandle,
    scope: String,
) -> Result<ImportResult, String> {
    info!(
        "Importing MCP servers from Claude Desktop with scope: {}",
        scope
    );

    // Get Claude Desktop config path based on platform
    let config_path = if cfg!(target_os = "macos") {
        dirs::home_dir()
@@ -460,43 +483,55 @@ pub async fn mcp_add_from_claude_desktop(
            .join("Claude")
            .join("claude_desktop_config.json")
    } else {
        return Err(
            "Import from Claude Desktop is only supported on macOS and Linux/WSL".to_string(),
        );
    };

    // Check if config file exists
    if !config_path.exists() {
        return Err(
            "Claude Desktop configuration not found. Make sure Claude Desktop is installed."
                .to_string(),
        );
    }

    // Read and parse the config file
    let config_content = fs::read_to_string(&config_path)
        .map_err(|e| format!("Failed to read Claude Desktop config: {}", e))?;

    let config: serde_json::Value = serde_json::from_str(&config_content)
        .map_err(|e| format!("Failed to parse Claude Desktop config: {}", e))?;

    // Extract MCP servers
    let mcp_servers = config
        .get("mcpServers")
        .and_then(|v| v.as_object())
        .ok_or_else(|| "No MCP servers found in Claude Desktop config".to_string())?;

    let mut imported_count = 0;
    let mut failed_count = 0;
    let mut server_results = Vec::new();

    // Import each server using add-json
    for (name, server_config) in mcp_servers {
        info!("Importing server: {}", name);

        // Convert Claude Desktop format to add-json format
        let mut json_config = serde_json::Map::new();

        // All Claude Desktop servers are stdio type
        json_config.insert(
            "type".to_string(),
            serde_json::Value::String("stdio".to_string()),
        );

        // Add command
        if let Some(command) = server_config.get("command").and_then(|v| v.as_str()) {
            json_config.insert(
                "command".to_string(),
                serde_json::Value::String(command.to_string()),
            );
        } else {
            failed_count += 1;
            server_results.push(ImportServerResult {
@@ -506,25 +541,28 @@ pub async fn mcp_add_from_claude_desktop(
            });
            continue;
        }

        // Add args if present
        if let Some(args) = server_config.get("args").and_then(|v| v.as_array()) {
            json_config.insert("args".to_string(), args.clone().into());
        } else {
            json_config.insert("args".to_string(), serde_json::Value::Array(vec![]));
        }

        // Add env if present
        if let Some(env) = server_config.get("env").and_then(|v| v.as_object()) {
            json_config.insert("env".to_string(), env.clone().into());
        } else {
            json_config.insert(
                "env".to_string(),
                serde_json::Value::Object(serde_json::Map::new()),
            );
        }

        // Convert to JSON string
        let json_str = serde_json::to_string(&json_config)
            .map_err(|e| format!("Failed to serialize config for {}: {}", name, e))?;

        // Call add-json command
        match mcp_add_json(app.clone(), name.clone(), json_str, scope.clone()).await {
            Ok(result) => {
@@ -559,9 +597,12 @@ pub async fn mcp_add_from_claude_desktop(
            }
        }
    }

    info!(
        "Import complete: {} imported, {} failed",
        imported_count, failed_count
    );

    Ok(ImportResult {
        imported_count,
        failed_count,
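
// The per-server conversion above, condensed (sketch; assumes serde_json,
// which this file already uses — a missing "command" yields None, matching
// the failed_count branch):
fn to_add_json(server: &serde_json::Value) -> Option<serde_json::Value> {
    let command = server.get("command")?.as_str()?;
    Some(serde_json::json!({
        "type": "stdio",
        "command": command,
        "args": server.get("args").cloned().unwrap_or_else(|| serde_json::json!([])),
        "env": server.get("env").cloned().unwrap_or_else(|| serde_json::json!({})),
    }))
}
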
@@ -573,7 +614,7 @@ pub async fn mcp_add_from_claude_desktop(
#[tauri::command]
pub async fn mcp_serve(app: AppHandle) -> Result<String, String> {
    info!("Starting Claude Code as MCP server");

    // Start the server in a separate process
    let claude_path = match find_claude_binary(&app) {
        Ok(path) => path,
@@ -582,10 +623,10 @@ pub async fn mcp_serve(app: AppHandle) -> Result<String, String> {
            return Err(e.to_string());
        }
    };

    let mut cmd = create_command_with_env(&claude_path);
    cmd.arg("mcp").arg("serve");

    match cmd.spawn() {
        Ok(_) => {
            info!("Successfully started Claude Code MCP server");
@@ -602,7 +643,7 @@ pub async fn mcp_serve(app: AppHandle) -> Result<String, String> {
#[tauri::command]
pub async fn mcp_test_connection(app: AppHandle, name: String) -> Result<String, String> {
    info!("Testing connection to MCP server: {}", name);

    // For now, we'll use the get command to test if the server exists
    match execute_claude_mcp_command(&app, vec!["get", &name]) {
        Ok(_) => Ok(format!("Connection to {} successful", name)),
@@ -614,7 +655,7 @@ pub async fn mcp_test_connection(app: AppHandle, name: String) -> Result<String,
#[tauri::command]
pub async fn mcp_reset_project_choices(app: AppHandle) -> Result<String, String> {
    info!("Resetting MCP project choices");

    match execute_claude_mcp_command(&app, vec!["reset-project-choices"]) {
        Ok(output) => {
            info!("Successfully reset MCP project choices");
@@ -631,7 +672,7 @@ pub async fn mcp_reset_project_choices(app: AppHandle) -> Result<String, String>
#[tauri::command]
pub async fn mcp_get_server_status() -> Result<HashMap<String, ServerStatus>, String> {
    info!("Getting MCP server status");

    // TODO: Implement actual status checking
    // For now, return empty status
    Ok(HashMap::new())
@@ -641,25 +682,23 @@ pub async fn mcp_get_server_status() -> Result<HashMap<String, ServerStatus>, St
#[tauri::command]
pub async fn mcp_read_project_config(project_path: String) -> Result<MCPProjectConfig, String> {
    info!("Reading .mcp.json from project: {}", project_path);

    let mcp_json_path = PathBuf::from(&project_path).join(".mcp.json");

    if !mcp_json_path.exists() {
        return Ok(MCPProjectConfig {
            mcp_servers: HashMap::new(),
        });
    }

    match fs::read_to_string(&mcp_json_path) {
        Ok(content) => match serde_json::from_str::<MCPProjectConfig>(&content) {
            Ok(config) => Ok(config),
            Err(e) => {
                error!("Failed to parse .mcp.json: {}", e);
                Err(format!("Failed to parse .mcp.json: {}", e))
            }
        },
        Err(e) => {
            error!("Failed to read .mcp.json: {}", e);
            Err(format!("Failed to read .mcp.json: {}", e))
@@ -674,14 +713,14 @@ pub async fn mcp_save_project_config(
    config: MCPProjectConfig,
) -> Result<String, String> {
    info!("Saving .mcp.json to project: {}", project_path);

    let mcp_json_path = PathBuf::from(&project_path).join(".mcp.json");

    let json_content = serde_json::to_string_pretty(&config)
        .map_err(|e| format!("Failed to serialize config: {}", e))?;

    fs::write(&mcp_json_path, json_content)
        .map_err(|e| format!("Failed to write .mcp.json: {}", e))?;

    Ok("Project MCP configuration saved".to_string())
}

@@ -1,6 +1,6 @@
pub mod agents;
pub mod claude;
pub mod mcp;
pub mod sandbox;
pub mod screenshot;
pub mod usage;
@@ -52,11 +52,11 @@ pub struct ImportResult {
#[tauri::command]
pub async fn list_sandbox_profiles(db: State<'_, AgentDb>) -> Result<Vec<SandboxProfile>, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    let mut stmt = conn
        .prepare("SELECT id, name, description, is_active, is_default, created_at, updated_at FROM sandbox_profiles ORDER BY name")
        .map_err(|e| e.to_string())?;

    let profiles = stmt
        .query_map([], |row| {
            Ok(SandboxProfile {
@@ -72,7 +72,7 @@ pub async fn list_sandbox_profiles(db: State<'_, AgentDb>) -> Result<Vec<Sandbox
        .map_err(|e| e.to_string())?
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| e.to_string())?;

    Ok(profiles)
}

@@ -84,15 +84,15 @@ pub async fn create_sandbox_profile(
    description: Option<String>,
) -> Result<SandboxProfile, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    conn.execute(
        "INSERT INTO sandbox_profiles (name, description) VALUES (?1, ?2)",
        params![name, description],
    )
    .map_err(|e| e.to_string())?;

    let id = conn.last_insert_rowid();

    // Fetch the created profile
    let profile = conn
        .query_row(
@@ -111,7 +111,7 @@ pub async fn create_sandbox_profile(
            },
        )
        .map_err(|e| e.to_string())?;

    Ok(profile)
}

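// The insert-then-fetch pattern above in miniature (sketch, assumes
// rusqlite): last_insert_rowid() reports the rowid of the most recent
// successful INSERT on this same connection, so the new row can be read
// back immediately without a second lookup key.
fn insert_and_fetch_name(conn: &rusqlite::Connection, name: &str) -> rusqlite::Result<String> {
    conn.execute(
        "INSERT INTO sandbox_profiles (name) VALUES (?1)",
        rusqlite::params![name],
    )?;
    let id = conn.last_insert_rowid();
    conn.query_row(
        "SELECT name FROM sandbox_profiles WHERE id = ?1",
        rusqlite::params![id],
        |row| row.get(0),
    )
}
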
@@ -126,7 +126,7 @@ pub async fn update_sandbox_profile(
    is_default: bool,
) -> Result<SandboxProfile, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    // If setting as default, unset other defaults
    if is_default {
        conn.execute(
@@ -135,13 +135,13 @@ pub async fn update_sandbox_profile(
        )
        .map_err(|e| e.to_string())?;
    }

    conn.execute(
        "UPDATE sandbox_profiles SET name = ?1, description = ?2, is_active = ?3, is_default = ?4 WHERE id = ?5",
        params![name, description, is_active, is_default, id],
    )
    .map_err(|e| e.to_string())?;

    // Fetch the updated profile
    let profile = conn
        .query_row(
@@ -160,7 +160,7 @@ pub async fn update_sandbox_profile(
            },
        )
        .map_err(|e| e.to_string())?;

    Ok(profile)
}

@@ -168,7 +168,7 @@ pub async fn update_sandbox_profile(
#[tauri::command]
pub async fn delete_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<(), String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    // Check if it's the default profile
    let is_default: bool = conn
        .query_row(
@@ -177,22 +177,25 @@ pub async fn delete_sandbox_profile(db: State<'_, AgentDb>, id: i64) -> Result<(
            |row| row.get(0),
        )
        .map_err(|e| e.to_string())?;

    if is_default {
        return Err("Cannot delete the default profile".to_string());
    }

    conn.execute("DELETE FROM sandbox_profiles WHERE id = ?1", params![id])
        .map_err(|e| e.to_string())?;

    Ok(())
}

/// Get a single sandbox profile by ID
#[tauri::command]
pub async fn get_sandbox_profile(
    db: State<'_, AgentDb>,
    id: i64,
) -> Result<SandboxProfile, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    let profile = conn
        .query_row(
            "SELECT id, name, description, is_active, is_default, created_at, updated_at FROM sandbox_profiles WHERE id = ?1",
@@ -210,7 +213,7 @@ pub async fn get_sandbox_profile(
            },
        )
        .map_err(|e| e.to_string())?;

    Ok(profile)
}

@@ -221,11 +224,11 @@ pub async fn list_sandbox_rules(
    profile_id: i64,
) -> Result<Vec<SandboxRule>, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    let mut stmt = conn
        .prepare("SELECT id, profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support, created_at FROM sandbox_rules WHERE profile_id = ?1 ORDER BY operation_type, pattern_value")
        .map_err(|e| e.to_string())?;

    let rules = stmt
        .query_map(params![profile_id], |row| {
            Ok(SandboxRule {
@@ -242,7 +245,7 @@ pub async fn list_sandbox_rules(
        .map_err(|e| e.to_string())?
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| e.to_string())?;

    Ok(rules)
}

@@ -258,18 +261,18 @@ pub async fn create_sandbox_rule(
    platform_support: Option<String>,
) -> Result<SandboxRule, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    // Validate rule doesn't conflict
    // TODO: Add more validation logic here

    conn.execute(
        "INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
        params![profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support],
    )
    .map_err(|e| e.to_string())?;

    let id = conn.last_insert_rowid();

    // Fetch the created rule
    let rule = conn
        .query_row(
@@ -289,7 +292,7 @@ pub async fn create_sandbox_rule(
            },
        )
        .map_err(|e| e.to_string())?;

    Ok(rule)
}

@@ -305,13 +308,13 @@ pub async fn update_sandbox_rule(
    platform_support: Option<String>,
) -> Result<SandboxRule, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    conn.execute(
        "UPDATE sandbox_rules SET operation_type = ?1, pattern_type = ?2, pattern_value = ?3, enabled = ?4, platform_support = ?5 WHERE id = ?6",
        params![operation_type, pattern_type, pattern_value, enabled, platform_support, id],
    )
    .map_err(|e| e.to_string())?;

    // Fetch the updated rule
    let rule = conn
        .query_row(
@@ -331,7 +334,7 @@ pub async fn update_sandbox_rule(
            },
        )
        .map_err(|e| e.to_string())?;

    Ok(rule)
}

@@ -339,10 +342,10 @@ pub async fn update_sandbox_rule(
#[tauri::command]
pub async fn delete_sandbox_rule(db: State<'_, AgentDb>, id: i64) -> Result<(), String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    conn.execute("DELETE FROM sandbox_rules WHERE id = ?1", params![id])
        .map_err(|e| e.to_string())?;

    Ok(())
}

@@ -359,38 +362,38 @@ pub async fn test_sandbox_profile(
    profile_id: i64,
) -> Result<String, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    // Load the profile and rules
    let profile = crate::sandbox::profile::load_profile(&conn, profile_id)
        .map_err(|e| format!("Failed to load profile: {}", e))?;

    if !profile.is_active {
        return Ok(format!(
            "Profile '{}' is currently inactive. Activate it to use with agents.",
            profile.name
        ));
    }

    let rules = crate::sandbox::profile::load_profile_rules(&conn, profile_id)
        .map_err(|e| format!("Failed to load profile rules: {}", e))?;

    if rules.is_empty() {
        return Ok(format!(
            "Profile '{}' has no rules configured. Add rules to define sandbox permissions.",
            profile.name
        ));
    }

    // Try to build the gaol profile
    let test_path = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("/tmp"));

    let builder = crate::sandbox::profile::ProfileBuilder::new(test_path.clone())
        .map_err(|e| format!("Failed to create profile builder: {}", e))?;

    let build_result = builder
        .build_profile_with_serialization(rules.clone())
        .map_err(|e| format!("Failed to build sandbox profile: {}", e))?;

    // Check platform support
    let platform_caps = crate::sandbox::platform::get_platform_capabilities();
    if !platform_caps.sandboxing_supported {
@@ -401,27 +404,23 @@ pub async fn test_sandbox_profile(
            platform_caps.os
        ));
    }

    // Try to execute a simple command in the sandbox
    let executor = crate::sandbox::executor::SandboxExecutor::new_with_serialization(
        build_result.profile,
        test_path.clone(),
        build_result.serialized,
    );

    // Use a simple echo command for testing
    let test_command = if cfg!(windows) { "cmd" } else { "echo" };

    let test_args = if cfg!(windows) {
        vec!["/C", "echo", "sandbox test successful"]
    } else {
        vec!["sandbox test successful"]
    };

    match executor.execute_sandboxed_spawn(test_command, &test_args, &test_path) {
        Ok(mut child) => {
            // Wait for the process to complete with a timeout
@@ -452,19 +451,17 @@ pub async fn test_sandbox_profile(
                        ))
                    }
                }
                Err(e) => Ok(format!(
                    "⚠️ Profile '{}' validated with warnings.\n\n\
                    • {} rules loaded and validated\n\
                    • Sandbox activation: Partial\n\
                    • Test process: Could not get exit status ({})\n\
                    • Platform: {}",
                    profile.name,
                    rules.len(),
                    e,
                    platform_caps.os
                )),
            }
        }
        Err(e) => {
@@ -509,176 +506,200 @@ pub async fn list_sandbox_violations(
    limit: Option<i64>,
) -> Result<Vec<SandboxViolation>, String> {
    let conn = db.0.lock().map_err(|e| e.to_string())?;

    // Build dynamic query
    let mut query = String::from(
        "SELECT id, profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid, denied_at
         FROM sandbox_violations WHERE 1=1",
    );

    let mut param_idx = 1;

    if profile_id.is_some() {
        query.push_str(&format!(" AND profile_id = ?{}", param_idx));
        param_idx += 1;
    }

    if agent_id.is_some() {
        query.push_str(&format!(" AND agent_id = ?{}", param_idx));
        param_idx += 1;
    }

    query.push_str(" ORDER BY denied_at DESC");

    if limit.is_some() {
        query.push_str(&format!(" LIMIT ?{}", param_idx));
    }

    // Execute query based on parameters
    let violations: Vec<SandboxViolation> = if let Some(pid) = profile_id {
        if let Some(aid) = agent_id {
            if let Some(lim) = limit {
                // All three parameters
                let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
                let rows = stmt
                    .query_map(params![pid, aid, lim], |row| {
                        Ok(SandboxViolation {
                            id: Some(row.get(0)?),
                            profile_id: row.get(1)?,
                            agent_id: row.get(2)?,
                            agent_run_id: row.get(3)?,
                            operation_type: row.get(4)?,
                            pattern_value: row.get(5)?,
                            process_name: row.get(6)?,
                            pid: row.get(7)?,
                            denied_at: row.get(8)?,
                        })
                    })
                    .map_err(|e| e.to_string())?;
                rows.collect::<Result<Vec<_>, _>>()
                    .map_err(|e| e.to_string())?
            } else {
                // profile_id and agent_id only
                let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
                let rows = stmt
                    .query_map(params![pid, aid], |row| {
                        Ok(SandboxViolation {
                            id: Some(row.get(0)?),
                            profile_id: row.get(1)?,
                            agent_id: row.get(2)?,
                            agent_run_id: row.get(3)?,
                            operation_type: row.get(4)?,
                            pattern_value: row.get(5)?,
                            process_name: row.get(6)?,
                            pid: row.get(7)?,
                            denied_at: row.get(8)?,
                        })
                    })
                    .map_err(|e| e.to_string())?;
                rows.collect::<Result<Vec<_>, _>>()
                    .map_err(|e| e.to_string())?
            }
        } else if let Some(lim) = limit {
            // profile_id and limit only
            let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
            let rows = stmt
                .query_map(params![pid, lim], |row| {
                    Ok(SandboxViolation {
                        id: Some(row.get(0)?),
                        profile_id: row.get(1)?,
                        agent_id: row.get(2)?,
                        agent_run_id: row.get(3)?,
                        operation_type: row.get(4)?,
                        pattern_value: row.get(5)?,
                        process_name: row.get(6)?,
                        pid: row.get(7)?,
                        denied_at: row.get(8)?,
                    })
                })
                .map_err(|e| e.to_string())?;
            rows.collect::<Result<Vec<_>, _>>()
                .map_err(|e| e.to_string())?
        } else {
            // profile_id only
            let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
            let rows = stmt
                .query_map(params![pid], |row| {
                    Ok(SandboxViolation {
                        id: Some(row.get(0)?),
                        profile_id: row.get(1)?,
                        agent_id: row.get(2)?,
                        agent_run_id: row.get(3)?,
                        operation_type: row.get(4)?,
                        pattern_value: row.get(5)?,
                        process_name: row.get(6)?,
                        pid: row.get(7)?,
                        denied_at: row.get(8)?,
                    })
                })
                .map_err(|e| e.to_string())?;
            rows.collect::<Result<Vec<_>, _>>()
                .map_err(|e| e.to_string())?
        }
    } else if let Some(aid) = agent_id {
        if let Some(lim) = limit {
            // agent_id and limit only
            let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
            let rows = stmt
                .query_map(params![aid, lim], |row| {
                    Ok(SandboxViolation {
                        id: Some(row.get(0)?),
                        profile_id: row.get(1)?,
                        agent_id: row.get(2)?,
                        agent_run_id: row.get(3)?,
                        operation_type: row.get(4)?,
                        pattern_value: row.get(5)?,
                        process_name: row.get(6)?,
                        pid: row.get(7)?,
                        denied_at: row.get(8)?,
                    })
                })
                .map_err(|e| e.to_string())?;
            rows.collect::<Result<Vec<_>, _>>()
                .map_err(|e| e.to_string())?
        } else {
            // agent_id only
            let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
            let rows = stmt
                .query_map(params![aid], |row| {
                    Ok(SandboxViolation {
                        id: Some(row.get(0)?),
                        profile_id: row.get(1)?,
                        agent_id: row.get(2)?,
                        agent_run_id: row.get(3)?,
                        operation_type: row.get(4)?,
                        pattern_value: row.get(5)?,
                        process_name: row.get(6)?,
                        pid: row.get(7)?,
                        denied_at: row.get(8)?,
                    })
                })
                .map_err(|e| e.to_string())?;
            rows.collect::<Result<Vec<_>, _>>()
                .map_err(|e| e.to_string())?
        }
    } else if let Some(lim) = limit {
        // limit only
        let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
        let rows = stmt
            .query_map(params![lim], |row| {
                Ok(SandboxViolation {
                    id: Some(row.get(0)?),
                    profile_id: row.get(1)?,
                    agent_id: row.get(2)?,
                    agent_run_id: row.get(3)?,
                    operation_type: row.get(4)?,
                    pattern_value: row.get(5)?,
                    process_name: row.get(6)?,
                    pid: row.get(7)?,
                    denied_at: row.get(8)?,
                })
            })
            .map_err(|e| e.to_string())?;
        rows.collect::<Result<Vec<_>, _>>()
            .map_err(|e| e.to_string())?
    } else {
        // No parameters
        let mut stmt = conn.prepare(&query).map_err(|e| e.to_string())?;
        let rows = stmt
            .query_map([], |row| {
                Ok(SandboxViolation {
                    id: Some(row.get(0)?),
                    profile_id: row.get(1)?,
                    agent_id: row.get(2)?,
                    agent_run_id: row.get(3)?,
                    operation_type: row.get(4)?,
                    pattern_value: row.get(5)?,
                    process_name: row.get(6)?,
                    pid: row.get(7)?,
                    denied_at: row.get(8)?,
                })
            })
            .map_err(|e| e.to_string())?;
        rows.collect::<Result<Vec<_>, _>>()
            .map_err(|e| e.to_string())?
    };

    Ok(violations)
}

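// The eight-way if-let ladder above exists because `params![]` requires a
// fixed arity. A sketch of the usual alternative with a dynamic parameter
// list (hypothetical helper, assumes rusqlite's params_from_iter; both
// optional values here are i64, so a homogeneous Vec suffices):
fn violation_ids(
    conn: &rusqlite::Connection,
    profile_id: Option<i64>,
    limit: Option<i64>,
) -> rusqlite::Result<Vec<i64>> {
    let mut sql = String::from("SELECT id FROM sandbox_violations WHERE 1=1");
    let mut values: Vec<i64> = Vec::new();
    if let Some(pid) = profile_id {
        sql.push_str(" AND profile_id = ?");
        values.push(pid);
    }
    sql.push_str(" ORDER BY denied_at DESC");
    if let Some(lim) = limit {
        sql.push_str(" LIMIT ?");
        values.push(lim);
    }
    let mut stmt = conn.prepare(&sql)?;
    let rows = stmt.query_map(rusqlite::params_from_iter(values.iter()), |row| row.get(0))?;
    rows.collect()
}
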
@@ -695,14 +716,14 @@ pub async fn log_sandbox_violation(
     pid: Option<i32>,
 ) -> Result<(), String> {
     let conn = db.0.lock().map_err(|e| e.to_string())?;

     conn.execute(
         "INSERT INTO sandbox_violations (profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid)
          VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
         params![profile_id, agent_id, agent_run_id, operation_type, pattern_value, process_name, pid],
     )
     .map_err(|e| e.to_string())?;

     Ok(())
 }
@@ -713,7 +734,7 @@ pub async fn clear_sandbox_violations(
     older_than_days: Option<i64>,
 ) -> Result<i64, String> {
     let conn = db.0.lock().map_err(|e| e.to_string())?;

     let query = if let Some(days) = older_than_days {
         format!(
             "DELETE FROM sandbox_violations WHERE denied_at < datetime('now', '-{} days')",
@@ -722,10 +743,9 @@ pub async fn clear_sandbox_violations(
     } else {
         "DELETE FROM sandbox_violations".to_string()
     };

-    let deleted = conn.execute(&query, [])
-        .map_err(|e| e.to_string())?;
+    let deleted = conn.execute(&query, []).map_err(|e| e.to_string())?;

     Ok(deleted as i64)
 }
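The delete above builds its SQL from an optional retention window. A sketch of the same pattern against an illustrative table, relying on SQLite's `datetime('now', '-N days')`; interpolating the integer is safe here because `days` is an `i64`, not user text:

```rust
use rusqlite::Connection;

// Delete rows older than a retention window, or everything when no
// window is given, mirroring the two-branch query construction above.
fn prune(conn: &Connection, older_than_days: Option<i64>) -> Result<usize, String> {
    let query = match older_than_days {
        Some(days) => format!(
            "DELETE FROM violations WHERE denied_at < datetime('now', '-{} days')",
            days
        ),
        None => "DELETE FROM violations".to_string(),
    };
    // execute returns the number of affected rows
    conn.execute(&query, []).map_err(|e| e.to_string())
}
```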
@@ -735,28 +755,30 @@ pub async fn get_sandbox_violation_stats(
     db: State<'_, AgentDb>,
 ) -> Result<serde_json::Value, String> {
     let conn = db.0.lock().map_err(|e| e.to_string())?;

     // Get total violations
     let total: i64 = conn
-        .query_row("SELECT COUNT(*) FROM sandbox_violations", [], |row| row.get(0))
+        .query_row("SELECT COUNT(*) FROM sandbox_violations", [], |row| {
+            row.get(0)
+        })
         .map_err(|e| e.to_string())?;

     // Get violations by operation type
     let mut stmt = conn
         .prepare(
             "SELECT operation_type, COUNT(*) as count
              FROM sandbox_violations
              GROUP BY operation_type
-             ORDER BY count DESC"
+             ORDER BY count DESC",
         )
         .map_err(|e| e.to_string())?;

     let by_operation: Vec<(String, i64)> = stmt
         .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
         .map_err(|e| e.to_string())?
         .collect::<Result<Vec<_>, _>>()
         .map_err(|e| e.to_string())?;

     // Get recent violations count (last 24 hours)
     let recent: i64 = conn
         .query_row(
@@ -765,7 +787,7 @@ pub async fn get_sandbox_violation_stats(
             |row| row.get(0),
         )
         .map_err(|e| e.to_string())?;

     Ok(serde_json::json!({
         "total": total,
         "recent_24h": recent,
@@ -789,10 +811,10 @@ pub async fn export_sandbox_profile(
         let conn = db.0.lock().map_err(|e| e.to_string())?;
         crate::sandbox::profile::load_profile(&conn, profile_id).map_err(|e| e.to_string())?
     };

     // Get the rules
     let rules = list_sandbox_rules(db.clone(), profile_id).await?;

     Ok(SandboxProfileExport {
         version: 1,
         exported_at: chrono::Utc::now().to_rfc3339(),
@@ -808,17 +830,14 @@ pub async fn export_all_sandbox_profiles(
 ) -> Result<SandboxProfileExport, String> {
     let profiles = list_sandbox_profiles(db.clone()).await?;
     let mut profile_exports = Vec::new();

     for profile in profiles {
         if let Some(id) = profile.id {
             let rules = list_sandbox_rules(db.clone(), id).await?;
-            profile_exports.push(SandboxProfileWithRules {
-                profile,
-                rules,
-            });
+            profile_exports.push(SandboxProfileWithRules { profile, rules });
         }
     }

     Ok(SandboxProfileExport {
         version: 1,
         exported_at: chrono::Utc::now().to_rfc3339(),
@@ -834,16 +853,19 @@ pub async fn import_sandbox_profiles(
     export_data: SandboxProfileExport,
 ) -> Result<Vec<ImportResult>, String> {
     let mut results = Vec::new();

     // Validate version
     if export_data.version != 1 {
-        return Err(format!("Unsupported export version: {}", export_data.version));
+        return Err(format!(
+            "Unsupported export version: {}",
+            export_data.version
+        ));
     }

     for profile_export in export_data.profiles {
         let mut profile = profile_export.profile;
         let original_name = profile.name.clone();

         // Check for name conflicts
         let existing: Result<i64, _> = {
             let conn = db.0.lock().map_err(|e| e.to_string())?;
@@ -853,29 +875,31 @@ pub async fn import_sandbox_profiles(
                 |row| row.get(0),
             )
         };

         let (imported, new_name) = match existing {
             Ok(_) => {
                 // Name conflict - append timestamp
-                let new_name = format!("{} (imported {})", profile.name, chrono::Utc::now().format("%Y-%m-%d %H:%M"));
+                let new_name = format!(
+                    "{} (imported {})",
+                    profile.name,
+                    chrono::Utc::now().format("%Y-%m-%d %H:%M")
+                );
                 profile.name = new_name.clone();
                 (true, Some(new_name))
             }
             Err(_) => (true, None),
         };

         if imported {
             // Reset profile fields for new insert
             profile.id = None;
             profile.is_default = false; // Never import as default

             // Create the profile
-            let created_profile = create_sandbox_profile(
-                db.clone(),
-                profile.name.clone(),
-                profile.description,
-            ).await?;
+            let created_profile =
+                create_sandbox_profile(db.clone(), profile.name.clone(), profile.description)
+                    .await?;

             if let Some(new_id) = created_profile.id {
                 // Import rules
                 for rule in profile_export.rules {
@@ -889,10 +913,11 @@ pub async fn import_sandbox_profiles(
                         rule.pattern_value,
                         rule.enabled,
                         rule.platform_support,
-                    ).await;
+                    )
+                    .await;
                 }
             }

             // Update profile status if needed
             if profile.is_active {
                 let _ = update_sandbox_profile(
@@ -902,18 +927,21 @@ pub async fn import_sandbox_profiles(
                     created_profile.description,
                     profile.is_active,
                     false, // Never set as default on import
-                ).await;
+                )
+                .await;
             }
         }

         results.push(ImportResult {
             profile_name: original_name,
             imported: true,
-            reason: new_name.as_ref().map(|_| "Name conflict resolved".to_string()),
+            reason: new_name
+                .as_ref()
+                .map(|_| "Name conflict resolved".to_string()),
             new_name,
         });
     }

     Ok(results)
 }
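The import path above resolves name conflicts by appending an import timestamp. The formatting call in isolation, with the same format string the diff uses:

```rust
use chrono::Utc;

// On a name collision, derive a unique import name by appending a
// minute-resolution timestamp, as import_sandbox_profiles does above.
fn deduplicate_name(name: &str) -> String {
    format!("{} (imported {})", name, Utc::now().format("%Y-%m-%d %H:%M"))
}
```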
@@ -1,20 +1,20 @@
-use headless_chrome::{Browser, LaunchOptions};
 use headless_chrome::protocol::cdp::Page;
+use headless_chrome::{Browser, LaunchOptions};
 use std::fs;
 use std::time::Duration;
 use tauri::AppHandle;

 /// Captures a screenshot of a URL using headless Chrome
 ///
 /// This function launches a headless Chrome browser, navigates to the specified URL,
 /// and captures a screenshot of either the entire page or a specific element.
 ///
 /// # Arguments
 /// * `app` - The Tauri application handle
 /// * `url` - The URL to capture
 /// * `selector` - Optional CSS selector for a specific element to capture
 /// * `full_page` - Whether to capture the entire page or just the viewport
 ///
 /// # Returns
 /// * `Result<String, String>` - The path to the saved screenshot file, or an error message
 #[tauri::command]
@@ -32,11 +32,10 @@ pub async fn capture_url_screenshot(
     );

     // Run the browser operations in a blocking task since headless_chrome is not async
-    let result = tokio::task::spawn_blocking(move || {
-        capture_screenshot_sync(url, selector, full_page)
-    })
-    .await
-    .map_err(|e| format!("Failed to spawn blocking task: {}", e))?;
+    let result =
+        tokio::task::spawn_blocking(move || capture_screenshot_sync(url, selector, full_page))
+            .await
+            .map_err(|e| format!("Failed to spawn blocking task: {}", e))?;

     // Log the result of the headless Chrome capture before returning
     match &result {
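Since headless_chrome is synchronous, the command above hands the capture to tokio's blocking pool so the async runtime isn't stalled. A runnable reduction of that shape; the closure body is a stand-in for the real `capture_screenshot_sync`:

```rust
// spawn_blocking returns a JoinHandle; awaiting it yields
// Result<T, JoinError>, and the inner closure's own Result is nested inside.
async fn capture(url: String) -> Result<String, String> {
    tokio::task::spawn_blocking(move || {
        // ...blocking headless-browser work happens here...
        Ok(format!("screenshot captured for {}", url))
    })
    .await
    .map_err(|e| format!("Failed to spawn blocking task: {}", e))?
}
```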
@@ -61,8 +60,8 @@ fn capture_screenshot_sync(
     };

     // Launch the browser
-    let browser = Browser::new(launch_options)
-        .map_err(|e| format!("Failed to launch browser: {}", e))?;
+    let browser =
+        Browser::new(launch_options).map_err(|e| format!("Failed to launch browser: {}", e))?;

     // Create a new tab
     let tab = browser
@@ -86,14 +85,17 @@ fn capture_screenshot_sync(
     // Wait explicitly for the <body> element to exist – this often prevents
     // "Unable to capture screenshot" CDP errors on some pages
     if let Err(e) = tab.wait_for_element("body") {
-        log::warn!("Timed out waiting for <body> element: {} – continuing anyway", e);
+        log::warn!(
+            "Timed out waiting for <body> element: {} – continuing anyway",
+            e
+        );
     }

     // Capture the screenshot
     let screenshot_data = if let Some(selector) = selector {
         // Wait for the element and capture it
         log::info!("Waiting for element with selector: {}", selector);

         let element = tab
             .wait_for_element(&selector)
             .map_err(|e| format!("Failed to find element '{}': {}", selector, e))?;
@@ -103,8 +105,11 @@ fn capture_screenshot_sync(
             .map_err(|e| format!("Failed to capture element screenshot: {}", e))?
     } else {
         // Capture the entire page or viewport
-        log::info!("Capturing {} screenshot", if full_page { "full page" } else { "viewport" });
+        log::info!(
+            "Capturing {} screenshot",
+            if full_page { "full page" } else { "viewport" }
+        );

         // Get the page dimensions for full page screenshot
         let clip = if full_page {
             // Execute JavaScript to get the full page dimensions
@@ -132,30 +137,30 @@ fn capture_screenshot_sync(
             )
             .map_err(|e| format!("Failed to get page dimensions: {}", e))?;

             // Extract dimensions from the result
             let width = dimensions
                 .value
                 .as_ref()
                 .and_then(|v| v.as_object())
                 .and_then(|obj| obj.get("width"))
                 .and_then(|v| v.as_f64())
                 .unwrap_or(1920.0);

             let height = dimensions
                 .value
                 .as_ref()
                 .and_then(|v| v.as_object())
                 .and_then(|obj| obj.get("height"))
                 .and_then(|v| v.as_f64())
                 .unwrap_or(1080.0);

             Some(Page::Viewport {
                 x: 0.0,
                 y: 0.0,
                 width,
                 height,
                 scale: 1.0,
             })
         } else {
             None
         };
@@ -176,13 +181,8 @@ fn capture_screenshot_sync(
                     err
                 );

-                tab.capture_screenshot(
-                    Page::CaptureScreenshotFormatOption::Png,
-                    None,
-                    clip,
-                    true,
-                )
-                .map_err(|e| format!("Failed to capture screenshot after retry: {}", e))?
+                tab.capture_screenshot(Page::CaptureScreenshotFormatOption::Png, None, clip, true)
+                    .map_err(|e| format!("Failed to capture screenshot after retry: {}", e))?
             }
         }
     };
@@ -208,13 +208,13 @@ fn capture_screenshot_sync(
 }

 /// Cleans up old screenshot files from the temporary directory
 ///
 /// This function removes screenshot files older than the specified number of minutes
 /// to prevent accumulation of temporary files.
 ///
 /// # Arguments
 /// * `older_than_minutes` - Remove files older than this many minutes (default: 60)
 ///
 /// # Returns
 /// * `Result<usize, String>` - The number of files deleted, or an error message
 #[tauri::command]
@@ -222,24 +222,29 @@ pub async fn cleanup_screenshot_temp_files(
     older_than_minutes: Option<u64>,
 ) -> Result<usize, String> {
     let minutes = older_than_minutes.unwrap_or(60);
-    log::info!("Cleaning up screenshot files older than {} minutes", minutes);
+    log::info!(
+        "Cleaning up screenshot files older than {} minutes",
+        minutes
+    );

     let temp_dir = std::env::temp_dir();
     let cutoff_time = chrono::Utc::now() - chrono::Duration::minutes(minutes as i64);
     let mut deleted_count = 0;

     // Read directory entries
-    let entries = fs::read_dir(&temp_dir)
-        .map_err(|e| format!("Failed to read temp directory: {}", e))?;
+    let entries =
+        fs::read_dir(&temp_dir).map_err(|e| format!("Failed to read temp directory: {}", e))?;

     for entry in entries {
         if let Ok(entry) = entry {
             let path = entry.path();

             // Check if it's a claudia screenshot file
             if let Some(filename) = path.file_name() {
                 if let Some(filename_str) = filename.to_str() {
-                    if filename_str.starts_with("claudia_screenshot_") && filename_str.ends_with(".png") {
+                    if filename_str.starts_with("claudia_screenshot_")
+                        && filename_str.ends_with(".png")
+                    {
                         // Check file age
                         if let Ok(metadata) = fs::metadata(&path) {
                             if let Ok(modified) = metadata.modified() {
@@ -258,7 +263,7 @@ pub async fn cleanup_screenshot_temp_files(
             }
         }
     }

     log::info!("Cleaned up {} old screenshot files", deleted_count);
     Ok(deleted_count)
 }
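The cleanup command above filters temp files by prefix, extension, and age. A condensed sketch of the same logic, assuming the `claudia_screenshot_` prefix shown in the diff; chrono's `From<SystemTime>` impl handles the mtime comparison:

```rust
use chrono::{DateTime, Utc};
use std::fs;

// Remove matching temp files whose mtime is older than the cutoff,
// counting deletions. Prefix and extension mirror the diff above.
fn cleanup(older_than_minutes: i64) -> Result<usize, String> {
    let cutoff = Utc::now() - chrono::Duration::minutes(older_than_minutes);
    let mut deleted = 0;
    let entries = fs::read_dir(std::env::temp_dir()).map_err(|e| e.to_string())?;
    for entry in entries.flatten() {
        let path = entry.path();
        let is_target = path
            .file_name()
            .and_then(|n| n.to_str())
            .map(|n| n.starts_with("claudia_screenshot_") && n.ends_with(".png"))
            .unwrap_or(false);
        if !is_target {
            continue;
        }
        if let Ok(modified) = entry.metadata().and_then(|m| m.modified()) {
            // SystemTime converts into DateTime<Utc> for the comparison
            if DateTime::<Utc>::from(modified) < cutoff && fs::remove_file(&path).is_ok() {
                deleted += 1;
            }
        }
    }
    Ok(deleted)
}
```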
@@ -1,9 +1,9 @@
-use std::collections::{HashMap, HashSet};
-use std::fs;
-use std::path::PathBuf;
 use chrono::{DateTime, Local, NaiveDate};
 use serde::{Deserialize, Serialize};
 use serde_json;
+use std::collections::{HashMap, HashSet};
+use std::fs;
+use std::path::PathBuf;
 use tauri::command;

 #[derive(Debug, Serialize, Deserialize, Clone)]
@@ -108,11 +108,21 @@ fn calculate_cost(model: &str, usage: &UsageData) -> f64 {
     let cache_read_tokens = usage.cache_read_input_tokens.unwrap_or(0) as f64;

     // Calculate cost based on model
     let (input_price, output_price, cache_write_price, cache_read_price) =
         if model.contains("opus-4") || model.contains("claude-opus-4") {
-            (OPUS_4_INPUT_PRICE, OPUS_4_OUTPUT_PRICE, OPUS_4_CACHE_WRITE_PRICE, OPUS_4_CACHE_READ_PRICE)
+            (
+                OPUS_4_INPUT_PRICE,
+                OPUS_4_OUTPUT_PRICE,
+                OPUS_4_CACHE_WRITE_PRICE,
+                OPUS_4_CACHE_READ_PRICE,
+            )
         } else if model.contains("sonnet-4") || model.contains("claude-sonnet-4") {
-            (SONNET_4_INPUT_PRICE, SONNET_4_OUTPUT_PRICE, SONNET_4_CACHE_WRITE_PRICE, SONNET_4_CACHE_READ_PRICE)
+            (
+                SONNET_4_INPUT_PRICE,
+                SONNET_4_OUTPUT_PRICE,
+                SONNET_4_CACHE_WRITE_PRICE,
+                SONNET_4_CACHE_READ_PRICE,
+            )
         } else {
             // Return 0 for unknown models to avoid incorrect cost estimations.
             (0.0, 0.0, 0.0, 0.0)
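`calculate_cost` multiplies token counts by per-model price constants that are defined earlier in the file and elided from this diff. A sketch of the arithmetic, assuming the usual per-million-token pricing convention; the prices below are placeholders, not the project's constants:

```rust
// Assumed per-million-token prices, for illustration only.
const INPUT_PRICE: f64 = 15.0; // $ per 1M input tokens (placeholder)
const OUTPUT_PRICE: f64 = 75.0; // $ per 1M output tokens (placeholder)

// Scale each token count down by one million before applying its price.
fn cost(input_tokens: u64, output_tokens: u64) -> f64 {
    (input_tokens as f64 / 1_000_000.0) * INPUT_PRICE
        + (output_tokens as f64 / 1_000_000.0) * OUTPUT_PRICE
}
```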
@@ -134,10 +144,11 @@ fn parse_jsonl_file(
 ) -> Vec<UsageEntry> {
     let mut entries = Vec::new();
     let mut actual_project_path: Option<String> = None;

     if let Ok(content) = fs::read_to_string(path) {
         // Extract session ID from the file path
-        let session_id = path.parent()
+        let session_id = path
+            .parent()
             .and_then(|p| p.file_name())
             .and_then(|n| n.to_str())
             .unwrap_or("unknown")
@@ -155,7 +166,7 @@ fn parse_jsonl_file(
                     actual_project_path = Some(cwd.to_string());
                 }
             }

             // Try to parse as JsonlEntry for usage data
             if let Ok(entry) = serde_json::from_value::<JsonlEntry>(json_value) {
                 if let Some(message) = &entry.message {
@@ -170,10 +181,11 @@ fn parse_jsonl_file(

                     if let Some(usage) = &message.usage {
                         // Skip entries without meaningful token usage
-                        if usage.input_tokens.unwrap_or(0) == 0 &&
-                           usage.output_tokens.unwrap_or(0) == 0 &&
-                           usage.cache_creation_input_tokens.unwrap_or(0) == 0 &&
-                           usage.cache_read_input_tokens.unwrap_or(0) == 0 {
+                        if usage.input_tokens.unwrap_or(0) == 0
+                            && usage.output_tokens.unwrap_or(0) == 0
+                            && usage.cache_creation_input_tokens.unwrap_or(0) == 0
+                            && usage.cache_read_input_tokens.unwrap_or(0) == 0
+                        {
                             continue;
                         }

@@ -184,17 +196,23 @@ fn parse_jsonl_file(
                                 0.0
                             }
                         });

                         // Use actual project path if found, otherwise use encoded name
-                        let project_path = actual_project_path.clone()
+                        let project_path = actual_project_path
+                            .clone()
                             .unwrap_or_else(|| encoded_project_name.to_string());

                         entries.push(UsageEntry {
                             timestamp: entry.timestamp,
-                            model: message.model.clone().unwrap_or_else(|| "unknown".to_string()),
+                            model: message
+                                .model
+                                .clone()
+                                .unwrap_or_else(|| "unknown".to_string()),
                             input_tokens: usage.input_tokens.unwrap_or(0),
                             output_tokens: usage.output_tokens.unwrap_or(0),
-                            cache_creation_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
+                            cache_creation_tokens: usage
+                                .cache_creation_input_tokens
+                                .unwrap_or(0),
                             cache_read_tokens: usage.cache_read_input_tokens.unwrap_or(0),
                             cost,
                             session_id: entry.session_id.unwrap_or_else(|| session_id.clone()),
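`parse_jsonl_file` walks a JSONL transcript line by line and ignores lines that do not deserialize rather than failing the whole file. The tolerant core of that loop, with an illustrative line type (the real `JsonlEntry` has more fields):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Line {
    timestamp: String,
    cwd: Option<String>, // Option fields default to None when absent
}

// Parse each line independently; filter_map drops lines that fail,
// matching the skip-on-error behavior above.
fn parse(content: &str) -> Vec<Line> {
    content
        .lines()
        .filter_map(|line| serde_json::from_str::<Line>(line).ok())
        .collect()
}
```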
@@ -263,10 +281,10 @@ fn get_all_usage_entries(claude_path: &PathBuf) -> Vec<UsageEntry> {
         let entries = parse_jsonl_file(&path, &project_name, &mut processed_hashes);
         all_entries.extend(entries);
     }

     // Sort by timestamp
     all_entries.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));

     all_entries
 }

@@ -275,9 +293,9 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
     let claude_path = dirs::home_dir()
         .ok_or("Failed to get home directory")?
         .join(".claude");

     let all_entries = get_all_usage_entries(&claude_path);

     if all_entries.is_empty() {
         return Ok(UsageStats {
             total_cost: 0.0,
@@ -292,11 +310,12 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
             by_project: vec![],
         });
     }

     // Filter by days if specified
     let filtered_entries = if let Some(days) = days {
         let cutoff = Local::now().naive_local().date() - chrono::Duration::days(days as i64);
-        all_entries.into_iter()
+        all_entries
+            .into_iter()
             .filter(|e| {
                 if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) {
                     dt.naive_local().date() >= cutoff
@@ -308,18 +327,18 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
     } else {
         all_entries
     };

     // Calculate aggregated stats
     let mut total_cost = 0.0;
     let mut total_input_tokens = 0u64;
     let mut total_output_tokens = 0u64;
     let mut total_cache_creation_tokens = 0u64;
     let mut total_cache_read_tokens = 0u64;

     let mut model_stats: HashMap<String, ModelUsage> = HashMap::new();
     let mut daily_stats: HashMap<String, DailyUsage> = HashMap::new();
     let mut project_stats: HashMap<String, ProjectUsage> = HashMap::new();

     for entry in &filtered_entries {
         // Update totals
         total_cost += entry.cost;
@@ -327,18 +346,20 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
         total_output_tokens += entry.output_tokens;
         total_cache_creation_tokens += entry.cache_creation_tokens;
         total_cache_read_tokens += entry.cache_read_tokens;

         // Update model stats
-        let model_stat = model_stats.entry(entry.model.clone()).or_insert(ModelUsage {
-            model: entry.model.clone(),
-            total_cost: 0.0,
-            total_tokens: 0,
-            input_tokens: 0,
-            output_tokens: 0,
-            cache_creation_tokens: 0,
-            cache_read_tokens: 0,
-            session_count: 0,
-        });
+        let model_stat = model_stats
+            .entry(entry.model.clone())
+            .or_insert(ModelUsage {
+                model: entry.model.clone(),
+                total_cost: 0.0,
+                total_tokens: 0,
+                input_tokens: 0,
+                output_tokens: 0,
+                cache_creation_tokens: 0,
+                cache_read_tokens: 0,
+                session_count: 0,
+            });
         model_stat.total_cost += entry.cost;
         model_stat.input_tokens += entry.input_tokens;
         model_stat.output_tokens += entry.output_tokens;
@@ -346,9 +367,14 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
         model_stat.cache_read_tokens += entry.cache_read_tokens;
         model_stat.total_tokens = model_stat.input_tokens + model_stat.output_tokens;
         model_stat.session_count += 1;

         // Update daily stats
-        let date = entry.timestamp.split('T').next().unwrap_or(&entry.timestamp).to_string();
+        let date = entry
+            .timestamp
+            .split('T')
+            .next()
+            .unwrap_or(&entry.timestamp)
+            .to_string();
         let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage {
             date,
             total_cost: 0.0,
@@ -356,43 +382,58 @@ pub fn get_usage_stats(days: Option<u32>) -> Result<UsageStats, String> {
             models_used: vec![],
         });
         daily_stat.total_cost += entry.cost;
-        daily_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens;
+        daily_stat.total_tokens += entry.input_tokens
+            + entry.output_tokens
+            + entry.cache_creation_tokens
+            + entry.cache_read_tokens;
         if !daily_stat.models_used.contains(&entry.model) {
             daily_stat.models_used.push(entry.model.clone());
         }

         // Update project stats
-        let project_stat = project_stats.entry(entry.project_path.clone()).or_insert(ProjectUsage {
-            project_path: entry.project_path.clone(),
-            project_name: entry.project_path.split('/').last()
-                .unwrap_or(&entry.project_path)
-                .to_string(),
-            total_cost: 0.0,
-            total_tokens: 0,
-            session_count: 0,
-            last_used: entry.timestamp.clone(),
-        });
+        let project_stat =
+            project_stats
+                .entry(entry.project_path.clone())
+                .or_insert(ProjectUsage {
+                    project_path: entry.project_path.clone(),
+                    project_name: entry
+                        .project_path
+                        .split('/')
+                        .last()
+                        .unwrap_or(&entry.project_path)
+                        .to_string(),
+                    total_cost: 0.0,
+                    total_tokens: 0,
+                    session_count: 0,
+                    last_used: entry.timestamp.clone(),
+                });
         project_stat.total_cost += entry.cost;
-        project_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens;
+        project_stat.total_tokens += entry.input_tokens
+            + entry.output_tokens
+            + entry.cache_creation_tokens
+            + entry.cache_read_tokens;
         project_stat.session_count += 1;
         if entry.timestamp > project_stat.last_used {
             project_stat.last_used = entry.timestamp.clone();
         }
     }

-    let total_tokens = total_input_tokens + total_output_tokens + total_cache_creation_tokens + total_cache_read_tokens;
+    let total_tokens = total_input_tokens
+        + total_output_tokens
+        + total_cache_creation_tokens
+        + total_cache_read_tokens;
     let total_sessions = filtered_entries.len() as u64;

     // Convert hashmaps to sorted vectors
     let mut by_model: Vec<ModelUsage> = model_stats.into_values().collect();
     by_model.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap());

     let mut by_date: Vec<DailyUsage> = daily_stats.into_values().collect();
     by_date.sort_by(|a, b| b.date.cmp(&a.date));

     let mut by_project: Vec<ProjectUsage> = project_stats.into_values().collect();
     by_project.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap());

     Ok(UsageStats {
         total_cost,
         total_tokens,
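Both aggregation passes in this file lean on HashMap's entry API: create a zeroed accumulator on first sight of a key, then mutate it in place. A compact sketch with simplified types:

```rust
use std::collections::HashMap;

#[derive(Default)]
struct Totals {
    cost: f64,
    tokens: u64,
}

// One accumulator per model name, created lazily via entry().or_insert_with.
fn aggregate(entries: &[(String, f64, u64)]) -> HashMap<String, Totals> {
    let mut stats: HashMap<String, Totals> = HashMap::new();
    for (model, cost, tokens) in entries {
        let stat = stats.entry(model.clone()).or_insert_with(Totals::default);
        stat.cost += cost;
        stat.tokens += tokens;
    }
    stats
}
```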
@@ -412,27 +453,26 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
     let claude_path = dirs::home_dir()
         .ok_or("Failed to get home directory")?
         .join(".claude");

     let all_entries = get_all_usage_entries(&claude_path);

     // Parse dates
-    let start = NaiveDate::parse_from_str(&start_date, "%Y-%m-%d")
-        .or_else(|_| {
-            // Try parsing ISO datetime format
-            DateTime::parse_from_rfc3339(&start_date)
-                .map(|dt| dt.naive_local().date())
-                .map_err(|e| format!("Invalid start date: {}", e))
-        })?;
-    let end = NaiveDate::parse_from_str(&end_date, "%Y-%m-%d")
-        .or_else(|_| {
-            // Try parsing ISO datetime format
-            DateTime::parse_from_rfc3339(&end_date)
-                .map(|dt| dt.naive_local().date())
-                .map_err(|e| format!("Invalid end date: {}", e))
-        })?;
+    let start = NaiveDate::parse_from_str(&start_date, "%Y-%m-%d").or_else(|_| {
+        // Try parsing ISO datetime format
+        DateTime::parse_from_rfc3339(&start_date)
+            .map(|dt| dt.naive_local().date())
+            .map_err(|e| format!("Invalid start date: {}", e))
+    })?;
+    let end = NaiveDate::parse_from_str(&end_date, "%Y-%m-%d").or_else(|_| {
+        // Try parsing ISO datetime format
+        DateTime::parse_from_rfc3339(&end_date)
+            .map(|dt| dt.naive_local().date())
+            .map_err(|e| format!("Invalid end date: {}", e))
+    })?;

     // Filter entries by date range
-    let filtered_entries: Vec<_> = all_entries.into_iter()
+    let filtered_entries: Vec<_> = all_entries
+        .into_iter()
         .filter(|e| {
             if let Ok(dt) = DateTime::parse_from_rfc3339(&e.timestamp) {
                 let date = dt.naive_local().date();
@@ -442,7 +482,7 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
             }
         })
         .collect();

     if filtered_entries.is_empty() {
         return Ok(UsageStats {
             total_cost: 0.0,
@@ -457,18 +497,18 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
             by_project: vec![],
         });
     }

     // Calculate aggregated stats (same logic as get_usage_stats)
     let mut total_cost = 0.0;
     let mut total_input_tokens = 0u64;
     let mut total_output_tokens = 0u64;
     let mut total_cache_creation_tokens = 0u64;
     let mut total_cache_read_tokens = 0u64;

     let mut model_stats: HashMap<String, ModelUsage> = HashMap::new();
     let mut daily_stats: HashMap<String, DailyUsage> = HashMap::new();
     let mut project_stats: HashMap<String, ProjectUsage> = HashMap::new();

     for entry in &filtered_entries {
         // Update totals
         total_cost += entry.cost;
@@ -476,18 +516,20 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
         total_output_tokens += entry.output_tokens;
         total_cache_creation_tokens += entry.cache_creation_tokens;
         total_cache_read_tokens += entry.cache_read_tokens;

         // Update model stats
-        let model_stat = model_stats.entry(entry.model.clone()).or_insert(ModelUsage {
-            model: entry.model.clone(),
-            total_cost: 0.0,
-            total_tokens: 0,
-            input_tokens: 0,
-            output_tokens: 0,
-            cache_creation_tokens: 0,
-            cache_read_tokens: 0,
-            session_count: 0,
-        });
+        let model_stat = model_stats
+            .entry(entry.model.clone())
+            .or_insert(ModelUsage {
+                model: entry.model.clone(),
+                total_cost: 0.0,
+                total_tokens: 0,
+                input_tokens: 0,
+                output_tokens: 0,
+                cache_creation_tokens: 0,
+                cache_read_tokens: 0,
+                session_count: 0,
+            });
         model_stat.total_cost += entry.cost;
         model_stat.input_tokens += entry.input_tokens;
         model_stat.output_tokens += entry.output_tokens;
@@ -495,9 +537,14 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
         model_stat.cache_read_tokens += entry.cache_read_tokens;
         model_stat.total_tokens = model_stat.input_tokens + model_stat.output_tokens;
         model_stat.session_count += 1;

         // Update daily stats
-        let date = entry.timestamp.split('T').next().unwrap_or(&entry.timestamp).to_string();
+        let date = entry
+            .timestamp
+            .split('T')
+            .next()
+            .unwrap_or(&entry.timestamp)
+            .to_string();
         let daily_stat = daily_stats.entry(date.clone()).or_insert(DailyUsage {
             date,
             total_cost: 0.0,
@@ -505,43 +552,58 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
             models_used: vec![],
         });
         daily_stat.total_cost += entry.cost;
-        daily_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens;
+        daily_stat.total_tokens += entry.input_tokens
+            + entry.output_tokens
+            + entry.cache_creation_tokens
+            + entry.cache_read_tokens;
         if !daily_stat.models_used.contains(&entry.model) {
             daily_stat.models_used.push(entry.model.clone());
         }

         // Update project stats
-        let project_stat = project_stats.entry(entry.project_path.clone()).or_insert(ProjectUsage {
-            project_path: entry.project_path.clone(),
-            project_name: entry.project_path.split('/').last()
-                .unwrap_or(&entry.project_path)
-                .to_string(),
-            total_cost: 0.0,
-            total_tokens: 0,
-            session_count: 0,
-            last_used: entry.timestamp.clone(),
-        });
+        let project_stat =
+            project_stats
+                .entry(entry.project_path.clone())
+                .or_insert(ProjectUsage {
+                    project_path: entry.project_path.clone(),
+                    project_name: entry
+                        .project_path
+                        .split('/')
+                        .last()
+                        .unwrap_or(&entry.project_path)
+                        .to_string(),
+                    total_cost: 0.0,
+                    total_tokens: 0,
+                    session_count: 0,
+                    last_used: entry.timestamp.clone(),
+                });
         project_stat.total_cost += entry.cost;
-        project_stat.total_tokens += entry.input_tokens + entry.output_tokens + entry.cache_creation_tokens + entry.cache_read_tokens;
+        project_stat.total_tokens += entry.input_tokens
+            + entry.output_tokens
+            + entry.cache_creation_tokens
+            + entry.cache_read_tokens;
         project_stat.session_count += 1;
         if entry.timestamp > project_stat.last_used {
             project_stat.last_used = entry.timestamp.clone();
         }
     }

-    let total_tokens = total_input_tokens + total_output_tokens + total_cache_creation_tokens + total_cache_read_tokens;
+    let total_tokens = total_input_tokens
+        + total_output_tokens
+        + total_cache_creation_tokens
+        + total_cache_read_tokens;
     let total_sessions = filtered_entries.len() as u64;

     // Convert hashmaps to sorted vectors
     let mut by_model: Vec<ModelUsage> = model_stats.into_values().collect();
     by_model.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap());

     let mut by_date: Vec<DailyUsage> = daily_stats.into_values().collect();
     by_date.sort_by(|a, b| b.date.cmp(&a.date));

     let mut by_project: Vec<ProjectUsage> = project_stats.into_values().collect();
     by_project.sort_by(|a, b| b.total_cost.partial_cmp(&a.total_cost).unwrap());

     Ok(UsageStats {
         total_cost,
         total_tokens,
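Both date parameters accept either a plain `%Y-%m-%d` date or a full RFC 3339 timestamp. The fallback chain in isolation:

```rust
use chrono::{DateTime, NaiveDate};

// Try the plain-date format first; on failure, fall back to parsing an
// RFC 3339 timestamp and taking its local calendar date.
fn parse_day(s: &str) -> Result<NaiveDate, String> {
    NaiveDate::parse_from_str(s, "%Y-%m-%d").or_else(|_| {
        DateTime::parse_from_rfc3339(s)
            .map(|dt| dt.naive_local().date())
            .map_err(|e| format!("Invalid date: {}", e))
    })
}
```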
@@ -557,23 +619,26 @@ pub fn get_usage_by_date_range(start_date: String, end_date: String) -> Result<U
 }

 #[command]
-pub fn get_usage_details(project_path: Option<String>, date: Option<String>) -> Result<Vec<UsageEntry>, String> {
+pub fn get_usage_details(
+    project_path: Option<String>,
+    date: Option<String>,
+) -> Result<Vec<UsageEntry>, String> {
     let claude_path = dirs::home_dir()
         .ok_or("Failed to get home directory")?
         .join(".claude");

     let mut all_entries = get_all_usage_entries(&claude_path);

     // Filter by project if specified
     if let Some(project) = project_path {
         all_entries.retain(|e| e.project_path == project);
     }

     // Filter by date if specified
     if let Some(date) = date {
         all_entries.retain(|e| e.timestamp.starts_with(&date));
     }

     Ok(all_entries)
 }

@@ -586,7 +651,7 @@ pub fn get_session_stats(
     let claude_path = dirs::home_dir()
         .ok_or("Failed to get home directory")?
         .join(".claude");

     let all_entries = get_all_usage_entries(&claude_path);

     let since_date = since.and_then(|s| NaiveDate::parse_from_str(&s, "%Y%m%d").ok());
@@ -609,14 +674,16 @@ pub fn get_session_stats(
     let mut session_stats: HashMap<String, ProjectUsage> = HashMap::new();
     for entry in &filtered_entries {
         let session_key = format!("{}/{}", entry.project_path, entry.session_id);
-        let project_stat = session_stats.entry(session_key).or_insert_with(|| ProjectUsage {
-            project_path: entry.project_path.clone(),
-            project_name: entry.session_id.clone(), // Using session_id as project_name for session view
-            total_cost: 0.0,
-            total_tokens: 0,
-            session_count: 0, // In this context, this will count entries per session
-            last_used: " ".to_string(),
-        });
+        let project_stat = session_stats
+            .entry(session_key)
+            .or_insert_with(|| ProjectUsage {
+                project_path: entry.project_path.clone(),
+                project_name: entry.session_id.clone(), // Using session_id as project_name for session view
+                total_cost: 0.0,
+                total_tokens: 0,
+                session_count: 0, // In this context, this will count entries per session
+                last_used: " ".to_string(),
+            });

         project_stat.total_cost += entry.cost;
         project_stat.total_tokens += entry.input_tokens
@@ -643,6 +710,5 @@ pub fn get_session_stats(
         by_session.sort_by(|a, b| b.last_used.cmp(&a.last_used));
     }

-
     Ok(by_session)
 }
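One caveat worth noting in passing: the by-cost sorts in this file unwrap `partial_cmp`, which would panic if a cost were ever NaN. `f64::total_cmp` (stable since Rust 1.62) gives a total order without the unwrap; a sketch of the alternative, not a change this commit makes:

```rust
// Sort (name, cost) pairs by descending cost; total_cmp orders NaN
// deterministically instead of panicking.
fn sort_by_cost_desc(costs: &mut Vec<(String, f64)>) {
    costs.sort_by(|a, b| b.1.total_cmp(&a.1));
}
```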
@@ -1,11 +1,11 @@
 // Learn more about Tauri commands at https://tauri.app/develop/calling-rust/

 // Declare modules
-pub mod commands;
-pub mod sandbox;
 pub mod checkpoint;
-pub mod process;
 pub mod claude_binary;
+pub mod commands;
+pub mod process;
+pub mod sandbox;

 #[cfg_attr(mobile, tauri::mobile_entry_point)]
 pub fn run() {
@@ -1,57 +1,52 @@
|
|||||||
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
|
// Prevents additional console window on Windows in release, DO NOT REMOVE!!
|
||||||
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
||||||
|
|
||||||
mod commands;
|
|
||||||
mod sandbox;
|
|
||||||
mod checkpoint;
|
mod checkpoint;
|
||||||
mod process;
|
|
||||||
mod claude_binary;
|
mod claude_binary;
|
||||||
|
mod commands;
|
||||||
|
mod process;
|
||||||
|
mod sandbox;
|
||||||
|
|
||||||
use tauri::Manager;
|
use checkpoint::state::CheckpointState;
|
||||||
use commands::claude::{
|
|
||||||
get_claude_settings, get_project_sessions, get_system_prompt, list_projects, open_new_session,
|
|
||||||
check_claude_version, save_system_prompt, save_claude_settings,
|
|
||||||
find_claude_md_files, read_claude_md_file, save_claude_md_file,
|
|
||||||
load_session_history, execute_claude_code, continue_claude_code, resume_claude_code,
|
|
||||||
list_directory_contents, search_files,
|
|
||||||
create_checkpoint, restore_checkpoint, list_checkpoints, fork_from_checkpoint,
|
|
||||||
get_session_timeline, update_checkpoint_settings, get_checkpoint_diff,
|
|
||||||
track_checkpoint_message, track_session_messages, check_auto_checkpoint, cleanup_old_checkpoints,
|
|
||||||
get_checkpoint_settings, clear_checkpoint_manager, get_checkpoint_state_stats,
|
|
||||||
get_recently_modified_files, cancel_claude_execution, ClaudeProcessState,
|
|
||||||
};
|
|
||||||
use commands::agents::{
|
use commands::agents::{
|
||||||
init_database, list_agents, create_agent, update_agent, delete_agent,
|
cleanup_finished_processes, create_agent, delete_agent, execute_agent, export_agent,
|
||||||
get_agent, execute_agent, list_agent_runs, get_agent_run,
|
export_agent_to_file, fetch_github_agent_content, fetch_github_agents, get_agent,
|
||||||
get_agent_run_with_real_time_metrics, list_agent_runs_with_metrics,
|
get_agent_run, get_agent_run_with_real_time_metrics, get_claude_binary_path,
|
||||||
list_running_sessions, kill_agent_session,
|
get_live_session_output, get_session_output, get_session_status, import_agent,
|
||||||
-    get_session_status, cleanup_finished_processes, get_session_output,
-    get_live_session_output, stream_session_output, get_claude_binary_path,
-    set_claude_binary_path, export_agent, export_agent_to_file, import_agent,
-    import_agent_from_file, fetch_github_agents, fetch_github_agent_content,
-    import_agent_from_github, list_claude_installations, AgentDb
+    import_agent_from_file, import_agent_from_github, init_database, kill_agent_session,
+    list_agent_runs, list_agent_runs_with_metrics, list_agents, list_claude_installations,
+    list_running_sessions, set_claude_binary_path, stream_session_output, update_agent, AgentDb,
 };
-use commands::sandbox::{
-    list_sandbox_profiles, create_sandbox_profile, update_sandbox_profile, delete_sandbox_profile,
-    get_sandbox_profile, list_sandbox_rules, create_sandbox_rule, update_sandbox_rule,
-    delete_sandbox_rule, get_platform_capabilities, test_sandbox_profile,
-    list_sandbox_violations, log_sandbox_violation, clear_sandbox_violations, get_sandbox_violation_stats,
-    export_sandbox_profile, export_all_sandbox_profiles, import_sandbox_profiles,
-};
-use commands::screenshot::{
-    capture_url_screenshot, cleanup_screenshot_temp_files,
-};
-use commands::usage::{
-    get_usage_stats, get_usage_by_date_range, get_usage_details, get_session_stats,
-};
+use commands::claude::{
+    cancel_claude_execution, check_auto_checkpoint, check_claude_version, cleanup_old_checkpoints,
+    clear_checkpoint_manager, continue_claude_code, create_checkpoint, execute_claude_code,
+    find_claude_md_files, fork_from_checkpoint, get_checkpoint_diff, get_checkpoint_settings,
+    get_checkpoint_state_stats, get_claude_settings, get_project_sessions,
+    get_recently_modified_files, get_session_timeline, get_system_prompt, list_checkpoints,
+    list_directory_contents, list_projects, load_session_history, open_new_session,
+    read_claude_md_file, restore_checkpoint, resume_claude_code, save_claude_md_file,
+    save_claude_settings, save_system_prompt, search_files, track_checkpoint_message,
+    track_session_messages, update_checkpoint_settings, ClaudeProcessState,
+};
 use commands::mcp::{
-    mcp_add, mcp_list, mcp_get, mcp_remove, mcp_add_json, mcp_add_from_claude_desktop,
-    mcp_serve, mcp_test_connection, mcp_reset_project_choices, mcp_get_server_status,
-    mcp_read_project_config, mcp_save_project_config,
+    mcp_add, mcp_add_from_claude_desktop, mcp_add_json, mcp_get, mcp_get_server_status, mcp_list,
+    mcp_read_project_config, mcp_remove, mcp_reset_project_choices, mcp_save_project_config,
+    mcp_serve, mcp_test_connection,
+};
+use commands::sandbox::{
+    clear_sandbox_violations, create_sandbox_profile, create_sandbox_rule, delete_sandbox_profile,
+    delete_sandbox_rule, export_all_sandbox_profiles, export_sandbox_profile,
+    get_platform_capabilities, get_sandbox_profile, get_sandbox_violation_stats,
+    import_sandbox_profiles, list_sandbox_profiles, list_sandbox_rules, list_sandbox_violations,
+    log_sandbox_violation, test_sandbox_profile, update_sandbox_profile, update_sandbox_rule,
+};
+use commands::screenshot::{capture_url_screenshot, cleanup_screenshot_temp_files};
+use commands::usage::{
+    get_session_stats, get_usage_by_date_range, get_usage_details, get_usage_stats,
 };
-use std::sync::Mutex;
-use checkpoint::state::CheckpointState;
 use process::ProcessRegistryState;
+use std::sync::Mutex;
+use tauri::Manager;

 fn main() {
     // Initialize logger
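A note on the churn above: `cargo fmt` orders whole `use` declarations alphabetically by path and sorts the names inside each braced list deterministically, so every change in this region is mechanical. A minimal illustration of the effect, using real command names from this module (the exact before-state is hypothetical):

    // Hand-ordered list, as it might be written during development.
    use commands::mcp::{mcp_list, mcp_add, mcp_get};

    // After `cargo fmt`: the braced list is sorted, so future diffs stay mechanical.
    use commands::mcp::{mcp_add, mcp_get, mcp_list};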
@@ -72,32 +67,34 @@ fn main() {
             // Initialize agents database
             let conn = init_database(&app.handle()).expect("Failed to initialize agents database");
             app.manage(AgentDb(Mutex::new(conn)));

             // Initialize checkpoint state
             let checkpoint_state = CheckpointState::new();

             // Set the Claude directory path
             if let Ok(claude_dir) = dirs::home_dir()
                 .ok_or_else(|| "Could not find home directory")
                 .and_then(|home| {
                     let claude_path = home.join(".claude");
-                    claude_path.canonicalize()
+                    claude_path
+                        .canonicalize()
                         .map_err(|_| "Could not find ~/.claude directory")
-                }) {
+                })
+            {
                 let state_clone = checkpoint_state.clone();
                 tauri::async_runtime::spawn(async move {
                     state_clone.set_claude_dir(claude_dir).await;
                 });
             }

             app.manage(checkpoint_state);

             // Initialize process registry
             app.manage(ProcessRegistryState::default());

             // Initialize Claude process state
             app.manage(ClaudeProcessState::default());

             Ok(())
         })
         .invoke_handler(tauri::generate_handler![
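For context on the `app.manage(...)` calls in this hunk: Tauri keeps one managed value per type, and any `#[tauri::command]` can receive it by declaring a `tauri::State` parameter of that type. A self-contained sketch of the same pattern, with a hypothetical `Counter` type standing in for the real `AgentDb`, `CheckpointState`, `ProcessRegistryState`, and `ClaudeProcessState`:

    use std::sync::Mutex;
    use tauri::{Manager, State};

    // Hypothetical state type, mirroring the AgentDb(Mutex<Connection>) shape.
    struct Counter(Mutex<i32>);

    // Managed values are injected into commands by type.
    #[tauri::command]
    fn bump(counter: State<'_, Counter>) -> i32 {
        let mut n = counter.0.lock().unwrap();
        *n += 1;
        *n
    }

    fn main() {
        tauri::Builder::default()
            .setup(|app| {
                app.manage(Counter(Mutex::new(0))); // same pattern as app.manage(AgentDb(...))
                Ok(())
            })
            .invoke_handler(tauri::generate_handler![bump])
            .run(tauri::generate_context!())
            .expect("error while running tauri application");
    }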
@@ -1,3 +1,3 @@
 pub mod registry;

 pub use registry::*;
@@ -1,8 +1,8 @@
+use chrono::{DateTime, Utc};
+use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
-use serde::{Deserialize, Serialize};
 use tokio::process::Child;
-use chrono::{DateTime, Utc};

 /// Information about a running agent process
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -50,7 +50,7 @@ impl ProcessRegistry {
         child: Child,
     ) -> Result<(), String> {
         let mut processes = self.processes.lock().map_err(|e| e.to_string())?;

         let process_info = ProcessInfo {
             run_id,
             agent_id,
@@ -84,7 +84,10 @@ impl ProcessRegistry {
     #[allow(dead_code)]
     pub fn get_running_processes(&self) -> Result<Vec<ProcessInfo>, String> {
         let processes = self.processes.lock().map_err(|e| e.to_string())?;
-        Ok(processes.values().map(|handle| handle.info.clone()).collect())
+        Ok(processes
+            .values()
+            .map(|handle| handle.info.clone())
+            .collect())
     }

     /// Get a specific running process
@@ -96,8 +99,8 @@ impl ProcessRegistry {

     /// Kill a running process with proper cleanup
     pub async fn kill_process(&self, run_id: i64) -> Result<bool, String> {
-        use log::{info, warn, error};
+        use log::{error, info, warn};

         // First check if the process exists and get its PID
         let (pid, child_arc) = {
             let processes = self.processes.lock().map_err(|e| e.to_string())?;
@@ -107,9 +110,12 @@ impl ProcessRegistry {
                 return Ok(false); // Process not found
             }
         };

-        info!("Attempting graceful shutdown of process {} (PID: {})", run_id, pid);
+        info!(
+            "Attempting graceful shutdown of process {} (PID: {})",
+            run_id, pid
+        );

         // Send kill signal to the process
         let kill_sent = {
             let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
@@ -128,52 +134,50 @@ impl ProcessRegistry {
                 false // Process already killed
             }
         };

         if !kill_sent {
             return Ok(false);
         }

         // Wait for the process to exit (with timeout)
-        let wait_result = tokio::time::timeout(
-            tokio::time::Duration::from_secs(5),
-            async {
-                loop {
-                    // Check if process has exited
-                    let status = {
-                        let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
-                        if let Some(child) = child_guard.as_mut() {
-                            match child.try_wait() {
-                                Ok(Some(status)) => {
-                                    info!("Process {} exited with status: {:?}", run_id, status);
-                                    *child_guard = None; // Clear the child handle
-                                    Some(Ok::<(), String>(()))
-                                }
-                                Ok(None) => {
-                                    // Still running
-                                    None
-                                }
-                                Err(e) => {
-                                    error!("Error checking process status: {}", e);
-                                    Some(Err(e.to_string()))
-                                }
-                            }
-                        } else {
-                            // Process already gone
-                            Some(Ok(()))
-                        }
-                    };
-
-                    match status {
-                        Some(result) => return result,
-                        None => {
-                            // Still running, wait a bit
-                            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
-                        }
-                    }
-                }
-            }
-        ).await;
+        let wait_result = tokio::time::timeout(tokio::time::Duration::from_secs(5), async {
+            loop {
+                // Check if process has exited
+                let status = {
+                    let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
+                    if let Some(child) = child_guard.as_mut() {
+                        match child.try_wait() {
+                            Ok(Some(status)) => {
+                                info!("Process {} exited with status: {:?}", run_id, status);
+                                *child_guard = None; // Clear the child handle
+                                Some(Ok::<(), String>(()))
+                            }
+                            Ok(None) => {
+                                // Still running
+                                None
+                            }
+                            Err(e) => {
+                                error!("Error checking process status: {}", e);
+                                Some(Err(e.to_string()))
+                            }
+                        }
+                    } else {
+                        // Process already gone
+                        Some(Ok(()))
+                    }
+                };
+
+                match status {
+                    Some(result) => return result,
+                    None => {
+                        // Still running, wait a bit
+                        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+                    }
+                }
+            }
+        })
+        .await;

         match wait_result {
             Ok(Ok(_)) => {
                 info!("Process {} exited gracefully", run_id);
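The restructured block above is rustfmt folding the `tokio::time::timeout(..., async { ... })` call; the logic is unchanged: poll `try_wait()` every 100 ms until the child exits or five seconds elapse. A stripped-down sketch of the same pattern, as a hypothetical standalone helper rather than the registry method itself:

    use std::time::Duration;

    // Poll a child until it exits, giving up after five seconds.
    async fn wait_with_timeout(child: &mut tokio::process::Child) -> Result<bool, String> {
        let waited = tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                // try_wait() returns Ok(Some(status)) once the child has exited.
                match child.try_wait().map_err(|e| e.to_string())? {
                    Some(_status) => return Ok::<(), String>(()),
                    None => tokio::time::sleep(Duration::from_millis(100)).await,
                }
            }
        })
        .await;
        match waited {
            Ok(inner) => inner.map(|()| true), // exited within the window
            Err(_elapsed) => Ok(false),        // still running after 5 s
        }
    }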
@@ -189,19 +193,19 @@ impl ProcessRegistry {
                 }
             }
         }

         // Remove from registry after killing
         self.unregister_process(run_id)?;

         Ok(true)
     }

     /// Kill a process by PID using system commands (fallback method)
     pub fn kill_process_by_pid(&self, run_id: i64, pid: u32) -> Result<bool, String> {
-        use log::{info, warn, error};
+        use log::{error, info, warn};

         info!("Attempting to kill process {} by PID {}", run_id, pid);

         let kill_result = if cfg!(target_os = "windows") {
             std::process::Command::new("taskkill")
                 .args(["/F", "/PID", &pid.to_string()])
@@ -211,22 +215,25 @@ impl ProcessRegistry {
             let term_result = std::process::Command::new("kill")
                 .args(["-TERM", &pid.to_string()])
                 .output();

             match &term_result {
                 Ok(output) if output.status.success() => {
                     info!("Sent SIGTERM to PID {}", pid);
                     // Give it 2 seconds to exit gracefully
                     std::thread::sleep(std::time::Duration::from_secs(2));

                     // Check if still running
                     let check_result = std::process::Command::new("kill")
                         .args(["-0", &pid.to_string()])
                         .output();

                     if let Ok(output) = check_result {
                         if output.status.success() {
                             // Still running, send SIGKILL
-                            warn!("Process {} still running after SIGTERM, sending SIGKILL", pid);
+                            warn!(
+                                "Process {} still running after SIGTERM, sending SIGKILL",
+                                pid
+                            );
                             std::process::Command::new("kill")
                                 .args(["-KILL", &pid.to_string()])
                                 .output()
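`kill_process_by_pid` escalates in the classic unix order: SIGTERM first, a grace period, a `kill -0` liveness probe, then SIGKILL. Condensed into one hypothetical helper for illustration:

    use std::process::Command;
    use std::{thread, time::Duration};

    // Ask politely, verify, then force-kill (unix only; sketch, not the registry code).
    fn terminate_then_kill(pid: u32) -> std::io::Result<()> {
        // SIGTERM lets the process run its cleanup handlers.
        Command::new("kill").args(["-TERM", &pid.to_string()]).output()?;
        thread::sleep(Duration::from_secs(2));

        // `kill -0` delivers no signal; exit status 0 just means the PID still exists.
        let alive = Command::new("kill")
            .args(["-0", &pid.to_string()])
            .output()?
            .status
            .success();

        if alive {
            // Escalate: SIGKILL cannot be caught or ignored.
            Command::new("kill").args(["-KILL", &pid.to_string()]).output()?;
        }
        Ok(())
    }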
@@ -246,7 +253,7 @@ impl ProcessRegistry {
                 }
             }
         };

         match kill_result {
             Ok(output) => {
                 if output.status.success() {
@@ -271,11 +278,11 @@ impl ProcessRegistry {
     #[allow(dead_code)]
     pub async fn is_process_running(&self, run_id: i64) -> Result<bool, String> {
         let processes = self.processes.lock().map_err(|e| e.to_string())?;

         if let Some(handle) = processes.get(&run_id) {
             let child_arc = handle.child.clone();
             drop(processes); // Release the lock before async operation

             let mut child_guard = child_arc.lock().map_err(|e| e.to_string())?;
             if let Some(ref mut child) = child_guard.as_mut() {
                 match child.try_wait() {
@@ -329,20 +336,20 @@ impl ProcessRegistry {
     pub async fn cleanup_finished_processes(&self) -> Result<Vec<i64>, String> {
         let mut finished_runs = Vec::new();
         let processes_lock = self.processes.clone();

         // First, identify finished processes
         {
             let processes = processes_lock.lock().map_err(|e| e.to_string())?;
             let run_ids: Vec<i64> = processes.keys().cloned().collect();
             drop(processes);

             for run_id in run_ids {
                 if !self.is_process_running(run_id).await? {
                     finished_runs.push(run_id);
                 }
             }
         }

         // Then remove them from the registry
         {
             let mut processes = processes_lock.lock().map_err(|e| e.to_string())?;
@@ -350,7 +357,7 @@ impl ProcessRegistry {
                 processes.remove(run_id);
             }
         }

         Ok(finished_runs)
     }
 }
@@ -368,4 +375,4 @@ impl Default for ProcessRegistryState {
     fn default() -> Self {
         Self(Arc::new(ProcessRegistry::new()))
     }
 }
@@ -4,26 +4,24 @@ use rusqlite::{params, Connection, Result};
 /// Create default sandbox profiles for initial setup
 pub fn create_default_profiles(conn: &Connection) -> Result<()> {
     // Check if we already have profiles
-    let count: i64 = conn.query_row(
-        "SELECT COUNT(*) FROM sandbox_profiles",
-        [],
-        |row| row.get(0),
-    )?;
+    let count: i64 = conn.query_row("SELECT COUNT(*) FROM sandbox_profiles", [], |row| {
+        row.get(0)
+    })?;

     if count > 0 {
         // Already have profiles, don't create defaults
         return Ok(());
     }

     // Create Standard Profile
     create_standard_profile(conn)?;

     // Create Minimal Profile
     create_minimal_profile(conn)?;

     // Create Development Profile
     create_development_profile(conn)?;

     Ok(())
 }

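`Connection::query_row` runs one statement, binds its parameters, and maps the single result row through a closure, which is why rustfmt can reshape the call above around the trailing closure without changing behavior. A self-contained usage sketch, assuming an in-memory database and an illustrative schema:

    use rusqlite::{Connection, Result};

    // Standalone sketch of conn.query_row as used above.
    fn profile_count() -> Result<i64> {
        let conn = Connection::open_in_memory()?;
        conn.execute_batch("CREATE TABLE sandbox_profiles (id INTEGER PRIMARY KEY);")?;
        let count: i64 =
            conn.query_row("SELECT COUNT(*) FROM sandbox_profiles", [], |row| row.get(0))?;
        Ok(count)
    }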
@@ -38,22 +36,57 @@ fn create_standard_profile(conn: &Connection) -> Result<()> {
             true // Set as default
         ],
     )?;

     let profile_id = conn.last_insert_rowid();

     // Add rules
     let rules = vec![
         // File access
-        ("file_read_all", "subpath", "{{PROJECT_PATH}}", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/usr/lib", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/usr/local/lib", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/System/Library", true, Some(r#"["macos"]"#)),
-        ("file_read_metadata", "subpath", "/", true, Some(r#"["macos"]"#)),
+        (
+            "file_read_all",
+            "subpath",
+            "{{PROJECT_PATH}}",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/usr/lib",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/usr/local/lib",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/System/Library",
+            true,
+            Some(r#"["macos"]"#),
+        ),
+        (
+            "file_read_metadata",
+            "subpath",
+            "/",
+            true,
+            Some(r#"["macos"]"#),
+        ),
         // Network access
-        ("network_outbound", "all", "", true, Some(r#"["linux", "macos"]"#)),
+        (
+            "network_outbound",
+            "all",
+            "",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
     ];

     for (op_type, pattern_type, pattern_value, enabled, platforms) in rules {
         conn.execute(
             "INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support)
@@ -61,7 +94,7 @@ fn create_standard_profile(conn: &Connection) -> Result<()> {
             params![profile_id, op_type, pattern_type, pattern_value, enabled, platforms],
         )?;
     }

     Ok(())
 }

@@ -76,9 +109,9 @@ fn create_minimal_profile(conn: &Connection) -> Result<()> {
             false
         ],
     )?;

     let profile_id = conn.last_insert_rowid();

     // Add minimal rules - only project access
     conn.execute(
         "INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support)
@@ -92,7 +125,7 @@ fn create_minimal_profile(conn: &Connection) -> Result<()> {
             Some(r#"["linux", "macos", "windows"]"#)
         ],
     )?;

     Ok(())
 }

@@ -107,26 +140,66 @@ fn create_development_profile(conn: &Connection) -> Result<()> {
             false
         ],
     )?;

     let profile_id = conn.last_insert_rowid();

     // Add development rules
     let rules = vec![
         // Broad file access
-        ("file_read_all", "subpath", "{{PROJECT_PATH}}", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "{{HOME}}", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/usr", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/opt", true, Some(r#"["linux", "macos"]"#)),
-        ("file_read_all", "subpath", "/Applications", true, Some(r#"["macos"]"#)),
-        ("file_read_metadata", "subpath", "/", true, Some(r#"["macos"]"#)),
+        (
+            "file_read_all",
+            "subpath",
+            "{{PROJECT_PATH}}",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "{{HOME}}",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/usr",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/opt",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
+        (
+            "file_read_all",
+            "subpath",
+            "/Applications",
+            true,
+            Some(r#"["macos"]"#),
+        ),
+        (
+            "file_read_metadata",
+            "subpath",
+            "/",
+            true,
+            Some(r#"["macos"]"#),
+        ),
         // Network access
-        ("network_outbound", "all", "", true, Some(r#"["linux", "macos"]"#)),
+        (
+            "network_outbound",
+            "all",
+            "",
+            true,
+            Some(r#"["linux", "macos"]"#),
+        ),
         // System info (macOS only)
         ("system_info_read", "all", "", true, Some(r#"["macos"]"#)),
     ];

     for (op_type, pattern_type, pattern_value, enabled, platforms) in rules {
         conn.execute(
             "INSERT INTO sandbox_rules (profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support)
@@ -134,6 +207,6 @@ fn create_development_profile(conn: &Connection) -> Result<()> {
             params![profile_id, op_type, pattern_type, pattern_value, enabled, platforms],
         )?;
     }

     Ok(())
 }
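Each default rule above is a `(operation_type, pattern_type, pattern_value, enabled, platform_support)` tuple, with `platform_support` stored as a JSON-encoded string array. A sketch of how a consumer might test platform applicability (hypothetical helper; `serde_json` assumed available, as elsewhere in this codebase):

    // One rule row, in the shape inserted by create_standard_profile above.
    type Rule = (&'static str, &'static str, &'static str, bool, Option<&'static str>);

    fn applies_on_this_platform(rule: &Rule) -> bool {
        match rule.4 {
            // No platform list means the rule applies everywhere.
            None => true,
            Some(json) => serde_json::from_str::<Vec<String>>(json)
                .map(|platforms| platforms.iter().any(|p| p == std::env::consts::OS))
                .unwrap_or(false),
        }
    }

    fn main() {
        let rule: Rule = (
            "file_read_all",
            "subpath",
            "{{PROJECT_PATH}}",
            true,
            Some(r#"["linux", "macos"]"#),
        );
        println!("applies: {}", applies_on_this_platform(&rule));
    }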
@@ -1,7 +1,9 @@
 use anyhow::{Context, Result};
 #[cfg(unix)]
-use gaol::sandbox::{ChildSandbox, ChildSandboxMethods, Command as GaolCommand, Sandbox, SandboxMethods};
-use log::{info, warn, error, debug};
+use gaol::sandbox::{
+    ChildSandbox, ChildSandboxMethods, Command as GaolCommand, Sandbox, SandboxMethods,
+};
+use log::{debug, error, info, warn};
 use std::env;
 use std::path::{Path, PathBuf};
 use std::process::Stdio;
@@ -25,12 +27,12 @@ impl SandboxExecutor {
             serialized_profile: None,
         }
     }

     /// Create a new sandbox executor with serialized profile for child process communication
     pub fn new_with_serialization(
         profile: gaol::profile::Profile,
         project_path: PathBuf,
-        serialized_profile: SerializedProfile
+        serialized_profile: SerializedProfile,
     ) -> Self {
         Self {
             profile,
@@ -41,15 +43,23 @@ impl SandboxExecutor {

     /// Execute a command in the sandbox (for the parent process)
     /// This is used when we need to spawn a child process with sandbox
-    pub fn execute_sandboxed_spawn(&self, command: &str, args: &[&str], cwd: &Path) -> Result<std::process::Child> {
+    pub fn execute_sandboxed_spawn(
+        &self,
+        command: &str,
+        args: &[&str],
+        cwd: &Path,
+    ) -> Result<std::process::Child> {
         info!("Executing sandboxed command: {} {:?}", command, args);

         // On macOS, we need to check if the command is allowed by the system
         #[cfg(target_os = "macos")]
         {
             // For testing purposes, we'll skip actual sandboxing for simple commands like echo
             if command == "echo" || command == "/bin/echo" {
-                debug!("Using direct execution for simple test command: {}", command);
+                debug!(
+                    "Using direct execution for simple test command: {}",
+                    command
+                );
                 return std::process::Command::new(command)
                     .args(args)
                     .current_dir(cwd)
@@ -60,44 +70,55 @@ impl SandboxExecutor {
                     .context("Failed to spawn test command");
             }
         }

         // Create the sandbox
         let sandbox = Sandbox::new(self.profile.clone());

         // Create the command
         let mut gaol_command = GaolCommand::new(command);
         for arg in args {
             gaol_command.arg(arg);
         }

         // Set environment variables
         gaol_command.env("GAOL_CHILD_PROCESS", "1");
         gaol_command.env("GAOL_SANDBOX_ACTIVE", "1");
-        gaol_command.env("GAOL_PROJECT_PATH", self.project_path.to_string_lossy().as_ref());
+        gaol_command.env(
+            "GAOL_PROJECT_PATH",
+            self.project_path.to_string_lossy().as_ref(),
+        );

         // Inherit specific parent environment variables that are safe
         for (key, value) in env::vars() {
             // Only pass through safe environment variables
-            if key.starts_with("PATH") || key.starts_with("HOME") || key.starts_with("USER")
-                || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_") {
+            if key.starts_with("PATH")
+                || key.starts_with("HOME")
+                || key.starts_with("USER")
+                || key == "SHELL"
+                || key == "LANG"
+                || key == "LC_ALL"
+                || key.starts_with("LC_")
+            {
                 gaol_command.env(&key, &value);
             }
         }

         // Try to start the sandboxed process using gaol
         match sandbox.start(&mut gaol_command) {
             Ok(process) => {
                 debug!("Successfully started sandboxed process using gaol");
                 // Unfortunately, gaol doesn't expose the underlying Child process
                 // So we need to use a different approach for now

                 // This is a limitation of the gaol library - we can't get the Child back
                 // For now, we'll have to use the fallback approach
-                warn!("Gaol started the process but we can't get the Child handle - using fallback");
+                warn!(
+                    "Gaol started the process but we can't get the Child handle - using fallback"
+                );

                 // Drop the process to avoid zombie
                 drop(process);

                 // Fall through to fallback
             }
             Err(e) => {
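Both spawn paths in this file copy only an allowlist of environment variables into the child instead of inheriting everything. A condensed sketch with `std::process::Command` (hypothetical helper; the explicit `env_clear()` is an assumption of this sketch — the gaol path builds its environment variable by variable, so it never inherits in the first place):

    use std::env;
    use std::process::Command;

    // Allowlist-based env inheritance, mirroring the filter used above.
    fn apply_safe_env(cmd: &mut Command) {
        cmd.env_clear(); // start from an empty environment
        for (key, value) in env::vars() {
            let safe = key == "PATH"
                || key == "HOME"
                || key == "USER"
                || key == "SHELL"
                || key == "LANG"
                || key.starts_with("LC_");
            if safe {
                cmd.env(&key, &value);
            }
        }
    }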
@@ -105,10 +126,10 @@ impl SandboxExecutor {
                 debug!("Gaol error details: {:?}", e);
             }
         }

         // Fallback: Use regular process spawn with sandbox activation in child
         info!("Using child-side sandbox activation as fallback");

         // Serialize the sandbox rules for the child process
         let rules_json = if let Some(ref serialized) = self.serialized_profile {
             serde_json::to_string(serialized)?
@@ -116,50 +137,70 @@ impl SandboxExecutor {
             let serialized_rules = self.extract_sandbox_rules()?;
             serde_json::to_string(&serialized_rules)?
         };

         let mut std_command = std::process::Command::new(command);
-        std_command.args(args)
+        std_command
+            .args(args)
             .current_dir(cwd)
             .env("GAOL_SANDBOX_ACTIVE", "1")
-            .env("GAOL_PROJECT_PATH", self.project_path.to_string_lossy().as_ref())
+            .env(
+                "GAOL_PROJECT_PATH",
+                self.project_path.to_string_lossy().as_ref(),
+            )
             .env("GAOL_SANDBOX_RULES", rules_json)
             .stdin(Stdio::piped())
             .stdout(Stdio::piped())
             .stderr(Stdio::piped());

-        std_command.spawn()
+        std_command
+            .spawn()
             .context("Failed to spawn process with sandbox environment")
     }

     /// Prepare a tokio Command for sandboxed execution
     /// The sandbox will be activated in the child process
     pub fn prepare_sandboxed_command(&self, command: &str, args: &[&str], cwd: &Path) -> Command {
         info!("Preparing sandboxed command: {} {:?}", command, args);

         let mut cmd = Command::new(command);
-        cmd.args(args)
-            .current_dir(cwd);
+        cmd.args(args).current_dir(cwd);

         // Inherit essential environment variables from parent process
         // This is crucial for commands like Claude that need to find Node.js
         for (key, value) in env::vars() {
             // Pass through PATH and other essential environment variables
-            if key == "PATH" || key == "HOME" || key == "USER"
-                || key == "SHELL" || key == "LANG" || key == "LC_ALL" || key.starts_with("LC_")
-                || key == "NODE_PATH" || key == "NVM_DIR" || key == "NVM_BIN" {
+            if key == "PATH"
+                || key == "HOME"
+                || key == "USER"
+                || key == "SHELL"
+                || key == "LANG"
+                || key == "LC_ALL"
+                || key.starts_with("LC_")
+                || key == "NODE_PATH"
+                || key == "NVM_DIR"
+                || key == "NVM_BIN"
+            {
                 debug!("Inheriting env var: {}={}", key, value);
                 cmd.env(&key, &value);
             }
         }

         // Serialize the sandbox rules for the child process
         let rules_json = if let Some(ref serialized) = self.serialized_profile {
             let json = serde_json::to_string(serialized).ok();
-            info!("🔧 Using serialized sandbox profile with {} operations", serialized.operations.len());
+            info!(
+                "🔧 Using serialized sandbox profile with {} operations",
+                serialized.operations.len()
+            );
             for (i, op) in serialized.operations.iter().enumerate() {
                 match op {
                     SerializedOperation::FileReadAll { path, is_subpath } => {
-                        info!("  Rule {}: FileReadAll {} (subpath: {})", i, path.display(), is_subpath);
+                        info!(
+                            "  Rule {}: FileReadAll {} (subpath: {})",
+                            i,
+                            path.display(),
+                            is_subpath
+                        );
                     }
                     SerializedOperation::NetworkOutbound { pattern } => {
                         info!("  Rule {}: NetworkOutbound {}", i, pattern);
@@ -179,7 +220,7 @@ impl SandboxExecutor {
                 .ok()
                 .and_then(|r| serde_json::to_string(&r).ok())
         };

         if let Some(json) = rules_json {
             // TEMPORARILY DISABLED: Claude Code might not understand these env vars and could hang
             // cmd.env("GAOL_SANDBOX_ACTIVE", "1");
@@ -188,19 +229,22 @@ impl SandboxExecutor {
             warn!("🚨 TEMPORARILY DISABLED sandbox environment variables for debugging");
             info!("🔧 Would have set sandbox environment variables for child process");
             info!("   GAOL_SANDBOX_ACTIVE=1 (disabled)");
-            info!("   GAOL_PROJECT_PATH={} (disabled)", self.project_path.display());
+            info!(
+                "   GAOL_PROJECT_PATH={} (disabled)",
+                self.project_path.display()
+            );
             info!("   GAOL_SANDBOX_RULES={} chars (disabled)", json.len());
         } else {
             warn!("🚨 Failed to serialize sandbox rules - running without sandbox!");
         }

         cmd.stdin(Stdio::null()) // Don't pipe stdin - we have no input to send
             .stdout(Stdio::piped())
             .stderr(Stdio::piped());

         cmd
     }

     /// Extract sandbox rules from the profile
     /// This is a workaround since gaol doesn't expose the operations
     fn extract_sandbox_rules(&self) -> Result<SerializedProfile> {
@@ -208,18 +252,18 @@ impl SandboxExecutor {
         // For now, return a default set based on what we know
         // This should be improved by tracking rules during profile creation
         let operations = vec![
             SerializedOperation::FileReadAll {
                 path: self.project_path.clone(),
-                is_subpath: true
+                is_subpath: true,
             },
             SerializedOperation::NetworkOutbound {
-                pattern: "all".to_string()
+                pattern: "all".to_string(),
             },
         ];

         Ok(SerializedProfile { operations })
     }

     /// Activate sandbox in the current process (for child processes)
     /// This should be called early in the child process
     pub fn activate_sandbox_in_child() -> Result<()> {
@@ -227,21 +271,23 @@ impl SandboxExecutor {
         if !should_activate_sandbox() {
             return Ok(());
         }

         info!("Activating sandbox in child process");

         // Get project path
-        let project_path = env::var("GAOL_PROJECT_PATH")
-            .context("GAOL_PROJECT_PATH not set")?;
+        let project_path = env::var("GAOL_PROJECT_PATH").context("GAOL_PROJECT_PATH not set")?;
         let project_path = PathBuf::from(project_path);

         // Try to deserialize the sandbox rules from environment
         let profile = if let Ok(rules_json) = env::var("GAOL_SANDBOX_RULES") {
             match serde_json::from_str::<SerializedProfile>(&rules_json) {
                 Ok(serialized) => {
-                    debug!("Deserializing {} sandbox rules", serialized.operations.len());
+                    debug!(
+                        "Deserializing {} sandbox rules",
+                        serialized.operations.len()
+                    );
                     deserialize_profile(serialized, &project_path)?
-                },
+                }
                 Err(e) => {
                     warn!("Failed to deserialize sandbox rules: {}", e);
                     // Fallback to minimal profile
@@ -253,10 +299,10 @@ impl SandboxExecutor {
             // Fallback to minimal profile
             create_minimal_profile(project_path)?
         };

         // Create and activate the child sandbox
         let sandbox = ChildSandbox::new(profile);

         match sandbox.activate() {
             Ok(_) => {
                 info!("Sandbox activated successfully");
@@ -280,12 +326,12 @@ impl SandboxExecutor {
             serialized_profile: None,
         }
     }

     /// Create a new sandbox executor with serialized profile (no-op on Windows)
     pub fn new_with_serialization(
         _profile: (),
         project_path: PathBuf,
-        serialized_profile: SerializedProfile
+        serialized_profile: SerializedProfile,
     ) -> Self {
         Self {
             project_path,
@@ -294,9 +340,17 @@ impl SandboxExecutor {
     }

     /// Execute a command in the sandbox (Windows - no sandboxing)
-    pub fn execute_sandboxed_spawn(&self, command: &str, args: &[&str], cwd: &Path) -> Result<std::process::Child> {
-        info!("Executing command without sandbox on Windows: {} {:?}", command, args);
+    pub fn execute_sandboxed_spawn(
+        &self,
+        command: &str,
+        args: &[&str],
+        cwd: &Path,
+    ) -> Result<std::process::Child> {
+        info!(
+            "Executing command without sandbox on Windows: {} {:?}",
+            command, args
+        );

         std::process::Command::new(command)
             .args(args)
             .current_dir(cwd)
@@ -309,23 +363,26 @@ impl SandboxExecutor {

     /// Prepare a sandboxed tokio Command (Windows - no sandboxing)
     pub fn prepare_sandboxed_command(&self, command: &str, args: &[&str], cwd: &Path) -> Command {
-        info!("Preparing command without sandbox on Windows: {} {:?}", command, args);
+        info!(
+            "Preparing command without sandbox on Windows: {} {:?}",
+            command, args
+        );

         let mut cmd = Command::new(command);
         cmd.args(args)
             .current_dir(cwd)
             .stdin(Stdio::null())
             .stdout(Stdio::piped())
             .stderr(Stdio::piped());

         cmd
     }

     /// Extract sandbox rules (no-op on Windows)
     fn extract_sandbox_rules(&self) -> Result<SerializedProfile> {
         Ok(SerializedProfile { operations: vec![] })
     }

     /// Activate sandbox in child process (no-op on Windows)
     pub fn activate_sandbox_in_child() -> Result<()> {
         debug!("Sandbox activation skipped on Windows");
@@ -341,11 +398,11 @@ pub fn should_activate_sandbox() -> bool {
 /// Helper to create a sandboxed tokio Command
 #[cfg(unix)]
 pub fn create_sandboxed_command(
     command: &str,
     args: &[&str],
     cwd: &Path,
     profile: gaol::profile::Profile,
-    project_path: PathBuf
+    project_path: PathBuf,
 ) -> Command {
     let executor = SandboxExecutor::new(profile, project_path);
     executor.prepare_sandboxed_command(command, args, cwd)
@@ -368,9 +425,12 @@ pub enum SerializedOperation {
 }

 #[cfg(unix)]
-fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Result<gaol::profile::Profile> {
+fn deserialize_profile(
+    serialized: SerializedProfile,
+    project_path: &Path,
+) -> Result<gaol::profile::Profile> {
     let mut operations = Vec::new();

     for op in serialized.operations {
         match op {
             SerializedOperation::FileReadAll { path, is_subpath } => {
@@ -401,12 +461,12 @@ fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Result<gaol::profile::Profile> {
             }
             SerializedOperation::NetworkTcp { port } => {
                 operations.push(gaol::profile::Operation::NetworkOutbound(
-                    gaol::profile::AddressPattern::Tcp(port)
+                    gaol::profile::AddressPattern::Tcp(port),
                 ));
             }
             SerializedOperation::NetworkLocalSocket { path } => {
                 operations.push(gaol::profile::Operation::NetworkOutbound(
-                    gaol::profile::AddressPattern::LocalSocket(path)
+                    gaol::profile::AddressPattern::LocalSocket(path),
                 ));
             }
             SerializedOperation::SystemInfoRead => {
@@ -414,40 +474,38 @@ fn deserialize_profile(serialized: SerializedProfile, project_path: &Path) -> Result<gaol::profile::Profile> {
             }
         }
     }

     // Always ensure project path access
     let has_project_access = operations.iter().any(|op| {
         matches!(op, gaol::profile::Operation::FileReadAll(gaol::profile::PathPattern::Subpath(p)) if p == project_path)
     });

     if !has_project_access {
         operations.push(gaol::profile::Operation::FileReadAll(
-            gaol::profile::PathPattern::Subpath(project_path.to_path_buf())
+            gaol::profile::PathPattern::Subpath(project_path.to_path_buf()),
         ));
     }

     let op_count = operations.len();
-    gaol::profile::Profile::new(operations)
-        .map_err(|e| {
-            error!("Failed to create profile: {:?}", e);
-            anyhow::anyhow!("Failed to create profile from {} operations: {:?}", op_count, e)
-        })
+    gaol::profile::Profile::new(operations).map_err(|e| {
+        error!("Failed to create profile: {:?}", e);
+        anyhow::anyhow!(
+            "Failed to create profile from {} operations: {:?}",
+            op_count,
+            e
+        )
+    })
 }

 #[cfg(unix)]
 fn create_minimal_profile(project_path: PathBuf) -> Result<gaol::profile::Profile> {
     let operations = vec![
-        gaol::profile::Operation::FileReadAll(
-            gaol::profile::PathPattern::Subpath(project_path)
-        ),
-        gaol::profile::Operation::NetworkOutbound(
-            gaol::profile::AddressPattern::All
-        ),
+        gaol::profile::Operation::FileReadAll(gaol::profile::PathPattern::Subpath(project_path)),
+        gaol::profile::Operation::NetworkOutbound(gaol::profile::AddressPattern::All),
     ];

-    gaol::profile::Profile::new(operations)
-        .map_err(|e| {
-            error!("Failed to create minimal profile: {:?}", e);
-            anyhow::anyhow!("Failed to create minimal sandbox profile: {:?}", e)
-        })
+    gaol::profile::Profile::new(operations).map_err(|e| {
+        error!("Failed to create minimal profile: {:?}", e);
+        anyhow::anyhow!("Failed to create minimal sandbox profile: {:?}", e)
+    })
 }
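To summarize the gaol flow in this file: a profile is a list of `Operation`s validated by `Profile::new`, the parent wraps the child via `Sandbox::start`, and the child re-arms itself with `ChildSandbox::new(profile).activate()`. A unix-only sketch mirroring `create_minimal_profile` above, using only items that appear in this diff:

    use std::path::PathBuf;

    #[cfg(unix)]
    fn demo_profile(project_path: PathBuf) -> Option<gaol::profile::Profile> {
        use gaol::profile::{AddressPattern, Operation, PathPattern, Profile};

        let operations = vec![
            // Read-only access to the project tree...
            Operation::FileReadAll(PathPattern::Subpath(project_path)),
            // ...plus unrestricted outbound network.
            Operation::NetworkOutbound(AddressPattern::All),
        ];
        // Profile::new fails if an operation cannot be enforced on this platform.
        Profile::new(operations).ok()
    }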
@@ -1,21 +1,21 @@
 #[allow(unused)]
-pub mod profile;
+pub mod defaults;
 #[allow(unused)]
 pub mod executor;
 #[allow(unused)]
 pub mod platform;
 #[allow(unused)]
-pub mod defaults;
+pub mod profile;

 // These are used in agents.rs and claude.rs via direct module paths
 #[allow(unused)]
-pub use profile::{SandboxProfile, SandboxRule, ProfileBuilder};
+pub use profile::{ProfileBuilder, SandboxProfile, SandboxRule};
 // These are used in main.rs and sandbox.rs
 #[allow(unused)]
-pub use executor::{SandboxExecutor, should_activate_sandbox};
+pub use executor::{should_activate_sandbox, SandboxExecutor};
 // These are used in sandbox.rs
 #[allow(unused)]
-pub use platform::{PlatformCapabilities, get_platform_capabilities};
+pub use platform::{get_platform_capabilities, PlatformCapabilities};
 // Used for initial setup
 #[allow(unused)]
 pub use defaults::create_default_profiles;
@@ -28,7 +28,7 @@ pub struct OperationSupport {
 /// Get the platform capabilities for sandboxing
 pub fn get_platform_capabilities() -> PlatformCapabilities {
     let os = env::consts::OS;

     match os {
         "linux" => get_linux_capabilities(),
         "macos" => get_macos_capabilities(),
@@ -176,4 +176,4 @@ fn get_unsupported_capabilities(os: &str) -> PlatformCapabilities {
 /// Check if sandboxing is available on the current platform
 pub fn is_sandboxing_available() -> bool {
     matches!(env::consts::OS, "linux" | "macos" | "freebsd")
 }
@@ -1,3 +1,4 @@
+use crate::sandbox::executor::{SerializedOperation, SerializedProfile};
 use anyhow::{Context, Result};
 #[cfg(unix)]
 use gaol::profile::{AddressPattern, Operation, OperationSupport, PathPattern, Profile};
@@ -5,7 +6,6 @@ use log::{debug, info, warn};
 use rusqlite::{params, Connection};
 use serde::{Deserialize, Serialize};
 use std::path::PathBuf;
-use crate::sandbox::executor::{SerializedOperation, SerializedProfile};

 /// Represents a sandbox profile from the database
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -37,7 +37,7 @@ pub struct ProfileBuildResult {
     #[cfg(unix)]
     pub profile: Profile,
     #[cfg(not(unix))]
     pub profile: (), // Placeholder for Windows
     pub serialized: SerializedProfile,
 }

@@ -50,56 +50,63 @@ pub struct ProfileBuilder {
 impl ProfileBuilder {
     /// Create a new profile builder
     pub fn new(project_path: PathBuf) -> Result<Self> {
-        let home_dir = dirs::home_dir()
-            .context("Could not determine home directory")?;
+        let home_dir = dirs::home_dir().context("Could not determine home directory")?;

         Ok(Self {
             project_path,
             home_dir,
         })
     }

     /// Build a gaol Profile from database rules filtered by agent permissions
-    pub fn build_agent_profile(&self, rules: Vec<SandboxRule>, sandbox_enabled: bool, enable_file_read: bool, enable_file_write: bool, enable_network: bool) -> Result<ProfileBuildResult> {
+    pub fn build_agent_profile(
+        &self,
+        rules: Vec<SandboxRule>,
+        sandbox_enabled: bool,
+        enable_file_read: bool,
+        enable_file_write: bool,
+        enable_network: bool,
+    ) -> Result<ProfileBuildResult> {
         // If sandbox is completely disabled, return an empty profile
         if !sandbox_enabled {
             return Ok(ProfileBuildResult {
                 #[cfg(unix)]
-                profile: Profile::new(vec![]).map_err(|_| anyhow::anyhow!("Failed to create empty profile"))?,
+                profile: Profile::new(vec![])
+                    .map_err(|_| anyhow::anyhow!("Failed to create empty profile"))?,
                 #[cfg(not(unix))]
                 profile: (),
                 serialized: SerializedProfile { operations: vec![] },
             });
         }

         let mut filtered_rules = Vec::new();

         for rule in rules {
             if !rule.enabled {
                 continue;
             }

             // Filter rules based on agent permissions
             let include_rule = match rule.operation_type.as_str() {
                 "file_read_all" | "file_read_metadata" => enable_file_read,
                 "network_outbound" => enable_network,
                 "system_info_read" => true, // Always allow system info reading
-                _ => true // Include unknown rule types by default
+                _ => true, // Include unknown rule types by default
             };

             if include_rule {
                 filtered_rules.push(rule);
             }
         }

         // Always ensure project path access if file reading is enabled
         if enable_file_read {
             let has_project_access = filtered_rules.iter().any(|rule| {
-                rule.operation_type == "file_read_all" &&
-                rule.pattern_type == "subpath" &&
-                rule.pattern_value.contains("{{PROJECT_PATH}}")
+                rule.operation_type == "file_read_all"
+                    && rule.pattern_type == "subpath"
+                    && rule.pattern_value.contains("{{PROJECT_PATH}}")
             });

             if !has_project_access {
                 // Add a default project access rule
                 filtered_rules.push(SandboxRule {
@@ -114,78 +121,99 @@ impl ProfileBuilder {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.build_profile_with_serialization(filtered_rules)
|
self.build_profile_with_serialization(filtered_rules)
|
||||||
}
|
}
|
||||||
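For illustration, a minimal sketch of the flag behavior above (Unix, platform permitting; paths are made up, and note that enable_file_write is accepted but no write rule type appears in the filter): with no stored rules and only file reads enabled, a subpath read rule for the project directory is still injected.

let builder = ProfileBuilder::new(PathBuf::from("/tmp/demo_project"))?;
let result = builder.build_agent_profile(
    Vec::new(), // no database rules
    true,       // sandbox_enabled
    true,       // enable_file_read
    false,      // enable_file_write
    false,      // enable_network
)?;
// The injected default rule surfaces as a serialized subpath read of the project dir.
assert!(matches!(
    result.serialized.operations.first(),
    Some(SerializedOperation::FileReadAll { is_subpath: true, .. })
));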

     /// Build a gaol Profile from database rules
     #[cfg(unix)]
     pub fn build_profile(&self, rules: Vec<SandboxRule>) -> Result<Profile> {
         let result = self.build_profile_with_serialization(rules)?;
         Ok(result.profile)
     }

     /// Build a gaol Profile from database rules (Windows stub)
     #[cfg(not(unix))]
     pub fn build_profile(&self, _rules: Vec<SandboxRule>) -> Result<()> {
         warn!("Sandbox profiles are not supported on Windows");
         Ok(())
     }

     /// Build a gaol Profile from database rules and return serialized operations
-    pub fn build_profile_with_serialization(&self, rules: Vec<SandboxRule>) -> Result<ProfileBuildResult> {
+    pub fn build_profile_with_serialization(
+        &self,
+        rules: Vec<SandboxRule>,
+    ) -> Result<ProfileBuildResult> {
         #[cfg(unix)]
         {
             let mut operations = Vec::new();
             let mut serialized_operations = Vec::new();

             for rule in rules {
                 if !rule.enabled {
                     continue;
                 }

                 // Check platform support
                 if !self.is_rule_supported_on_platform(&rule) {
-                    debug!("Skipping rule {} - not supported on current platform", rule.operation_type);
+                    debug!(
+                        "Skipping rule {} - not supported on current platform",
+                        rule.operation_type
+                    );
                     continue;
                 }

                 match self.build_operation_with_serialization(&rule) {
                     Ok(Some((op, serialized))) => {
                         // Check if operation is supported on current platform
-                        if matches!(op.support(), gaol::profile::OperationSupportLevel::CanBeAllowed) {
+                        if matches!(
+                            op.support(),
+                            gaol::profile::OperationSupportLevel::CanBeAllowed
+                        ) {
                             operations.push(op);
                             serialized_operations.push(serialized);
                         } else {
-                            warn!("Operation {:?} not supported at desired level on current platform", rule.operation_type);
+                            warn!(
+                                "Operation {:?} not supported at desired level on current platform",
+                                rule.operation_type
+                            );
                         }
-                    },
+                    }
                     Ok(None) => {
-                        debug!("Skipping unsupported operation type: {}", rule.operation_type);
-                    },
+                        debug!(
+                            "Skipping unsupported operation type: {}",
+                            rule.operation_type
+                        );
+                    }
                     Err(e) => {
-                        warn!("Failed to build operation for rule {}: {}", rule.id.unwrap_or(0), e);
+                        warn!(
+                            "Failed to build operation for rule {}: {}",
+                            rule.id.unwrap_or(0),
+                            e
+                        );
                     }
                 }
             }

             // Ensure project path access is included
             let has_project_access = serialized_operations.iter().any(|op| {
                 matches!(op, SerializedOperation::FileReadAll { path, is_subpath: true } if path == &self.project_path)
             });

             if !has_project_access {
-                operations.push(Operation::FileReadAll(PathPattern::Subpath(self.project_path.clone())));
+                operations.push(Operation::FileReadAll(PathPattern::Subpath(
+                    self.project_path.clone(),
+                )));
                 serialized_operations.push(SerializedOperation::FileReadAll {
                     path: self.project_path.clone(),
                     is_subpath: true,
                 });
             }

             // Create the profile
             let profile = Profile::new(operations)
                 .map_err(|_| anyhow::anyhow!("Failed to create sandbox profile - some operations may not be supported on this platform"))?;

             Ok(ProfileBuildResult {
                 profile,
                 serialized: SerializedProfile {
@@ -193,22 +221,22 @@ impl ProfileBuilder {
                 },
             })
         }

         #[cfg(not(unix))]
         {
             // On Windows, we just create a serialized profile without actual sandboxing
             let mut serialized_operations = Vec::new();

             for rule in rules {
                 if !rule.enabled {
                     continue;
                 }

                 if let Ok(Some(serialized)) = self.build_serialized_operation(&rule) {
                     serialized_operations.push(serialized);
                 }
             }

             Ok(ProfileBuildResult {
                 profile: (),
                 serialized: SerializedProfile {
@@ -217,7 +245,7 @@ impl ProfileBuilder {
             })
         }
     }

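The serialized half of ProfileBuildResult lets the rule set outlive the in-process gaol Profile, for example to hand to a spawned process. One plausible handoff, assuming SerializedProfile derives serde's Serialize elsewhere in the crate (not shown in this hunk) and using a hypothetical environment variable name:

use std::process::Command;

let result = builder.build_profile_with_serialization(rules)?;
let json = serde_json::to_string(&result.serialized)?; // assumes #[derive(Serialize)]
let mut child = Command::new("claude");
child.env("CLAUDIA_SANDBOX_PROFILE_JSON", &json); // hypothetical variable, for illustration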
     /// Build a gaol Operation from a database rule
     #[cfg(unix)]
     fn build_operation(&self, rule: &SandboxRule) -> Result<Option<Operation>> {
@@ -227,97 +255,125 @@ impl ProfileBuilder {
             Err(e) => Err(e),
         }
     }

     /// Build a gaol Operation and its serialized form from a database rule
     #[cfg(unix)]
-    fn build_operation_with_serialization(&self, rule: &SandboxRule) -> Result<Option<(Operation, SerializedOperation)>> {
+    fn build_operation_with_serialization(
+        &self,
+        rule: &SandboxRule,
+    ) -> Result<Option<(Operation, SerializedOperation)>> {
         match rule.operation_type.as_str() {
             "file_read_all" => {
-                let (pattern, path, is_subpath) = self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?;
+                let (pattern, path, is_subpath) =
+                    self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?;
                 Ok(Some((
                     Operation::FileReadAll(pattern),
-                    SerializedOperation::FileReadAll { path, is_subpath }
+                    SerializedOperation::FileReadAll { path, is_subpath },
                 )))
-            },
+            }
             "file_read_metadata" => {
-                let (pattern, path, is_subpath) = self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?;
+                let (pattern, path, is_subpath) =
+                    self.build_path_pattern_with_info(&rule.pattern_type, &rule.pattern_value)?;
                 Ok(Some((
                     Operation::FileReadMetadata(pattern),
-                    SerializedOperation::FileReadMetadata { path, is_subpath }
+                    SerializedOperation::FileReadMetadata { path, is_subpath },
                 )))
-            },
+            }
             "network_outbound" => {
-                let (pattern, serialized) = self.build_address_pattern_with_serialization(&rule.pattern_type, &rule.pattern_value)?;
+                let (pattern, serialized) = self.build_address_pattern_with_serialization(
+                    &rule.pattern_type,
+                    &rule.pattern_value,
+                )?;
                 Ok(Some((Operation::NetworkOutbound(pattern), serialized)))
-            },
-            "system_info_read" => {
-                Ok(Some((
-                    Operation::SystemInfoRead,
-                    SerializedOperation::SystemInfoRead
-                )))
-            },
-            _ => Ok(None)
+            }
+            "system_info_read" => Ok(Some((
+                Operation::SystemInfoRead,
+                SerializedOperation::SystemInfoRead,
+            ))),
+            _ => Ok(None),
         }
     }

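Read as a mapping, the match above pairs each database row with a gaol operation and its serialized twin. A sketch with illustrative rule values (callable only from within this impl, since the method is private):

// rule = { operation_type: "file_read_all", pattern_type: "subpath",
//          pattern_value: "{{PROJECT_PATH}}" }
let (op, serialized) = self
    .build_operation_with_serialization(&rule)?
    .expect("file_read_all is a known operation type");
// op         == Operation::FileReadAll(PathPattern::Subpath(<project dir>))
// serialized == SerializedOperation::FileReadAll { path: <project dir>, is_subpath: true }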
     /// Build a PathPattern from pattern type and value
     #[cfg(unix)]
     fn build_path_pattern(&self, pattern_type: &str, pattern_value: &str) -> Result<PathPattern> {
         let (pattern, _, _) = self.build_path_pattern_with_info(pattern_type, pattern_value)?;
         Ok(pattern)
     }

     /// Build a PathPattern and return additional info for serialization
     #[cfg(unix)]
-    fn build_path_pattern_with_info(&self, pattern_type: &str, pattern_value: &str) -> Result<(PathPattern, PathBuf, bool)> {
+    fn build_path_pattern_with_info(
+        &self,
+        pattern_type: &str,
+        pattern_value: &str,
+    ) -> Result<(PathPattern, PathBuf, bool)> {
         // Replace template variables
         let expanded_value = pattern_value
             .replace("{{PROJECT_PATH}}", &self.project_path.to_string_lossy())
             .replace("{{HOME}}", &self.home_dir.to_string_lossy());

         let path = PathBuf::from(expanded_value);

         match pattern_type {
             "literal" => Ok((PathPattern::Literal(path.clone()), path, false)),
             "subpath" => Ok((PathPattern::Subpath(path.clone()), path, true)),
-            _ => Err(anyhow::anyhow!("Unknown path pattern type: {}", pattern_type))
+            _ => Err(anyhow::anyhow!(
+                "Unknown path pattern type: {}",
+                pattern_type
+            )),
         }
     }

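The template step above is plain string substitution performed before any pattern matching; a self-contained illustration with made-up paths:

use std::path::PathBuf;

fn main() {
    let project_path = PathBuf::from("/work/app"); // illustrative value
    let expanded =
        "{{PROJECT_PATH}}/src".replace("{{PROJECT_PATH}}", &project_path.to_string_lossy());
    // With pattern_type "subpath" this becomes PathPattern::Subpath("/work/app/src");
    // with "literal" it becomes PathPattern::Literal of the same path.
    assert_eq!(expanded, "/work/app/src");
}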
     /// Build an AddressPattern from pattern type and value
     #[cfg(unix)]
-    fn build_address_pattern(&self, pattern_type: &str, pattern_value: &str) -> Result<AddressPattern> {
-        let (pattern, _) = self.build_address_pattern_with_serialization(pattern_type, pattern_value)?;
+    fn build_address_pattern(
+        &self,
+        pattern_type: &str,
+        pattern_value: &str,
+    ) -> Result<AddressPattern> {
+        let (pattern, _) =
+            self.build_address_pattern_with_serialization(pattern_type, pattern_value)?;
         Ok(pattern)
     }

     /// Build an AddressPattern and its serialized form
     #[cfg(unix)]
-    fn build_address_pattern_with_serialization(&self, pattern_type: &str, pattern_value: &str) -> Result<(AddressPattern, SerializedOperation)> {
+    fn build_address_pattern_with_serialization(
+        &self,
+        pattern_type: &str,
+        pattern_value: &str,
+    ) -> Result<(AddressPattern, SerializedOperation)> {
         match pattern_type {
             "all" => Ok((
                 AddressPattern::All,
-                SerializedOperation::NetworkOutbound { pattern: "all".to_string() }
+                SerializedOperation::NetworkOutbound {
+                    pattern: "all".to_string(),
+                },
             )),
             "tcp" => {
-                let port = pattern_value.parse::<u16>()
+                let port = pattern_value
+                    .parse::<u16>()
                     .context("Invalid TCP port number")?;
                 Ok((
                     AddressPattern::Tcp(port),
-                    SerializedOperation::NetworkTcp { port }
+                    SerializedOperation::NetworkTcp { port },
                 ))
-            },
+            }
             "local_socket" => {
                 let path = PathBuf::from(pattern_value);
                 Ok((
                     AddressPattern::LocalSocket(path.clone()),
-                    SerializedOperation::NetworkLocalSocket { path }
+                    SerializedOperation::NetworkLocalSocket { path },
                 ))
-            },
-            _ => Err(anyhow::anyhow!("Unknown address pattern type: {}", pattern_type))
+            }
+            _ => Err(anyhow::anyhow!(
+                "Unknown address pattern type: {}",
+                pattern_type
+            )),
        }
     }

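For reference, the pattern-type vocabulary the match above accepts, with illustrative values, plus a self-contained check of the "tcp" parsing step:

// ("all", _)          -> AddressPattern::All
// ("tcp", "443")      -> AddressPattern::Tcp(443)
// ("local_socket", p) -> AddressPattern::LocalSocket(PathBuf::from(p))
fn main() {
    let port = "443".parse::<u16>().expect("Invalid TCP port number");
    assert_eq!(port, 443u16);
}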
     /// Check if a rule is supported on the current platform
     fn is_rule_supported_on_platform(&self, rule: &SandboxRule) -> bool {
         if let Some(platforms_json) = &rule.platform_support {
@@ -332,37 +388,42 @@ impl ProfileBuilder {

     /// Build only the serialized operation (for Windows)
     #[cfg(not(unix))]
-    fn build_serialized_operation(&self, rule: &SandboxRule) -> Result<Option<SerializedOperation>> {
+    fn build_serialized_operation(
+        &self,
+        rule: &SandboxRule,
+    ) -> Result<Option<SerializedOperation>> {
         let pattern_value = self.expand_pattern_value(&rule.pattern_value);

         match rule.operation_type.as_str() {
             "file_read_all" => {
-                let (path, is_subpath) = self.parse_path_pattern(&rule.pattern_type, &pattern_value)?;
+                let (path, is_subpath) =
+                    self.parse_path_pattern(&rule.pattern_type, &pattern_value)?;
                 Ok(Some(SerializedOperation::FileReadAll { path, is_subpath }))
             }
             "file_read_metadata" => {
-                let (path, is_subpath) = self.parse_path_pattern(&rule.pattern_type, &pattern_value)?;
-                Ok(Some(SerializedOperation::FileReadMetadata { path, is_subpath }))
+                let (path, is_subpath) =
+                    self.parse_path_pattern(&rule.pattern_type, &pattern_value)?;
+                Ok(Some(SerializedOperation::FileReadMetadata {
+                    path,
+                    is_subpath,
+                }))
             }
-            "network_outbound" => {
-                Ok(Some(SerializedOperation::NetworkOutbound { pattern: pattern_value }))
-            }
+            "network_outbound" => Ok(Some(SerializedOperation::NetworkOutbound {
+                pattern: pattern_value,
+            })),
             "network_tcp" => {
-                let port = pattern_value.parse::<u16>()
-                    .context("Invalid TCP port")?;
+                let port = pattern_value.parse::<u16>().context("Invalid TCP port")?;
                 Ok(Some(SerializedOperation::NetworkTcp { port }))
             }
             "network_local_socket" => {
                 let path = PathBuf::from(pattern_value);
                 Ok(Some(SerializedOperation::NetworkLocalSocket { path }))
             }
-            "system_info_read" => {
-                Ok(Some(SerializedOperation::SystemInfoRead))
-            }
+            "system_info_read" => Ok(Some(SerializedOperation::SystemInfoRead)),
             _ => Ok(None),
         }
     }

     /// Helper method to expand pattern values (Windows version)
     #[cfg(not(unix))]
     fn expand_pattern_value(&self, pattern_value: &str) -> String {
@@ -370,16 +431,23 @@ impl ProfileBuilder {
             .replace("{{PROJECT_PATH}}", &self.project_path.to_string_lossy())
             .replace("{{HOME}}", &self.home_dir.to_string_lossy())
     }

     /// Helper method to parse path patterns (Windows version)
     #[cfg(not(unix))]
-    fn parse_path_pattern(&self, pattern_type: &str, pattern_value: &str) -> Result<(PathBuf, bool)> {
+    fn parse_path_pattern(
+        &self,
+        pattern_type: &str,
+        pattern_value: &str,
+    ) -> Result<(PathBuf, bool)> {
         let path = PathBuf::from(pattern_value);

         match pattern_type {
             "literal" => Ok((path, false)),
             "subpath" => Ok((path, true)),
-            _ => Err(anyhow::anyhow!("Unknown path pattern type: {}", pattern_type))
+            _ => Err(anyhow::anyhow!(
+                "Unknown path pattern type: {}",
+                pattern_type
+            )),
         }
     }
 }
@@ -400,7 +468,7 @@ pub fn load_profile(conn: &Connection, profile_id: i64) -> Result<SandboxProfile
             created_at: row.get(5)?,
             updated_at: row.get(6)?,
         })
-        }
+        },
     )
     .context("Failed to load sandbox profile")
 }
@@ -421,7 +489,7 @@ pub fn load_default_profile(conn: &Connection) -> Result<SandboxProfile> {
             created_at: row.get(5)?,
             updated_at: row.get(6)?,
         })
-        }
+        },
     )
     .context("Failed to load default sandbox profile")
 }
@@ -432,40 +500,45 @@ pub fn load_profile_rules(conn: &Connection, profile_id: i64) -> Result<Vec<Sand
         "SELECT id, profile_id, operation_type, pattern_type, pattern_value, enabled, platform_support, created_at
          FROM sandbox_rules WHERE profile_id = ?1 AND enabled = 1"
     )?;

-    let rules = stmt.query_map(params![profile_id], |row| {
-        Ok(SandboxRule {
-            id: Some(row.get(0)?),
-            profile_id: row.get(1)?,
-            operation_type: row.get(2)?,
-            pattern_type: row.get(3)?,
-            pattern_value: row.get(4)?,
-            enabled: row.get(5)?,
-            platform_support: row.get(6)?,
-            created_at: row.get(7)?,
-        })
-    })?
-    .collect::<Result<Vec<_>, _>>()?;
+    let rules = stmt
+        .query_map(params![profile_id], |row| {
+            Ok(SandboxRule {
+                id: Some(row.get(0)?),
+                profile_id: row.get(1)?,
+                operation_type: row.get(2)?,
+                pattern_type: row.get(3)?,
+                pattern_value: row.get(4)?,
+                enabled: row.get(5)?,
+                platform_support: row.get(6)?,
+                created_at: row.get(7)?,
+            })
+        })?
+        .collect::<Result<Vec<_>, _>>()?;

     Ok(rules)
 }

 /// Get or create the gaol Profile for execution
 #[cfg(unix)]
-pub fn get_gaol_profile(conn: &Connection, profile_id: Option<i64>, project_path: PathBuf) -> Result<Profile> {
+pub fn get_gaol_profile(
+    conn: &Connection,
+    profile_id: Option<i64>,
+    project_path: PathBuf,
+) -> Result<Profile> {
     // Load the profile
     let profile = if let Some(id) = profile_id {
         load_profile(conn, id)?
     } else {
         load_default_profile(conn)?
     };

     info!("Using sandbox profile: {}", profile.name);

     // Load the rules
     let rules = load_profile_rules(conn, profile.id.unwrap())?;
     info!("Loaded {} sandbox rules", rules.len());

     // Build the gaol profile
     let builder = ProfileBuilder::new(project_path)?;
     builder.build_profile(rules)
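A sketch of the Unix call path end to end, assuming a rusqlite Connection to the app database; file names are illustrative:

use rusqlite::Connection;
use std::path::PathBuf;

fn demo() -> anyhow::Result<()> {
    let conn = Connection::open("claudia.db")?; // illustrative path
    // profile_id = None falls back to the row marked is_default in sandbox_profiles
    let profile = get_gaol_profile(&conn, None, PathBuf::from("/work/app"))?;
    let _ = profile; // the gaol Profile handed to the sandboxed child
    Ok(())
}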
@@ -473,7 +546,11 @@ pub fn get_gaol_profile(conn: &Connection, profile_id: Option<i64>, project_path

 /// Get or create the gaol Profile for execution (Windows stub)
 #[cfg(not(unix))]
-pub fn get_gaol_profile(_conn: &Connection, _profile_id: Option<i64>, _project_path: PathBuf) -> Result<()> {
+pub fn get_gaol_profile(
+    _conn: &Connection,
+    _profile_id: Option<i64>,
+    _project_path: PathBuf,
+) -> Result<()> {
     warn!("Sandbox profiles are not supported on Windows");
     Ok(())
 }
@@ -14,36 +14,37 @@ pub fn execute_claude_task(
     timeout_secs: u64,
 ) -> Result<ClaudeOutput> {
     let mut cmd = Command::new("claude");

     // Add task
     cmd.arg("-p").arg(task);

     // Add system prompt if provided
     if let Some(prompt) = system_prompt {
         cmd.arg("--system-prompt").arg(prompt);
     }

     // Add model if provided
     if let Some(m) = model {
         cmd.arg("--model").arg(m);
     }

     // Always add these flags for testing
-    cmd.arg("--output-format").arg("stream-json")
-        .arg("--verbose")
-        .arg("--dangerously-skip-permissions")
-        .current_dir(project_path)
-        .stdout(Stdio::piped())
-        .stderr(Stdio::piped());
+    cmd.arg("--output-format")
+        .arg("stream-json")
+        .arg("--verbose")
+        .arg("--dangerously-skip-permissions")
+        .current_dir(project_path)
+        .stdout(Stdio::piped())
+        .stderr(Stdio::piped());

     // Add sandbox profile ID if provided
     if let Some(profile_id) = sandbox_profile_id {
         cmd.env("CLAUDIA_SANDBOX_PROFILE_ID", profile_id.to_string());
     }

     // Execute with timeout (use gtimeout on macOS, timeout on Linux)
     let start = std::time::Instant::now();

     let timeout_cmd = if cfg!(target_os = "macos") {
         // On macOS, try gtimeout (from GNU coreutils) first, fallback to direct execution
         if std::process::Command::new("which")
@@ -60,15 +61,15 @@ pub fn execute_claude_task(
     } else {
         "timeout"
     };

     let output = if timeout_cmd.is_empty() {
         // Run without timeout wrapper
-        cmd.output()
-            .context("Failed to execute Claude command")?
+        cmd.output().context("Failed to execute Claude command")?
     } else {
         // Run with timeout wrapper
         let mut timeout_cmd = Command::new(timeout_cmd);
-        timeout_cmd.arg(timeout_secs.to_string())
+        timeout_cmd
+            .arg(timeout_secs.to_string())
             .arg("claude")
             .args(cmd.get_args())
             .current_dir(project_path)
@@ -78,9 +79,9 @@ pub fn execute_claude_task(
             .output()
             .context("Failed to execute Claude command with timeout")?
     };

     let duration = start.elapsed();

     Ok(ClaudeOutput {
         stdout: String::from_utf8_lossy(&output.stdout).to_string(),
         stderr: String::from_utf8_lossy(&output.stderr).to_string(),
@@ -103,7 +104,7 @@ impl ClaudeOutput {
     pub fn contains_operation(&self, operation: &str) -> bool {
         self.stdout.contains(operation) || self.stderr.contains(operation)
     }

     /// Check if operation was blocked (look for permission denied, sandbox violation, etc)
     pub fn operation_was_blocked(&self, operation: &str) -> bool {
         let blocked_patterns = [
@@ -114,16 +115,16 @@ impl ClaudeOutput {
             "access denied",
             "sandbox violation",
         ];

         let output = format!("{}\n{}", self.stdout, self.stderr).to_lowercase();
         let op_lower = operation.to_lowercase();

         // Check if operation was mentioned along with a block pattern
-        blocked_patterns.iter().any(|pattern| {
-            output.contains(&op_lower) && output.contains(pattern)
-        })
+        blocked_patterns
+            .iter()
+            .any(|pattern| output.contains(&op_lower) && output.contains(pattern))
     }

     /// Check if file read was successful
     pub fn file_read_succeeded(&self, filename: &str) -> bool {
         // Look for patterns indicating successful file read
@@ -133,10 +134,12 @@ impl ClaudeOutput {
             &format!("Contents of {}", filename),
             "test content", // Our test files contain this
         ];

-        patterns.iter().any(|pattern| self.contains_operation(pattern))
+        patterns
+            .iter()
+            .any(|pattern| self.contains_operation(pattern))
     }

     /// Check if network connection was attempted
     pub fn network_attempted(&self, host: &str) -> bool {
         let patterns = [
@@ -145,8 +148,10 @@ impl ClaudeOutput {
             &format!("connect to {}", host),
             host,
         ];

-        patterns.iter().any(|pattern| self.contains_operation(pattern))
+        patterns
+            .iter()
+            .any(|pattern| self.contains_operation(pattern))
     }
 }

@@ -156,24 +161,27 @@ pub mod tasks {
     pub fn read_file(filename: &str) -> String {
         format!("Read the file {} and show me its contents", filename)
     }

     /// Task to attempt network connection
     pub fn connect_network(host: &str) -> String {
         format!("Try to connect to {} and tell me if it works", host)
     }

     /// Task to do multiple operations
     pub fn multi_operation() -> String {
         "Read the file ./test.txt in the current directory and show its contents".to_string()
     }

     /// Task to test file write
     pub fn write_file(filename: &str, content: &str) -> String {
-        format!("Create a file called {} with the content '{}'", filename, content)
+        format!(
+            "Create a file called {} with the content '{}'",
+            filename, content
+        )
     }

     /// Task to test process spawning
     pub fn spawn_process(command: &str) -> String {
         format!("Run the command '{}' and show me the output", command)
     }
 }
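A sketch of how a test drives this helper, with the argument order taken from the agent tests further down (the task and system-prompt positions are an assumption, since only the tail of the signature is visible in this hunk):

let result = execute_claude_task(
    &project_dir,                  // prepared test directory
    &tasks::read_file("test.txt"), // natural-language task
    None,                          // system prompt
    Some("sonnet"),                // model
    Some(profile_id),              // sandbox profile to apply
    20,                            // timeout in seconds
)?;
assert!(result.file_read_succeeded("test.txt") || result.operation_was_blocked("test.txt"));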
@@ -10,9 +10,8 @@ use tempfile::{tempdir, TempDir};
 /// Using parking_lot::Mutex which doesn't poison on panic
 use parking_lot::Mutex;

-pub static TEST_DB: Lazy<Mutex<TestDatabase>> = Lazy::new(|| {
-    Mutex::new(TestDatabase::new().expect("Failed to create test database"))
-});
+pub static TEST_DB: Lazy<Mutex<TestDatabase>> =
+    Lazy::new(|| Mutex::new(TestDatabase::new().expect("Failed to create test database")));

 /// Test database manager
 pub struct TestDatabase {
@@ -26,13 +25,13 @@ impl TestDatabase {
         let temp_dir = tempdir()?;
         let db_path = temp_dir.path().join("test_sandbox.db");
         let conn = Connection::open(&db_path)?;

         // Initialize schema
         Self::init_schema(&conn)?;

         Ok(Self { conn, temp_dir })
     }

     /// Initialize database schema
     fn init_schema(conn: &Connection) -> Result<()> {
         // Create sandbox profiles table
@@ -48,7 +47,7 @@ impl TestDatabase {
             )",
             [],
         )?;

         // Create sandbox rules table
         conn.execute(
             "CREATE TABLE IF NOT EXISTS sandbox_rules (
@@ -64,7 +63,7 @@ impl TestDatabase {
             )",
             [],
         )?;

         // Create agents table
         conn.execute(
             "CREATE TABLE IF NOT EXISTS agents (
@@ -80,7 +79,7 @@ impl TestDatabase {
             )",
             [],
         )?;

         // Create agent_runs table
         conn.execute(
             "CREATE TABLE IF NOT EXISTS agent_runs (
@@ -101,7 +100,7 @@ impl TestDatabase {
             )",
             [],
         )?;

         // Create sandbox violations table
         conn.execute(
             "CREATE TABLE IF NOT EXISTS sandbox_violations (
@@ -120,7 +119,7 @@ impl TestDatabase {
             )",
             [],
         )?;

         // Create trigger to update the updated_at timestamp for agents
         conn.execute(
             "CREATE TRIGGER IF NOT EXISTS update_agent_timestamp
@@ -131,7 +130,7 @@ impl TestDatabase {
             END",
             [],
         )?;

         // Create trigger to update sandbox profile timestamp
         conn.execute(
             "CREATE TRIGGER IF NOT EXISTS update_sandbox_profile_timestamp
@@ -142,10 +141,10 @@ impl TestDatabase {
             END",
             [],
         )?;

         Ok(())
     }

     /// Create a test profile with rules
     pub fn create_test_profile(&self, name: &str, rules: Vec<TestRule>) -> Result<i64> {
         // Insert profile
@@ -153,9 +152,9 @@ impl TestDatabase {
             "INSERT INTO sandbox_profiles (name, description, is_active, is_default) VALUES (?1, ?2, ?3, ?4)",
             params![name, format!("Test profile: {name}"), true, false],
         )?;

         let profile_id = self.conn.last_insert_rowid();

         // Insert rules
         for rule in rules {
             self.conn.execute(
@@ -171,10 +170,10 @@ impl TestDatabase {
                 ],
             )?;
         }

         Ok(profile_id)
     }

     /// Reset database to clean state
     pub fn reset(&self) -> Result<()> {
         // Delete in the correct order to respect foreign key constraints
@@ -208,7 +207,7 @@ impl TestRule {
             platform_support: Some(r#"["linux", "macos"]"#.to_string()),
         }
     }

     /// Create a network rule
     pub fn network_all() -> Self {
         Self {
@@ -219,7 +218,7 @@ impl TestRule {
             platform_support: Some(r#"["linux", "macos"]"#.to_string()),
         }
     }

     /// Create a network TCP rule
     pub fn network_tcp(port: u16) -> Self {
         Self {
@@ -230,7 +229,7 @@ impl TestRule {
             platform_support: Some(r#"["macos"]"#.to_string()),
         }
     }

     /// Create a system info read rule
     pub fn system_info_read() -> Self {
         Self {
@@ -256,25 +255,28 @@ impl TestFileSystem {
     pub fn new() -> Result<Self> {
         let root = tempdir()?;
         let root_path = root.path();

         // Create project directory
         let project_path = root_path.join("test_project");
         std::fs::create_dir_all(&project_path)?;

         // Create allowed directory
         let allowed_path = root_path.join("allowed");
         std::fs::create_dir_all(&allowed_path)?;
         std::fs::write(allowed_path.join("test.txt"), "allowed content")?;

         // Create forbidden directory
         let forbidden_path = root_path.join("forbidden");
         std::fs::create_dir_all(&forbidden_path)?;
         std::fs::write(forbidden_path.join("secret.txt"), "forbidden content")?;

         // Create project files
         std::fs::write(project_path.join("main.rs"), "fn main() {}")?;
-        std::fs::write(project_path.join("Cargo.toml"), "[package]\nname = \"test\"")?;
+        std::fs::write(
+            project_path.join("Cargo.toml"),
+            "[package]\nname = \"test\"",
+        )?;

         Ok(Self {
             root,
             project_path,
@@ -287,14 +289,12 @@ impl TestFileSystem {
 /// Standard test profiles
 pub mod profiles {
     use super::*;

     /// Minimal profile - only project access
     pub fn minimal(project_path: &str) -> Vec<TestRule> {
-        vec![
-            TestRule::file_read(project_path, true),
-        ]
+        vec![TestRule::file_read(project_path, true)]
     }

     /// Standard profile - project + system libraries
     pub fn standard(project_path: &str) -> Vec<TestRule> {
         vec![
@@ -304,7 +304,7 @@ pub mod profiles {
             TestRule::network_all(),
         ]
     }

     /// Development profile - more permissive
     pub fn development(project_path: &str, home_dir: &str) -> Vec<TestRule> {
         vec![
@@ -316,18 +316,17 @@ pub mod profiles {
             TestRule::system_info_read(),
         ]
     }

     /// Network-only profile
     pub fn network_only() -> Vec<TestRule> {
-        vec![
-            TestRule::network_all(),
-        ]
+        vec![TestRule::network_all()]
     }

     /// File-only profile
     pub fn file_only(paths: Vec<&str>) -> Vec<TestRule> {
-        paths.into_iter()
+        paths
+            .into_iter()
             .map(|path| TestRule::file_read(path, true))
             .collect()
     }
 }
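The profile helpers compose from the TestRule constructors above; a short sketch with illustrative paths:

// Each path becomes a subpath file-read rule via TestRule::file_read(path, true).
let rules = profiles::file_only(vec!["/usr/lib", "/usr/local/lib"]);
assert_eq!(rules.len(), 2);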
@@ -15,7 +15,10 @@ pub fn is_sandboxing_supported() -> bool {
 macro_rules! skip_if_unsupported {
     () => {
         if !$crate::sandbox::common::is_sandboxing_supported() {
-            eprintln!("Skipping test: sandboxing not supported on {}", std::env::consts::OS);
+            eprintln!(
+                "Skipping test: sandboxing not supported on {}",
+                std::env::consts::OS
+            );
             return;
         }
     };
@@ -39,7 +42,7 @@ impl PlatformConfig {
             supports_file_read: true,
             supports_metadata_read: false, // Cannot be precisely controlled
             supports_network_all: true,
             supports_network_tcp: false, // Cannot filter by port
             supports_network_local: false, // Cannot filter by path
             supports_system_info: false,
         },
@@ -89,54 +92,53 @@ impl TestCommand {
             working_dir: None,
         }
     }

     /// Add an argument
     pub fn arg(mut self, arg: &str) -> Self {
         self.args.push(arg.to_string());
         self
     }

     /// Add multiple arguments
     pub fn args(mut self, args: &[&str]) -> Self {
         self.args.extend(args.iter().map(|s| s.to_string()));
         self
     }

     /// Set an environment variable
     pub fn env(mut self, key: &str, value: &str) -> Self {
         self.env_vars.push((key.to_string(), value.to_string()));
         self
     }

     /// Set working directory
     pub fn current_dir(mut self, dir: &Path) -> Self {
         self.working_dir = Some(dir.to_path_buf());
         self
     }

     /// Execute the command with timeout
     pub fn execute_with_timeout(&self, timeout: Duration) -> Result<Output> {
         let mut cmd = Command::new(&self.command);

         cmd.args(&self.args);

         for (key, value) in &self.env_vars {
             cmd.env(key, value);
         }

         if let Some(dir) = &self.working_dir {
             cmd.current_dir(dir);
         }

         // On Unix, we can use a timeout mechanism
         #[cfg(unix)]
         {
             use std::time::Instant;

             let start = Instant::now();
-            let mut child = cmd.spawn()
-                .context("Failed to spawn command")?;
+            let mut child = cmd.spawn().context("Failed to spawn command")?;

             loop {
                 match child.try_wait() {
                     Ok(Some(status)) => {
@@ -158,19 +160,18 @@ impl TestCommand {
                 }
             }
         }

         #[cfg(not(unix))]
         {
             // Fallback for non-Unix platforms
-            cmd.output()
-                .context("Failed to execute command")
+            cmd.output().context("Failed to execute command")
         }
     }

     /// Execute and expect success
     pub fn execute_expect_success(&self) -> Result<String> {
         let output = self.execute_with_timeout(Duration::from_secs(10))?;

         if !output.status.success() {
             let stderr = String::from_utf8_lossy(&output.stderr);
             return Err(anyhow::anyhow!(
@@ -178,31 +179,27 @@ impl TestCommand {
                 output.status.code()
             ));
         }

         Ok(String::from_utf8_lossy(&output.stdout).to_string())
     }

     /// Execute and expect failure
     pub fn execute_expect_failure(&self) -> Result<String> {
         let output = self.execute_with_timeout(Duration::from_secs(10))?;

         if output.status.success() {
             let stdout = String::from_utf8_lossy(&output.stdout);
             return Err(anyhow::anyhow!(
                 "Command unexpectedly succeeded. Stdout: {stdout}"
             ));
         }

         Ok(String::from_utf8_lossy(&output.stderr).to_string())
     }
 }

 /// Create a simple test binary that attempts an operation
-pub fn create_test_binary(
-    name: &str,
-    code: &str,
-    test_dir: &Path,
-) -> Result<PathBuf> {
+pub fn create_test_binary(name: &str, code: &str, test_dir: &Path) -> Result<PathBuf> {
     create_test_binary_with_deps(name, code, test_dir, &[])
 }

@@ -215,7 +212,7 @@ pub fn create_test_binary_with_deps(
 ) -> Result<PathBuf> {
     let src_dir = test_dir.join("src");
     std::fs::create_dir_all(&src_dir)?;

     // Build dependencies section
     let deps_section = if dependencies.is_empty() {
         String::new()
@@ -226,7 +223,7 @@ pub fn create_test_binary_with_deps(
         }
         deps
     };

     // Create Cargo.toml
     let cargo_toml = format!(
         r#"[package]
@@ -240,10 +237,10 @@ path = "src/main.rs"
 {deps_section}"#
     );
     std::fs::write(test_dir.join("Cargo.toml"), cargo_toml)?;

     // Create main.rs
     std::fs::write(src_dir.join("main.rs"), code)?;

     // Build the binary
     let output = Command::new("cargo")
         .arg("build")
@@ -251,12 +248,12 @@ path = "src/main.rs"
         .current_dir(test_dir)
         .output()
         .context("Failed to build test binary")?;

     if !output.status.success() {
         let stderr = String::from_utf8_lossy(&output.stderr);
         return Err(anyhow::anyhow!("Failed to build test binary: {stderr}"));
     }

     let binary_path = test_dir.join("target/release").join(name);
     Ok(binary_path)
 }
@@ -281,7 +278,7 @@ fn main() {{
 "#
     )
 }

 /// Code that reads file metadata
 pub fn file_metadata(path: &str) -> String {
     format!(
@@ -300,7 +297,7 @@ fn main() {{
 "#
     )
 }

 /// Code that makes a network connection
 pub fn network_connect(addr: &str) -> String {
     format!(
@@ -321,7 +318,7 @@ fn main() {{
 "#
     )
 }

 /// Code that reads system information
 pub fn system_info() -> &'static str {
     r#"
@@ -368,7 +365,7 @@ fn main() {
 }
 "#
 }

 /// Code that tries to spawn a process
 pub fn spawn_process() -> &'static str {
     r#"
@@ -387,7 +384,7 @@ fn main() {
 }
 "#
 }

 /// Code that uses fork (requires libc)
 pub fn fork_process() -> &'static str {
     r#"
@@ -418,7 +415,7 @@ fn main() {
 }
 "#
 }

 /// Code that uses exec (requires libc)
 pub fn exec_process() -> &'static str {
     r#"
@@ -446,7 +443,7 @@ fn main() {
 }
 "#
 }

 /// Code that tries to write a file
 pub fn file_write(path: &str) -> String {
     format!(
@@ -483,4 +480,4 @@ pub fn assert_sandbox_success(output: &str) {
 /// Assert that a command output indicates failure
 pub fn assert_sandbox_failure(output: &str) {
     assert_output_contains(output, "FAILURE:");
 }
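A sketch tying the helpers together: compile a probe from one of the code generators above and run it under TestCommand. The TestCommand::new constructor name and the `code` module path are assumptions; only their bodies appear in this diff.

let dir = tempfile::tempdir()?;
let bin = create_test_binary("probe_meta", &code::file_metadata("/etc/hostname"), dir.path())?;
let out = TestCommand::new(&bin.to_string_lossy()) // constructor name assumed
    .current_dir(dir.path())
    .execute_expect_success()?;
assert_sandbox_success(&out);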
@@ -1,8 +1,8 @@
 //! Common test utilities and helpers for sandbox testing

+pub mod claude_real;
 pub mod fixtures;
 pub mod helpers;
-pub mod claude_real;

+pub use claude_real::*;
 pub use fixtures::*;
 pub use helpers::*;
-pub use claude_real::*;
@@ -8,17 +8,18 @@ use serial_test::serial;
 #[serial]
 fn test_agent_with_minimal_profile() {
     skip_if_unsupported!();

     // Create test environment
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
     let test_db = TEST_DB.lock();
     test_db.reset().expect("Failed to reset database");

     // Create minimal sandbox profile
     let rules = profiles::minimal(&test_fs.project_path.to_string_lossy());
-    let profile_id = test_db.create_test_profile("minimal_agent_test", rules)
+    let profile_id = test_db
+        .create_test_profile("minimal_agent_test", rules)
         .expect("Failed to create test profile");

     // Create test agent
     test_db.conn.execute(
         "INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)",
@@ -30,9 +31,9 @@ fn test_agent_with_minimal_profile() {
             profile_id
         ],
     ).expect("Failed to create agent");

     let _agent_id = test_db.conn.last_insert_rowid();

     // Execute real Claude command with minimal profile
     let result = execute_claude_task(
         &test_fs.project_path,
@@ -41,8 +42,9 @@ fn test_agent_with_minimal_profile() {
         Some("sonnet"),
         Some(profile_id),
         20, // 20 second timeout
-    ).expect("Failed to execute Claude command");
+    )
+    .expect("Failed to execute Claude command");

     // Debug output
     eprintln!("=== Claude Output ===");
     eprintln!("Exit code: {}", result.exit_code);
@@ -50,10 +52,13 @@ fn test_agent_with_minimal_profile() {
     eprintln!("STDERR:\n{}", result.stderr);
     eprintln!("Duration: {:?}", result.duration);
     eprintln!("===================");

     // Basic verification - just check Claude ran
-    assert!(result.exit_code == 0 || result.exit_code == 124, // 0 = success, 124 = timeout
-        "Claude should execute (exit code: {})", result.exit_code);
+    assert!(
+        result.exit_code == 0 || result.exit_code == 124, // 0 = success, 124 = timeout
+        "Claude should execute (exit code: {})",
+        result.exit_code
+    );
 }
|
|
||||||
/// Test agent execution with standard sandbox profile
|
/// Test agent execution with standard sandbox profile
|
||||||
@@ -61,17 +66,18 @@ fn test_agent_with_minimal_profile() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
fn test_agent_with_standard_profile() {
|
fn test_agent_with_standard_profile() {
|
||||||
skip_if_unsupported!();
|
skip_if_unsupported!();
|
||||||
|
|
||||||
// Create test environment
|
// Create test environment
|
||||||
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
|
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
|
||||||
let test_db = TEST_DB.lock();
|
let test_db = TEST_DB.lock();
|
||||||
test_db.reset().expect("Failed to reset database");
|
test_db.reset().expect("Failed to reset database");
|
||||||
|
|
||||||
// Create standard sandbox profile
|
// Create standard sandbox profile
|
||||||
let rules = profiles::standard(&test_fs.project_path.to_string_lossy());
|
let rules = profiles::standard(&test_fs.project_path.to_string_lossy());
|
||||||
let profile_id = test_db.create_test_profile("standard_agent_test", rules)
|
let profile_id = test_db
|
||||||
|
.create_test_profile("standard_agent_test", rules)
|
||||||
.expect("Failed to create test profile");
|
.expect("Failed to create test profile");
|
||||||
|
|
||||||
// Create test agent
|
// Create test agent
|
||||||
test_db.conn.execute(
|
test_db.conn.execute(
|
||||||
"INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)",
|
"INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||||
@@ -83,9 +89,9 @@ fn test_agent_with_standard_profile() {
|
|||||||
profile_id
|
profile_id
|
||||||
],
|
],
|
||||||
).expect("Failed to create agent");
|
).expect("Failed to create agent");
|
||||||
|
|
||||||
let _agent_id = test_db.conn.last_insert_rowid();
|
let _agent_id = test_db.conn.last_insert_rowid();
|
||||||
|
|
||||||
// Execute real Claude command with standard profile
|
// Execute real Claude command with standard profile
|
||||||
let result = execute_claude_task(
|
let result = execute_claude_task(
|
||||||
&test_fs.project_path,
|
&test_fs.project_path,
|
||||||
@@ -94,18 +100,22 @@ fn test_agent_with_standard_profile() {
|
|||||||
Some("sonnet"),
|
Some("sonnet"),
|
||||||
Some(profile_id),
|
Some(profile_id),
|
||||||
20, // 20 second timeout
|
20, // 20 second timeout
|
||||||
).expect("Failed to execute Claude command");
|
)
|
||||||
|
.expect("Failed to execute Claude command");
|
||||||
|
|
||||||
// Debug output
|
// Debug output
|
||||||
eprintln!("=== Claude Output (Standard Profile) ===");
|
eprintln!("=== Claude Output (Standard Profile) ===");
|
||||||
eprintln!("Exit code: {}", result.exit_code);
|
eprintln!("Exit code: {}", result.exit_code);
|
||||||
eprintln!("STDOUT:\n{}", result.stdout);
|
eprintln!("STDOUT:\n{}", result.stdout);
|
||||||
eprintln!("STDERR:\n{}", result.stderr);
|
eprintln!("STDERR:\n{}", result.stderr);
|
||||||
eprintln!("===================");
|
eprintln!("===================");
|
||||||
|
|
||||||
// Basic verification
|
// Basic verification
|
||||||
assert!(result.exit_code == 0 || result.exit_code == 124,
|
assert!(
|
||||||
"Claude should execute with standard profile (exit code: {})", result.exit_code);
|
result.exit_code == 0 || result.exit_code == 124,
|
||||||
|
"Claude should execute with standard profile (exit code: {})",
|
||||||
|
result.exit_code
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test agent execution without sandbox (control test)
|
/// Test agent execution without sandbox (control test)
|
||||||
@@ -113,25 +123,28 @@ fn test_agent_with_standard_profile() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
fn test_agent_without_sandbox() {
|
fn test_agent_without_sandbox() {
|
||||||
skip_if_unsupported!();
|
skip_if_unsupported!();
|
||||||
|
|
||||||
// Create test environment
|
// Create test environment
|
||||||
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
|
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
|
||||||
let test_db = TEST_DB.lock();
|
let test_db = TEST_DB.lock();
|
||||||
test_db.reset().expect("Failed to reset database");
|
test_db.reset().expect("Failed to reset database");
|
||||||
|
|
||||||
// Create agent without sandbox profile
|
// Create agent without sandbox profile
|
||||||
test_db.conn.execute(
|
test_db
|
||||||
"INSERT INTO agents (name, icon, system_prompt, model) VALUES (?1, ?2, ?3, ?4)",
|
.conn
|
||||||
rusqlite::params![
|
.execute(
|
||||||
"Unsandboxed Agent",
|
"INSERT INTO agents (name, icon, system_prompt, model) VALUES (?1, ?2, ?3, ?4)",
|
||||||
"⚠️",
|
rusqlite::params![
|
||||||
"You are a test agent without sandbox restrictions.",
|
"Unsandboxed Agent",
|
||||||
"sonnet"
|
"⚠️",
|
||||||
],
|
"You are a test agent without sandbox restrictions.",
|
||||||
).expect("Failed to create agent");
|
"sonnet"
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.expect("Failed to create agent");
|
||||||
|
|
||||||
let _agent_id = test_db.conn.last_insert_rowid();
|
let _agent_id = test_db.conn.last_insert_rowid();
|
||||||
|
|
||||||
// Execute real Claude command without sandbox profile
|
// Execute real Claude command without sandbox profile
|
||||||
let result = execute_claude_task(
|
let result = execute_claude_task(
|
||||||
&test_fs.project_path,
|
&test_fs.project_path,
|
||||||
@@ -139,19 +152,23 @@ fn test_agent_without_sandbox() {
|
|||||||
Some("You are a test agent without sandbox restrictions."),
|
Some("You are a test agent without sandbox restrictions."),
|
||||||
Some("sonnet"),
|
Some("sonnet"),
|
||||||
None, // No sandbox profile
|
None, // No sandbox profile
|
||||||
20, // 20 second timeout
|
20, // 20 second timeout
|
||||||
).expect("Failed to execute Claude command");
|
)
|
||||||
|
.expect("Failed to execute Claude command");
|
||||||
|
|
||||||
// Debug output
|
// Debug output
|
||||||
eprintln!("=== Claude Output (No Sandbox) ===");
|
eprintln!("=== Claude Output (No Sandbox) ===");
|
||||||
eprintln!("Exit code: {}", result.exit_code);
|
eprintln!("Exit code: {}", result.exit_code);
|
||||||
eprintln!("STDOUT:\n{}", result.stdout);
|
eprintln!("STDOUT:\n{}", result.stdout);
|
||||||
eprintln!("STDERR:\n{}", result.stderr);
|
eprintln!("STDERR:\n{}", result.stderr);
|
||||||
eprintln!("===================");
|
eprintln!("===================");
|
||||||
|
|
||||||
// Basic verification
|
// Basic verification
|
||||||
assert!(result.exit_code == 0 || result.exit_code == 124,
|
assert!(
|
||||||
"Claude should execute without sandbox (exit code: {})", result.exit_code);
|
result.exit_code == 0 || result.exit_code == 124,
|
||||||
|
"Claude should execute without sandbox (exit code: {})",
|
||||||
|
result.exit_code
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Test agent run violation logging
|
/// Test agent run violation logging
|
||||||
@@ -159,15 +176,16 @@ fn test_agent_without_sandbox() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
fn test_agent_run_violation_logging() {
|
fn test_agent_run_violation_logging() {
|
||||||
skip_if_unsupported!();
|
skip_if_unsupported!();
|
||||||
|
|
||||||
// Create test environment
|
// Create test environment
|
||||||
let test_db = TEST_DB.lock();
|
let test_db = TEST_DB.lock();
|
||||||
test_db.reset().expect("Failed to reset database");
|
test_db.reset().expect("Failed to reset database");
|
||||||
|
|
||||||
// Create a test profile first
|
// Create a test profile first
|
||||||
let profile_id = test_db.create_test_profile("violation_test", vec![])
|
let profile_id = test_db
|
||||||
|
.create_test_profile("violation_test", vec![])
|
||||||
.expect("Failed to create test profile");
|
.expect("Failed to create test profile");
|
||||||
|
|
||||||
// Create a test agent
|
// Create a test agent
|
||||||
test_db.conn.execute(
|
test_db.conn.execute(
|
||||||
"INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)",
|
"INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||||
@@ -179,9 +197,9 @@ fn test_agent_run_violation_logging() {
|
|||||||
profile_id
|
profile_id
|
||||||
],
|
],
|
||||||
).expect("Failed to create agent");
|
).expect("Failed to create agent");
|
||||||
|
|
||||||
let agent_id = test_db.conn.last_insert_rowid();
|
let agent_id = test_db.conn.last_insert_rowid();
|
||||||
|
|
||||||
// Create a test agent run
|
// Create a test agent run
|
||||||
test_db.conn.execute(
|
test_db.conn.execute(
|
||||||
"INSERT INTO agent_runs (agent_id, agent_name, agent_icon, task, model, project_path) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
"INSERT INTO agent_runs (agent_id, agent_name, agent_icon, task, model, project_path) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
|
||||||
@@ -194,23 +212,26 @@ fn test_agent_run_violation_logging() {
|
|||||||
"/test/path"
|
"/test/path"
|
||||||
],
|
],
|
||||||
).expect("Failed to create agent run");
|
).expect("Failed to create agent run");
|
||||||
|
|
||||||
let agent_run_id = test_db.conn.last_insert_rowid();
|
let agent_run_id = test_db.conn.last_insert_rowid();
|
||||||
|
|
||||||
// Insert test violations
|
// Insert test violations
|
||||||
test_db.conn.execute(
|
test_db.conn.execute(
|
||||||
"INSERT INTO sandbox_violations (profile_id, agent_id, agent_run_id, operation_type, pattern_value)
|
"INSERT INTO sandbox_violations (profile_id, agent_id, agent_run_id, operation_type, pattern_value)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5)",
|
VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||||
rusqlite::params![profile_id, agent_id, agent_run_id, "file_read_all", "/etc/passwd"],
|
rusqlite::params![profile_id, agent_id, agent_run_id, "file_read_all", "/etc/passwd"],
|
||||||
).expect("Failed to insert violation");
|
).expect("Failed to insert violation");
|
||||||
|
|
||||||
// Query violations
|
// Query violations
|
||||||
let count: i64 = test_db.conn.query_row(
|
let count: i64 = test_db
|
||||||
"SELECT COUNT(*) FROM sandbox_violations WHERE agent_id = ?1",
|
.conn
|
||||||
rusqlite::params![agent_id],
|
.query_row(
|
||||||
|row| row.get(0),
|
"SELECT COUNT(*) FROM sandbox_violations WHERE agent_id = ?1",
|
||||||
).expect("Failed to query violations");
|
rusqlite::params![agent_id],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.expect("Failed to query violations");
|
||||||
|
|
||||||
assert_eq!(count, 1, "Should have recorded one violation");
|
assert_eq!(count, 1, "Should have recorded one violation");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,21 +240,23 @@ fn test_agent_run_violation_logging() {
|
|||||||
#[serial]
|
#[serial]
|
||||||
fn test_profile_switching() {
|
fn test_profile_switching() {
|
||||||
skip_if_unsupported!();
|
skip_if_unsupported!();
|
||||||
|
|
||||||
// Create test environment
|
// Create test environment
|
||||||
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
|
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
|
||||||
let test_db = TEST_DB.lock();
|
let test_db = TEST_DB.lock();
|
||||||
test_db.reset().expect("Failed to reset database");
|
test_db.reset().expect("Failed to reset database");
|
||||||
|
|
||||||
// Create two different profiles
|
// Create two different profiles
|
||||||
let minimal_rules = profiles::minimal(&test_fs.project_path.to_string_lossy());
|
let minimal_rules = profiles::minimal(&test_fs.project_path.to_string_lossy());
|
||||||
let minimal_id = test_db.create_test_profile("minimal_switch", minimal_rules)
|
let minimal_id = test_db
|
||||||
|
.create_test_profile("minimal_switch", minimal_rules)
|
||||||
.expect("Failed to create minimal profile");
|
.expect("Failed to create minimal profile");
|
||||||
|
|
||||||
let standard_rules = profiles::standard(&test_fs.project_path.to_string_lossy());
|
let standard_rules = profiles::standard(&test_fs.project_path.to_string_lossy());
|
||||||
let standard_id = test_db.create_test_profile("standard_switch", standard_rules)
|
let standard_id = test_db
|
||||||
|
.create_test_profile("standard_switch", standard_rules)
|
||||||
.expect("Failed to create standard profile");
|
.expect("Failed to create standard profile");
|
||||||
|
|
||||||
// Create agent initially with minimal profile
|
// Create agent initially with minimal profile
|
||||||
test_db.conn.execute(
|
test_db.conn.execute(
|
||||||
"INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)",
|
"INSERT INTO agents (name, icon, system_prompt, model, sandbox_profile_id) VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||||
@@ -245,21 +268,27 @@ fn test_profile_switching() {
|
|||||||
minimal_id
|
minimal_id
|
||||||
],
|
],
|
||||||
).expect("Failed to create agent");
|
).expect("Failed to create agent");
|
||||||
|
|
||||||
let agent_id = test_db.conn.last_insert_rowid();
|
let agent_id = test_db.conn.last_insert_rowid();
|
||||||
|
|
||||||
// Update agent to use standard profile
|
// Update agent to use standard profile
|
||||||
test_db.conn.execute(
|
test_db
|
||||||
"UPDATE agents SET sandbox_profile_id = ?1 WHERE id = ?2",
|
.conn
|
||||||
rusqlite::params![standard_id, agent_id],
|
.execute(
|
||||||
).expect("Failed to update agent profile");
|
"UPDATE agents SET sandbox_profile_id = ?1 WHERE id = ?2",
|
||||||
|
rusqlite::params![standard_id, agent_id],
|
||||||
|
)
|
||||||
|
.expect("Failed to update agent profile");
|
||||||
|
|
||||||
// Verify profile was updated
|
// Verify profile was updated
|
||||||
let current_profile: i64 = test_db.conn.query_row(
|
let current_profile: i64 = test_db
|
||||||
"SELECT sandbox_profile_id FROM agents WHERE id = ?1",
|
.conn
|
||||||
rusqlite::params![agent_id],
|
.query_row(
|
||||||
|row| row.get(0),
|
"SELECT sandbox_profile_id FROM agents WHERE id = ?1",
|
||||||
).expect("Failed to query agent profile");
|
rusqlite::params![agent_id],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.expect("Failed to query agent profile");
|
||||||
|
|
||||||
assert_eq!(current_profile, standard_id, "Profile should be updated");
|
assert_eq!(current_profile, standard_id, "Profile should be updated");
|
||||||
}
|
}
|
||||||
|
|||||||
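
The `execute(...)` and `query_row(...)` rewrites above are rustfmt's method-chain rule: a chain that fits within the default 100-column `max_width` stays on one line, and one that does not is broken with every link on its own line, indented one level. Following the shapes in this diff:

```rust
// Short chain: kept on a single line.
let id = test_db.conn.last_insert_rowid();

// Chain that overflows max_width: one link per line, one indent level.
let profile_id = test_db
    .create_test_profile("standard_agent_test", rules)
    .expect("Failed to create test profile");
```
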
@@ -8,23 +8,27 @@ use serial_test::serial;
 #[serial]
 fn test_claude_with_default_sandbox() {
     skip_if_unsupported!();

     // Create test environment
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
     let test_db = TEST_DB.lock();
     test_db.reset().expect("Failed to reset database");

     // Create default sandbox profile
     let rules = profiles::standard(&test_fs.project_path.to_string_lossy());
-    let profile_id = test_db.create_test_profile("claude_default", rules)
+    let profile_id = test_db
+        .create_test_profile("claude_default", rules)
         .expect("Failed to create test profile");

     // Set as default and active
-    test_db.conn.execute(
-        "UPDATE sandbox_profiles SET is_default = 1, is_active = 1 WHERE id = ?1",
-        rusqlite::params![profile_id],
-    ).expect("Failed to set default profile");
+    test_db
+        .conn
+        .execute(
+            "UPDATE sandbox_profiles SET is_default = 1, is_active = 1 WHERE id = ?1",
+            rusqlite::params![profile_id],
+        )
+        .expect("Failed to set default profile");

     // Execute real Claude command with default sandbox profile
     let result = execute_claude_task(
         &test_fs.project_path,
@@ -33,18 +37,22 @@ fn test_claude_with_default_sandbox() {
         Some("sonnet"),
         Some(profile_id),
         20, // 20 second timeout
-    ).expect("Failed to execute Claude command");
+    )
+    .expect("Failed to execute Claude command");

     // Debug output
     eprintln!("=== Claude Output (Default Sandbox) ===");
     eprintln!("Exit code: {}", result.exit_code);
     eprintln!("STDOUT:\n{}", result.stdout);
     eprintln!("STDERR:\n{}", result.stderr);
     eprintln!("===================");

     // Basic verification
-    assert!(result.exit_code == 0 || result.exit_code == 124,
-        "Claude should execute with default sandbox (exit code: {})", result.exit_code);
+    assert!(
+        result.exit_code == 0 || result.exit_code == 124,
+        "Claude should execute with default sandbox (exit code: {})",
+        result.exit_code
+    );
 }

 /// Test Claude Code with sandboxing disabled
@@ -52,23 +60,27 @@ fn test_claude_with_default_sandbox() {
 #[serial]
 fn test_claude_sandbox_disabled() {
     skip_if_unsupported!();

     // Create test environment
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
     let test_db = TEST_DB.lock();
     test_db.reset().expect("Failed to reset database");

     // Create profile but mark as inactive
     let rules = profiles::standard(&test_fs.project_path.to_string_lossy());
-    let profile_id = test_db.create_test_profile("claude_inactive", rules)
+    let profile_id = test_db
+        .create_test_profile("claude_inactive", rules)
         .expect("Failed to create test profile");

     // Set as default but inactive
-    test_db.conn.execute(
-        "UPDATE sandbox_profiles SET is_default = 1, is_active = 0 WHERE id = ?1",
-        rusqlite::params![profile_id],
-    ).expect("Failed to set inactive profile");
+    test_db
+        .conn
+        .execute(
+            "UPDATE sandbox_profiles SET is_default = 1, is_active = 0 WHERE id = ?1",
+            rusqlite::params![profile_id],
+        )
+        .expect("Failed to set inactive profile");

     // Execute real Claude command without active sandbox
     let result = execute_claude_task(
         &test_fs.project_path,
@@ -76,19 +88,23 @@ fn test_claude_sandbox_disabled() {
         Some("You are Claude. Only perform the requested task."),
         Some("sonnet"),
         None, // No sandbox since profile is inactive
         20, // 20 second timeout
-    ).expect("Failed to execute Claude command");
+    )
+    .expect("Failed to execute Claude command");

     // Debug output
     eprintln!("=== Claude Output (Inactive Sandbox) ===");
     eprintln!("Exit code: {}", result.exit_code);
     eprintln!("STDOUT:\n{}", result.stdout);
     eprintln!("STDERR:\n{}", result.stderr);
     eprintln!("===================");

     // Basic verification
-    assert!(result.exit_code == 0 || result.exit_code == 124,
-        "Claude should execute without active sandbox (exit code: {})", result.exit_code);
+    assert!(
+        result.exit_code == 0 || result.exit_code == 124,
+        "Claude should execute without active sandbox (exit code: {})",
+        result.exit_code
+    );
 }

 /// Test Claude Code session operations
@@ -96,31 +112,31 @@ fn test_claude_sandbox_disabled() {
 #[serial]
 fn test_claude_session_operations() {
     // This test doesn't require actual Claude execution

     // Create test environment
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create mock session structure
     let claude_dir = test_fs.root.path().join(".claude");
     let projects_dir = claude_dir.join("projects");
     let project_id = test_fs.project_path.to_string_lossy().replace('/', "-");
     let session_dir = projects_dir.join(&project_id);

     std::fs::create_dir_all(&session_dir).expect("Failed to create session dir");

     // Create mock session file
     let session_id = "test-session-123";
     let session_file = session_dir.join(format!("{}.jsonl", session_id));

     let session_data = serde_json::json!({
         "type": "session_start",
         "cwd": test_fs.project_path.to_string_lossy(),
         "timestamp": "2024-01-01T00:00:00Z"
     });

     std::fs::write(&session_file, format!("{}\n", session_data))
         .expect("Failed to write session file");

     // Verify session file exists
     assert!(session_file.exists(), "Session file should exist");
 }
@@ -131,11 +147,11 @@ fn test_claude_session_operations() {
 fn test_claude_settings_sandbox_config() {
     // Create test environment
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create mock settings
     let claude_dir = test_fs.root.path().join(".claude");
     std::fs::create_dir_all(&claude_dir).expect("Failed to create claude dir");

     let settings_file = claude_dir.join("settings.json");
     let settings = serde_json::json!({
         "sandboxEnabled": true,
@@ -143,18 +159,23 @@ fn test_claude_settings_sandbox_config() {
         "theme": "dark",
         "model": "sonnet"
     });

-    std::fs::write(&settings_file, serde_json::to_string_pretty(&settings).unwrap())
-        .expect("Failed to write settings");
+    std::fs::write(
+        &settings_file,
+        serde_json::to_string_pretty(&settings).unwrap(),
+    )
+    .expect("Failed to write settings");

     // Read and verify settings
-    let content = std::fs::read_to_string(&settings_file)
-        .expect("Failed to read settings");
-    let parsed: serde_json::Value = serde_json::from_str(&content)
-        .expect("Failed to parse settings");
+    let content = std::fs::read_to_string(&settings_file).expect("Failed to read settings");
+    let parsed: serde_json::Value =
+        serde_json::from_str(&content).expect("Failed to parse settings");

     assert_eq!(parsed["sandboxEnabled"], true, "Sandbox should be enabled");
-    assert_eq!(parsed["defaultSandboxProfile"], "standard", "Default profile should be standard");
+    assert_eq!(
+        parsed["defaultSandboxProfile"], "standard",
+        "Default profile should be standard"
+    );
 }

 /// Test profile-based file access restrictions
@@ -162,22 +183,23 @@ fn test_claude_settings_sandbox_config() {
 #[serial]
 fn test_profile_file_access_simulation() {
     skip_if_unsupported!();

     // Create test environment
     let _test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
     let test_db = TEST_DB.lock();
     test_db.reset().expect("Failed to reset database");

     // Create a custom profile with specific file access
     let custom_rules = vec![
         TestRule::file_read("{{PROJECT_PATH}}", true),
         TestRule::file_read("/usr/local/bin", true),
         TestRule::file_read("/etc/hosts", false), // Literal file
     ];

-    let profile_id = test_db.create_test_profile("file_access_test", custom_rules)
+    let profile_id = test_db
+        .create_test_profile("file_access_test", custom_rules)
         .expect("Failed to create test profile");

     // Load the profile rules
     let loaded_rules: Vec<(String, String, String)> = test_db.conn
         .prepare("SELECT operation_type, pattern_type, pattern_value FROM sandbox_rules WHERE profile_id = ?1")
@@ -188,9 +210,11 @@ fn test_profile_file_access_simulation() {
         .expect("Failed to query rules")
         .collect::<Result<Vec<_>, _>>()
         .expect("Failed to collect rules");

     // Verify rules were created correctly
     assert_eq!(loaded_rules.len(), 3, "Should have 3 rules");
-    assert!(loaded_rules.iter().any(|(op, _, _)| op == "file_read_all"),
-        "Should have file_read_all operation");
-}
+    assert!(
+        loaded_rules.iter().any(|(op, _, _)| op == "file_read_all"),
+        "Should have file_read_all operation"
+    );
+}
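
The `assert!` and `assert_eq!` rewrites in the file above show how rustfmt formats macro calls like ordinary function calls: when the arguments overflow the line width, each argument moves to its own line and the closing `);` drops to its own line. Before/after, following the pattern in this diff:

```rust
// Before: manual two-line wrapping, over the width limit.
assert!(result.exit_code == 0 || result.exit_code == 124,
    "Claude should execute (exit code: {})", result.exit_code);

// After cargo fmt: one argument per line, dangling close paren.
assert!(
    result.exit_code == 0 || result.exit_code == 124,
    "Claude should execute (exit code: {})",
    result.exit_code
);
```
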
@@ -2,4 +2,4 @@
 #[cfg(test)]
 mod agent_sandbox;
 #[cfg(test)]
 mod claude_sandbox;

@@ -3,7 +3,7 @@ use crate::sandbox::common::*;
 use crate::skip_if_unsupported;
 use claudia_lib::sandbox::executor::SandboxExecutor;
 use claudia_lib::sandbox::profile::ProfileBuilder;
-use gaol::profile::{Profile, Operation, PathPattern};
+use gaol::profile::{Operation, PathPattern, Profile};
 use serial_test::serial;
 use tempfile::TempDir;

@@ -12,21 +12,21 @@ use tempfile::TempDir;
 #[serial]
 fn test_allowed_file_read() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();
     if !platform.supports_file_read {
         eprintln!("Skipping test: file read not supported on this platform");
         return;
     }

     // Create test file system
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile allowing project path access
-    let operations = vec![
-        Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -34,13 +34,13 @@ fn test_allowed_file_read() {
             return;
         }
     };

     // Create test binary that reads from allowed path
     let test_code = test_code::file_read(&test_fs.project_path.join("main.rs").to_string_lossy());
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_file_read", &test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -63,21 +63,21 @@ fn test_allowed_file_read() {
 #[serial]
 fn test_forbidden_file_read() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();
     if !platform.supports_file_read {
         eprintln!("Skipping test: file read not supported on this platform");
         return;
     }

     // Create test file system
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile allowing only project path (not forbidden path)
-    let operations = vec![
-        Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -85,14 +85,14 @@ fn test_forbidden_file_read() {
             return;
         }
     };

     // Create test binary that reads from forbidden path
     let forbidden_file = test_fs.forbidden_path.join("secret.txt");
     let test_code = test_code::file_read(&forbidden_file.to_string_lossy());
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_forbidden_read", &test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -105,7 +105,9 @@ fn test_forbidden_file_read() {
             // On some platforms (like macOS), gaol might not block all file reads
             // so we check if the operation failed OR if it's a platform limitation
             if status.success() {
-                eprintln!("WARNING: File read was not blocked - this might be a platform limitation");
+                eprintln!(
+                    "WARNING: File read was not blocked - this might be a platform limitation"
+                );
                 // Check if we're on a platform where this is expected
                 let platform_config = PlatformConfig::current();
                 if !platform_config.supports_file_read {
@@ -124,15 +126,15 @@ fn test_forbidden_file_read() {
 #[serial]
 fn test_file_write_always_forbidden() {
     skip_if_unsupported!();

     // Create test file system
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile with file read permissions (write should still be blocked)
-    let operations = vec![
-        Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -140,14 +142,14 @@ fn test_file_write_always_forbidden() {
             return;
         }
     };

     // Create test binary that tries to write a file
     let write_path = test_fs.project_path.join("test_write.txt");
     let test_code = test_code::file_write(&write_path.to_string_lossy());
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_file_write", &test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -177,28 +179,28 @@ fn test_file_write_always_forbidden() {
 #[serial]
 fn test_file_metadata_operations() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();
     if !platform.supports_metadata_read && !platform.supports_file_read {
         eprintln!("Skipping test: metadata read not supported on this platform");
         return;
     }

     // Create test file system
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile with metadata read permission
     let operations = if platform.supports_metadata_read {
-        vec![
-            Operation::FileReadMetadata(PathPattern::Subpath(test_fs.project_path.clone())),
-        ]
+        vec![Operation::FileReadMetadata(PathPattern::Subpath(
+            test_fs.project_path.clone(),
+        ))]
     } else {
         // On Linux, metadata is allowed if file read is allowed
-        vec![
-            Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-        ]
+        vec![Operation::FileReadAll(PathPattern::Subpath(
+            test_fs.project_path.clone(),
+        ))]
     };

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -206,14 +208,14 @@ fn test_file_metadata_operations() {
             return;
         }
     };

     // Create test binary that reads file metadata
     let test_file = test_fs.project_path.join("main.rs");
     let test_code = test_code::file_metadata(&test_file.to_string_lossy());
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_metadata", &test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -224,7 +226,10 @@ fn test_file_metadata_operations() {
         Ok(mut child) => {
             let status = child.wait().expect("Failed to wait for child");
             if platform.supports_metadata_read || platform.supports_file_read {
-                assert!(status.success(), "Metadata read should succeed when allowed");
+                assert!(
+                    status.success(),
+                    "Metadata read should succeed when allowed"
+                );
             }
         }
         Err(e) => {
@@ -238,33 +243,32 @@ fn test_file_metadata_operations() {
 #[serial]
 fn test_template_variable_expansion() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();
     if !platform.supports_file_read {
         eprintln!("Skipping test: file read not supported on this platform");
         return;
     }

     // Create test database and profile
     let test_db = TEST_DB.lock();
     test_db.reset().expect("Failed to reset database");

     // Create a profile with template variables
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
-    let rules = vec![
-        TestRule::file_read("{{PROJECT_PATH}}", true),
-    ];
+    let rules = vec![TestRule::file_read("{{PROJECT_PATH}}", true)];

-    let profile_id = test_db.create_test_profile("template_test", rules)
+    let profile_id = test_db
+        .create_test_profile("template_test", rules)
         .expect("Failed to create test profile");

     // Load and build the profile
     let db_rules = claudia_lib::sandbox::profile::load_profile_rules(&test_db.conn, profile_id)
         .expect("Failed to load profile rules");

     let builder = ProfileBuilder::new(test_fs.project_path.clone())
         .expect("Failed to create profile builder");

     let profile = match builder.build_profile(db_rules) {
         Ok(p) => p,
         Err(_) => {
@@ -272,13 +276,13 @@ fn test_template_variable_expansion() {
             return;
         }
     };

     // Create test binary that reads from project path
     let test_code = test_code::file_read(&test_fs.project_path.join("main.rs").to_string_lossy());
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_template", &test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -294,4 +298,4 @@ fn test_template_variable_expansion() {
             eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
         }
     }
 }
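
The `vec![...]` rewrites in the file above are rustfmt collapsing a single-element vec: instead of one element per line with a trailing comma, the element hugs the brackets and only the innermost call argument overflows onto its own line. From the shapes in this diff:

```rust
// Before: block style with one element per line.
let operations = vec![
    Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
];

// After cargo fmt: the single element hugs vec![...], and the
// innermost argument is the only part that wraps.
let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
    test_fs.project_path.clone(),
))];
```
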
@@ -4,8 +4,8 @@ mod file_operations;
 #[cfg(test)]
 mod network_operations;
 #[cfg(test)]
-mod system_info;
-#[cfg(test)]
 mod process_isolation;
 #[cfg(test)]
-mod violations;
+mod system_info;
+#[cfg(test)]
+mod violations;
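
The first hunk of the next file below shows the companion rule to the module reordering earlier in this commit: inside a braced import list, rustfmt sorts the names alphabetically.

```rust
// Before: hand-ordered list.
use gaol::profile::{Profile, Operation, AddressPattern};

// After cargo fmt: names inside the braces sorted alphabetically.
use gaol::profile::{AddressPattern, Operation, Profile};
```
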
@@ -2,7 +2,7 @@
 use crate::sandbox::common::*;
 use crate::skip_if_unsupported;
 use claudia_lib::sandbox::executor::SandboxExecutor;
-use gaol::profile::{Profile, Operation, AddressPattern};
+use gaol::profile::{AddressPattern, Operation, Profile};
 use serial_test::serial;
 use std::net::TcpListener;
 use tempfile::TempDir;
@@ -10,7 +10,10 @@ use tempfile::TempDir;
 /// Get an available port for testing
 fn get_available_port() -> u16 {
     let listener = TcpListener::bind("127.0.0.1:0").expect("Failed to bind to 0");
-    let port = listener.local_addr().expect("Failed to get local addr").port();
+    let port = listener
+        .local_addr()
+        .expect("Failed to get local addr")
+        .port();
     drop(listener); // Release the port
     port
 }
@@ -20,21 +23,19 @@ fn get_available_port() -> u16 {
 #[serial]
 fn test_allowed_network_all() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();
     if !platform.supports_network_all {
         eprintln!("Skipping test: network all not supported on this platform");
         return;
     }

     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile allowing all network access
-    let operations = vec![
-        Operation::NetworkOutbound(AddressPattern::All),
-    ];
+    let operations = vec![Operation::NetworkOutbound(AddressPattern::All)];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -42,18 +43,18 @@ fn test_allowed_network_all() {
             return;
         }
     };

     // Create test binary that connects to localhost
     let port = get_available_port();
     let test_code = test_code::network_connect(&format!("127.0.0.1:{}", port));
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_network", &test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Start a listener on the port
-    let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
-        .expect("Failed to bind listener");
+    let listener =
+        TcpListener::bind(format!("127.0.0.1:{}", port)).expect("Failed to bind listener");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -66,9 +67,12 @@ fn test_allowed_network_all() {
             std::thread::spawn(move || {
                 let _ = listener.accept();
             });

             let status = child.wait().expect("Failed to wait for child");
-            assert!(status.success(), "Network connection should succeed when allowed");
+            assert!(
+                status.success(),
+                "Network connection should succeed when allowed"
+            );
         }
         Err(e) => {
             eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
@@ -81,15 +85,15 @@ fn test_allowed_network_all() {
 #[serial]
 fn test_forbidden_network() {
     skip_if_unsupported!();

     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile without network permissions
-    let operations = vec![
-        Operation::FileReadAll(gaol::profile::PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(gaol::profile::PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -97,13 +101,13 @@ fn test_forbidden_network() {
             return;
         }
     };

     // Create test binary that tries to connect
     let test_code = test_code::network_connect("google.com:80");
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_no_network", &test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -137,19 +141,19 @@ fn test_network_tcp_port_specific() {
         eprintln!("Skipping test: TCP port filtering not supported");
         return;
     }

     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Get two ports - one allowed, one forbidden
     let allowed_port = get_available_port();
     let forbidden_port = get_available_port();

     // Create profile allowing only specific port
-    let operations = vec![
-        Operation::NetworkOutbound(AddressPattern::Tcp(allowed_port)),
-    ];
+    let operations = vec![Operation::NetworkOutbound(AddressPattern::Tcp(
+        allowed_port,
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -157,17 +161,17 @@ fn test_network_tcp_port_specific() {
             return;
         }
     };

     // Test 1: Allowed port
     {
         let test_code = test_code::network_connect(&format!("127.0.0.1:{}", allowed_port));
         let binary_dir = TempDir::new().expect("Failed to create temp dir");
         let binary_path = create_test_binary("test_allowed_port", &test_code, binary_dir.path())
             .expect("Failed to create test binary");

         let listener = TcpListener::bind(format!("127.0.0.1:{}", allowed_port))
             .expect("Failed to bind listener");

         let executor = SandboxExecutor::new(profile.clone(), test_fs.project_path.clone());
         match executor.execute_sandboxed_spawn(
             &binary_path.to_string_lossy(),
@@ -178,23 +182,26 @@ fn test_network_tcp_port_specific() {
                 std::thread::spawn(move || {
                     let _ = listener.accept();
                 });

                 let status = child.wait().expect("Failed to wait for child");
-                assert!(status.success(), "Connection to allowed port should succeed");
+                assert!(
+                    status.success(),
+                    "Connection to allowed port should succeed"
+                );
             }
             Err(e) => {
                 eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
             }
         }
     }

     // Test 2: Forbidden port
     {
         let test_code = test_code::network_connect(&format!("127.0.0.1:{}", forbidden_port));
         let binary_dir = TempDir::new().expect("Failed to create temp dir");
         let binary_path = create_test_binary("test_forbidden_port", &test_code, binary_dir.path())
             .expect("Failed to create test binary");

         let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
         match executor.execute_sandboxed_spawn(
             &binary_path.to_string_lossy(),
@@ -203,7 +210,10 @@ fn test_network_tcp_port_specific() {
         ) {
             Ok(mut child) => {
                 let status = child.wait().expect("Failed to wait for child");
|
||||||
assert!(!status.success(), "Connection to forbidden port should fail");
|
assert!(
|
||||||
|
!status.success(),
|
||||||
|
"Connection to forbidden port should fail"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
|
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
|
||||||
@@ -218,28 +228,26 @@ fn test_network_tcp_port_specific() {
|
|||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
fn test_local_socket_connections() {
|
fn test_local_socket_connections() {
|
||||||
skip_if_unsupported!();
|
skip_if_unsupported!();
|
||||||
|
|
||||||
let platform = PlatformConfig::current();
|
let platform = PlatformConfig::current();
|
||||||
|
|
||||||
// Create test project
|
// Create test project
|
||||||
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
|
let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
|
||||||
let socket_path = test_fs.project_path.join("test.sock");
|
let socket_path = test_fs.project_path.join("test.sock");
|
||||||
|
|
||||||
// Create appropriate profile based on platform
|
// Create appropriate profile based on platform
|
||||||
let operations = if platform.supports_network_local {
|
let operations = if platform.supports_network_local {
|
||||||
vec![
|
vec![Operation::NetworkOutbound(AddressPattern::LocalSocket(
|
||||||
Operation::NetworkOutbound(AddressPattern::LocalSocket(socket_path.clone())),
|
socket_path.clone(),
|
||||||
]
|
))]
|
||||||
} else if platform.supports_network_all {
|
} else if platform.supports_network_all {
|
||||||
// Fallback to allowing all network
|
// Fallback to allowing all network
|
||||||
vec![
|
vec![Operation::NetworkOutbound(AddressPattern::All)]
|
||||||
Operation::NetworkOutbound(AddressPattern::All),
|
|
||||||
]
|
|
||||||
} else {
|
} else {
|
||||||
eprintln!("Skipping test: no network support on this platform");
|
eprintln!("Skipping test: no network support on this platform");
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
|
|
||||||
let profile = match Profile::new(operations) {
|
let profile = match Profile::new(operations) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
@@ -247,7 +255,7 @@ fn test_local_socket_connections() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create test binary that connects to local socket
|
// Create test binary that connects to local socket
|
||||||
let test_code = format!(
|
let test_code = format!(
|
||||||
r#"
|
r#"
|
||||||
@@ -267,15 +275,15 @@ fn main() {{
|
|||||||
"#,
|
"#,
|
||||||
socket_path.to_string_lossy()
|
socket_path.to_string_lossy()
|
||||||
);
|
);
|
||||||
|
|
||||||
let binary_dir = TempDir::new().expect("Failed to create temp dir");
|
let binary_dir = TempDir::new().expect("Failed to create temp dir");
|
||||||
let binary_path = create_test_binary("test_local_socket", &test_code, binary_dir.path())
|
let binary_path = create_test_binary("test_local_socket", &test_code, binary_dir.path())
|
||||||
.expect("Failed to create test binary");
|
.expect("Failed to create test binary");
|
||||||
|
|
||||||
// Create Unix socket listener
|
// Create Unix socket listener
|
||||||
use std::os::unix::net::UnixListener;
|
use std::os::unix::net::UnixListener;
|
||||||
let listener = UnixListener::bind(&socket_path).expect("Failed to bind Unix socket");
|
let listener = UnixListener::bind(&socket_path).expect("Failed to bind Unix socket");
|
||||||
|
|
||||||
// Execute in sandbox
|
// Execute in sandbox
|
||||||
let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
|
let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
|
||||||
match executor.execute_sandboxed_spawn(
|
match executor.execute_sandboxed_spawn(
|
||||||
@@ -287,15 +295,18 @@ fn main() {{
|
|||||||
std::thread::spawn(move || {
|
std::thread::spawn(move || {
|
||||||
let _ = listener.accept();
|
let _ = listener.accept();
|
||||||
});
|
});
|
||||||
|
|
||||||
let status = child.wait().expect("Failed to wait for child");
|
let status = child.wait().expect("Failed to wait for child");
|
||||||
assert!(status.success(), "Local socket connection should succeed when allowed");
|
assert!(
|
||||||
|
status.success(),
|
||||||
|
"Local socket connection should succeed when allowed"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
|
eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up socket file
|
// Clean up socket file
|
||||||
let _ = std::fs::remove_file(&socket_path);
|
let _ = std::fs::remove_file(&socket_path);
|
||||||
}
|
}
|
||||||
|
|||||||
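Note on the pattern the network tests above repeat: build an operation list, attempt Profile::new, and bail out early when the platform cannot express the rule. A minimal sketch of that flow, using the same gaol types the tests import; the helper name port_profile is hypothetical:

use gaol::profile::{AddressPattern, Operation, Profile};

// Hypothetical helper distilling the tests' setup: allow outbound TCP to one
// port, returning None where the platform rejects the rule (the tests skip
// with an early return in that case).
fn port_profile(port: u16) -> Option<Profile> {
    let operations = vec![Operation::NetworkOutbound(AddressPattern::Tcp(port))];
    Profile::new(operations).ok()
}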
@@ -2,7 +2,7 @@
 use crate::sandbox::common::*;
 use crate::skip_if_unsupported;
 use claudia_lib::sandbox::executor::SandboxExecutor;
-use gaol::profile::{Profile, Operation, PathPattern, AddressPattern};
+use gaol::profile::{AddressPattern, Operation, PathPattern, Profile};
 use serial_test::serial;
 use tempfile::TempDir;

@@ -11,16 +11,16 @@ use tempfile::TempDir;
 #[serial]
 fn test_process_spawn_forbidden() {
     skip_if_unsupported!();

     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile with various permissions (process spawn should still be blocked)
     let operations = vec![
         Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
         Operation::NetworkOutbound(AddressPattern::All),
     ];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -28,13 +28,13 @@ fn test_process_spawn_forbidden() {
             return;
         }
     };

     // Create test binary that tries to spawn a process
     let test_code = test_code::spawn_process();
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_spawn", test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -49,7 +49,10 @@ fn test_process_spawn_forbidden() {
             eprintln!("WARNING: Process spawning was not blocked");
             // macOS sandbox might have limitations
             if std::env::consts::OS != "linux" {
-                eprintln!("Process spawning might not be fully blocked on {}", std::env::consts::OS);
+                eprintln!(
+                    "Process spawning might not be fully blocked on {}",
+                    std::env::consts::OS
+                );
             } else {
                 panic!("Process spawning should be blocked on Linux");
             }
@@ -67,15 +70,15 @@ fn test_process_spawn_forbidden() {
 #[cfg(unix)]
 fn test_fork_forbidden() {
     skip_if_unsupported!();

     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create minimal profile
-    let operations = vec![
-        Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -83,14 +86,19 @@ fn test_fork_forbidden() {
             return;
         }
     };

     // Create test binary that tries to fork
     let test_code = test_code::fork_process();

     let binary_dir = TempDir::new().expect("Failed to create temp dir");
-    let binary_path = create_test_binary_with_deps("test_fork", test_code, binary_dir.path(), &[("libc", "0.2")])
-        .expect("Failed to create test binary");
+    let binary_path = create_test_binary_with_deps(
+        "test_fork",
+        test_code,
+        binary_dir.path(),
+        &[("libc", "0.2")],
+    )
+    .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -120,15 +128,15 @@ fn test_fork_forbidden() {
 #[cfg(unix)]
 fn test_exec_forbidden() {
     skip_if_unsupported!();

     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create minimal profile
-    let operations = vec![
-        Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -136,14 +144,19 @@ fn test_exec_forbidden() {
             return;
         }
     };

     // Create test binary that tries to exec
     let test_code = test_code::exec_process();

     let binary_dir = TempDir::new().expect("Failed to create temp dir");
-    let binary_path = create_test_binary_with_deps("test_exec", test_code, binary_dir.path(), &[("libc", "0.2")])
-        .expect("Failed to create test binary");
+    let binary_path = create_test_binary_with_deps(
+        "test_exec",
+        test_code,
+        binary_dir.path(),
+        &[("libc", "0.2")],
+    )
+    .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -172,15 +185,15 @@ fn test_exec_forbidden() {
 #[serial]
 fn test_thread_creation_allowed() {
     skip_if_unsupported!();

     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create minimal profile
-    let operations = vec![
-        Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -188,7 +201,7 @@ fn test_thread_creation_allowed() {
             return;
         }
     };

     // Create test binary that creates threads
     let test_code = r#"
 use std::thread;
@@ -211,11 +224,11 @@ fn main() {
     }
 }
 "#;

     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_thread", test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -231,4 +244,4 @@ fn main() {
             eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
         }
     }
 }
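The spawn/fork/exec tests above all reduce to the same platform-dependent check: a child that ran successfully is a hard failure on Linux and only a warning elsewhere. A sketch of that logic in one place, mirroring the assertions in the diff; check_spawn_blocked is a hypothetical name:

use std::process::ExitStatus;

// Hypothetical helper: hard-fail on Linux, warn on platforms whose sandbox
// may not block process creation (as the tests above allow).
fn check_spawn_blocked(status: ExitStatus) {
    if status.success() {
        if std::env::consts::OS == "linux" {
            panic!("Process spawning should be blocked on Linux");
        } else {
            eprintln!(
                "Process spawning might not be fully blocked on {}",
                std::env::consts::OS
            );
        }
    }
}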
@@ -2,7 +2,7 @@
 use crate::sandbox::common::*;
 use crate::skip_if_unsupported;
 use claudia_lib::sandbox::executor::SandboxExecutor;
-use gaol::profile::{Profile, Operation};
+use gaol::profile::{Operation, Profile};
 use serial_test::serial;
 use tempfile::TempDir;

@@ -11,21 +11,19 @@ use tempfile::TempDir;
 #[serial]
 fn test_system_info_read() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();
     if !platform.supports_system_info {
         eprintln!("Skipping test: system info read not supported on this platform");
         return;
     }

     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile allowing system info read
-    let operations = vec![
-        Operation::SystemInfoRead,
-    ];
+    let operations = vec![Operation::SystemInfoRead];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -33,13 +31,13 @@ fn test_system_info_read() {
             return;
         }
     };

     // Create test binary that reads system info
     let test_code = test_code::system_info();
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_sysinfo", test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -49,7 +47,10 @@ fn test_system_info_read() {
     ) {
         Ok(mut child) => {
             let status = child.wait().expect("Failed to wait for child");
-            assert!(status.success(), "System info read should succeed when allowed");
+            assert!(
+                status.success(),
+                "System info read should succeed when allowed"
+            );
         }
         Err(e) => {
             eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
@@ -64,12 +65,12 @@ fn test_system_info_read() {
 fn test_forbidden_system_info() {
     // Create test project
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile without system info permission
-    let operations = vec![
-        Operation::FileReadAll(gaol::profile::PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(gaol::profile::PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -77,13 +78,13 @@ fn test_forbidden_system_info() {
             return;
         }
     };

     // Create test binary that reads system info
     let test_code = test_code::system_info();
     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_no_sysinfo", test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -118,27 +119,33 @@ fn test_forbidden_system_info() {
 #[serial]
 fn test_platform_specific_system_info() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();

     match std::env::consts::OS {
         "linux" => {
             // On Linux, system info is never allowed
-            assert!(!platform.supports_system_info,
-                "Linux should not support system info read");
+            assert!(
+                !platform.supports_system_info,
+                "Linux should not support system info read"
+            );
         }
         "macos" => {
             // On macOS, system info can be allowed
-            assert!(platform.supports_system_info,
-                "macOS should support system info read");
+            assert!(
+                platform.supports_system_info,
+                "macOS should support system info read"
+            );
         }
         "freebsd" => {
             // On FreeBSD, system info is always allowed (can't be restricted)
-            assert!(platform.supports_system_info,
-                "FreeBSD always allows system info read");
+            assert!(
+                platform.supports_system_info,
+                "FreeBSD always allows system info read"
+            );
         }
         _ => {
             eprintln!("Unknown platform behavior for system info");
         }
     }
 }
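The per-OS expectations asserted above can be collected into one lookup. A sketch based only on what these tests assert; system_info_supported is a hypothetical name:

// Whether Operation::SystemInfoRead can be permitted, per the tests above.
fn system_info_supported(os: &str) -> Option<bool> {
    match os {
        "linux" => Some(false),  // never allowed
        "macos" => Some(true),   // can be allowed
        "freebsd" => Some(true), // always allowed; cannot be restricted
        _ => None,               // behavior unknown on other platforms
    }
}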
@@ -2,7 +2,7 @@
 use crate::sandbox::common::*;
 use crate::skip_if_unsupported;
 use claudia_lib::sandbox::executor::SandboxExecutor;
-use gaol::profile::{Profile, Operation, PathPattern};
+use gaol::profile::{Operation, PathPattern, Profile};
 use serial_test::serial;
 use std::sync::{Arc, Mutex};
 use tempfile::TempDir;
@@ -27,19 +27,19 @@ impl ViolationCollector {
             violations: Arc::new(Mutex::new(Vec::new())),
         }
     }

     fn record(&self, operation_type: &str, pattern_value: Option<&str>, process_name: &str) {
         let event = ViolationEvent {
             operation_type: operation_type.to_string(),
             pattern_value: pattern_value.map(|s| s.to_string()),
             process_name: process_name.to_string(),
         };

         if let Ok(mut violations) = self.violations.lock() {
             violations.push(event);
         }
     }

     fn get_violations(&self) -> Vec<ViolationEvent> {
         self.violations.lock().unwrap().clone()
     }
@@ -50,22 +50,22 @@ impl ViolationCollector {
 #[serial]
 fn test_violation_detection() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();
     if !platform.supports_file_read {
         eprintln!("Skipping test: file read not supported on this platform");
         return;
     }

     // Create test file system
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");
     let collector = ViolationCollector::new();

     // Create profile allowing only project path
-    let operations = vec![
-        Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -73,19 +73,31 @@ fn test_violation_detection() {
             return;
         }
     };

     // Test various forbidden operations
     let test_cases = vec![
-        ("file_read", test_code::file_read(&test_fs.forbidden_path.join("secret.txt").to_string_lossy()), "file_read_forbidden"),
-        ("file_write", test_code::file_write(&test_fs.project_path.join("new.txt").to_string_lossy()), "file_write_forbidden"),
-        ("process_spawn", test_code::spawn_process().to_string(), "process_spawn_forbidden"),
+        (
+            "file_read",
+            test_code::file_read(&test_fs.forbidden_path.join("secret.txt").to_string_lossy()),
+            "file_read_forbidden",
+        ),
+        (
+            "file_write",
+            test_code::file_write(&test_fs.project_path.join("new.txt").to_string_lossy()),
+            "file_write_forbidden",
+        ),
+        (
+            "process_spawn",
+            test_code::spawn_process().to_string(),
+            "process_spawn_forbidden",
+        ),
     ];

     for (op_type, test_code, binary_name) in test_cases {
         let binary_dir = TempDir::new().expect("Failed to create temp dir");
         let binary_path = create_test_binary(binary_name, &test_code, binary_dir.path())
             .expect("Failed to create test binary");

         let executor = SandboxExecutor::new(profile.clone(), test_fs.project_path.clone());
         match executor.execute_sandboxed_spawn(
             &binary_path.to_string_lossy(),
@@ -104,7 +116,7 @@ fn test_violation_detection() {
             }
         }
     }

     // Verify violations were detected
     let violations = collector.get_violations();
     // On some platforms (like macOS), sandbox might not block all operations
@@ -122,25 +134,25 @@ fn test_violation_detection() {
 #[serial]
 fn test_violation_patterns() {
     skip_if_unsupported!();

     let platform = PlatformConfig::current();
     if !platform.supports_file_read {
         eprintln!("Skipping test: file read not supported on this platform");
         return;
     }

     // Create test file system
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create profile with specific allowed paths
     let allowed_dir = test_fs.root.path().join("allowed_specific");
     std::fs::create_dir_all(&allowed_dir).expect("Failed to create allowed dir");

     let operations = vec![
         Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
         Operation::FileReadAll(PathPattern::Literal(allowed_dir.join("file.txt"))),
     ];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -148,21 +160,25 @@ fn test_violation_patterns() {
             return;
         }
     };

     // Test accessing different forbidden paths
-    let forbidden_db_path = test_fs.forbidden_path.join("data.db").to_string_lossy().to_string();
+    let forbidden_db_path = test_fs
+        .forbidden_path
+        .join("data.db")
+        .to_string_lossy()
+        .to_string();
     let forbidden_paths = vec![
         ("/etc/passwd", "system_file"),
         ("/tmp/test.txt", "temp_file"),
         (forbidden_db_path.as_str(), "forbidden_db"),
     ];

     for (path, test_name) in forbidden_paths {
         let test_code = test_code::file_read(path);
         let binary_dir = TempDir::new().expect("Failed to create temp dir");
         let binary_path = create_test_binary(test_name, &test_code, binary_dir.path())
             .expect("Failed to create test binary");

         let executor = SandboxExecutor::new(profile.clone(), test_fs.project_path.clone());
         match executor.execute_sandboxed_spawn(
             &binary_path.to_string_lossy(),
@@ -173,7 +189,10 @@ fn test_violation_patterns() {
                 let status = child.wait().expect("Failed to wait for child");
                 // Some platforms might not block all file access
                 if status.success() {
-                    eprintln!("WARNING: Access to {} was allowed (possible platform limitation)", path);
+                    eprintln!(
+                        "WARNING: Access to {} was allowed (possible platform limitation)",
+                        path
+                    );
                     if std::env::consts::OS == "linux" && path.starts_with("/etc") {
                         panic!("Access to {} should be denied on Linux", path);
                     }
@@ -191,15 +210,15 @@ fn test_violation_patterns() {
 #[serial]
 fn test_multiple_violations_sequence() {
     skip_if_unsupported!();

     // Create test file system
     let test_fs = TestFileSystem::new().expect("Failed to create test filesystem");

     // Create minimal profile
-    let operations = vec![
-        Operation::FileReadAll(PathPattern::Subpath(test_fs.project_path.clone())),
-    ];
+    let operations = vec![Operation::FileReadAll(PathPattern::Subpath(
+        test_fs.project_path.clone(),
+    ))];

     let profile = match Profile::new(operations) {
         Ok(p) => p,
         Err(_) => {
@@ -207,7 +226,7 @@ fn test_multiple_violations_sequence() {
             return;
         }
     };

     // Create test binary that attempts multiple forbidden operations
     let test_code = r#"
 use std::fs;
@@ -249,11 +268,11 @@ fn main() {{
         }}
     }}
 "#;

     let binary_dir = TempDir::new().expect("Failed to create temp dir");
     let binary_path = create_test_binary("test_multi_violations", test_code, binary_dir.path())
         .expect("Failed to create test binary");

     // Execute in sandbox
     let executor = SandboxExecutor::new(profile, test_fs.project_path.clone());
     match executor.execute_sandboxed_spawn(
@@ -275,4 +294,4 @@ fn main() {{
             eprintln!("Sandbox execution failed: {} (may be expected in CI)", e);
         }
     }
 }
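ViolationCollector above is the usual Arc<Mutex<Vec<_>>> accumulator: record() tolerates a poisoned lock, get_violations() clones the buffer out. A stripped-down sketch of the same pattern; the generic struct here is illustrative, not the project's type:

use std::sync::{Arc, Mutex};

#[derive(Clone)]
struct Collector<T: Clone> {
    items: Arc<Mutex<Vec<T>>>,
}

impl<T: Clone> Collector<T> {
    fn new() -> Self {
        Self {
            items: Arc::new(Mutex::new(Vec::new())),
        }
    }

    fn record(&self, item: T) {
        // Skip rather than panic if another test thread poisoned the lock.
        if let Ok(mut items) = self.items.lock() {
            items.push(item);
        }
    }

    fn snapshot(&self) -> Vec<T> {
        self.items.lock().unwrap().clone()
    }
}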
@@ -1,5 +1,5 @@
 //! Comprehensive test suite for sandbox functionality
 //!
 //! This test suite validates the sandboxing capabilities across different platforms,
 //! ensuring that security policies are correctly enforced.

@@ -14,4 +14,4 @@ pub mod unit;
 pub mod integration;

 #[cfg(unix)]
 pub mod e2e;
@@ -1,6 +1,6 @@
 //! Unit tests for SandboxExecutor
-use claudia_lib::sandbox::executor::{SandboxExecutor, should_activate_sandbox};
-use gaol::profile::{Profile, Operation, PathPattern, AddressPattern};
+use claudia_lib::sandbox::executor::{should_activate_sandbox, SandboxExecutor};
+use gaol::profile::{AddressPattern, Operation, PathPattern, Profile};
 use std::env;
 use std::path::PathBuf;

@@ -10,7 +10,7 @@ fn create_test_profile(project_path: PathBuf) -> Profile {
         Operation::FileReadAll(PathPattern::Subpath(project_path)),
         Operation::NetworkOutbound(AddressPattern::All),
     ];

     Profile::new(operations).expect("Failed to create test profile")
 }

@@ -18,7 +18,7 @@ fn create_test_profile(project_path: PathBuf) -> Profile {
 fn test_executor_creation() {
     let project_path = PathBuf::from("/test/project");
     let profile = create_test_profile(project_path.clone());

     let _executor = SandboxExecutor::new(profile, project_path);
     // Executor should be created successfully
 }
@@ -27,16 +27,25 @@ fn test_executor_creation() {
 fn test_should_activate_sandbox_env_var() {
     // Test when env var is not set
     env::remove_var("GAOL_SANDBOX_ACTIVE");
-    assert!(!should_activate_sandbox(), "Should not activate when env var is not set");
+    assert!(
+        !should_activate_sandbox(),
+        "Should not activate when env var is not set"
+    );

     // Test when env var is set to "1"
     env::set_var("GAOL_SANDBOX_ACTIVE", "1");
-    assert!(should_activate_sandbox(), "Should activate when env var is '1'");
+    assert!(
+        should_activate_sandbox(),
+        "Should activate when env var is '1'"
+    );

     // Test when env var is set to other value
     env::set_var("GAOL_SANDBOX_ACTIVE", "0");
-    assert!(!should_activate_sandbox(), "Should not activate when env var is not '1'");
+    assert!(
+        !should_activate_sandbox(),
+        "Should not activate when env var is not '1'"
+    );

     // Clean up
     env::remove_var("GAOL_SANDBOX_ACTIVE");
 }
@@ -46,9 +55,9 @@ fn test_prepare_sandboxed_command() {
     let project_path = PathBuf::from("/test/project");
     let profile = create_test_profile(project_path.clone());
     let executor = SandboxExecutor::new(profile, project_path.clone());

     let _cmd = executor.prepare_sandboxed_command("echo", &["hello"], &project_path);

     // The command should have sandbox environment variables set
     // Note: We can't easily test Command internals, but we can verify it doesn't panic
 }
@@ -57,10 +66,10 @@ fn test_prepare_sandboxed_command() {
 fn test_executor_with_empty_profile() {
     let project_path = PathBuf::from("/test/project");
     let profile = Profile::new(vec![]).expect("Failed to create empty profile");

     let executor = SandboxExecutor::new(profile, project_path.clone());
     let _cmd = executor.prepare_sandboxed_command("echo", &["test"], &project_path);

     // Should handle empty profile gracefully
 }

@@ -76,15 +85,16 @@ fn test_executor_with_complex_profile() {
         Operation::NetworkOutbound(AddressPattern::Tcp(443)),
         Operation::SystemInfoRead,
     ];

     // Only create profile with supported operations
-    let filtered_ops: Vec<_> = operations.into_iter()
+    let filtered_ops: Vec<_> = operations
+        .into_iter()
         .filter(|op| {
             use gaol::profile::{OperationSupport, OperationSupportLevel};
             matches!(op.support(), OperationSupportLevel::CanBeAllowed)
         })
         .collect();

     if !filtered_ops.is_empty() {
         let profile = Profile::new(filtered_ops).expect("Failed to create complex profile");
         let executor = SandboxExecutor::new(profile, project_path.clone());
@@ -97,12 +107,12 @@ fn test_command_environment_setup() {
     let project_path = PathBuf::from("/test/project");
     let profile = create_test_profile(project_path.clone());
     let executor = SandboxExecutor::new(profile, project_path.clone());

     // Test with various arguments
     let _cmd1 = executor.prepare_sandboxed_command("ls", &[], &project_path);
     let _cmd2 = executor.prepare_sandboxed_command("cat", &["file.txt"], &project_path);
     let _cmd3 = executor.prepare_sandboxed_command("grep", &["-r", "pattern", "."], &project_path);

     // Commands should be prepared without panic
 }

@@ -110,18 +120,18 @@ fn test_command_environment_setup() {
 #[cfg(unix)]
 fn test_spawn_sandboxed_process() {
     use crate::sandbox::common::is_sandboxing_supported;

     if !is_sandboxing_supported() {
         return;
     }

     let project_path = env::current_dir().unwrap_or_else(|_| PathBuf::from("/tmp"));
     let profile = create_test_profile(project_path.clone());
     let executor = SandboxExecutor::new(profile, project_path.clone());

     // Try to spawn a simple command
     let result = executor.execute_sandboxed_spawn("echo", &["sandbox test"], &project_path);

     // On supported platforms, this should either succeed or fail gracefully
     match result {
         Ok(mut child) => {
@@ -133,4 +143,4 @@ fn test_spawn_sandboxed_process() {
             println!("Sandbox spawn failed (expected in some environments): {e}");
         }
     }
 }
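The three assertions above pin the contract down to: activate only when GAOL_SANDBOX_ACTIVE is exactly "1". A plausible implementation consistent with those tests; the real one lives in claudia_lib::sandbox::executor and may differ:

// Activation gate consistent with the tests above: unset => false,
// "1" => true, any other value => false.
fn should_activate_sandbox() -> bool {
    std::env::var("GAOL_SANDBOX_ACTIVE")
        .map(|v| v == "1")
        .unwrap_or(false)
}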
@@ -1,7 +1,7 @@
 //! Unit tests for sandbox components
 #[cfg(test)]
-mod profile_builder;
+mod executor;
 #[cfg(test)]
 mod platform;
 #[cfg(test)]
-mod executor;
+mod profile_builder;
@@ -1,13 +1,13 @@
|
|||||||
//! Unit tests for platform capabilities
|
//! Unit tests for platform capabilities
|
||||||
use claudia_lib::sandbox::platform::{get_platform_capabilities, is_sandboxing_available};
|
use claudia_lib::sandbox::platform::{get_platform_capabilities, is_sandboxing_available};
|
||||||
use std::env;
|
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
|
use std::env;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sandboxing_availability() {
|
fn test_sandboxing_availability() {
|
||||||
let is_available = is_sandboxing_available();
|
let is_available = is_sandboxing_available();
|
||||||
let expected = matches!(env::consts::OS, "linux" | "macos" | "freebsd");
|
let expected = matches!(env::consts::OS, "linux" | "macos" | "freebsd");
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
is_available, expected,
|
is_available, expected,
|
||||||
"Sandboxing availability should match platform support"
|
"Sandboxing availability should match platform support"
|
||||||
@@ -17,44 +17,59 @@ fn test_sandboxing_availability() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_platform_capabilities_structure() {
|
fn test_platform_capabilities_structure() {
|
||||||
let caps = get_platform_capabilities();
|
let caps = get_platform_capabilities();
|
||||||
|
|
||||||
// Verify basic structure
|
// Verify basic structure
|
||||||
assert_eq!(caps.os, env::consts::OS, "OS should match current platform");
|
assert_eq!(caps.os, env::consts::OS, "OS should match current platform");
|
||||||
assert!(!caps.operations.is_empty() || !caps.sandboxing_supported,
|
assert!(
|
||||||
"Should have operations if sandboxing is supported");
|
!caps.operations.is_empty() || !caps.sandboxing_supported,
|
||||||
assert!(!caps.notes.is_empty(), "Should have platform-specific notes");
|
"Should have operations if sandboxing is supported"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!caps.notes.is_empty(),
|
||||||
|
"Should have platform-specific notes"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
fn test_linux_capabilities() {
|
fn test_linux_capabilities() {
|
||||||
let caps = get_platform_capabilities();
|
let caps = get_platform_capabilities();
|
||||||
|
|
||||||
assert_eq!(caps.os, "linux");
|
assert_eq!(caps.os, "linux");
|
||||||
assert!(caps.sandboxing_supported);
|
assert!(caps.sandboxing_supported);
|
||||||
|
|
||||||
// Verify Linux-specific capabilities
|
// Verify Linux-specific capabilities
|
||||||
let file_read = caps.operations.iter()
|
let file_read = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "file_read_all")
|
.find(|op| op.operation == "file_read_all")
|
||||||
.expect("file_read_all should be present");
|
.expect("file_read_all should be present");
|
||||||
assert_eq!(file_read.support_level, "can_be_allowed");
|
assert_eq!(file_read.support_level, "can_be_allowed");
|
||||||
|
|
||||||
let metadata_read = caps.operations.iter()
|
let metadata_read = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "file_read_metadata")
|
.find(|op| op.operation == "file_read_metadata")
|
||||||
.expect("file_read_metadata should be present");
|
.expect("file_read_metadata should be present");
|
||||||
assert_eq!(metadata_read.support_level, "cannot_be_precisely");
|
assert_eq!(metadata_read.support_level, "cannot_be_precisely");
|
||||||
|
|
||||||
let network_all = caps.operations.iter()
|
let network_all = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "network_outbound_all")
|
.find(|op| op.operation == "network_outbound_all")
|
||||||
.expect("network_outbound_all should be present");
|
.expect("network_outbound_all should be present");
|
||||||
assert_eq!(network_all.support_level, "can_be_allowed");
|
assert_eq!(network_all.support_level, "can_be_allowed");
|
||||||
|
|
||||||
let network_tcp = caps.operations.iter()
|
let network_tcp = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "network_outbound_tcp")
|
.find(|op| op.operation == "network_outbound_tcp")
|
||||||
.expect("network_outbound_tcp should be present");
|
.expect("network_outbound_tcp should be present");
|
||||||
assert_eq!(network_tcp.support_level, "cannot_be_precisely");
|
assert_eq!(network_tcp.support_level, "cannot_be_precisely");
|
||||||
|
|
||||||
let system_info = caps.operations.iter()
|
let system_info = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "system_info_read")
|
.find(|op| op.operation == "system_info_read")
|
||||||
.expect("system_info_read should be present");
|
.expect("system_info_read should be present");
|
||||||
assert_eq!(system_info.support_level, "never");
|
assert_eq!(system_info.support_level, "never");
|
||||||
@@ -64,27 +79,35 @@ fn test_linux_capabilities() {
|
|||||||
#[cfg(target_os = "macos")]
|
#[cfg(target_os = "macos")]
|
||||||
fn test_macos_capabilities() {
|
fn test_macos_capabilities() {
|
||||||
let caps = get_platform_capabilities();
|
let caps = get_platform_capabilities();
|
||||||
|
|
||||||
assert_eq!(caps.os, "macos");
|
assert_eq!(caps.os, "macos");
|
||||||
assert!(caps.sandboxing_supported);
|
assert!(caps.sandboxing_supported);
|
||||||
|
|
||||||
// Verify macOS-specific capabilities
|
// Verify macOS-specific capabilities
|
||||||
let file_read = caps.operations.iter()
|
let file_read = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "file_read_all")
|
.find(|op| op.operation == "file_read_all")
|
||||||
.expect("file_read_all should be present");
|
.expect("file_read_all should be present");
|
||||||
assert_eq!(file_read.support_level, "can_be_allowed");
|
assert_eq!(file_read.support_level, "can_be_allowed");
|
||||||
|
|
||||||
let metadata_read = caps.operations.iter()
|
let metadata_read = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "file_read_metadata")
|
.find(|op| op.operation == "file_read_metadata")
|
||||||
.expect("file_read_metadata should be present");
|
.expect("file_read_metadata should be present");
|
||||||
assert_eq!(metadata_read.support_level, "can_be_allowed");
|
assert_eq!(metadata_read.support_level, "can_be_allowed");
|
||||||
|
|
||||||
let network_tcp = caps.operations.iter()
|
let network_tcp = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "network_outbound_tcp")
|
.find(|op| op.operation == "network_outbound_tcp")
|
||||||
.expect("network_outbound_tcp should be present");
|
.expect("network_outbound_tcp should be present");
|
||||||
assert_eq!(network_tcp.support_level, "can_be_allowed");
|
assert_eq!(network_tcp.support_level, "can_be_allowed");
|
||||||
|
|
||||||
let system_info = caps.operations.iter()
|
let system_info = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "system_info_read")
|
.find(|op| op.operation == "system_info_read")
|
||||||
.expect("system_info_read should be present");
|
.expect("system_info_read should be present");
|
||||||
assert_eq!(system_info.support_level, "can_be_allowed");
|
assert_eq!(system_info.support_level, "can_be_allowed");
|
||||||
@@ -94,17 +117,21 @@ fn test_macos_capabilities() {
|
|||||||
#[cfg(target_os = "freebsd")]
|
#[cfg(target_os = "freebsd")]
|
||||||
fn test_freebsd_capabilities() {
|
fn test_freebsd_capabilities() {
|
||||||
let caps = get_platform_capabilities();
|
let caps = get_platform_capabilities();
|
||||||
|
|
||||||
assert_eq!(caps.os, "freebsd");
|
assert_eq!(caps.os, "freebsd");
|
||||||
assert!(caps.sandboxing_supported);
|
assert!(caps.sandboxing_supported);
|
||||||
|
|
||||||
// Verify FreeBSD-specific capabilities
|
// Verify FreeBSD-specific capabilities
|
||||||
let file_read = caps.operations.iter()
|
let file_read = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "file_read_all")
|
.find(|op| op.operation == "file_read_all")
|
||||||
.expect("file_read_all should be present");
|
.expect("file_read_all should be present");
|
||||||
assert_eq!(file_read.support_level, "never");
|
assert_eq!(file_read.support_level, "never");
|
||||||
|
|
||||||
let system_info = caps.operations.iter()
|
let system_info = caps
|
||||||
|
.operations
|
||||||
|
.iter()
|
||||||
.find(|op| op.operation == "system_info_read")
|
.find(|op| op.operation == "system_info_read")
|
||||||
.expect("system_info_read should be present");
|
.expect("system_info_read should be present");
|
||||||
assert_eq!(system_info.support_level, "always");
|
assert_eq!(system_info.support_level, "always");
|
||||||
@@ -114,7 +141,7 @@ fn test_freebsd_capabilities() {
|
|||||||
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "freebsd")))]
|
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "freebsd")))]
|
||||||
fn test_unsupported_platform_capabilities() {
|
fn test_unsupported_platform_capabilities() {
|
||||||
let caps = get_platform_capabilities();
|
let caps = get_platform_capabilities();
|
||||||
|
|
||||||
assert!(!caps.sandboxing_supported);
|
assert!(!caps.sandboxing_supported);
|
||||||
assert_eq!(caps.operations.len(), 0);
|
assert_eq!(caps.operations.len(), 0);
|
||||||
assert!(caps.notes.iter().any(|note| note.contains("not supported")));
|
assert!(caps.notes.iter().any(|note| note.contains("not supported")));
|
||||||
@@ -123,12 +150,18 @@ fn test_unsupported_platform_capabilities() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_all_operations_have_descriptions() {
|
fn test_all_operations_have_descriptions() {
|
||||||
let caps = get_platform_capabilities();
|
let caps = get_platform_capabilities();
|
||||||
|
|
||||||
for op in &caps.operations {
|
for op in &caps.operations {
|
||||||
assert!(!op.description.is_empty(),
|
assert!(
|
||||||
"Operation {} should have a description", op.operation);
|
!op.description.is_empty(),
|
||||||
assert!(!op.support_level.is_empty(),
|
"Operation {} should have a description",
|
||||||
"Operation {} should have a support level", op.operation);
|
op.operation
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!op.support_level.is_empty(),
|
||||||
|
"Operation {} should have a support level",
|
||||||
|
op.operation
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,7 +169,7 @@ fn test_all_operations_have_descriptions() {
|
|||||||
fn test_support_level_values() {
|
fn test_support_level_values() {
|
||||||
let caps = get_platform_capabilities();
|
let caps = get_platform_capabilities();
|
||||||
let valid_levels = ["never", "can_be_allowed", "cannot_be_precisely", "always"];
|
let valid_levels = ["never", "can_be_allowed", "cannot_be_precisely", "always"];
|
||||||
|
|
||||||
for op in &caps.operations {
|
for op in &caps.operations {
|
||||||
assert!(
|
assert!(
|
||||||
valid_levels.contains(&op.support_level.as_str()),
|
valid_levels.contains(&op.support_level.as_str()),
|
||||||
@@ -145,4 +178,4 @@ fn test_support_level_values() {
|
|||||||
op.support_level
|
op.support_level
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -18,8 +18,7 @@ fn make_rule(
         pattern_value: pattern_value.to_string(),
         enabled: true,
         platform_support: platforms.map(|p| {
-            serde_json::to_string(&p.iter().map(|s| s.to_string()).collect::<Vec<_>>())
-                .unwrap()
+            serde_json::to_string(&p.iter().map(|s| s.to_string()).collect::<Vec<_>>()).unwrap()
         }),
         created_at: String::new(),
     }
@@ -29,34 +28,53 @@ fn make_rule(
 fn test_profile_builder_creation() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path.clone());
 
-    assert!(builder.is_ok(), "ProfileBuilder should be created successfully");
+    assert!(
+        builder.is_ok(),
+        "ProfileBuilder should be created successfully"
+    );
 }
 
 #[test]
 fn test_empty_rules_creates_empty_profile() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
     let profile = builder.build_profile(vec![]);
-    assert!(profile.is_ok(), "Empty rules should create valid empty profile");
+    assert!(
+        profile.is_ok(),
+        "Empty rules should create valid empty profile"
+    );
 }
 
 #[test]
 fn test_file_read_rule_parsing() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path.clone()).unwrap();
 
     let rules = vec![
-        make_rule("file_read_all", "literal", "/usr/lib/test.so", Some(&["linux", "macos"])),
-        make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"])),
+        make_rule(
+            "file_read_all",
+            "literal",
+            "/usr/lib/test.so",
+            Some(&["linux", "macos"]),
+        ),
+        make_rule(
+            "file_read_all",
+            "subpath",
+            "/usr/lib",
+            Some(&["linux", "macos"]),
+        ),
     ];
 
     let _profile = builder.build_profile(rules);
 
     // Profile creation might fail on unsupported platforms, but parsing should work
     if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" {
-        assert!(_profile.is_ok(), "File read rules should be parsed on supported platforms");
+        assert!(
+            _profile.is_ok(),
+            "File read rules should be parsed on supported platforms"
+        );
     }
 }
 
@@ -64,17 +82,25 @@ fn test_file_read_rule_parsing() {
 fn test_network_rule_parsing() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
     let rules = vec![
         make_rule("network_outbound", "all", "", Some(&["linux", "macos"])),
         make_rule("network_outbound", "tcp", "8080", Some(&["macos"])),
-        make_rule("network_outbound", "local_socket", "/tmp/socket", Some(&["macos"])),
+        make_rule(
+            "network_outbound",
+            "local_socket",
+            "/tmp/socket",
+            Some(&["macos"]),
+        ),
     ];
 
     let _profile = builder.build_profile(rules);
 
     if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" {
-        assert!(_profile.is_ok(), "Network rules should be parsed on supported platforms");
+        assert!(
+            _profile.is_ok(),
+            "Network rules should be parsed on supported platforms"
+        );
     }
 }
 
@@ -82,15 +108,16 @@ fn test_network_rule_parsing() {
 fn test_system_info_rule_parsing() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
-    let rules = vec![
-        make_rule("system_info_read", "all", "", Some(&["macos"])),
-    ];
+    let rules = vec![make_rule("system_info_read", "all", "", Some(&["macos"]))];
 
     let _profile = builder.build_profile(rules);
 
     if std::env::consts::OS == "macos" {
-        assert!(_profile.is_ok(), "System info rule should be parsed on macOS");
+        assert!(
+            _profile.is_ok(),
+            "System info rule should be parsed on macOS"
+        );
     }
 }
 
@@ -98,12 +125,22 @@ fn test_system_info_rule_parsing() {
 fn test_template_variable_replacement() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path.clone()).unwrap();
 
     let rules = vec![
-        make_rule("file_read_all", "subpath", "{{PROJECT_PATH}}/src", Some(&["linux", "macos"])),
-        make_rule("file_read_all", "subpath", "{{HOME}}/.config", Some(&["linux", "macos"])),
+        make_rule(
+            "file_read_all",
+            "subpath",
+            "{{PROJECT_PATH}}/src",
+            Some(&["linux", "macos"]),
+        ),
+        make_rule(
+            "file_read_all",
+            "subpath",
+            "{{HOME}}/.config",
+            Some(&["linux", "macos"]),
+        ),
     ];
 
     let _profile = builder.build_profile(rules);
     // We can't easily verify the exact paths without inspecting the Profile internals,
     // but this test ensures template replacement doesn't panic
@@ -113,10 +150,15 @@ fn test_template_variable_replacement() {
 fn test_disabled_rules_are_ignored() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
-    let mut rule = make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"]));
+    let mut rule = make_rule(
+        "file_read_all",
+        "subpath",
+        "/usr/lib",
+        Some(&["linux", "macos"]),
+    );
     rule.enabled = false;
 
     let profile = builder.build_profile(vec![rule]);
     assert!(profile.is_ok(), "Disabled rules should be ignored");
 }
@@ -125,21 +167,30 @@ fn test_disabled_rules_are_ignored() {
 fn test_platform_filtering() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
     let current_os = std::env::consts::OS;
-    let other_os = if current_os == "linux" { "macos" } else { "linux" };
+    let other_os = if current_os == "linux" {
+        "macos"
+    } else {
+        "linux"
+    };
 
     let rules = vec![
         // Rule for current platform
         make_rule("file_read_all", "subpath", "/test1", Some(&[current_os])),
         // Rule for other platform
        make_rule("file_read_all", "subpath", "/test2", Some(&[other_os])),
         // Rule for both platforms
-        make_rule("file_read_all", "subpath", "/test3", Some(&["linux", "macos"])),
+        make_rule(
+            "file_read_all",
+            "subpath",
+            "/test3",
+            Some(&["linux", "macos"]),
+        ),
         // Rule with no platform specification (should be included)
         make_rule("file_read_all", "subpath", "/test4", None),
     ];
 
     let _profile = builder.build_profile(rules);
     // Rules for other platforms should be filtered out
 }
@@ -148,11 +199,14 @@ fn test_platform_filtering() {
 fn test_invalid_operation_type() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
-    let rules = vec![
-        make_rule("invalid_operation", "subpath", "/test", Some(&["linux", "macos"])),
-    ];
+    let rules = vec![make_rule(
+        "invalid_operation",
+        "subpath",
+        "/test",
+        Some(&["linux", "macos"]),
+    )];
 
     let _profile = builder.build_profile(rules);
     assert!(_profile.is_ok(), "Invalid operations should be skipped");
 }
@@ -161,11 +215,14 @@ fn test_invalid_operation_type() {
 fn test_invalid_pattern_type() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
-    let rules = vec![
-        make_rule("file_read_all", "invalid_pattern", "/test", Some(&["linux", "macos"])),
-    ];
+    let rules = vec![make_rule(
+        "file_read_all",
+        "invalid_pattern",
+        "/test",
+        Some(&["linux", "macos"]),
+    )];
 
     let _profile = builder.build_profile(rules);
     // Should either skip the rule or fail gracefully
 }
@@ -174,11 +231,14 @@ fn test_invalid_pattern_type() {
 fn test_invalid_tcp_port() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
-    let rules = vec![
-        make_rule("network_outbound", "tcp", "not_a_number", Some(&["macos"])),
-    ];
+    let rules = vec![make_rule(
+        "network_outbound",
+        "tcp",
+        "not_a_number",
+        Some(&["macos"]),
+    )];
 
     let _profile = builder.build_profile(rules);
     // Should handle invalid port gracefully
 }
@@ -188,13 +248,12 @@ fn test_invalid_tcp_port() {
 #[test_case("network_outbound", "all", "" ; "network all operation")]
 #[test_case("system_info_read", "all", "" ; "system info operation")]
 fn test_operation_support_level(operation_type: &str, pattern_type: &str, pattern_value: &str) {
-
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
     let rule = make_rule(operation_type, pattern_type, pattern_value, None);
     let rules = vec![rule];
 
     match builder.build_profile(rules) {
         Ok(_) => {
             // Profile created successfully - operation is supported
@@ -211,27 +270,43 @@ fn test_operation_support_level(operation_type: &str, pattern_type: &str, patter
 fn test_complex_profile_with_multiple_rules() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path.clone()).unwrap();
 
     let rules = vec![
         // File operations
-        make_rule("file_read_all", "subpath", "{{PROJECT_PATH}}", Some(&["linux", "macos"])),
-        make_rule("file_read_all", "subpath", "/usr/lib", Some(&["linux", "macos"])),
-        make_rule("file_read_all", "literal", "/etc/hosts", Some(&["linux", "macos"])),
+        make_rule(
+            "file_read_all",
+            "subpath",
+            "{{PROJECT_PATH}}",
+            Some(&["linux", "macos"]),
+        ),
+        make_rule(
+            "file_read_all",
+            "subpath",
+            "/usr/lib",
+            Some(&["linux", "macos"]),
+        ),
+        make_rule(
+            "file_read_all",
+            "literal",
+            "/etc/hosts",
+            Some(&["linux", "macos"]),
+        ),
         make_rule("file_read_metadata", "subpath", "/", Some(&["macos"])),
 
         // Network operations
         make_rule("network_outbound", "all", "", Some(&["linux", "macos"])),
         make_rule("network_outbound", "tcp", "443", Some(&["macos"])),
         make_rule("network_outbound", "tcp", "80", Some(&["macos"])),
 
         // System info
         make_rule("system_info_read", "all", "", Some(&["macos"])),
     ];
 
     let _profile = builder.build_profile(rules);
 
     if std::env::consts::OS == "linux" || std::env::consts::OS == "macos" {
-        assert!(_profile.is_ok(), "Complex profile should be created on supported platforms");
+        assert!(
+            _profile.is_ok(),
+            "Complex profile should be created on supported platforms"
+        );
     }
 }
 
@@ -239,14 +314,24 @@ fn test_complex_profile_with_multiple_rules() {
 fn test_rule_order_preservation() {
     let project_path = PathBuf::from("/test/project");
     let builder = ProfileBuilder::new(project_path).unwrap();
 
     // Create rules with specific order
     let rules = vec![
-        make_rule("file_read_all", "subpath", "/first", Some(&["linux", "macos"])),
+        make_rule(
+            "file_read_all",
+            "subpath",
+            "/first",
+            Some(&["linux", "macos"]),
+        ),
         make_rule("network_outbound", "all", "", Some(&["linux", "macos"])),
-        make_rule("file_read_all", "subpath", "/second", Some(&["linux", "macos"])),
+        make_rule(
+            "file_read_all",
+            "subpath",
+            "/second",
+            Some(&["linux", "macos"]),
+        ),
     ];
 
     let _profile = builder.build_profile(rules);
     // Order should be preserved in the resulting profile
 }
@@ -1,5 +1,5 @@
 //! Main entry point for sandbox tests
 //!
 //! This file integrates all the sandbox test modules and provides
 //! a central location for running the comprehensive test suite.
 #![allow(dead_code)]
@@ -8,4 +8,4 @@
 mod sandbox;
 
 // Re-export test modules to make them discoverable
 pub use sandbox::*;