init: push source
src-tauri/src/checkpoint/manager.rs (new file, 741 lines)
@@ -0,0 +1,741 @@
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
use chrono::{Utc, TimeZone, DateTime};
use tokio::sync::RwLock;
use log;

use super::{
    Checkpoint, CheckpointMetadata, FileSnapshot, FileTracker, FileState,
    CheckpointResult, SessionTimeline, CheckpointStrategy, CheckpointPaths,
    storage::{CheckpointStorage, self},
};

/// Manages checkpoint operations for a session
pub struct CheckpointManager {
    project_id: String,
    session_id: String,
    project_path: PathBuf,
    file_tracker: Arc<RwLock<FileTracker>>,
    pub storage: Arc<CheckpointStorage>,
    timeline: Arc<RwLock<SessionTimeline>>,
    current_messages: Arc<RwLock<Vec<String>>>, // JSONL messages
}

impl CheckpointManager {
    /// Create a new checkpoint manager
    pub async fn new(
        project_id: String,
        session_id: String,
        project_path: PathBuf,
        claude_dir: PathBuf,
    ) -> Result<Self> {
        let storage = Arc::new(CheckpointStorage::new(claude_dir.clone()));

        // Initialize storage
        storage.init_storage(&project_id, &session_id)?;

        // Load or create timeline
        let paths = CheckpointPaths::new(&claude_dir, &project_id, &session_id);
        let timeline = if paths.timeline_file.exists() {
            storage.load_timeline(&paths.timeline_file)?
        } else {
            SessionTimeline::new(session_id.clone())
        };

        let file_tracker = FileTracker {
            tracked_files: HashMap::new(),
        };

        Ok(Self {
            project_id,
            session_id,
            project_path,
            file_tracker: Arc::new(RwLock::new(file_tracker)),
            storage,
            timeline: Arc::new(RwLock::new(timeline)),
            current_messages: Arc::new(RwLock::new(Vec::new())),
        })
    }

    /// Track a new message in the session
    pub async fn track_message(&self, jsonl_message: String) -> Result<()> {
        let mut messages = self.current_messages.write().await;
        messages.push(jsonl_message.clone());

        // Parse message to check for tool usage
        if let Ok(msg) = serde_json::from_str::<serde_json::Value>(&jsonl_message) {
            if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
                if let Some(content_array) = content.as_array() {
                    for item in content_array {
                        if item.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
                            if let Some(tool_name) = item.get("name").and_then(|n| n.as_str()) {
                                if let Some(input) = item.get("input") {
                                    self.track_tool_operation(tool_name, input).await?;
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Track file operations from tool usage
    async fn track_tool_operation(&self, tool: &str, input: &serde_json::Value) -> Result<()> {
        match tool.to_lowercase().as_str() {
            "edit" | "write" | "multiedit" => {
                if let Some(file_path) = input.get("file_path").and_then(|p| p.as_str()) {
                    self.track_file_modification(file_path).await?;
                }
            }
            "bash" => {
                // Try to detect file modifications from bash commands
                if let Some(command) = input.get("command").and_then(|c| c.as_str()) {
                    self.track_bash_side_effects(command).await?;
                }
            }
            _ => {}
        }
        Ok(())
    }

    /// Track a file modification
    pub async fn track_file_modification(&self, file_path: &str) -> Result<()> {
        let mut tracker = self.file_tracker.write().await;
        let full_path = self.project_path.join(file_path);

        // Read current file state
        let (hash, exists, _size, modified) = if full_path.exists() {
            let content = fs::read_to_string(&full_path)
                .unwrap_or_default();
            let metadata = fs::metadata(&full_path)?;
            let modified = metadata.modified()
                .ok()
                .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
                .map(|d| Utc.timestamp_opt(d.as_secs() as i64, d.subsec_nanos()).unwrap())
                .unwrap_or_else(Utc::now);

            (
                storage::CheckpointStorage::calculate_file_hash(&content),
                true,
                metadata.len(),
                modified
            )
        } else {
            (String::new(), false, 0, Utc::now())
        };

        // Check if file has actually changed
        let is_modified = if let Some(existing_state) = tracker.tracked_files.get(&PathBuf::from(file_path)) {
            // File is modified if:
            // 1. Hash has changed
            // 2. Existence state has changed
            // 3. It was already marked as modified
            existing_state.last_hash != hash ||
            existing_state.exists != exists ||
            existing_state.is_modified
        } else {
            // New file is always considered modified
            true
        };

        tracker.tracked_files.insert(
            PathBuf::from(file_path),
            FileState {
                last_hash: hash,
                is_modified,
                last_modified: modified,
                exists,
            },
        );

        Ok(())
    }

    /// Track potential file changes from bash commands
    async fn track_bash_side_effects(&self, command: &str) -> Result<()> {
        // Common file-modifying commands
        let file_commands = [
            "echo", "cat", "cp", "mv", "rm", "touch", "sed", "awk",
            "npm", "yarn", "pnpm", "bun", "cargo", "make", "gcc", "g++",
        ];

        // Simple heuristic: if command contains file-modifying operations
        for cmd in &file_commands {
            if command.contains(cmd) {
                // Mark all tracked files as potentially modified
                let mut tracker = self.file_tracker.write().await;
                for (_, state) in tracker.tracked_files.iter_mut() {
                    state.is_modified = true;
                }
                break;
            }
        }

        Ok(())
    }

    /// Create a checkpoint
    pub async fn create_checkpoint(
        &self,
        description: Option<String>,
        parent_checkpoint_id: Option<String>,
    ) -> Result<CheckpointResult> {
        let messages = self.current_messages.read().await;
        let message_index = messages.len().saturating_sub(1);

        // Extract metadata from the last user message
        let (user_prompt, model_used, total_tokens) = self.extract_checkpoint_metadata(&messages).await?;

        // Ensure every file in the project is tracked so new checkpoints include all files
        // Recursively walk the project directory and track each file
        fn collect_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<std::path::PathBuf>) -> Result<(), std::io::Error> {
            for entry in std::fs::read_dir(dir)? {
                let entry = entry?;
                let path = entry.path();
                if path.is_dir() {
                    // Skip hidden directories like .git
                    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
                        if name.starts_with('.') {
                            continue;
                        }
                    }
                    collect_files(&path, base, files)?;
                } else if path.is_file() {
                    // Compute relative path from project root
                    if let Ok(rel) = path.strip_prefix(base) {
                        files.push(rel.to_path_buf());
                    }
                }
            }
            Ok(())
        }
        let mut all_files = Vec::new();
        let project_dir = &self.project_path;
        let _ = collect_files(project_dir.as_path(), project_dir.as_path(), &mut all_files);
        for rel in all_files {
            if let Some(p) = rel.to_str() {
                // Track each file for snapshot
                let _ = self.track_file_modification(p).await;
            }
        }

        // Generate checkpoint ID early so snapshots reference it
        let checkpoint_id = storage::CheckpointStorage::generate_checkpoint_id();

        // Create file snapshots
        let file_snapshots = self.create_file_snapshots(&checkpoint_id).await?;

        // Generate checkpoint struct
        let checkpoint = Checkpoint {
            id: checkpoint_id.clone(),
            session_id: self.session_id.clone(),
            project_id: self.project_id.clone(),
            message_index,
            timestamp: Utc::now(),
            description,
            parent_checkpoint_id: {
                if let Some(parent_id) = parent_checkpoint_id {
                    Some(parent_id)
                } else {
                    // Perform an asynchronous read to avoid blocking within the runtime
                    let timeline = self.timeline.read().await;
                    timeline.current_checkpoint_id.clone()
                }
            },
            metadata: CheckpointMetadata {
                total_tokens,
                model_used,
                user_prompt,
                file_changes: file_snapshots.len(),
                snapshot_size: storage::CheckpointStorage::estimate_checkpoint_size(
                    &messages.join("\n"),
                    &file_snapshots,
                ),
            },
        };

        // Save checkpoint
        let messages_content = messages.join("\n");
        let result = self.storage.save_checkpoint(
            &self.project_id,
            &self.session_id,
            &checkpoint,
            file_snapshots,
            &messages_content,
        )?;

        // Reload timeline from disk so in-memory timeline has updated nodes and total_checkpoints
        let claude_dir = self.storage.claude_dir.clone();
        let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id);
        let updated_timeline = self.storage.load_timeline(&paths.timeline_file)?;
        {
            let mut timeline_lock = self.timeline.write().await;
            *timeline_lock = updated_timeline;
        }

        // Update timeline (current checkpoint only)
        let mut timeline = self.timeline.write().await;
        timeline.current_checkpoint_id = Some(checkpoint_id);

        // Reset file tracker
        let mut tracker = self.file_tracker.write().await;
        for (_, state) in tracker.tracked_files.iter_mut() {
            state.is_modified = false;
        }

        Ok(result)
    }

    /// Extract metadata from messages for checkpoint
    async fn extract_checkpoint_metadata(
        &self,
        messages: &[String],
    ) -> Result<(String, String, u64)> {
        let mut user_prompt = String::new();
        let mut model_used = String::from("unknown");
        let mut total_tokens = 0u64;

        // Iterate through messages in reverse to find the last user prompt
        for msg_str in messages.iter().rev() {
            if let Ok(msg) = serde_json::from_str::<serde_json::Value>(msg_str) {
                // Check for user message
                if msg.get("type").and_then(|t| t.as_str()) == Some("user") {
                    if let Some(content) = msg.get("message")
                        .and_then(|m| m.get("content"))
                        .and_then(|c| c.as_array())
                    {
                        for item in content {
                            if item.get("type").and_then(|t| t.as_str()) == Some("text") {
                                if let Some(text) = item.get("text").and_then(|t| t.as_str()) {
                                    user_prompt = text.to_string();
                                    break;
                                }
                            }
                        }
                    }
                }

                // Extract model info
                if let Some(model) = msg.get("model").and_then(|m| m.as_str()) {
                    model_used = model.to_string();
                }

                // Also check for model in message.model (assistant messages)
                if let Some(message) = msg.get("message") {
                    if let Some(model) = message.get("model").and_then(|m| m.as_str()) {
                        model_used = model.to_string();
                    }
                }

                // Count tokens - check both top-level and nested usage
                // First check for usage in message.usage (assistant messages)
                if let Some(message) = msg.get("message") {
                    if let Some(usage) = message.get("usage") {
                        if let Some(input) = usage.get("input_tokens").and_then(|t| t.as_u64()) {
                            total_tokens += input;
                        }
                        if let Some(output) = usage.get("output_tokens").and_then(|t| t.as_u64()) {
                            total_tokens += output;
                        }
                        // Also count cache tokens
                        if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) {
                            total_tokens += cache_creation;
                        }
                        if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) {
                            total_tokens += cache_read;
                        }
                    }
                }

                // Then check for top-level usage (result messages)
                if let Some(usage) = msg.get("usage") {
                    if let Some(input) = usage.get("input_tokens").and_then(|t| t.as_u64()) {
                        total_tokens += input;
                    }
                    if let Some(output) = usage.get("output_tokens").and_then(|t| t.as_u64()) {
                        total_tokens += output;
                    }
                    // Also count cache tokens
                    if let Some(cache_creation) = usage.get("cache_creation_input_tokens").and_then(|t| t.as_u64()) {
                        total_tokens += cache_creation;
                    }
                    if let Some(cache_read) = usage.get("cache_read_input_tokens").and_then(|t| t.as_u64()) {
                        total_tokens += cache_read;
                    }
                }
            }
        }

        Ok((user_prompt, model_used, total_tokens))
    }

    /// Create file snapshots for all tracked modified files
    async fn create_file_snapshots(&self, checkpoint_id: &str) -> Result<Vec<FileSnapshot>> {
        let tracker = self.file_tracker.read().await;
        let mut snapshots = Vec::new();

        for (rel_path, state) in &tracker.tracked_files {
            // Skip files that haven't been modified
            if !state.is_modified {
                continue;
            }

            let full_path = self.project_path.join(rel_path);

            let (content, exists, permissions, size, current_hash) = if full_path.exists() {
                let content = fs::read_to_string(&full_path)
                    .unwrap_or_default();
                let current_hash = storage::CheckpointStorage::calculate_file_hash(&content);

                // Don't skip based on hash - if is_modified is true, we should snapshot it
                // The hash check in track_file_modification already determined if it changed

                let metadata = fs::metadata(&full_path)?;
                let permissions = {
                    #[cfg(unix)]
                    {
                        use std::os::unix::fs::PermissionsExt;
                        Some(metadata.permissions().mode())
                    }
                    #[cfg(not(unix))]
                    {
                        None
                    }
                };
                (content, true, permissions, metadata.len(), current_hash)
            } else {
                (String::new(), false, None, 0, String::new())
            };

            snapshots.push(FileSnapshot {
                checkpoint_id: checkpoint_id.to_string(),
                file_path: rel_path.clone(),
                content,
                hash: current_hash,
                is_deleted: !exists,
                permissions,
                size,
            });
        }

        Ok(snapshots)
    }

    /// Restore a checkpoint
    pub async fn restore_checkpoint(&self, checkpoint_id: &str) -> Result<CheckpointResult> {
        // Load checkpoint data
        let (checkpoint, file_snapshots, messages) = self.storage.load_checkpoint(
            &self.project_id,
            &self.session_id,
            checkpoint_id,
        )?;

        // First, collect all files currently in the project to handle deletions
        fn collect_all_project_files(dir: &std::path::Path, base: &std::path::Path, files: &mut Vec<std::path::PathBuf>) -> Result<(), std::io::Error> {
            for entry in std::fs::read_dir(dir)? {
                let entry = entry?;
                let path = entry.path();
                if path.is_dir() {
                    // Skip hidden directories like .git
                    if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
                        if name.starts_with('.') {
                            continue;
                        }
                    }
                    collect_all_project_files(&path, base, files)?;
                } else if path.is_file() {
                    // Compute relative path from project root
                    if let Ok(rel) = path.strip_prefix(base) {
                        files.push(rel.to_path_buf());
                    }
                }
            }
            Ok(())
        }

        let mut current_files = Vec::new();
        let _ = collect_all_project_files(&self.project_path, &self.project_path, &mut current_files);

        // Create a set of files that should exist after restore
        let mut checkpoint_files = std::collections::HashSet::new();
        for snapshot in &file_snapshots {
            if !snapshot.is_deleted {
                checkpoint_files.insert(snapshot.file_path.clone());
            }
        }

        // Delete files that exist now but shouldn't exist in the checkpoint
        let mut warnings = Vec::new();
        let mut files_processed = 0;

        for current_file in current_files {
            if !checkpoint_files.contains(&current_file) {
                // This file exists now but not in the checkpoint, so delete it
                let full_path = self.project_path.join(&current_file);
                match fs::remove_file(&full_path) {
                    Ok(_) => {
                        files_processed += 1;
                        log::info!("Deleted file not in checkpoint: {:?}", current_file);
                    }
                    Err(e) => {
                        warnings.push(format!("Failed to delete {}: {}", current_file.display(), e));
                    }
                }
            }
        }

        // Clean up empty directories
        fn remove_empty_dirs(dir: &std::path::Path, base: &std::path::Path) -> Result<bool, std::io::Error> {
            if dir == base {
                return Ok(false); // Don't remove the base directory
            }

            let mut is_empty = true;
            for entry in fs::read_dir(dir)? {
                let entry = entry?;
                let path = entry.path();
                if path.is_dir() {
                    if !remove_empty_dirs(&path, base)? {
                        is_empty = false;
                    }
                } else {
                    is_empty = false;
                }
            }

            if is_empty {
                fs::remove_dir(dir)?;
                Ok(true)
            } else {
                Ok(false)
            }
        }

        // Clean up any empty directories left after file deletion
        let _ = remove_empty_dirs(&self.project_path, &self.project_path);

        // Restore files from checkpoint
        for snapshot in &file_snapshots {
            match self.restore_file_snapshot(snapshot).await {
                Ok(_) => files_processed += 1,
                Err(e) => warnings.push(format!("Failed to restore {}: {}",
                    snapshot.file_path.display(), e)),
            }
        }

        // Update current messages
        let mut current_messages = self.current_messages.write().await;
        current_messages.clear();
        for line in messages.lines() {
            current_messages.push(line.to_string());
        }

        // Update timeline
        let mut timeline = self.timeline.write().await;
        timeline.current_checkpoint_id = Some(checkpoint_id.to_string());

        // Update file tracker
        let mut tracker = self.file_tracker.write().await;
        tracker.tracked_files.clear();
        for snapshot in &file_snapshots {
            if !snapshot.is_deleted {
                tracker.tracked_files.insert(
                    snapshot.file_path.clone(),
                    FileState {
                        last_hash: snapshot.hash.clone(),
                        is_modified: false,
                        last_modified: Utc::now(),
                        exists: true,
                    },
                );
            }
        }

        Ok(CheckpointResult {
            checkpoint: checkpoint.clone(),
            files_processed,
            warnings,
        })
    }

    /// Restore a single file from snapshot
    async fn restore_file_snapshot(&self, snapshot: &FileSnapshot) -> Result<()> {
        let full_path = self.project_path.join(&snapshot.file_path);

        if snapshot.is_deleted {
            // Delete the file if it exists
            if full_path.exists() {
                fs::remove_file(&full_path)
                    .context("Failed to delete file")?;
            }
        } else {
            // Create parent directories if needed
            if let Some(parent) = full_path.parent() {
                fs::create_dir_all(parent)
                    .context("Failed to create parent directories")?;
            }

            // Write file content
            fs::write(&full_path, &snapshot.content)
                .context("Failed to write file")?;

            // Restore permissions if available
            #[cfg(unix)]
            if let Some(mode) = snapshot.permissions {
                use std::os::unix::fs::PermissionsExt;
                let permissions = std::fs::Permissions::from_mode(mode);
                fs::set_permissions(&full_path, permissions)
                    .context("Failed to set file permissions")?;
            }
        }

        Ok(())
    }

    /// Get the current timeline
    pub async fn get_timeline(&self) -> SessionTimeline {
        self.timeline.read().await.clone()
    }

    /// List all checkpoints
    pub async fn list_checkpoints(&self) -> Vec<Checkpoint> {
        let timeline = self.timeline.read().await;
        let mut checkpoints = Vec::new();

        if let Some(root) = &timeline.root_node {
            Self::collect_checkpoints_from_node(root, &mut checkpoints);
        }

        checkpoints
    }

    /// Recursively collect checkpoints from timeline tree
    fn collect_checkpoints_from_node(node: &super::TimelineNode, checkpoints: &mut Vec<Checkpoint>) {
        checkpoints.push(node.checkpoint.clone());
        for child in &node.children {
            Self::collect_checkpoints_from_node(child, checkpoints);
        }
    }

    /// Fork from a checkpoint
    pub async fn fork_from_checkpoint(
        &self,
        checkpoint_id: &str,
        description: Option<String>,
    ) -> Result<CheckpointResult> {
        // Load the checkpoint to fork from
        let (_base_checkpoint, _, _) = self.storage.load_checkpoint(
            &self.project_id,
            &self.session_id,
            checkpoint_id,
        )?;

        // Restore to that checkpoint first
        self.restore_checkpoint(checkpoint_id).await?;

        // Create a new checkpoint with the fork
        let fork_description = description.unwrap_or_else(|| {
            format!("Fork from checkpoint {}", &checkpoint_id[..8])
        });

        self.create_checkpoint(Some(fork_description), Some(checkpoint_id.to_string())).await
    }

    /// Check if auto-checkpoint should be triggered
    pub async fn should_auto_checkpoint(&self, message: &str) -> bool {
        let timeline = self.timeline.read().await;

        if !timeline.auto_checkpoint_enabled {
            return false;
        }

        match timeline.checkpoint_strategy {
            CheckpointStrategy::Manual => false,
            CheckpointStrategy::PerPrompt => {
                // Check if message is a user prompt
                if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
                    msg.get("type").and_then(|t| t.as_str()) == Some("user")
                } else {
                    false
                }
            }
            CheckpointStrategy::PerToolUse => {
                // Check if message contains tool use
                if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
                    if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) {
                        content.iter().any(|item| {
                            item.get("type").and_then(|t| t.as_str()) == Some("tool_use")
                        })
                    } else {
                        false
                    }
                } else {
                    false
                }
            }
            CheckpointStrategy::Smart => {
                // Smart strategy: checkpoint after destructive operations
                if let Ok(msg) = serde_json::from_str::<serde_json::Value>(message) {
                    if let Some(content) = msg.get("message").and_then(|m| m.get("content")).and_then(|c| c.as_array()) {
                        content.iter().any(|item| {
                            if item.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
                                let tool_name = item.get("name").and_then(|n| n.as_str()).unwrap_or("");
                                matches!(tool_name.to_lowercase().as_str(),
                                    "write" | "edit" | "multiedit" | "bash" | "rm" | "delete")
                            } else {
                                false
                            }
                        })
                    } else {
                        false
                    }
                } else {
                    false
                }
            }
        }
    }

    /// Update checkpoint settings
    pub async fn update_settings(
        &self,
        auto_checkpoint_enabled: bool,
        checkpoint_strategy: CheckpointStrategy,
    ) -> Result<()> {
        let mut timeline = self.timeline.write().await;
        timeline.auto_checkpoint_enabled = auto_checkpoint_enabled;
        timeline.checkpoint_strategy = checkpoint_strategy;

        // Save updated timeline
        let claude_dir = self.storage.claude_dir.clone();
        let paths = CheckpointPaths::new(&claude_dir, &self.project_id, &self.session_id);
        self.storage.save_timeline(&paths.timeline_file, &timeline)?;

        Ok(())
    }

    /// Get files modified since a given timestamp
    pub async fn get_files_modified_since(&self, since: DateTime<Utc>) -> Vec<PathBuf> {
        let tracker = self.file_tracker.read().await;
        tracker.tracked_files
            .iter()
            .filter(|(_, state)| state.last_modified > since && state.is_modified)
            .map(|(path, _)| path.clone())
            .collect()
    }

    /// Get the last modification time of any tracked file
    pub async fn get_last_modification_time(&self) -> Option<DateTime<Utc>> {
        let tracker = self.file_tracker.read().await;
        tracker.tracked_files
            .values()
            .map(|state| state.last_modified)
            .max()
    }
}
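The intended call flow of the manager, as a minimal sketch: feed each JSONL line to track_message, ask should_auto_checkpoint, and create a checkpoint when it fires. A tokio runtime is assumed; the IDs, paths, and the crate path crate::checkpoint::manager below are illustrative, not part of this commit.

// Hedged usage sketch; all concrete values are placeholders.
use std::path::PathBuf;
use crate::checkpoint::manager::CheckpointManager; // assumed crate path

async fn demo() -> anyhow::Result<()> {
    let manager = CheckpointManager::new(
        "my-project".to_string(),            // project_id (illustrative)
        "session-abc".to_string(),           // session_id (illustrative)
        PathBuf::from("/tmp/my-project"),    // project_path (illustrative)
        PathBuf::from("/home/user/.claude"), // claude_dir (illustrative)
    ).await?;

    // One JSONL line as produced by the session stream (shape taken from the parsing code above).
    let line = r#"{"type":"user","message":{"content":[{"type":"text","text":"hi"}]}}"#.to_string();

    if manager.should_auto_checkpoint(&line).await {
        manager.create_checkpoint(None, None).await?;
    }
    manager.track_message(line).await?;
    Ok(())
}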
src-tauri/src/checkpoint/mod.rs (new file, 256 lines)
@@ -0,0 +1,256 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use chrono::{DateTime, Utc};

pub mod manager;
pub mod storage;
pub mod state;

/// Represents a checkpoint in the session timeline
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Checkpoint {
    /// Unique identifier for the checkpoint
    pub id: String,
    /// Session ID this checkpoint belongs to
    pub session_id: String,
    /// Project ID for the session
    pub project_id: String,
    /// Index of the last message in this checkpoint
    pub message_index: usize,
    /// Timestamp when checkpoint was created
    pub timestamp: DateTime<Utc>,
    /// User-provided description
    pub description: Option<String>,
    /// Parent checkpoint ID for fork tracking
    pub parent_checkpoint_id: Option<String>,
    /// Metadata about the checkpoint
    pub metadata: CheckpointMetadata,
}

/// Metadata associated with a checkpoint
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CheckpointMetadata {
    /// Total tokens used up to this point
    pub total_tokens: u64,
    /// Model used for the last operation
    pub model_used: String,
    /// The user prompt that led to this state
    pub user_prompt: String,
    /// Number of file changes in this checkpoint
    pub file_changes: usize,
    /// Size of all file snapshots in bytes
    pub snapshot_size: u64,
}

/// Represents a snapshot of a file at a checkpoint
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FileSnapshot {
    /// Checkpoint this snapshot belongs to
    pub checkpoint_id: String,
    /// Relative path from project root
    pub file_path: PathBuf,
    /// Full content of the file (will be compressed)
    pub content: String,
    /// SHA-256 hash for integrity verification
    pub hash: String,
    /// Whether this file was deleted at this checkpoint
    pub is_deleted: bool,
    /// File permissions (Unix mode)
    pub permissions: Option<u32>,
    /// File size in bytes
    pub size: u64,
}

/// Represents a node in the timeline tree
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TimelineNode {
    /// The checkpoint at this node
    pub checkpoint: Checkpoint,
    /// Child nodes (for branches/forks)
    pub children: Vec<TimelineNode>,
    /// IDs of file snapshots associated with this checkpoint
    pub file_snapshot_ids: Vec<String>,
}

/// The complete timeline for a session
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionTimeline {
    /// Session ID this timeline belongs to
    pub session_id: String,
    /// Root node of the timeline tree
    pub root_node: Option<TimelineNode>,
    /// ID of the current active checkpoint
    pub current_checkpoint_id: Option<String>,
    /// Whether auto-checkpointing is enabled
    pub auto_checkpoint_enabled: bool,
    /// Strategy for automatic checkpoints
    pub checkpoint_strategy: CheckpointStrategy,
    /// Total number of checkpoints in timeline
    pub total_checkpoints: usize,
}

/// Strategy for automatic checkpoint creation
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CheckpointStrategy {
    /// Only create checkpoints manually
    Manual,
    /// Create checkpoint after each user prompt
    PerPrompt,
    /// Create checkpoint after each tool use
    PerToolUse,
    /// Create checkpoint after destructive operations
    Smart,
}

/// Tracks the state of files for checkpointing
#[derive(Debug, Clone)]
pub struct FileTracker {
    /// Map of file paths to their current state
    pub tracked_files: HashMap<PathBuf, FileState>,
}

/// State of a tracked file
#[derive(Debug, Clone)]
pub struct FileState {
    /// Last known hash of the file
    pub last_hash: String,
    /// Whether the file has been modified since last checkpoint
    pub is_modified: bool,
    /// Last modification timestamp
    pub last_modified: DateTime<Utc>,
    /// Whether the file currently exists
    pub exists: bool,
}

/// Result of a checkpoint operation
#[derive(Debug, Serialize, Deserialize)]
pub struct CheckpointResult {
    /// The created/restored checkpoint
    pub checkpoint: Checkpoint,
    /// Number of files snapshot/restored
    pub files_processed: usize,
    /// Any warnings during the operation
    pub warnings: Vec<String>,
}

/// Diff between two checkpoints
#[derive(Debug, Serialize, Deserialize)]
pub struct CheckpointDiff {
    /// Source checkpoint ID
    pub from_checkpoint_id: String,
    /// Target checkpoint ID
    pub to_checkpoint_id: String,
    /// Files that were modified
    pub modified_files: Vec<FileDiff>,
    /// Files that were added
    pub added_files: Vec<PathBuf>,
    /// Files that were deleted
    pub deleted_files: Vec<PathBuf>,
    /// Token usage difference
    pub token_delta: i64,
}

/// Diff for a single file
#[derive(Debug, Serialize, Deserialize)]
pub struct FileDiff {
    /// File path
    pub path: PathBuf,
    /// Number of additions
    pub additions: usize,
    /// Number of deletions
    pub deletions: usize,
    /// Unified diff content (optional)
    pub diff_content: Option<String>,
}

impl Default for CheckpointStrategy {
    fn default() -> Self {
        CheckpointStrategy::Smart
    }
}

impl SessionTimeline {
    /// Create a new empty timeline
    pub fn new(session_id: String) -> Self {
        Self {
            session_id,
            root_node: None,
            current_checkpoint_id: None,
            auto_checkpoint_enabled: false,
            checkpoint_strategy: CheckpointStrategy::default(),
            total_checkpoints: 0,
        }
    }

    /// Find a checkpoint by ID in the timeline tree
    pub fn find_checkpoint(&self, checkpoint_id: &str) -> Option<&TimelineNode> {
        self.root_node.as_ref()
            .and_then(|root| Self::find_in_tree(root, checkpoint_id))
    }

    fn find_in_tree<'a>(node: &'a TimelineNode, checkpoint_id: &str) -> Option<&'a TimelineNode> {
        if node.checkpoint.id == checkpoint_id {
            return Some(node);
        }

        for child in &node.children {
            if let Some(found) = Self::find_in_tree(child, checkpoint_id) {
                return Some(found);
            }
        }

        None
    }
}

/// Checkpoint storage paths
pub struct CheckpointPaths {
    pub timeline_file: PathBuf,
    pub checkpoints_dir: PathBuf,
    pub files_dir: PathBuf,
}

impl CheckpointPaths {
    pub fn new(claude_dir: &PathBuf, project_id: &str, session_id: &str) -> Self {
        let base_dir = claude_dir
            .join("projects")
            .join(project_id)
            .join(".timelines")
            .join(session_id);

        Self {
            timeline_file: base_dir.join("timeline.json"),
            checkpoints_dir: base_dir.join("checkpoints"),
            files_dir: base_dir.join("files"),
        }
    }

    pub fn checkpoint_dir(&self, checkpoint_id: &str) -> PathBuf {
        self.checkpoints_dir.join(checkpoint_id)
    }

    pub fn checkpoint_metadata_file(&self, checkpoint_id: &str) -> PathBuf {
        self.checkpoint_dir(checkpoint_id).join("metadata.json")
    }

    pub fn checkpoint_messages_file(&self, checkpoint_id: &str) -> PathBuf {
        self.checkpoint_dir(checkpoint_id).join("messages.jsonl")
    }

    pub fn file_snapshot_path(&self, _checkpoint_id: &str, file_hash: &str) -> PathBuf {
        // In content-addressable storage, files are stored by hash in the content pool
        self.files_dir.join("content_pool").join(file_hash)
    }

    pub fn file_reference_path(&self, checkpoint_id: &str, safe_filename: &str) -> PathBuf {
        // References are stored per checkpoint
        self.files_dir.join("refs").join(checkpoint_id).join(format!("{}.json", safe_filename))
    }
}
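For orientation, the on-disk layout that CheckpointPaths::new and its helpers produce, shown relative to the Claude directory (the <...> placeholders are illustrative):

projects/<project_id>/.timelines/<session_id>/
    timeline.json
    checkpoints/<checkpoint_id>/metadata.json
    checkpoints/<checkpoint_id>/messages.jsonl
    files/content_pool/<file_hash>
    files/refs/<checkpoint_id>/<safe_filename>.json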
src-tauri/src/checkpoint/state.rs (new file, 186 lines)
@@ -0,0 +1,186 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use anyhow::Result;

use super::manager::CheckpointManager;

/// Manages checkpoint managers for active sessions
///
/// This struct maintains a stateful collection of CheckpointManager instances,
/// one per active session, to avoid recreating them on every command invocation.
/// It provides thread-safe access to managers and handles their lifecycle.
#[derive(Default, Clone)]
pub struct CheckpointState {
    /// Map of session_id to CheckpointManager
    /// Uses Arc<CheckpointManager> to allow sharing across async boundaries
    managers: Arc<RwLock<HashMap<String, Arc<CheckpointManager>>>>,
    /// The Claude directory path for consistent access
    claude_dir: Arc<RwLock<Option<PathBuf>>>,
}

impl CheckpointState {
    /// Creates a new CheckpointState instance
    pub fn new() -> Self {
        Self {
            managers: Arc::new(RwLock::new(HashMap::new())),
            claude_dir: Arc::new(RwLock::new(None)),
        }
    }

    /// Sets the Claude directory path
    ///
    /// This should be called once during application initialization
    pub async fn set_claude_dir(&self, claude_dir: PathBuf) {
        let mut dir = self.claude_dir.write().await;
        *dir = Some(claude_dir);
    }

    /// Gets or creates a CheckpointManager for a session
    ///
    /// If a manager already exists for the session, it returns the existing one.
    /// Otherwise, it creates a new manager and stores it for future use.
    ///
    /// # Arguments
    /// * `session_id` - The session identifier
    /// * `project_id` - The project identifier
    /// * `project_path` - The path to the project directory
    ///
    /// # Returns
    /// An Arc reference to the CheckpointManager for thread-safe sharing
    pub async fn get_or_create_manager(
        &self,
        session_id: String,
        project_id: String,
        project_path: PathBuf,
    ) -> Result<Arc<CheckpointManager>> {
        let mut managers = self.managers.write().await;

        // Check if manager already exists
        if let Some(manager) = managers.get(&session_id) {
            return Ok(Arc::clone(manager));
        }

        // Get Claude directory
        let claude_dir = {
            let dir = self.claude_dir.read().await;
            dir.as_ref()
                .ok_or_else(|| anyhow::anyhow!("Claude directory not set"))?
                .clone()
        };

        // Create new manager
        let manager = CheckpointManager::new(
            project_id,
            session_id.clone(),
            project_path,
            claude_dir,
        ).await?;

        let manager_arc = Arc::new(manager);
        managers.insert(session_id, Arc::clone(&manager_arc));

        Ok(manager_arc)
    }

    /// Gets an existing CheckpointManager for a session
    ///
    /// Returns None if no manager exists for the session
    pub async fn get_manager(&self, session_id: &str) -> Option<Arc<CheckpointManager>> {
        let managers = self.managers.read().await;
        managers.get(session_id).map(Arc::clone)
    }

    /// Removes a CheckpointManager for a session
    ///
    /// This should be called when a session ends to free resources
    pub async fn remove_manager(&self, session_id: &str) -> Option<Arc<CheckpointManager>> {
        let mut managers = self.managers.write().await;
        managers.remove(session_id)
    }

    /// Clears all managers
    ///
    /// This is useful for cleanup during application shutdown
    pub async fn clear_all(&self) {
        let mut managers = self.managers.write().await;
        managers.clear();
    }

    /// Gets the number of active managers
    pub async fn active_count(&self) -> usize {
        let managers = self.managers.read().await;
        managers.len()
    }

    /// Lists all active session IDs
    pub async fn list_active_sessions(&self) -> Vec<String> {
        let managers = self.managers.read().await;
        managers.keys().cloned().collect()
    }

    /// Checks if a session has an active manager
    pub async fn has_active_manager(&self, session_id: &str) -> bool {
        self.get_manager(session_id).await.is_some()
    }

    /// Clears all managers and returns the count that were cleared
    pub async fn clear_all_and_count(&self) -> usize {
        let count = self.active_count().await;
        self.clear_all().await;
        count
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[tokio::test]
    async fn test_checkpoint_state_lifecycle() {
        let state = CheckpointState::new();
        let temp_dir = TempDir::new().unwrap();
        let claude_dir = temp_dir.path().to_path_buf();

        // Set Claude directory
        state.set_claude_dir(claude_dir.clone()).await;

        // Create a manager
        let session_id = "test-session-123".to_string();
        let project_id = "test-project".to_string();
        let project_path = temp_dir.path().join("project");
        std::fs::create_dir_all(&project_path).unwrap();

        let manager1 = state.get_or_create_manager(
            session_id.clone(),
            project_id.clone(),
            project_path.clone(),
        ).await.unwrap();

        // Getting the same session should return the same manager
        let manager2 = state.get_or_create_manager(
            session_id.clone(),
            project_id.clone(),
            project_path.clone(),
        ).await.unwrap();

        assert!(Arc::ptr_eq(&manager1, &manager2));
        assert_eq!(state.active_count().await, 1);

        // Remove the manager
        let removed = state.remove_manager(&session_id).await;
        assert!(removed.is_some());
        assert_eq!(state.active_count().await, 0);

        // Getting after removal should create a new one
        let manager3 = state.get_or_create_manager(
            session_id.clone(),
            project_id,
            project_path,
        ).await.unwrap();

        assert!(!Arc::ptr_eq(&manager1, &manager3));
    }
}
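A sketch of how CheckpointState might be wired into the Tauri side, assuming Tauri's managed-state API (Builder::manage and tauri::State); the command name, argument shapes, and error mapping below are illustrative and not part of this commit.

// Hedged sketch: register the state once at startup, then borrow it in commands.
use std::path::PathBuf;
use crate::checkpoint::state::CheckpointState; // assumed crate path

#[tauri::command]
async fn create_checkpoint_cmd(
    state: tauri::State<'_, CheckpointState>,
    session_id: String,
    project_id: String,
    project_path: String,
) -> Result<(), String> {
    // Reuse the per-session manager instead of rebuilding it on every invocation.
    let manager = state
        .get_or_create_manager(session_id, project_id, PathBuf::from(project_path))
        .await
        .map_err(|e| e.to_string())?;
    manager.create_checkpoint(None, None).await.map_err(|e| e.to_string())?;
    Ok(())
}

// At startup (sketch): tauri::Builder::default().manage(CheckpointState::new()) ...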
src-tauri/src/checkpoint/storage.rs (new file, 474 lines)
@@ -0,0 +1,474 @@
use anyhow::{Context, Result};
use std::fs;
use std::path::{Path, PathBuf};
use sha2::{Sha256, Digest};
use zstd::stream::{encode_all, decode_all};
use uuid::Uuid;

use super::{
    Checkpoint, FileSnapshot, SessionTimeline,
    TimelineNode, CheckpointPaths, CheckpointResult
};

/// Manages checkpoint storage operations
pub struct CheckpointStorage {
    pub claude_dir: PathBuf,
    compression_level: i32,
}

impl CheckpointStorage {
    /// Create a new checkpoint storage instance
    pub fn new(claude_dir: PathBuf) -> Self {
        Self {
            claude_dir,
            compression_level: 3, // Default zstd compression level
        }
    }

    /// Initialize checkpoint storage for a session
    pub fn init_storage(&self, project_id: &str, session_id: &str) -> Result<()> {
        let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);

        // Create directory structure
        fs::create_dir_all(&paths.checkpoints_dir)
            .context("Failed to create checkpoints directory")?;
        fs::create_dir_all(&paths.files_dir)
            .context("Failed to create files directory")?;

        // Initialize empty timeline if it doesn't exist
        if !paths.timeline_file.exists() {
            let timeline = SessionTimeline::new(session_id.to_string());
            self.save_timeline(&paths.timeline_file, &timeline)?;
        }

        Ok(())
    }

    /// Save a checkpoint to disk
    pub fn save_checkpoint(
        &self,
        project_id: &str,
        session_id: &str,
        checkpoint: &Checkpoint,
        file_snapshots: Vec<FileSnapshot>,
        messages: &str, // JSONL content up to checkpoint
    ) -> Result<CheckpointResult> {
        let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
        let checkpoint_dir = paths.checkpoint_dir(&checkpoint.id);

        // Create checkpoint directory
        fs::create_dir_all(&checkpoint_dir)
            .context("Failed to create checkpoint directory")?;

        // Save checkpoint metadata
        let metadata_path = paths.checkpoint_metadata_file(&checkpoint.id);
        let metadata_json = serde_json::to_string_pretty(checkpoint)
            .context("Failed to serialize checkpoint metadata")?;
        fs::write(&metadata_path, metadata_json)
            .context("Failed to write checkpoint metadata")?;

        // Save messages (compressed)
        let messages_path = paths.checkpoint_messages_file(&checkpoint.id);
        let compressed_messages = encode_all(messages.as_bytes(), self.compression_level)
            .context("Failed to compress messages")?;
        fs::write(&messages_path, compressed_messages)
            .context("Failed to write compressed messages")?;

        // Save file snapshots
        let mut warnings = Vec::new();
        let mut files_processed = 0;

        for snapshot in &file_snapshots {
            match self.save_file_snapshot(&paths, snapshot) {
                Ok(_) => files_processed += 1,
                Err(e) => warnings.push(format!("Failed to save {}: {}",
                    snapshot.file_path.display(), e)),
            }
        }

        // Update timeline
        self.update_timeline_with_checkpoint(
            &paths.timeline_file,
            checkpoint,
            &file_snapshots
        )?;

        Ok(CheckpointResult {
            checkpoint: checkpoint.clone(),
            files_processed,
            warnings,
        })
    }

    /// Save a single file snapshot
    fn save_file_snapshot(&self, paths: &CheckpointPaths, snapshot: &FileSnapshot) -> Result<()> {
        // Use content-addressable storage: store files by their hash
        // This prevents duplication of identical file content across checkpoints
        let content_pool_dir = paths.files_dir.join("content_pool");
        fs::create_dir_all(&content_pool_dir)
            .context("Failed to create content pool directory")?;

        // Store the actual content in the content pool
        let content_file = content_pool_dir.join(&snapshot.hash);

        // Only write the content if it doesn't already exist
        if !content_file.exists() {
            // Compress and save file content
            let compressed_content = encode_all(snapshot.content.as_bytes(), self.compression_level)
                .context("Failed to compress file content")?;
            fs::write(&content_file, compressed_content)
                .context("Failed to write file content to pool")?;
        }

        // Create a reference in the checkpoint-specific directory
        let checkpoint_refs_dir = paths.files_dir.join("refs").join(&snapshot.checkpoint_id);
        fs::create_dir_all(&checkpoint_refs_dir)
            .context("Failed to create checkpoint refs directory")?;

        // Save file metadata with reference to content
        let ref_metadata = serde_json::json!({
            "path": snapshot.file_path,
            "hash": snapshot.hash,
            "is_deleted": snapshot.is_deleted,
            "permissions": snapshot.permissions,
            "size": snapshot.size,
        });

        // Use a sanitized filename for the reference
        let safe_filename = snapshot.file_path
            .to_string_lossy()
            .replace('/', "_")
            .replace('\\', "_");
        let ref_path = checkpoint_refs_dir.join(format!("{}.json", safe_filename));

        fs::write(&ref_path, serde_json::to_string_pretty(&ref_metadata)?)
            .context("Failed to write file reference")?;

        Ok(())
    }

    /// Load a checkpoint from disk
    pub fn load_checkpoint(
        &self,
        project_id: &str,
        session_id: &str,
        checkpoint_id: &str,
    ) -> Result<(Checkpoint, Vec<FileSnapshot>, String)> {
        let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);

        // Load checkpoint metadata
        let metadata_path = paths.checkpoint_metadata_file(checkpoint_id);
        let metadata_json = fs::read_to_string(&metadata_path)
            .context("Failed to read checkpoint metadata")?;
        let checkpoint: Checkpoint = serde_json::from_str(&metadata_json)
            .context("Failed to parse checkpoint metadata")?;

        // Load messages
        let messages_path = paths.checkpoint_messages_file(checkpoint_id);
        let compressed_messages = fs::read(&messages_path)
            .context("Failed to read compressed messages")?;
        let messages = String::from_utf8(decode_all(&compressed_messages[..])
            .context("Failed to decompress messages")?)
            .context("Invalid UTF-8 in messages")?;

        // Load file snapshots
        let file_snapshots = self.load_file_snapshots(&paths, checkpoint_id)?;

        Ok((checkpoint, file_snapshots, messages))
    }

    /// Load all file snapshots for a checkpoint
    fn load_file_snapshots(
        &self,
        paths: &CheckpointPaths,
        checkpoint_id: &str
    ) -> Result<Vec<FileSnapshot>> {
        let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
        if !refs_dir.exists() {
            return Ok(Vec::new());
        }

        let content_pool_dir = paths.files_dir.join("content_pool");
        let mut snapshots = Vec::new();

        // Read all reference files
        for entry in fs::read_dir(&refs_dir)? {
            let entry = entry?;
            let path = entry.path();

            // Skip non-JSON files
            if path.extension().and_then(|e| e.to_str()) != Some("json") {
                continue;
            }

            // Load reference metadata
            let ref_json = fs::read_to_string(&path)
                .context("Failed to read file reference")?;
            let ref_metadata: serde_json::Value = serde_json::from_str(&ref_json)
                .context("Failed to parse file reference")?;

            let hash = ref_metadata["hash"].as_str()
                .ok_or_else(|| anyhow::anyhow!("Missing hash in reference"))?;

            // Load content from pool
            let content_file = content_pool_dir.join(hash);
            let content = if content_file.exists() {
                let compressed_content = fs::read(&content_file)
                    .context("Failed to read file content from pool")?;
                String::from_utf8(decode_all(&compressed_content[..])
                    .context("Failed to decompress file content")?)
                    .context("Invalid UTF-8 in file content")?
            } else {
                // Handle missing content gracefully
                log::warn!("Content file missing for hash: {}", hash);
                String::new()
            };

            snapshots.push(FileSnapshot {
                checkpoint_id: checkpoint_id.to_string(),
                file_path: PathBuf::from(ref_metadata["path"].as_str().unwrap_or("")),
                content,
                hash: hash.to_string(),
                is_deleted: ref_metadata["is_deleted"].as_bool().unwrap_or(false),
                permissions: ref_metadata["permissions"].as_u64().map(|p| p as u32),
                size: ref_metadata["size"].as_u64().unwrap_or(0),
            });
        }

        Ok(snapshots)
    }

    /// Save timeline to disk
    pub fn save_timeline(&self, timeline_path: &Path, timeline: &SessionTimeline) -> Result<()> {
        let timeline_json = serde_json::to_string_pretty(timeline)
            .context("Failed to serialize timeline")?;
        fs::write(timeline_path, timeline_json)
            .context("Failed to write timeline")?;
        Ok(())
    }

    /// Load timeline from disk
    pub fn load_timeline(&self, timeline_path: &Path) -> Result<SessionTimeline> {
        let timeline_json = fs::read_to_string(timeline_path)
            .context("Failed to read timeline")?;
        let timeline: SessionTimeline = serde_json::from_str(&timeline_json)
            .context("Failed to parse timeline")?;
        Ok(timeline)
    }

    /// Update timeline with a new checkpoint
    fn update_timeline_with_checkpoint(
        &self,
        timeline_path: &Path,
        checkpoint: &Checkpoint,
        file_snapshots: &[FileSnapshot],
    ) -> Result<()> {
        let mut timeline = self.load_timeline(timeline_path)?;

        let new_node = TimelineNode {
            checkpoint: checkpoint.clone(),
            children: Vec::new(),
            file_snapshot_ids: file_snapshots.iter()
                .map(|s| s.hash.clone())
                .collect(),
        };

        // If this is the first checkpoint
        if timeline.root_node.is_none() {
            timeline.root_node = Some(new_node);
            timeline.current_checkpoint_id = Some(checkpoint.id.clone());
        } else if let Some(parent_id) = &checkpoint.parent_checkpoint_id {
            // Check if parent exists before modifying
            let parent_exists = timeline.find_checkpoint(parent_id).is_some();

            if parent_exists {
                if let Some(root) = &mut timeline.root_node {
                    Self::add_child_to_node(root, parent_id, new_node)?;
                    timeline.current_checkpoint_id = Some(checkpoint.id.clone());
                }
            } else {
                anyhow::bail!("Parent checkpoint not found: {}", parent_id);
            }
        }

        timeline.total_checkpoints += 1;
        self.save_timeline(timeline_path, &timeline)?;

        Ok(())
    }

    /// Recursively add a child node to the timeline tree
    fn add_child_to_node(
        node: &mut TimelineNode,
        parent_id: &str,
        child: TimelineNode
    ) -> Result<()> {
        if node.checkpoint.id == parent_id {
            node.children.push(child);
            return Ok(());
        }

        for child_node in &mut node.children {
            if Self::add_child_to_node(child_node, parent_id, child.clone()).is_ok() {
                return Ok(());
            }
        }

        anyhow::bail!("Parent checkpoint not found: {}", parent_id)
    }

    /// Calculate hash of file content
    pub fn calculate_file_hash(content: &str) -> String {
        let mut hasher = Sha256::new();
        hasher.update(content.as_bytes());
        format!("{:x}", hasher.finalize())
    }

    /// Generate a new checkpoint ID
    pub fn generate_checkpoint_id() -> String {
        Uuid::new_v4().to_string()
    }

    /// Estimate storage size for a checkpoint
    pub fn estimate_checkpoint_size(
        messages: &str,
        file_snapshots: &[FileSnapshot],
    ) -> u64 {
        let messages_size = messages.len() as u64;
        let files_size: u64 = file_snapshots.iter()
            .map(|s| s.content.len() as u64)
            .sum();

        // Estimate compressed size (typically 20-30% of original for text)
        (messages_size + files_size) / 4
    }

    /// Clean up old checkpoints based on retention policy
    pub fn cleanup_old_checkpoints(
        &self,
        project_id: &str,
        session_id: &str,
        keep_count: usize,
    ) -> Result<usize> {
        let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
        let timeline = self.load_timeline(&paths.timeline_file)?;

        // Collect all checkpoint IDs in chronological order
        let mut all_checkpoints = Vec::new();
        if let Some(root) = &timeline.root_node {
            Self::collect_checkpoints(root, &mut all_checkpoints);
        }

        // Sort by timestamp (oldest first)
        all_checkpoints.sort_by(|a, b| a.timestamp.cmp(&b.timestamp));

        // Keep only the most recent checkpoints
        let to_remove = all_checkpoints.len().saturating_sub(keep_count);
        let mut removed_count = 0;

        for checkpoint in all_checkpoints.into_iter().take(to_remove) {
            if self.remove_checkpoint(&paths, &checkpoint.id).is_ok() {
                removed_count += 1;
            }
        }

        // Run garbage collection to clean up orphaned content
        if removed_count > 0 {
            match self.garbage_collect_content(project_id, session_id) {
                Ok(gc_count) => {
                    log::info!("Garbage collected {} orphaned content files", gc_count);
                }
                Err(e) => {
                    log::warn!("Failed to garbage collect content: {}", e);
                }
            }
        }

        Ok(removed_count)
    }

    /// Collect all checkpoints from the tree in order
    fn collect_checkpoints(node: &TimelineNode, checkpoints: &mut Vec<Checkpoint>) {
        checkpoints.push(node.checkpoint.clone());
        for child in &node.children {
            Self::collect_checkpoints(child, checkpoints);
        }
    }

    /// Remove a checkpoint and its associated files
    fn remove_checkpoint(&self, paths: &CheckpointPaths, checkpoint_id: &str) -> Result<()> {
        // Remove checkpoint metadata directory
        let checkpoint_dir = paths.checkpoint_dir(checkpoint_id);
        if checkpoint_dir.exists() {
            fs::remove_dir_all(&checkpoint_dir)
                .context("Failed to remove checkpoint directory")?;
        }

        // Remove file references for this checkpoint
        let refs_dir = paths.files_dir.join("refs").join(checkpoint_id);
        if refs_dir.exists() {
            fs::remove_dir_all(&refs_dir)
                .context("Failed to remove file references")?;
        }

        // Note: We don't remove content from the pool here as it might be
        // referenced by other checkpoints. Use garbage_collect_content() for that.

        Ok(())
    }

    /// Garbage collect unreferenced content from the content pool
    pub fn garbage_collect_content(
        &self,
        project_id: &str,
        session_id: &str,
    ) -> Result<usize> {
        let paths = CheckpointPaths::new(&self.claude_dir, project_id, session_id);
        let content_pool_dir = paths.files_dir.join("content_pool");
        let refs_dir = paths.files_dir.join("refs");

        if !content_pool_dir.exists() {
            return Ok(0);
        }

        // Collect all referenced hashes
        let mut referenced_hashes = std::collections::HashSet::new();

        if refs_dir.exists() {
            for checkpoint_entry in fs::read_dir(&refs_dir)? {
                let checkpoint_dir = checkpoint_entry?.path();
                if checkpoint_dir.is_dir() {
                    for ref_entry in fs::read_dir(&checkpoint_dir)? {
                        let ref_path = ref_entry?.path();
                        if ref_path.extension().and_then(|e| e.to_str()) == Some("json") {
                            if let Ok(ref_json) = fs::read_to_string(&ref_path) {
                                if let Ok(ref_metadata) = serde_json::from_str::<serde_json::Value>(&ref_json) {
                                    if let Some(hash) = ref_metadata["hash"].as_str() {
                                        referenced_hashes.insert(hash.to_string());
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Remove unreferenced content
        let mut removed_count = 0;
        for entry in fs::read_dir(&content_pool_dir)? {
            let content_file = entry?.path();
            if content_file.is_file() {
                if let Some(hash) = content_file.file_name().and_then(|n| n.to_str()) {
                    if !referenced_hashes.contains(hash) {
                        if fs::remove_file(&content_file).is_ok() {
                            removed_count += 1;
                        }
                    }
                }
            }
        }

        Ok(removed_count)
    }
}
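A small sketch of why the content pool deduplicates across checkpoints: identical snapshot content yields the same SHA-256 hash, so save_file_snapshot writes the compressed bytes once and later checkpoints only add a reference file (the literal content below is illustrative).

// Sketch: identical content -> identical hash -> a single files/content_pool/<hash> entry.
use crate::checkpoint::storage::CheckpointStorage; // assumed crate path

fn dedup_example() {
    let a = CheckpointStorage::calculate_file_hash("fn main() {}\n");
    let b = CheckpointStorage::calculate_file_hash("fn main() {}\n");
    assert_eq!(a, b); // a second snapshot of the same content only writes a refs/<checkpoint_id>/*.json reference
}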