style: apply cargo fmt across entire Rust codebase

- Remove Rust formatting check from CI workflow since formatting is now applied
- Standardize import ordering and organization throughout codebase
- Fix indentation, spacing, and line breaks for consistency
- Clean up trailing whitespace and formatting inconsistencies
- Apply rustfmt to all Rust source files, including the checkpoint, sandbox, commands, and test modules

This establishes a consistent code style baseline for the project.
Author: Mufeed VH
Date:   2025-06-25 03:45:59 +05:30
Parent: bb48a32784
Commit: bcffce0a08
41 changed files with 3617 additions and 2662 deletions
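
Below, for context, is a minimal sketch of the kind of mechanical rewrite rustfmt performs under its default style; the function and the path passed to it are illustrative only, not taken from this commit's diff.

use std::fs;
use std::time::UNIX_EPOCH;

// Hypothetical pre-format input, written as one long line:
//     let secs = meta.modified().ok().and_then(|t| t.duration_since(UNIX_EPOCH).ok()).map(|d| d.as_secs());
// After formatting, a chain that exceeds the line-width limit is broken
// onto one method call per line, as seen throughout the hunks below.
fn modified_secs(path: &str) -> Option<u64> {
    let meta = fs::metadata(path).ok()?;
    meta.modified()
        .ok()
        .and_then(|t| t.duration_since(UNIX_EPOCH).ok())
        .map(|d| d.as_secs())
}

fn main() {
    println!("{:?}", modified_secs("Cargo.toml"));
}

Assuming the standard toolchain, cargo fmt --all rewrites every workspace crate in place, while cargo fmt --all -- --check only reports unformatted files; the latter is the kind of CI step this commit removes now that the baseline is formatted.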

View File: checkpoint/manager.rs

@@ -1,16 +1,16 @@
use anyhow::{Context, Result};
+ use chrono::{DateTime, TimeZone, Utc};
+ use log;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
- use chrono::{Utc, TimeZone, DateTime};
use tokio::sync::RwLock;
- use log;
use super::{
-     Checkpoint, CheckpointMetadata, FileSnapshot, FileTracker, FileState,
-     CheckpointResult, SessionTimeline, CheckpointStrategy, CheckpointPaths,
-     storage::{CheckpointStorage, self},
+     storage::{self, CheckpointStorage},
+     Checkpoint, CheckpointMetadata, CheckpointPaths, CheckpointResult, CheckpointStrategy,
+     FileSnapshot, FileState, FileTracker, SessionTimeline,
};
/// Manages checkpoint operations for a session
@@ -33,10 +33,10 @@ impl CheckpointManager {
claude_dir: PathBuf,
) -> Result<Self> {
let storage = Arc::new(CheckpointStorage::new(claude_dir.clone()));
// Initialize storage
storage.init_storage(&project_id, &session_id)?;
// Load or create timeline
let paths = CheckpointPaths::new(&claude_dir, &project_id, &session_id);
let timeline = if paths.timeline_file.exists() {
@@ -44,11 +44,11 @@ impl CheckpointManager {
} else {
SessionTimeline::new(session_id.clone())
};
let file_tracker = FileTracker {
tracked_files: HashMap::new(),
};
Ok(Self {
project_id,
session_id,
@@ -59,12 +59,12 @@ impl CheckpointManager {
current_messages: Arc::new(RwLock::new(Vec::new())),
})
}
/// Track a new message in the session
pub async fn track_message(&self, jsonl_message: String) -> Result<()> {
let mut messages = self.current_messages.write().await;
messages.push(jsonl_message.clone());
// Parse message to check for tool usage
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&jsonl_message) {
if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
@@ -81,10 +81,10 @@ impl CheckpointManager {
}
}
}
Ok(())
}
/// Track file operations from tool usage
async fn track_tool_operation(&self, tool: &str, input: &serde_json::Value) -> Result<()> {
match tool.to_lowercase().as_str() {
@@ -103,47 +103,51 @@ impl CheckpointManager {
}
Ok(())
}
/// Track a file modification
pub async fn track_file_modification(&self, file_path: &str) -> Result<()> {
let mut tracker = self.file_tracker.write().await;
let full_path = self.project_path.join(file_path);
// Read current file state
let (hash, exists, _size, modified) = if full_path.exists() {
- let content = fs::read_to_string(&full_path)
-     .unwrap_or_default();
+ let content = fs::read_to_string(&full_path).unwrap_or_default();
let metadata = fs::metadata(&full_path)?;
- let modified = metadata.modified()
+ let modified = metadata
+     .modified()
.ok()
.and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
- .map(|d| Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos()).unwrap())
+ .map(|d| {
+     Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos())
+         .unwrap()
+ })
.unwrap_or_else(Utc::now);
(
storage::CheckpointStorage::calculate_file_hash(&content),
true,
metadata.len(),
- modified
+ modified,
)
} else {
(String::new(), false, 0, Utc::now())
};
// Check if file has actually changed
- let is_modified = if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) {
-     // File is modified if:
-     // 1. Hash has changed
-     // 2. Existence state has changed
-     // 3. It was already marked as modified
-     existing_state.last_hash != hash ||
-     existing_state.exists != exists ||
-     existing_state.is_modified
- } else {
-     // New file is always considered modified
-     true
- };
+ let is_modified =
+     if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) {
+         // File is modified if:
+         // 1. Hash has changed
+         // 2. Existence state has changed
+         // 3. It was already marked as modified
+         existing_state.last_hash != hash
+             || existing_state.exists != exists
+             || existing_state.is_modified
+     } else {
+         // New file is always considered modified
+         true
+     };
tracker.tracked_files.insert(
PathBuf::from(file_path),
FileState {
@@ -153,18 +157,18 @@ impl CheckpointManager {
exists,
},
);
Ok(())
}
/// Track potential file changes from bash commands
async fn track_bash_side_effects(&self, command: &str) -> Result<()> {
// Common file-modifying commands
let file_commands = [
"echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk",
"npm", "yarn", "pnpm", "bun", "cargo", "make", "gcc", "g++",
"echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk", "npm", "yarn", "pnpm", "bun",
"cargo", "make", "gcc", "g++",
];
// Simple heuristic: if command contains file-modifying operations
for cmd in &file_commands {
if command.contains(cmd) {
@@ -176,10 +180,10 @@ impl CheckpointManager {
break;
}
}
Ok(())
}
/// Create a checkpoint
pub async fn create_checkpoint(
&self,
@@ -188,13 +192,18 @@ impl CheckpointManager {
) -> Result<CheckpointResult> {
let messages = self.current_messages.read().await;
let message_index = messages.len().saturating_sub(1);
// Extract metadata from the last user message
- let (user_prompt, model_used, total_tokens) = self.extract_checkpoint_metadata(&messages).await?;
+ let (user_prompt, model_used, total_tokens) =
+     self.extract_checkpoint_metadata(&messages).await?;
// Ensure every file in the project is tracked so new checkpoints include all files
// Recursively walk the project directory and track each file
- fn collect_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<std::path::PathBuf>) -> Result<(), std::io::Error> {
+ fn collect_files(
+     dir: &std::path::Path,
+     base: &std::path::Path,
+     files: &mut Vec<std::path::PathBuf>,
+ ) -> Result<(), std::io::Error> {
for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
@@ -224,13 +233,13 @@ impl CheckpointManager {
let _ = self.track_file_modification(p).await;
}
}
// Generate checkpoint ID early so snapshots reference it
let checkpoint_id = storage::CheckpointStorage::generate_checkpoint_id();
// Create file snapshots
let file_snapshots = self.create_file_snapshots(&checkpoint_id).await?;
// Generate checkpoint struct
let checkpoint = Checkpoint {
id: checkpoint_id.clone(),
@@ -259,7 +268,7 @@ impl CheckpointManager {
),
},
};
// Save checkpoint
let messages_content = messages.join("\n");
let result = self.storage.save_checkpoint(
@@ -269,7 +278,7 @@ impl CheckpointManager {
file_snapshots,
&messages_content,
)?;
// Reload timeline from disk so in-memory timeline has updated nodes and total_checkpoints
let claude_dir = self.storage.claude_dir.clone();
let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id);
@@ -278,20 +287,20 @@ impl CheckpointManager {
let mut timeline_lock = self.timeline.write().await;
*timeline_lock = updated_timeline;
}
// Update timeline (current checkpoint only)
let mut timeline = self.timeline.write().await;
timeline.current_checkpoint_id = Some(checkpoint_id);
// Reset file tracker
let mut tracker = self.file_tracker.write().await;
for (_, state) in tracker.tracked_files.iter_mut() {
state.is_modified = false;
}
Ok(result)
}
/// Extract metadata from messages for checkpoint
async fn extract_checkpoint_metadata(
&self,
@@ -300,13 +309,14 @@ impl CheckpointManager {
let mut user_prompt = String::new();
let mut model_used = String::from("unknown");
let mut total_tokens = 0u64;
// Iterate through messages in reverse to find the last user prompt
for msg_str in messages.iter().rev() {
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(msg_str) {
// Check for user message
if msg.get("type").and_then(|t| t.as_str()) == Some("user") {
- if let Some(content) = msg.get("message")
+ if let Some(content) = msg
+     .get("message")
.and_then(|m| m.get("content"))
.and_then(|c| c.as_array())
{
@@ -320,19 +330,19 @@ impl CheckpointManager {
}
}
}
// Extract model info
if let Some(model) = msg.get("model").and_then(|m| m.as_str()) {
model_used = model.to_string();
}
// Also check for model in message.model (assistant messages)
if let Some(message) = msg.get("message") {
if let Some(model) = message.get("model").and_then(|m| m.as_str()) {
model_used = model.to_string();
}
}
// Count tokens - check both top-level and nested usage
// First check for usage in message.usage (assistant messages)
if let Some(message) = msg.get("message") {
@@ -344,15 +354,21 @@ impl CheckpointManager {
total_tokens += output;
}
// Also count cache tokens
- if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) {
+ if let Some(cache_creation) = usage
+     .get("cache_creation_input_tokens")
+     .and_then(|t| t.as_u64())
+ {
total_tokens += cache_creation;
}
- if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) {
+ if let Some(cache_read) = usage
+     .get("cache_read_input_tokens")
+     .and_then(|t| t.as_u64())
+ {
total_tokens += cache_read;
}
}
}
// Then check for top-level usage (result messages)
if let Some(usage) = msg.get("usage") {
if let Some(input) = usage.get("input_tokens").and_then(|t| t.as_u64()) {
@@ -362,40 +378,45 @@ impl CheckpointManager {
total_tokens += output;
}
// Also count cache tokens
- if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) {
+ if let Some(cache_creation) = usage
+     .get("cache_creation_input_tokens")
+     .and_then(|t| t.as_u64())
+ {
total_tokens += cache_creation;
}
- if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) {
+ if let Some(cache_read) = usage
+     .get("cache_read_input_tokens")
+     .and_then(|t| t.as_u64())
+ {
total_tokens += cache_read;
}
}
}
}
Ok((user_prompt, model_used, total_tokens))
}
/// Create file snapshots for all tracked modified files
async fn create_file_snapshots(&self, checkpoint_id: &str) -> Result<Vec<FileSnapshot>> {
let tracker = self.file_tracker.read().await;
let mut snapshots = Vec::new();
for (rel_path, state) in &tracker.tracked_files {
// Skip files that haven't been modified
if !state.is_modified {
continue;
}
let full_path = self.project_path.join(rel_path);
let (content, exists, permissions, size, current_hash) = if full_path.exists() {
- let content = fs::read_to_string(&full_path)
-     .unwrap_or_default();
+ let content = fs::read_to_string(&full_path).unwrap_or_default();
let current_hash = storage::CheckpointStorage::calculate_file_hash(&content);
// Don't skip based on hash - if is_modified is true, we should snapshot it
// The hash check in track_file_modification already determined if it changed
let metadata = fs::metadata(&full_path)?;
let permissions = {
#[cfg(unix)]
@@ -412,7 +433,7 @@ impl CheckpointManager {
} else {
(String::new(), false, None, 0, String::new())
};
snapshots.push(FileSnapshot {
checkpoint_id: checkpoint_id.to_string(),
file_path: rel_path.clone(),
@@ -423,21 +444,23 @@ impl CheckpointManager {
size,
});
}
Ok(snapshots)
}
/// Restore a checkpoint
pub async fn restore_checkpoint(&self, checkpoint_id: &str) -> Result<CheckpointResult> {
// Load checkpoint data
- let (checkpoint, file_snapshots, messages) = self.storage.load_checkpoint(
-     &self.project_id,
-     &self.session_id,
-     checkpoint_id,
- )?;
+ let (checkpoint, file_snapshots, messages) =
+     self.storage
+         .load_checkpoint(&self.project_id, &self.session_id, checkpoint_id)?;
// First, collect all files currently in the project to handle deletions
- fn collect_all_project_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<std::path::PathBuf>) -> Result<(), std::io::Error> {
+ fn collect_all_project_files(
+     dir: &std::path::Path,
+     base: &std::path::Path,
+     files: &mut Vec<std::path::PathBuf>,
+ ) -> Result<(), std::io::Error> {
for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
@@ -458,10 +481,11 @@ impl CheckpointManager {
}
Ok(())
}
let mut current_files = Vec::new();
- let _ = collect_all_project_files(&self.project_path, &self.project_path, &mut current_files);
+ let _ =
+     collect_all_project_files(&self.project_path, &self.project_path, &mut current_files);
// Create a set of files that should exist after restore
let mut checkpoint_files = std::collections::HashSet::new();
for snapshot in &file_snapshots {
@@ -469,11 +493,11 @@ impl CheckpointManager {
checkpoint_files.insert(snapshot.file_path.clone());
}
}
// Delete files that exist now but shouldn't exist in the checkpoint
let mut warnings = Vec::new();
let mut files_processed = 0;
for current_file in current_files {
if !checkpoint_files.contains(&current_file) {
// This file exists now but not in the checkpoint, so delete it
@@ -484,18 +508,25 @@ impl CheckpointManager {
log::info!("Deleted file not in checkpoint: {:?}", current_file);
}
Err(e) => {
warnings.push(format!("Failed to delete {}: {}", current_file.display(), e));
warnings.push(format!(
"Failed to delete {}: {}",
current_file.display(),
e
));
}
}
}
}
// Clean up empty directories
- fn remove_empty_dirs(dir: &std::path::Path, base: &std::path::Path) -> Result<bool, std::io::Error> {
+ fn remove_empty_dirs(
+     dir: &std::path::Path,
+     base: &std::path::Path,
+ ) -> Result<bool, std::io::Error> {
if dir == base {
return Ok(false); // Don't remove the base directory
}
let mut is_empty = true;
for entry in fs::read_dir(dir)? {
let entry = entry?;
@@ -508,7 +539,7 @@ impl CheckpointManager {
is_empty = false;
}
}
if is_empty {
fs::remove_dir(dir)?;
Ok(true)
@@ -516,30 +547,33 @@ impl CheckpointManager {
Ok(false)
}
}
// Clean up any empty directories left after file deletion
let _ = remove_empty_dirs(&self.project_path, &self.project_path);
// Restore files from checkpoint
for snapshot in &file_snapshots {
match self.restore_file_snapshot(snapshot).await {
Ok(_) => files_processed += 1,
- Err(e) => warnings.push(format!("Failed to restore {}: {}",
-     snapshot.file_path.display(), e)),
+ Err(e) => warnings.push(format!(
+     "Failed to restore {}: {}",
+     snapshot.file_path.display(),
+     e
+ )),
}
}
// Update current messages
let mut current_messages = self.current_messages.write().await;
current_messages.clear();
for line in messages.lines() {
current_messages.push(line.to_string());
}
// Update timeline
let mut timeline = self.timeline.write().await;
timeline.current_checkpoint_id = Some(checkpoint_id.to_string());
// Update file tracker
let mut tracker = self.file_tracker.write().await;
tracker.tracked_files.clear();
@@ -556,35 +590,32 @@ impl CheckpointManager {
);
}
}
Ok(CheckpointResult {
checkpoint: checkpoint.clone(),
files_processed,
warnings,
})
}
/// Restore a single file from snapshot
async fn restore_file_snapshot(&self, snapshot: &FileSnapshot) -> Result<()> {
let full_path = self.project_path.join(&snapshot.file_path);
if snapshot.is_deleted {
// Delete the file if it exists
if full_path.exists() {
- fs::remove_file(&full_path)
-     .context("Failed to delete file")?;
+ fs::remove_file(&full_path).context("Failed to delete file")?;
}
} else {
// Create parent directories if needed
if let Some(parent) = full_path.parent() {
- fs::create_dir_all(parent)
-     .context("Failed to create parent directories")?;
+ fs::create_dir_all(parent).context("Failed to create parent directories")?;
}
// Write file content
- fs::write(&full_path, &snapshot.content)
-     .context("Failed to write file")?;
+ fs::write(&full_path, &snapshot.content).context("Failed to write file")?;
// Restore permissions if available
#[cfg(unix)]
if let Some(mode) = snapshot.permissions {
@@ -594,35 +625,38 @@ impl CheckpointManager {
.context("Failed to set file permissions")?;
}
}
Ok(())
}
/// Get the current timeline
pub async fn get_timeline(&self) -> SessionTimeline {
self.timeline.read().await.clone()
}
/// List all checkpoints
pub async fn list_checkpoints(&self) -> Vec<Checkpoint> {
let timeline = self.timeline.read().await;
let mut checkpoints = Vec::new();
if let Some(root) = &timeline.root_node {
Self::collect_checkpoints_from_node(root, &mut checkpoints);
}
checkpoints
}
/// Recursively collect checkpoints from timeline tree
- fn collect_checkpoints_from_node(node: &super::TimelineNode, checkpoints: &mut Vec<Checkpoint>) {
+ fn collect_checkpoints_from_node(
+     node: &super::TimelineNode,
+     checkpoints: &mut Vec<Checkpoint>,
+ ) {
checkpoints.push(node.checkpoint.clone());
for child in &node.children {
Self::collect_checkpoints_from_node(child, checkpoints);
}
}
/// Fork from a checkpoint
pub async fn fork_from_checkpoint(
&self,
@@ -630,31 +664,29 @@ impl CheckpointManager {
description: Option<String>,
) -> Result<CheckpointResult> {
// Load the checkpoint to fork from
- let (_base_checkpoint, _, _) = self.storage.load_checkpoint(
-     &self.project_id,
-     &self.session_id,
-     checkpoint_id,
- )?;
+ let (_base_checkpoint, _, _) =
+     self.storage
+         .load_checkpoint(&self.project_id, &self.session_id, checkpoint_id)?;
// Restore to that checkpoint first
self.restore_checkpoint(checkpoint_id).await?;
// Create a new checkpoint with the fork
- let fork_description = description.unwrap_or_else(|| {
-     format!("Fork from checkpoint {}", &checkpoint_id[..8])
- });
- self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string())).await
+ let fork_description =
+     description.unwrap_or_else(|| format!("Fork from checkpoint {}", &checkpoint_id[..8]));
+ self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string()))
+     .await
}
/// Check if auto-checkpoint should be triggered
pub async fn should_auto_checkpoint(&self, message: &str) -> bool {
let timeline = self.timeline.read().await;
if !timeline.auto_checkpoint_enabled {
return false;
}
match timeline.checkpoint_strategy {
CheckpointStrategy::Manual => false,
CheckpointStrategy::PerPrompt => {
@@ -668,7 +700,11 @@ impl CheckpointManager {
CheckpointStrategy::PerToolUse => {
// Check if message contains tool use
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
- if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) {
+ if let Some(content) = msg
+     .get("message")
+     .and_then(|m| m.get("content"))
+     .and_then(|c| c.as_array())
+ {
content.iter().any(|item| {
item.get("type").and_then(|t| t.as_str()) == Some("tool_use")
})
@@ -682,12 +718,19 @@ impl CheckpointManager {
CheckpointStrategy::Smart => {
// Smart strategy: checkpoint after destructive operations
if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
- if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) {
+ if let Some(content) = msg
+     .get("message")
+     .and_then(|m| m.get("content"))
+     .and_then(|c| c.as_array())
+ {
content.iter().any(|item| {
if item.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
let tool_name = item.get("name").and_then(|n| n.as_str()).unwrap_or("");
matches!(tool_name.to_lowercase().as_str(),
"write" | "edit" | "multiedit" | "bash" | "rm" | "delete")
let tool_name =
item.get("name").and_then(|n| n.as_str()).unwrap_or("");
matches!(
tool_name.to_lowercase().as_str(),
"write" | "edit" | "multiedit" | "bash" | "rm" | "delete"
)
} else {
false
}
@@ -701,7 +744,7 @@ impl CheckpointManager {
}
}
}
/// Update checkpoint settings
pub async fn update_settings(
&self,
@@ -711,31 +754,34 @@ impl CheckpointManager {
let mut timeline = self.timeline.write().await;
timeline.auto_checkpoint_enabled = auto_checkpoint_enabled;
timeline.checkpoint_strategy = checkpoint_strategy;
// Save updated timeline
let claude_dir = self.storage.claude_dir.clone();
let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id);
- self.storage.save_timeline(&paths.timeline_file, &timeline)?;
+ self.storage
+     .save_timeline(&paths.timeline_file, &timeline)?;
Ok(())
}
/// Get files modified since a given timestamp
pub async fn get_files_modified_since(&self, since: DateTime<Utc>) -> Vec<PathBuf> {
let tracker = self.file_tracker.read().await;
- tracker.tracked_files
+ tracker
+     .tracked_files
.iter()
.filter(|(_, state)| state.last_modified > since && state.is_modified)
.map(|(path, _)| path.clone())
.collect()
}
/// Get the last modification time of any tracked file
pub async fn get_last_modification_time(&self) -> Option<DateTime<Utc>> {
let tracker = self.file_tracker.read().await;
- tracker.tracked_files
+ tracker
+     .tracked_files
.values()
.map(|state| state.last_modified)
.max()
}
}

View File: checkpoint/mod.rs

@@ -1,11 +1,11 @@
+ use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
- use chrono::{DateTime, Utc};
pub mod manager;
- pub mod storage;
pub mod state;
+ pub mod storage;
/// Represents a checkpoint in the session timeline
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -188,24 +188,25 @@ impl SessionTimeline {
total_checkpoints: 0,
}
}
/// Find a checkpoint by ID in the timeline tree
pub fn find_checkpoint(&self, checkpoint_id: &str) -> Option<&TimelineNode> {
- self.root_node.as_ref()
+ self.root_node
+     .as_ref()
.and_then(|root| Self::find_in_tree(root, checkpoint_id))
}
fn find_in_tree<'a>(node: &'a TimelineNode, checkpoint_id: &str) -> Option<&'a TimelineNode> {
if node.checkpoint.id == checkpoint_id {
return Some(node);
}
for child in &node.children {
if let Some(found) = Self::find_in_tree(child, checkpoint_id) {
return Some(found);
}
}
None
}
}
@@ -224,35 +225,38 @@ impl CheckpointPaths {
.join(project_id)
.join(".timelines")
.join(session_id);
Self {
timeline_file: base_dir.join("timeline.json"),
checkpoints_dir: base_dir.join("checkpoints"),
files_dir: base_dir.join("files"),
}
}
pub fn checkpoint_dir(&self, checkpoint_id: &str) -> PathBuf {
self.checkpoints_dir.join(checkpoint_id)
}
pub fn checkpoint_metadata_file(&self, checkpoint_id: &str) -> PathBuf {
self.checkpoint_dir(checkpoint_id).join("metadata.json")
}
pub fn checkpoint_messages_file(&self, checkpoint_id: &str) -> PathBuf {
self.checkpoint_dir(checkpoint_id).join("messages.jsonl")
}
#[allow(dead_code)]
pub fn file_snapshot_path(&self, _checkpoint_id: &str, file_hash: &str) -> PathBuf {
// In content-addressable storage, files are stored by hash in the content pool
self.files_dir.join("content_pool").join(file_hash)
}
#[allow(dead_code)]
pub fn file_reference_path(&self, checkpoint_id: &str, safe_filename: &str) -> PathBuf {
// References are stored per checkpoint
self.files_dir.join("refs").join(checkpoint_id).join(format!("{}.json", safe_filename))
self.files_dir
.join("refs")
.join(checkpoint_id)
.join(format!("{}.json", safe_filename))
}
}

View File: checkpoint/state.rs

@@ -1,13 +1,13 @@
+ use anyhow::Result;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
- use anyhow::Result;
use super::manager::CheckpointManager;
/// Manages checkpoint managers for active sessions
///
/// This struct maintains a stateful collection of CheckpointManager instances,
/// one per active session, to avoid recreating them on every command invocation.
/// It provides thread-safe access to managers and handles their lifecycle.
@@ -28,25 +28,25 @@ impl CheckpointState {
claude_dir: Arc::new(RwLock::new(None)),
}
}
/// Sets the Claude directory path
///
/// This should be called once during application initialization
pub async fn set_claude_dir(&self, claude_dir: PathBuf) {
let mut dir = self.claude_dir.write().await;
*dir = Some(claude_dir);
}
/// Gets or creates a CheckpointManager for a session
///
/// If a manager already exists for the session, it returns the existing one.
/// Otherwise, it creates a new manager and stores it for future use.
///
/// # Arguments
/// * `session_id` - The session identifier
/// * `project_id` - The project identifier
/// * `project_path` - The path to the project directory
///
/// # Returns
/// An Arc reference to the CheckpointManager for thread-safe sharing
pub async fn get_or_create_manager(
@@ -56,12 +56,12 @@ impl CheckpointState {
project_path: PathBuf,
) -> Result<Arc<CheckpointManager>> {
let mut managers = self.managers.write().await;
// Check if manager already exists
if let Some(manager) = managers.get(&session_id) {
return Ok(Arc::clone(manager));
}
// Get Claude directory
let claude_dir = {
let dir = self.claude_dir.read().await;
@@ -69,65 +69,62 @@ impl CheckpointState {
.ok_or_else(|| anyhow::anyhow!("Claude directory not set"))?
.clone()
};
// Create new manager
- let manager = CheckpointManager::new(
-     project_id,
-     session_id.clone(),
-     project_path,
-     claude_dir,
- ).await?;
+ let manager =
+     CheckpointManager::new(project_id, session_id.clone(), project_path, claude_dir)
+         .await?;
let manager_arc = Arc::new(manager);
managers.insert(session_id, Arc::clone(&manager_arc));
Ok(manager_arc)
}
/// Gets an existing CheckpointManager for a session
///
/// Returns None if no manager exists for the session
#[allow(dead_code)]
pub async fn get_manager(&self, session_id: &str) -> Option<Arc<CheckpointManager>> {
let managers = self.managers.read().await;
managers.get(session_id).map(Arc::clone)
}
/// Removes a CheckpointManager for a session
///
/// This should be called when a session ends to free resources
pub async fn remove_manager(&self, session_id: &str) -> Option<Arc<CheckpointManager>> {
let mut managers = self.managers.write().await;
managers.remove(session_id)
}
/// Clears all managers
///
/// This is useful for cleanup during application shutdown
#[allow(dead_code)]
pub async fn clear_all(&self) {
let mut managers = self.managers.write().await;
managers.clear();
}
/// Gets the number of active managers
pub async fn active_count(&self) -> usize {
let managers = self.managers.read().await;
managers.len()
}
/// Lists all active session IDs
pub async fn list_active_sessions(&self) -> Vec<String> {
let managers = self.managers.read().await;
managers.keys().cloned().collect()
}
/// Checks if a session has an active manager
#[allow(dead_code)]
pub async fn has_active_manager(&self, session_id: &str) -> bool {
self.get_manager(session_id).await.is_some()
}
/// Clears all managers and returns the count that were cleared
#[allow(dead_code)]
pub async fn clear_all_and_count(&self) -> usize {
@@ -141,50 +138,47 @@ impl CheckpointState {
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_checkpoint_state_lifecycle() {
let state = CheckpointState::new();
let temp_dir = TempDir::new().unwrap();
let claude_dir = temp_dir.path().to_path_buf();
// Set Claude directory
state.set_claude_dir(claude_dir.clone()).await;
// Create a manager
let session_id = "test-session-123".to_string();
let project_id = "test-project".to_string();
let project_path = temp_dir.path().join("project");
std::fs::create_dir_all(&project_path).unwrap();
- let manager1 = state.get_or_create_manager(
-     session_id.clone(),
-     project_id.clone(),
-     project_path.clone(),
- ).await.unwrap();
+ let manager1 = state
+     .get_or_create_manager(session_id.clone(), project_id.clone(), project_path.clone())
+     .await
+     .unwrap();
// Getting the same session should return the same manager
- let manager2 = state.get_or_create_manager(
-     session_id.clone(),
-     project_id.clone(),
-     project_path.clone(),
- ).await.unwrap();
+ let manager2 = state
+     .get_or_create_manager(session_id.clone(), project_id.clone(), project_path.clone())
+     .await
+     .unwrap();
assert!(Arc::ptr_eq(&manager1, &manager2));
assert_eq!(state.active_count().await, 1);
// Remove the manager
let removed = state.remove_manager(&session_id).await;
assert!(removed.is_some());
assert_eq!(state.active_count().await, 0);
// Getting after removal should create a new one
- let manager3 = state.get_or_create_manager(
-     session_id.clone(),
-     project_id,
-     project_path,
- ).await.unwrap();
+ let manager3 = state
+     .get_or_create_manager(session_id.clone(), project_id, project_path)
+     .await
+     .unwrap();
assert!(!Arc::ptr_eq(&manager1, &manager3));
}
}

View File: checkpoint/storage.rs

@@ -1,13 +1,12 @@
use anyhow::{Context, Result};
+ use sha2::{Digest, Sha256};
use std::fs;
use std::path::{Path, PathBuf};
- use sha2::{Sha256, Digest};
- use zstd::stream::{encode_all, decode_all};
use uuid::Uuid;
+ use zstd::stream::{decode_all, encode_all};
use super::{
-     Checkpoint, FileSnapshot, SessionTimeline,
-     TimelineNode, CheckpointPaths, CheckpointResult
+     Checkpoint, CheckpointPaths, CheckpointResult, FileSnapshot, SessionTimeline, TimelineNode,
};
/// Manages checkpoint storage operations
@@ -24,26 +23,25 @@ impl CheckpointStorage {
compression_level: 3, // Default zstd compression level
}
}
/// Initialize checkpoint storage for a session
pub fn init_storage(&self, project_id: &str, session_id: &str) -> Result<()> {
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
// Create directory structure
fs::create_dir_all(&paths.checkpoints_dir)
.context("Failed to create checkpoints directory")?;
- fs::create_dir_all(&paths.files_dir)
-     .context("Failed to create files directory")?;
+ fs::create_dir_all(&paths.files_dir).context("Failed to create files directory")?;
// Initialize empty timeline if it doesn't exist
if !paths.timeline_file.exists() {
let timeline = SessionTimeline::new(session_id.to_string());
self.save_timeline(&paths.timeline_file, &timeline)?;
}
Ok(())
}
/// Save a checkpoint to disk
pub fn save_checkpoint(
&self,
@@ -55,76 +53,73 @@ impl CheckpointStorage {
) -> Result<CheckpointResult> {
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
let checkpoint_dir = paths.checkpoint_dir(&checkpoint.id);
// Create checkpoint directory
- fs::create_dir_all(&checkpoint_dir)
-     .context("Failed to create checkpoint directory")?;
+ fs::create_dir_all(&checkpoint_dir).context("Failed to create checkpoint directory")?;
// Save checkpoint metadata
let metadata_path = paths.checkpoint_metadata_file(&checkpoint.id);
let metadata_json = serde_json::to_string_pretty(checkpoint)
.context("Failed to serialize checkpoint metadata")?;
- fs::write(&metadata_path, metadata_json)
-     .context("Failed to write checkpoint metadata")?;
+ fs::write(&metadata_path, metadata_json).context("Failed to write checkpoint metadata")?;
// Save messages (compressed)
let messages_path = paths.checkpoint_messages_file(&checkpoint.id);
let compressed_messages = encode_all(messages.as_bytes(), self.compression_level)
.context("Failed to compress messages")?;
fs::write(&messages_path, compressed_messages)
.context("Failed to write compressed messages")?;
// Save file snapshots
let mut warnings = Vec::new();
let mut files_processed = 0;
for snapshot in &file_snapshots {
match self.save_file_snapshot(&paths, snapshot) {
Ok(_) => files_processed += 1,
- Err(e) => warnings.push(format!("Failed to save {}: {}",
-     snapshot.file_path.display(), e)),
+ Err(e) => warnings.push(format!(
+     "Failed to save {}: {}",
+     snapshot.file_path.display(),
+     e
+ )),
}
}
// Update timeline
- self.update_timeline_with_checkpoint(
-     &paths.timeline_file,
-     checkpoint,
-     &file_snapshots
- )?;
+ self.update_timeline_with_checkpoint(&paths.timeline_file, checkpoint, &file_snapshots)?;
Ok(CheckpointResult {
checkpoint: checkpoint.clone(),
files_processed,
warnings,
})
}
/// Save a single file snapshot
fn save_file_snapshot(&self, paths: &CheckpointPaths, snapshot: &FileSnapshot) -> Result<()> {
// Use content-addressable storage: store files by their hash
// This prevents duplication of identical file content across checkpoints
let content_pool_dir = paths.files_dir.join("content_pool");
- fs::create_dir_all(&content_pool_dir)
-     .context("Failed to create content pool directory")?;
+ fs::create_dir_all(&content_pool_dir).context("Failed to create content pool directory")?;
// Store the actual content in the content pool
let content_file = content_pool_dir.join(&snapshot.hash);
// Only write the content if it doesn't already exist
if !content_file.exists() {
// Compress and save file content
- let compressed_content = encode_all(snapshot.content.as_bytes(), self.compression_level)
-     .context("Failed to compress file content")?;
+ let compressed_content =
+     encode_all(snapshot.content.as_bytes(), self.compression_level)
+         .context("Failed to compress file content")?;
fs::write(&content_file, compressed_content)
.context("Failed to write file content to pool")?;
}
// Create a reference in the checkpoint-specific directory
let checkpoint_refs_dir = paths.files_dir.join("refs").join(&snapshot.checkpoint_id);
fs::create_dir_all(&checkpoint_refs_dir)
.context("Failed to create checkpoint refs directory")?;
// Save file metadata with reference to content
let ref_metadata = serde_json::json!({
"path": snapshot.file_path,
@@ -133,20 +128,21 @@ impl CheckpointStorage {
"permissions": snapshot.permissions,
"size": snapshot.size,
});
// Use a sanitized filename for the reference
- let safe_filename = snapshot.file_path
+ let safe_filename = snapshot
+     .file_path
.to_string_lossy()
.replace('/', "_")
.replace('\\', "_");
let ref_path = checkpoint_refs_dir.join(format!("{}.json", safe_filename));
fs::write(&ref_path, serde_json::to_string_pretty(&ref_metadata)?)
.context("Failed to write file reference")?;
Ok(())
}
/// Load a checkpoint from disk
pub fn load_checkpoint(
&self,
@@ -155,75 +151,78 @@ impl CheckpointStorage {
checkpoint_id: &str,
) -> Result<(Checkpoint, Vec<FileSnapshot>, String)> {
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
// Load checkpoint metadata
let metadata_path = paths.checkpoint_metadata_file(checkpoint_id);
- let metadata_json = fs::read_to_string(&metadata_path)
-     .context("Failed to read checkpoint metadata")?;
- let checkpoint: Checkpoint = serde_json::from_str(&metadata_json)
-     .context("Failed to parse checkpoint metadata")?;
+ let metadata_json =
+     fs::read_to_string(&metadata_path).context("Failed to read checkpoint metadata")?;
+ let checkpoint: Checkpoint =
+     serde_json::from_str(&metadata_json).context("Failed to parse checkpoint metadata")?;
// Load messages
let messages_path = paths.checkpoint_messages_file(checkpoint_id);
- let compressed_messages = fs::read(&messages_path)
-     .context("Failed to read compressed messages")?;
- let messages = String::from_utf8(decode_all(&compressed_messages[..])
-     .context("Failed to decompress messages")?)
-     .context("Invalid UTF-8 in messages")?;
+ let compressed_messages =
+     fs::read(&messages_path).context("Failed to read compressed messages")?;
+ let messages = String::from_utf8(
+     decode_all(&compressed_messages[..]).context("Failed to decompress messages")?,
+ )
+ .context("Invalid UTF-8 in messages")?;
// Load file snapshots
let file_snapshots = self.load_file_snapshots(&paths, checkpoint_id)?;
Ok((checkpoint, file_snapshots, messages))
}
/// Load all file snapshots for a checkpoint
fn load_file_snapshots(
- &self,
- paths: &CheckpointPaths,
- checkpoint_id: &str
+ &self,
+ paths: &CheckpointPaths,
+ checkpoint_id: &str,
) -> Result<Vec<FileSnapshot>> {
let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
if !refs_dir.exists() {
return Ok(Vec::new());
}
let content_pool_dir = paths.files_dir.join("content_pool");
let mut snapshots = Vec::new();
// Read all reference files
for entry in fs::read_dir(&refs_dir)? {
let entry = entry?;
let path = entry.path();
// Skip non-JSON files
if path.extension().and_then(|e| e.to_str()) != Some("json") {
continue;
}
// Load reference metadata
- let ref_json = fs::read_to_string(&path)
-     .context("Failed to read file reference")?;
- let ref_metadata: serde_json::Value = serde_json::from_str(&ref_json)
-     .context("Failed to parse file reference")?;
- let hash = ref_metadata["hash"].as_str()
+ let ref_json = fs::read_to_string(&path).context("Failed to read file reference")?;
+ let ref_metadata: serde_json::Value =
+     serde_json::from_str(&ref_json).context("Failed to parse file reference")?;
+ let hash = ref_metadata["hash"]
+     .as_str()
.ok_or_else(|| anyhow::anyhow!("Missing hash in reference"))?;
// Load content from pool
let content_file = content_pool_dir.join(hash);
let content = if content_file.exists() {
- let compressed_content = fs::read(&content_file)
-     .context("Failed to read file content from pool")?;
- String::from_utf8(decode_all(&compressed_content[..])
-     .context("Failed to decompress file content")?)
-     .context("Invalid UTF-8 in file content")?
+ let compressed_content =
+     fs::read(&content_file).context("Failed to read file content from pool")?;
+ String::from_utf8(
+     decode_all(&compressed_content[..])
+         .context("Failed to decompress file content")?,
+ )
+ .context("Invalid UTF-8 in file content")?
} else {
// Handle missing content gracefully
log::warn!("Content file missing for hash: {}", hash);
String::new()
};
snapshots.push(FileSnapshot {
checkpoint_id: checkpoint_id.to_string(),
file_path: PathBuf::from(ref_metadata["path"].as_str().unwrap_or("")),
@@ -234,28 +233,26 @@ impl CheckpointStorage {
size: ref_metadata["size"].as_u64().unwrap_or(0),
});
}
Ok(snapshots)
}
/// Save timeline to disk
pub fn save_timeline(&self, timeline_path: &Path, timeline: &SessionTimeline) -> Result<()> {
- let timeline_json = serde_json::to_string_pretty(timeline)
-     .context("Failed to serialize timeline")?;
- fs::write(timeline_path, timeline_json)
-     .context("Failed to write timeline")?;
+ let timeline_json =
+     serde_json::to_string_pretty(timeline).context("Failed to serialize timeline")?;
+ fs::write(timeline_path, timeline_json).context("Failed to write timeline")?;
Ok(())
}
/// Load timeline from disk
pub fn load_timeline(&self, timeline_path: &Path) -> Result<SessionTimeline> {
- let timeline_json = fs::read_to_string(timeline_path)
-     .context("Failed to read timeline")?;
- let timeline: SessionTimeline = serde_json::from_str(&timeline_json)
-     .context("Failed to parse timeline")?;
+ let timeline_json = fs::read_to_string(timeline_path).context("Failed to read timeline")?;
+ let timeline: SessionTimeline =
+     serde_json::from_str(&timeline_json).context("Failed to parse timeline")?;
Ok(timeline)
}
/// Update timeline with a new checkpoint
fn update_timeline_with_checkpoint(
&self,
@@ -264,15 +261,13 @@ impl CheckpointStorage {
file_snapshots: &[FileSnapshot],
) -> Result<()> {
let mut timeline = self.load_timeline(timeline_path)?;
let new_node = TimelineNode {
checkpoint: checkpoint.clone(),
children: Vec::new(),
- file_snapshot_ids: file_snapshots.iter()
-     .map(|s| s.hash.clone())
-     .collect(),
+ file_snapshot_ids: file_snapshots.iter().map(|s| s.hash.clone()).collect(),
};
// If this is the first checkpoint
if timeline.root_node.is_none() {
timeline.root_node = Some(new_node);
@@ -280,7 +275,7 @@ impl CheckpointStorage {
} else if let Some(parent_id) = &checkpoint.parent_checkpoint_id {
// Check if parent exists before modifying
let parent_exists = timeline.find_checkpoint(parent_id).is_some();
if parent_exists {
if let Some(root) = &mut timeline.root_node {
Self::add_child_to_node(root, parent_id, new_node)?;
@@ -290,59 +285,54 @@ impl CheckpointStorage {
anyhow::bail!("Parent checkpoint not found: {}", parent_id);
}
}
timeline.total_checkpoints += 1;
self.save_timeline(timeline_path, &timeline)?;
Ok(())
}
/// Recursively add a child node to the timeline tree
fn add_child_to_node(
- node: &mut TimelineNode,
- parent_id: &str,
- child: TimelineNode
+ node: &mut TimelineNode,
+ parent_id: &str,
+ child: TimelineNode,
) -> Result<()> {
if node.checkpoint.id == parent_id {
node.children.push(child);
return Ok(());
}
for child_node in &mut node.children {
if Self::add_child_to_node(child_node, parent_id, child.clone()).is_ok() {
return Ok(());
}
}
anyhow::bail!("Parent checkpoint not found: {}", parent_id)
}
/// Calculate hash of file content
pub fn calculate_file_hash(content: &str) -> String {
let mut hasher = Sha256::new();
hasher.update(content.as_bytes());
format!("{:x}", hasher.finalize())
}
/// Generate a new checkpoint ID
pub fn generate_checkpoint_id() -> String {
Uuid::new_v4().to_string()
}
/// Estimate storage size for a checkpoint
- pub fn estimate_checkpoint_size(
-     messages: &str,
-     file_snapshots: &[FileSnapshot],
- ) -> u64 {
+ pub fn estimate_checkpoint_size(messages: &str, file_snapshots: &[FileSnapshot]) -> u64 {
let messages_size = messages.len() as u64;
- let files_size: u64 = file_snapshots.iter()
-     .map(|s| s.content.len() as u64)
-     .sum();
+ let files_size: u64 = file_snapshots.iter().map(|s| s.content.len() as u64).sum();
// Estimate compressed size (typically 20-30% of original for text)
(messages_size + files_size) / 4
}
/// Clean up old checkpoints based on retention policy
pub fn cleanup_old_checkpoints(
&self,
@@ -352,26 +342,26 @@ impl CheckpointStorage {
) -> Result<usize> {
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
let timeline = self.load_timeline(&paths.timeline_file)?;
// Collect all checkpoint IDs in chronological order
let mut all_checkpoints = Vec::new();
if let Some(root) = &timeline.root_node {
Self::collect_checkpoints(root, &mut all_checkpoints);
}
// Sort by timestamp (oldest first)
all_checkpoints.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));
// Keep only the most recent checkpoints
let to_remove = all_checkpoints.len().saturating_sub(keep_count);
let mut removed_count = 0;
for checkpoint in all_checkpoints.into_iter().take(to_remove) {
if self.remove_checkpoint(&paths, &checkpoint.id).is_ok() {
removed_count += 1;
}
}
// Run garbage collection to clean up orphaned content
if removed_count > 0 {
match self.garbage_collect_content(project_id, session_id) {
@@ -383,10 +373,10 @@ impl CheckpointStorage {
}
}
}
Ok(removed_count)
}
/// Collect all checkpoints from the tree in order
fn collect_checkpoints(node: &TimelineNode, checkpoints: &mut Vec<Checkpoint>) {
checkpoints.push(node.checkpoint.clone());
@@ -394,46 +384,40 @@ impl CheckpointStorage {
Self::collect_checkpoints(child, checkpoints);
}
}
/// Remove a checkpoint and its associated files
fn remove_checkpoint(&self, paths: &CheckpointPaths, checkpoint_id: &str) -> Result<()> {
// Remove checkpoint metadata directory
let checkpoint_dir = paths.checkpoint_dir(checkpoint_id);
if checkpoint_dir.exists() {
- fs::remove_dir_all(&checkpoint_dir)
-     .context("Failed to remove checkpoint directory")?;
+ fs::remove_dir_all(&checkpoint_dir).context("Failed to remove checkpoint directory")?;
}
// Remove file references for this checkpoint
let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
if refs_dir.exists() {
- fs::remove_dir_all(&refs_dir)
-     .context("Failed to remove file references")?;
+ fs::remove_dir_all(&refs_dir).context("Failed to remove file references")?;
}
// Note: We don't remove content from the pool here as it might be
// referenced by other checkpoints. Use garbage_collect_content() for that.
Ok(())
}
/// Garbage collect unreferenced content from the content pool
- pub fn garbage_collect_content(
-     &self,
-     project_id: &str,
-     session_id: &str,
- ) -> Result<usize> {
+ pub fn garbage_collect_content(&self, project_id: &str, session_id: &str) -> Result<usize> {
let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
let content_pool_dir = paths.files_dir.join("content_pool");
let refs_dir = paths.files_dir.join("refs");
if !content_pool_dir.exists() {
return Ok(0);
}
// Collect all referenced hashes
let mut referenced_hashes = std::collections::HashSet::new();
if refs_dir.exists() {
for checkpoint_entry in fs::read_dir(&refs_dir)? {
let checkpoint_dir = checkpoint_entry?.path();
@@ -442,7 +426,9 @@ impl CheckpointStorage {
let ref_path = ref_entry?.path();
if ref_path.extension().and_then(|e| e.to_str()) == Some("json") {
if let Ok(ref_json) = fs::read_to_string(&ref_path) {
- if let Ok(ref_metadata) = serde_json::from_str::<serde_json::Value>(&ref_json) {
+ if let Ok(ref_metadata) =
+     serde_json::from_str::<serde_json::Value>(&ref_json)
+ {
if let Some(hash) = ref_metadata["hash"].as_str() {
referenced_hashes.insert(hash.to_string());
}
@@ -453,7 +439,7 @@ impl CheckpointStorage {
}
}
}
// Remove unreferenced content
let mut removed_count = 0;
for entry in fs::read_dir(&content_pool_dir)? {
@@ -468,7 +454,7 @@ impl CheckpointStorage {
}
}
}
Ok(removed_count)
}
}