增加默认模型映射

This commit is contained in:
2025-10-11 01:24:18 +08:00
parent 27bc42d872
commit fed1e63c34
2 changed files with 120 additions and 8 deletions

View File

@@ -233,7 +233,7 @@ pub fn init_database(app: &AppHandle) -> SqliteResult<Connection> {
icon TEXT NOT NULL,
system_prompt TEXT NOT NULL,
default_task TEXT,
model TEXT NOT NULL DEFAULT 'sonnet',
model TEXT NOT NULL DEFAULT 'claude-sonnet-4-20250514',
enable_file_read BOOLEAN NOT NULL DEFAULT 1,
enable_file_write BOOLEAN NOT NULL DEFAULT 1,
enable_network BOOLEAN NOT NULL DEFAULT 0,
@@ -247,7 +247,7 @@ pub fn init_database(app: &AppHandle) -> SqliteResult<Connection> {
// Add columns to existing table if they don't exist
let _ = conn.execute("ALTER TABLE agents ADD COLUMN default_task TEXT", []);
let _ = conn.execute(
"ALTER TABLE agents ADD COLUMN model TEXT DEFAULT 'sonnet'",
"ALTER TABLE agents ADD COLUMN model TEXT DEFAULT 'claude-sonnet-4-20250514'",
[],
);
let _ = conn.execute("ALTER TABLE agents ADD COLUMN hooks TEXT", []);
@@ -344,6 +344,38 @@ pub fn init_database(app: &AppHandle) -> SqliteResult<Connection> {
[],
)?;
// Create model mappings table for configurable model aliases
conn.execute(
"CREATE TABLE IF NOT EXISTS model_mappings (
alias TEXT PRIMARY KEY,
model_name TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)",
[],
)?;
// Initialize default model mappings if empty
let count: i64 = conn.query_row(
"SELECT COUNT(*) FROM model_mappings",
[],
|row| row.get(0),
).unwrap_or(0);
if count == 0 {
conn.execute(
"INSERT INTO model_mappings (alias, model_name) VALUES ('sonnet', 'claude-sonnet-4-20250514')",
[],
)?;
conn.execute(
"INSERT INTO model_mappings (alias, model_name) VALUES ('opus', 'claude-opus-4-1-20250805')",
[],
)?;
conn.execute(
"INSERT INTO model_mappings (alias, model_name) VALUES ('haiku', 'claude-haiku-4-20250410')",
[],
)?;
}
Ok(conn)
}
@@ -397,7 +429,7 @@ pub async fn create_agent(
hooks: Option<String>,
) -> Result<Agent, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
let model = model.unwrap_or_else(|| "sonnet".to_string());
let model = model.unwrap_or_else(|| "claude-sonnet-4-20250514".to_string());
let enable_file_read = enable_file_read.unwrap_or(true);
let enable_file_write = enable_file_write.unwrap_or(true);
let enable_network = enable_network.unwrap_or(false);
@@ -453,7 +485,7 @@ pub async fn update_agent(
hooks: Option<String>,
) -> Result<Agent, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
let model = model.unwrap_or_else(|| "sonnet".to_string());
let model = model.unwrap_or_else(|| "claude-sonnet-4-20250514".to_string());
// Build dynamic query based on provided parameters
let mut query =
@@ -549,7 +581,7 @@ pub async fn get_agent(db: State<'_, AgentDb>, id: i64) -> Result<Agent, String>
icon: row.get(2)?,
system_prompt: row.get(3)?,
default_task: row.get(4)?,
model: row.get::<_, String>(5).unwrap_or_else(|_| "sonnet".to_string()),
model: row.get::<_, String>(5).unwrap_or_else(|_| "claude-sonnet-4-20250514".to_string()),
enable_file_read: row.get::<_, bool>(6).unwrap_or(true),
enable_file_write: row.get::<_, bool>(7).unwrap_or(true),
enable_network: row.get::<_, bool>(8).unwrap_or(false),
@@ -695,6 +727,15 @@ pub async fn execute_agent(
let agent = get_agent(db.clone(), agent_id).await?;
let execution_model = model.unwrap_or(agent.model.clone());
// Resolve model alias to actual model name using mappings
let resolved_model = get_model_by_alias(&db, &execution_model)
.unwrap_or_else(|_| {
warn!("Model alias '{}' not found, using as-is", execution_model);
execution_model.clone()
});
info!("Resolved model: {} -> {}", execution_model, resolved_model);
// Create .claude/settings.json with agent hooks if it doesn't exist
if let Some(hooks_json) = &agent.hooks {
let claude_dir = std::path::Path::new(&project_path).join(".claude");
@@ -759,7 +800,7 @@ pub async fn execute_agent(
"--system-prompt".to_string(),
agent.system_prompt.clone(),
"--model".to_string(),
execution_model.clone(),
resolved_model.clone(), // Use resolved model name
"--output-format".to_string(),
"stream-json".to_string(),
"--verbose".to_string(),
@@ -768,9 +809,9 @@ pub async fn execute_agent(
// Execute based on whether we should use sidecar or system binary
if should_use_sidecar(&claude_path) {
spawn_agent_sidecar(app, run_id, agent_id, agent.name.clone(), args, project_path, task, execution_model, db, registry).await
spawn_agent_sidecar(app, run_id, agent_id, agent.name.clone(), args, project_path, task, resolved_model, db, registry).await
} else {
spawn_agent_system(app, run_id, agent_id, agent.name.clone(), claude_path, args, project_path, task, execution_model, db, registry).await
spawn_agent_system(app, run_id, agent_id, agent.name.clone(), claude_path, args, project_path, task, resolved_model, db, registry).await
}
}
@@ -2176,3 +2217,71 @@ pub async fn load_agent_session_history(
Err(format!("Session file not found: {}", session_id))
}
}
/// Model mapping structure
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ModelMapping {
pub alias: String,
pub model_name: String,
pub updated_at: String,
}
/// Get all model mappings
#[tauri::command]
pub async fn get_model_mappings(db: State<'_, AgentDb>) -> Result<Vec<ModelMapping>, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
let mut stmt = conn
.prepare("SELECT alias, model_name, updated_at FROM model_mappings ORDER BY alias")
.map_err(|e| e.to_string())?;
let mappings = stmt
.query_map([], |row| {
Ok(ModelMapping {
alias: row.get(0)?,
model_name: row.get(1)?,
updated_at: row.get(2)?,
})
})
.map_err(|e| e.to_string())?
.collect::<Result<Vec<_>, _>>()
.map_err(|e| e.to_string())?;
Ok(mappings)
}
/// Update a model mapping
#[tauri::command]
pub async fn update_model_mapping(
db: State<'_, AgentDb>,
alias: String,
model_name: String,
) -> Result<(), String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
conn.execute(
"INSERT OR REPLACE INTO model_mappings (alias, model_name, updated_at) VALUES (?1, ?2, CURRENT_TIMESTAMP)",
params![alias, model_name],
)
.map_err(|e| e.to_string())?;
Ok(())
}
/// Get model name by alias (with fallback)
fn get_model_by_alias(db: &AgentDb, alias: &str) -> Result<String, String> {
let conn = db.0.lock().map_err(|e| e.to_string())?;
// If alias looks like a full model name (contains 'claude-'), return it directly
if alias.starts_with("claude-") {
return Ok(alias.to_string());
}
// Otherwise, look up the mapping
conn.query_row(
"SELECT model_name FROM model_mappings WHERE alias = ?1",
params![alias],
|row| row.get(0),
)
.map_err(|_| format!("Model alias '{}' not found in mappings", alias))
}

View File

@@ -18,6 +18,7 @@ use commands::agents::{
import_agent_from_file, import_agent_from_github, init_database, kill_agent_session,
list_agent_runs, list_agent_runs_with_metrics, list_agents, list_claude_installations,
list_running_sessions, load_agent_session_history, set_claude_binary_path, stream_session_output, update_agent, AgentDb,
get_model_mappings, update_model_mapping,
};
use commands::claude::{
cancel_claude_execution, check_auto_checkpoint, check_claude_version, cleanup_old_checkpoints,
@@ -316,6 +317,8 @@ fn main() {
fetch_github_agents,
fetch_github_agent_content,
import_agent_from_github,
get_model_mappings,
update_model_mapping,
// Usage & Analytics
get_usage_stats,