From 3dc741fd6b950f52185b85d08c3fbe0452f63cab Mon Sep 17 00:00:00 2001 From: Vivek R <123vivekr@gmail.com> Date: Mon, 23 Jun 2025 00:45:58 +0530 Subject: [PATCH] feat: add ability to stop Claude execution mid-way using loading icon as cancel button --- src-tauri/src/commands/claude.rs | 96 +++++++++++++-- src-tauri/src/main.rs | 6 +- src/components/ClaudeCodeSession.tsx | 86 +++++++++++-- src/components/FloatingPromptInput.tsx | 159 +++++++++++++------------ src/lib/api.ts | 7 ++ 5 files changed, 257 insertions(+), 97 deletions(-) diff --git a/src-tauri/src/commands/claude.rs b/src-tauri/src/commands/claude.rs index 638a368..4aa3e66 100644 --- a/src-tauri/src/commands/claude.rs +++ b/src-tauri/src/commands/claude.rs @@ -6,10 +6,25 @@ use std::time::SystemTime; use std::io::{BufRead, BufReader}; use std::process::Stdio; use tauri::{AppHandle, Emitter, Manager}; -use tokio::process::Command; +use tokio::process::{Command, Child}; +use tokio::sync::Mutex; +use std::sync::Arc; use crate::process::ProcessHandle; use crate::checkpoint::{CheckpointResult, CheckpointDiff, SessionTimeline, Checkpoint}; +/// Global state to track current Claude process +pub struct ClaudeProcessState { + pub current_process: Arc>>, +} + +impl Default for ClaudeProcessState { + fn default() -> Self { + Self { + current_process: Arc::new(Mutex::new(None)), + } + } +} + /// Represents a project in the ~/.claude/projects directory #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Project { @@ -925,6 +940,41 @@ pub async fn resume_claude_code( spawn_claude_process(app, cmd).await } +/// Cancel the currently running Claude Code execution +#[tauri::command] +pub async fn cancel_claude_execution(app: AppHandle) -> Result<(), String> { + log::info!("Cancelling Claude Code execution"); + + let claude_state = app.state::(); + let mut current_process = claude_state.current_process.lock().await; + + if let Some(mut child) = current_process.take() { + // Try to get the PID before killing + let pid = child.id(); + log::info!("Attempting to kill Claude process with PID: {:?}", pid); + + // Kill the process + match child.kill().await { + Ok(_) => { + log::info!("Successfully killed Claude process"); + // Emit cancellation event + let _ = app.emit("claude-cancelled", true); + // Also emit complete with false to indicate failure + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + let _ = app.emit("claude-complete", false); + Ok(()) + } + Err(e) => { + log::error!("Failed to kill Claude process: {}", e); + Err(format!("Failed to kill Claude process: {}", e)) + } + } + } else { + log::warn!("No active Claude process to cancel"); + Ok(()) + } +} + /// Helper function to check if sandboxing should be used based on settings fn should_use_sandbox(app: &AppHandle) -> Result { // First check if sandboxing is even available on this platform @@ -1097,10 +1147,21 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St let stdout = child.stdout.take().ok_or("Failed to get stdout")?; let stderr = child.stderr.take().ok_or("Failed to get stderr")?; + // Get the child PID for logging + let pid = child.id(); + log::info!("Spawned Claude process with PID: {:?}", pid); + // Create readers let stdout_reader = BufReader::new(stdout); let stderr_reader = BufReader::new(stderr); + // Store the child process in the global state + let claude_state = app.state::(); + { + let mut current_process = claude_state.current_process.lock().await; + *current_process = Some(child); + } + // Spawn tasks to read stdout and stderr let app_handle = app.clone(); let stdout_task = tokio::spawn(async move { @@ -1123,24 +1184,33 @@ async fn spawn_claude_process(app: AppHandle, mut cmd: Command) -> Result<(), St }); // Wait for the process to complete + let app_handle_wait = app.clone(); + let claude_state_wait = claude_state.current_process.clone(); tokio::spawn(async move { let _ = stdout_task.await; let _ = stderr_task.await; - match child.wait().await { - Ok(status) => { - log::info!("Claude process exited with status: {}", status); - // Add a small delay to ensure all messages are processed - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let _ = app.emit("claude-complete", status.success()); - } - Err(e) => { - log::error!("Failed to wait for Claude process: {}", e); - // Add a small delay to ensure all messages are processed - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let _ = app.emit("claude-complete", false); + // Get the child from the state to wait on it + let mut current_process = claude_state_wait.lock().await; + if let Some(mut child) = current_process.take() { + match child.wait().await { + Ok(status) => { + log::info!("Claude process exited with status: {}", status); + // Add a small delay to ensure all messages are processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + let _ = app_handle_wait.emit("claude-complete", status.success()); + } + Err(e) => { + log::error!("Failed to wait for Claude process: {}", e); + // Add a small delay to ensure all messages are processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + let _ = app_handle_wait.emit("claude-complete", false); + } } } + + // Clear the process from state + *current_process = None; }); Ok(()) diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 0f7b3c3..6c61732 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -17,7 +17,7 @@ use commands::claude::{ get_session_timeline, update_checkpoint_settings, get_checkpoint_diff, track_checkpoint_message, track_session_messages, check_auto_checkpoint, cleanup_old_checkpoints, get_checkpoint_settings, clear_checkpoint_manager, get_checkpoint_state_stats, - get_recently_modified_files, + get_recently_modified_files, cancel_claude_execution, ClaudeProcessState, }; use commands::agents::{ init_database, list_agents, create_agent, update_agent, delete_agent, @@ -93,6 +93,9 @@ fn main() { // Initialize process registry app.manage(ProcessRegistryState::default()); + // Initialize Claude process state + app.manage(ClaudeProcessState::default()); + Ok(()) }) .invoke_handler(tauri::generate_handler![ @@ -111,6 +114,7 @@ fn main() { execute_claude_code, continue_claude_code, resume_claude_code, + cancel_claude_execution, list_directory_contents, search_files, create_checkpoint, diff --git a/src/components/ClaudeCodeSession.tsx b/src/components/ClaudeCodeSession.tsx index b5956ba..6f17ba5 100644 --- a/src/components/ClaudeCodeSession.tsx +++ b/src/components/ClaudeCodeSession.tsx @@ -9,7 +9,8 @@ import { ChevronDown, GitBranch, Settings, - Globe + Globe, + Square } from "lucide-react"; import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; @@ -84,6 +85,7 @@ export const ClaudeCodeSession: React.FC = ({ const [showForkDialog, setShowForkDialog] = useState(false); const [forkCheckpointId, setForkCheckpointId] = useState(null); const [forkSessionName, setForkSessionName] = useState(""); + const [isCancelling, setIsCancelling] = useState(false); // New state for preview feature const [showPreview, setShowPreview] = useState(false); @@ -278,6 +280,7 @@ export const ClaudeCodeSession: React.FC = ({ const completeUnlisten = await listen("claude-complete", async (event) => { console.log('[ClaudeCodeSession] Received claude-complete:', event.payload); setIsLoading(false); + setIsCancelling(false); hasActiveSessionRef.current = false; if (!event.payload) { setError("Claude execution failed"); @@ -437,6 +440,40 @@ export const ClaudeCodeSession: React.FC = ({ setTimelineVersion((v) => v + 1); }; + const handleCancelExecution = async () => { + if (!isLoading || isCancelling) return; + + try { + setIsCancelling(true); + + // Cancel the Claude execution + await api.cancelClaudeExecution(); + + // Clean up listeners + unlistenRefs.current.forEach(unlisten => unlisten()); + unlistenRefs.current = []; + + // Add a system message indicating cancellation + const cancelMessage: ClaudeStreamMessage = { + type: "system", + subtype: "cancelled", + result: "Execution cancelled by user", + timestamp: new Date().toISOString() + }; + setMessages(prev => [...prev, cancelMessage]); + + // Reset states + setIsLoading(false); + hasActiveSessionRef.current = false; + setError(null); + } catch (err) { + console.error("Failed to cancel execution:", err); + setError("Failed to cancel execution"); + } finally { + setIsCancelling(false); + } + }; + const handleFork = (checkpointId: string) => { setForkCheckpointId(checkpointId); setForkSessionName(`Fork-${new Date().toISOString().slice(0, 10)}`); @@ -817,6 +854,42 @@ export const ClaudeCodeSession: React.FC = ({ {messagesList} )} + + {isLoading && enhancedMessages.length === 0 && ( +
+
+ + + {session ? "Loading session history..." : "Initializing Claude Code..."} + +
+
+ )} + + + {enhancedMessages.map((message, index) => ( + + + + + + ))} + + + {/* Show loading indicator when processing, even if there are messages */} + {isLoading && enhancedMessages.length > 0 && ( +
+ + + {isCancelling ? "Cancelling..." : "Processing..."} + +
+ )} {/* Floating Prompt Input - Always visible */} @@ -824,6 +897,7 @@ export const ClaudeCodeSession: React.FC = ({ = ({ )} - {/* Preview Prompt Dialog */} - setShowPreviewPrompt(false)} - /> - {/* Fork Dialog */} @@ -912,4 +978,4 @@ export const ClaudeCodeSession: React.FC = ({ )} ); -}; \ No newline at end of file +}; diff --git a/src/components/FloatingPromptInput.tsx b/src/components/FloatingPromptInput.tsx index b803d3b..eeedaf6 100644 --- a/src/components/FloatingPromptInput.tsx +++ b/src/components/FloatingPromptInput.tsx @@ -1,12 +1,13 @@ import React, { useState, useRef, useEffect } from "react"; import { motion, AnimatePresence } from "framer-motion"; -import { - Send, - Maximize2, +import { + Send, + Maximize2, Minimize2, ChevronUp, Sparkles, - Zap + Zap, + Square } from "lucide-react"; import { cn } from "@/lib/utils"; import { Button } from "@/components/ui/button"; @@ -42,6 +43,10 @@ interface FloatingPromptInputProps { * Optional className for styling */ className?: string; + /** + * Callback when cancel is clicked (only during loading) + */ + onCancel?: () => void; } export interface FloatingPromptInputRef { @@ -81,14 +86,18 @@ const MODELS: Model[] = [ * isLoading={false} * /> */ -export const FloatingPromptInput = React.forwardRef(({ - onSend, - isLoading = false, - disabled = false, - defaultModel = "sonnet", - projectPath, - className, -}, ref) => { +const FloatingPromptInputInner = ( + { + onSend, + isLoading = false, + disabled = false, + defaultModel = "sonnet", + projectPath, + className, + onCancel, + }: FloatingPromptInputProps, + ref: React.Ref, +) => { const [prompt, setPrompt] = useState(""); const [selectedModel, setSelectedModel] = useState<"sonnet" | "opus">(defaultModel); const [isExpanded, setIsExpanded] = useState(false); @@ -98,11 +107,11 @@ export const FloatingPromptInput = React.forwardRef([]); const [dragActive, setDragActive] = useState(false); - + const textareaRef = useRef(null); const expandedTextareaRef = useRef(null); const unlistenDragDropRef = useRef<(() => void) | null>(null); - + // Expose a method to add images programmatically React.useImperativeHandle( ref, @@ -113,17 +122,17 @@ export const FloatingPromptInput = React.forwardRef { const target = isExpanded ? expandedTextareaRef.current : textareaRef.current; target?.focus(); target?.setSelectionRange(newPrompt.length, newPrompt.length); }, 0); - + return newPrompt; }); } @@ -144,7 +153,7 @@ export const FloatingPromptInput = React.forwardRef m[0])); const pathsSet = new Set(); // Use Set to ensure uniqueness - + for (const match of matches) { const path = match[1]; console.log('[extractImagePaths] Processing path:', path); @@ -155,7 +164,7 @@ export const FloatingPromptInput = React.forwardRef `@${p}`).join(' '); const newPrompt = currentPrompt + (currentPrompt.endsWith(' ') || currentPrompt === '' ? '' : ' ') + mentionsToAdd + ' '; - + setTimeout(() => { const target = isExpanded ? expandedTextareaRef.current : textareaRef.current; target?.focus(); @@ -260,7 +269,7 @@ export const FloatingPromptInput = React.forwardRef) => { const newValue = e.target.value; const newCursorPosition = e.target.selectionStart || 0; - + // Check if @ was just typed if (projectPath?.trim() && newValue.length > prompt.length && newValue[newCursorPosition - 1] === '@') { console.log('[FloatingPromptInput] @ detected, projectPath:', projectPath); @@ -268,7 +277,7 @@ export const FloatingPromptInput = React.forwardRef= cursorPosition) { // Find the @ position before cursor @@ -283,7 +292,7 @@ export const FloatingPromptInput = React.forwardRef { textarea.focus(); @@ -321,7 +330,7 @@ export const FloatingPromptInput = React.forwardRef { setShowFilePicker(false); setFilePickerQuery(""); @@ -338,7 +347,7 @@ export const FloatingPromptInput = React.forwardRef - + {/* Image previews in expanded mode */} {embeddedImages.length > 0 && ( )} - +