chore: initialize recovered claude workspace
This commit is contained in:
309
src/utils/hooks/AsyncHookRegistry.ts
Normal file
309
src/utils/hooks/AsyncHookRegistry.ts
Normal file
@@ -0,0 +1,309 @@
|
||||
import type {
|
||||
AsyncHookJSONOutput,
|
||||
HookEvent,
|
||||
SyncHookJSONOutput,
|
||||
} from 'src/entrypoints/agentSdkTypes.js'
|
||||
import { logForDebugging } from '../debug.js'
|
||||
import type { ShellCommand } from '../ShellCommand.js'
|
||||
import { invalidateSessionEnvCache } from '../sessionEnvironment.js'
|
||||
import { jsonParse, jsonStringify } from '../slowOperations.js'
|
||||
import { emitHookResponse, startHookProgressInterval } from './hookEvents.js'
|
||||
|
||||
export type PendingAsyncHook = {
|
||||
processId: string
|
||||
hookId: string
|
||||
hookName: string
|
||||
hookEvent: HookEvent | 'StatusLine' | 'FileSuggestion'
|
||||
toolName?: string
|
||||
pluginId?: string
|
||||
startTime: number
|
||||
timeout: number
|
||||
command: string
|
||||
responseAttachmentSent: boolean
|
||||
shellCommand?: ShellCommand
|
||||
stopProgressInterval: () => void
|
||||
}
|
||||
|
||||
// Global registry state
|
||||
const pendingHooks = new Map<string, PendingAsyncHook>()
|
||||
|
||||
export function registerPendingAsyncHook({
|
||||
processId,
|
||||
hookId,
|
||||
asyncResponse,
|
||||
hookName,
|
||||
hookEvent,
|
||||
command,
|
||||
shellCommand,
|
||||
toolName,
|
||||
pluginId,
|
||||
}: {
|
||||
processId: string
|
||||
hookId: string
|
||||
asyncResponse: AsyncHookJSONOutput
|
||||
hookName: string
|
||||
hookEvent: HookEvent | 'StatusLine' | 'FileSuggestion'
|
||||
command: string
|
||||
shellCommand: ShellCommand
|
||||
toolName?: string
|
||||
pluginId?: string
|
||||
}): void {
|
||||
const timeout = asyncResponse.asyncTimeout || 15000 // Default 15s
|
||||
logForDebugging(
|
||||
`Hooks: Registering async hook ${processId} (${hookName}) with timeout ${timeout}ms`,
|
||||
)
|
||||
const stopProgressInterval = startHookProgressInterval({
|
||||
hookId,
|
||||
hookName,
|
||||
hookEvent,
|
||||
getOutput: async () => {
|
||||
const taskOutput = pendingHooks.get(processId)?.shellCommand?.taskOutput
|
||||
if (!taskOutput) {
|
||||
return { stdout: '', stderr: '', output: '' }
|
||||
}
|
||||
const stdout = await taskOutput.getStdout()
|
||||
const stderr = taskOutput.getStderr()
|
||||
return { stdout, stderr, output: stdout + stderr }
|
||||
},
|
||||
})
|
||||
pendingHooks.set(processId, {
|
||||
processId,
|
||||
hookId,
|
||||
hookName,
|
||||
hookEvent,
|
||||
toolName,
|
||||
pluginId,
|
||||
command,
|
||||
startTime: Date.now(),
|
||||
timeout,
|
||||
responseAttachmentSent: false,
|
||||
shellCommand,
|
||||
stopProgressInterval,
|
||||
})
|
||||
}
|
||||
|
||||
export function getPendingAsyncHooks(): PendingAsyncHook[] {
|
||||
return Array.from(pendingHooks.values()).filter(
|
||||
hook => !hook.responseAttachmentSent,
|
||||
)
|
||||
}
|
||||
|
||||
async function finalizeHook(
|
||||
hook: PendingAsyncHook,
|
||||
exitCode: number,
|
||||
outcome: 'success' | 'error' | 'cancelled',
|
||||
): Promise<void> {
|
||||
hook.stopProgressInterval()
|
||||
const taskOutput = hook.shellCommand?.taskOutput
|
||||
const stdout = taskOutput ? await taskOutput.getStdout() : ''
|
||||
const stderr = taskOutput?.getStderr() ?? ''
|
||||
hook.shellCommand?.cleanup()
|
||||
emitHookResponse({
|
||||
hookId: hook.hookId,
|
||||
hookName: hook.hookName,
|
||||
hookEvent: hook.hookEvent,
|
||||
output: stdout + stderr,
|
||||
stdout,
|
||||
stderr,
|
||||
exitCode,
|
||||
outcome,
|
||||
})
|
||||
}
|
||||
|
||||
export async function checkForAsyncHookResponses(): Promise<
|
||||
Array<{
|
||||
processId: string
|
||||
response: SyncHookJSONOutput
|
||||
hookName: string
|
||||
hookEvent: HookEvent | 'StatusLine' | 'FileSuggestion'
|
||||
toolName?: string
|
||||
pluginId?: string
|
||||
stdout: string
|
||||
stderr: string
|
||||
exitCode?: number
|
||||
}>
|
||||
> {
|
||||
const responses: {
|
||||
processId: string
|
||||
response: SyncHookJSONOutput
|
||||
hookName: string
|
||||
hookEvent: HookEvent | 'StatusLine' | 'FileSuggestion'
|
||||
toolName?: string
|
||||
pluginId?: string
|
||||
stdout: string
|
||||
stderr: string
|
||||
exitCode?: number
|
||||
}[] = []
|
||||
|
||||
const pendingCount = pendingHooks.size
|
||||
logForDebugging(`Hooks: Found ${pendingCount} total hooks in registry`)
|
||||
|
||||
// Snapshot hooks before processing — we'll mutate the map after.
|
||||
const hooks = Array.from(pendingHooks.values())
|
||||
|
||||
const settled = await Promise.allSettled(
|
||||
hooks.map(async hook => {
|
||||
const stdout = (await hook.shellCommand?.taskOutput.getStdout()) ?? ''
|
||||
const stderr = hook.shellCommand?.taskOutput.getStderr() ?? ''
|
||||
logForDebugging(
|
||||
`Hooks: Checking hook ${hook.processId} (${hook.hookName}) - attachmentSent: ${hook.responseAttachmentSent}, stdout length: ${stdout.length}`,
|
||||
)
|
||||
|
||||
if (!hook.shellCommand) {
|
||||
logForDebugging(
|
||||
`Hooks: Hook ${hook.processId} has no shell command, removing from registry`,
|
||||
)
|
||||
hook.stopProgressInterval()
|
||||
return { type: 'remove' as const, processId: hook.processId }
|
||||
}
|
||||
|
||||
logForDebugging(`Hooks: Hook shell status ${hook.shellCommand.status}`)
|
||||
|
||||
if (hook.shellCommand.status === 'killed') {
|
||||
logForDebugging(
|
||||
`Hooks: Hook ${hook.processId} is ${hook.shellCommand.status}, removing from registry`,
|
||||
)
|
||||
hook.stopProgressInterval()
|
||||
hook.shellCommand.cleanup()
|
||||
return { type: 'remove' as const, processId: hook.processId }
|
||||
}
|
||||
|
||||
if (hook.shellCommand.status !== 'completed') {
|
||||
return { type: 'skip' as const }
|
||||
}
|
||||
|
||||
if (hook.responseAttachmentSent || !stdout.trim()) {
|
||||
logForDebugging(
|
||||
`Hooks: Skipping hook ${hook.processId} - already delivered/sent or no stdout`,
|
||||
)
|
||||
hook.stopProgressInterval()
|
||||
return { type: 'remove' as const, processId: hook.processId }
|
||||
}
|
||||
|
||||
const lines = stdout.split('\n')
|
||||
logForDebugging(
|
||||
`Hooks: Processing ${lines.length} lines of stdout for ${hook.processId}`,
|
||||
)
|
||||
|
||||
const execResult = await hook.shellCommand.result
|
||||
const exitCode = execResult.code
|
||||
|
||||
let response: SyncHookJSONOutput = {}
|
||||
for (const line of lines) {
|
||||
if (line.trim().startsWith('{')) {
|
||||
logForDebugging(
|
||||
`Hooks: Found JSON line: ${line.trim().substring(0, 100)}...`,
|
||||
)
|
||||
try {
|
||||
const parsed = jsonParse(line.trim())
|
||||
if (!('async' in parsed)) {
|
||||
logForDebugging(
|
||||
`Hooks: Found sync response from ${hook.processId}: ${jsonStringify(parsed)}`,
|
||||
)
|
||||
response = parsed
|
||||
break
|
||||
}
|
||||
} catch {
|
||||
logForDebugging(
|
||||
`Hooks: Failed to parse JSON from ${hook.processId}: ${line.trim()}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hook.responseAttachmentSent = true
|
||||
await finalizeHook(hook, exitCode, exitCode === 0 ? 'success' : 'error')
|
||||
|
||||
return {
|
||||
type: 'response' as const,
|
||||
processId: hook.processId,
|
||||
isSessionStart: hook.hookEvent === 'SessionStart',
|
||||
payload: {
|
||||
processId: hook.processId,
|
||||
response,
|
||||
hookName: hook.hookName,
|
||||
hookEvent: hook.hookEvent,
|
||||
toolName: hook.toolName,
|
||||
pluginId: hook.pluginId,
|
||||
stdout,
|
||||
stderr,
|
||||
exitCode,
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
// allSettled — isolate failures so one throwing callback doesn't orphan
|
||||
// already-applied side effects (responseAttachmentSent, finalizeHook) from others.
|
||||
let sessionStartCompleted = false
|
||||
for (const s of settled) {
|
||||
if (s.status !== 'fulfilled') {
|
||||
logForDebugging(
|
||||
`Hooks: checkForAsyncHookResponses callback rejected: ${s.reason}`,
|
||||
{ level: 'error' },
|
||||
)
|
||||
continue
|
||||
}
|
||||
const r = s.value
|
||||
if (r.type === 'remove') {
|
||||
pendingHooks.delete(r.processId)
|
||||
} else if (r.type === 'response') {
|
||||
responses.push(r.payload)
|
||||
pendingHooks.delete(r.processId)
|
||||
if (r.isSessionStart) sessionStartCompleted = true
|
||||
}
|
||||
}
|
||||
|
||||
if (sessionStartCompleted) {
|
||||
logForDebugging(
|
||||
`Invalidating session env cache after SessionStart hook completed`,
|
||||
)
|
||||
invalidateSessionEnvCache()
|
||||
}
|
||||
|
||||
logForDebugging(
|
||||
`Hooks: checkForNewResponses returning ${responses.length} responses`,
|
||||
)
|
||||
return responses
|
||||
}
|
||||
|
||||
export function removeDeliveredAsyncHooks(processIds: string[]): void {
|
||||
for (const processId of processIds) {
|
||||
const hook = pendingHooks.get(processId)
|
||||
if (hook && hook.responseAttachmentSent) {
|
||||
logForDebugging(`Hooks: Removing delivered hook ${processId}`)
|
||||
hook.stopProgressInterval()
|
||||
pendingHooks.delete(processId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export async function finalizePendingAsyncHooks(): Promise<void> {
|
||||
const hooks = Array.from(pendingHooks.values())
|
||||
await Promise.all(
|
||||
hooks.map(async hook => {
|
||||
if (hook.shellCommand?.status === 'completed') {
|
||||
const result = await hook.shellCommand.result
|
||||
await finalizeHook(
|
||||
hook,
|
||||
result.code,
|
||||
result.code === 0 ? 'success' : 'error',
|
||||
)
|
||||
} else {
|
||||
if (hook.shellCommand && hook.shellCommand.status !== 'killed') {
|
||||
hook.shellCommand.kill()
|
||||
}
|
||||
await finalizeHook(hook, 1, 'cancelled')
|
||||
}
|
||||
}),
|
||||
)
|
||||
pendingHooks.clear()
|
||||
}
|
||||
|
||||
// Test utility function to clear all hooks
|
||||
export function clearAllAsyncHooks(): void {
|
||||
for (const hook of pendingHooks.values()) {
|
||||
hook.stopProgressInterval()
|
||||
}
|
||||
pendingHooks.clear()
|
||||
}
|
||||
141
src/utils/hooks/apiQueryHookHelper.ts
Normal file
141
src/utils/hooks/apiQueryHookHelper.ts
Normal file
@@ -0,0 +1,141 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import type { QuerySource } from '../../constants/querySource.js'
|
||||
import { queryModelWithoutStreaming } from '../../services/api/claude.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { createAbortController } from '../../utils/abortController.js'
|
||||
import { logError } from '../../utils/log.js'
|
||||
import { toError } from '../errors.js'
|
||||
import { extractTextContent } from '../messages.js'
|
||||
import { asSystemPrompt } from '../systemPromptType.js'
|
||||
import type { REPLHookContext } from './postSamplingHooks.js'
|
||||
|
||||
export type ApiQueryHookContext = REPLHookContext & {
|
||||
queryMessageCount?: number
|
||||
}
|
||||
|
||||
export type ApiQueryHookConfig<TResult> = {
|
||||
name: QuerySource
|
||||
shouldRun: (context: ApiQueryHookContext) => Promise<boolean>
|
||||
|
||||
// Build the complete message list to send to the API
|
||||
buildMessages: (context: ApiQueryHookContext) => Message[]
|
||||
|
||||
// Optional: override system prompt (defaults to context.systemPrompt)
|
||||
systemPrompt?: string
|
||||
|
||||
// Optional: whether to use tools from context (defaults to true)
|
||||
// Set to false to pass empty tools array
|
||||
useTools?: boolean
|
||||
|
||||
parseResponse: (content: string, context: ApiQueryHookContext) => TResult
|
||||
logResult: (
|
||||
result: ApiQueryResult<TResult>,
|
||||
context: ApiQueryHookContext,
|
||||
) => void
|
||||
// Must be a function to ensure lazy loading (config is accessed before allowed)
|
||||
// Receives context so callers can inherit the main loop model if desired.
|
||||
getModel: (context: ApiQueryHookContext) => string
|
||||
}
|
||||
|
||||
export type ApiQueryResult<TResult> =
|
||||
| {
|
||||
type: 'success'
|
||||
queryName: string
|
||||
result: TResult
|
||||
messageId: string
|
||||
model: string
|
||||
uuid: string
|
||||
}
|
||||
| {
|
||||
type: 'error'
|
||||
queryName: string
|
||||
error: Error
|
||||
uuid: string
|
||||
}
|
||||
|
||||
export function createApiQueryHook<TResult>(
|
||||
config: ApiQueryHookConfig<TResult>,
|
||||
) {
|
||||
return async (context: ApiQueryHookContext): Promise<void> => {
|
||||
try {
|
||||
const shouldRun = await config.shouldRun(context)
|
||||
if (!shouldRun) {
|
||||
return
|
||||
}
|
||||
|
||||
const uuid = randomUUID()
|
||||
|
||||
// Build messages using the config's buildMessages function
|
||||
const messages = config.buildMessages(context)
|
||||
context.queryMessageCount = messages.length
|
||||
|
||||
// Use config's system prompt if provided, otherwise use context's
|
||||
const systemPrompt = config.systemPrompt
|
||||
? asSystemPrompt([config.systemPrompt])
|
||||
: context.systemPrompt
|
||||
|
||||
// Use config's tools preference (defaults to true = use context tools)
|
||||
const useTools = config.useTools ?? true
|
||||
const tools = useTools ? context.toolUseContext.options.tools : []
|
||||
|
||||
// Get model (lazy loaded)
|
||||
const model = config.getModel(context)
|
||||
|
||||
// Make API call
|
||||
const response = await queryModelWithoutStreaming({
|
||||
messages,
|
||||
systemPrompt,
|
||||
thinkingConfig: { type: 'disabled' as const },
|
||||
tools,
|
||||
signal: createAbortController().signal,
|
||||
options: {
|
||||
getToolPermissionContext: async () => {
|
||||
const appState = context.toolUseContext.getAppState()
|
||||
return appState.toolPermissionContext
|
||||
},
|
||||
model,
|
||||
toolChoice: undefined,
|
||||
isNonInteractiveSession:
|
||||
context.toolUseContext.options.isNonInteractiveSession,
|
||||
hasAppendSystemPrompt:
|
||||
!!context.toolUseContext.options.appendSystemPrompt,
|
||||
temperatureOverride: 0,
|
||||
agents: context.toolUseContext.options.agentDefinitions.activeAgents,
|
||||
querySource: config.name,
|
||||
mcpTools: [],
|
||||
agentId: context.toolUseContext.agentId,
|
||||
},
|
||||
})
|
||||
|
||||
// Parse response
|
||||
const content = extractTextContent(response.message.content).trim()
|
||||
|
||||
try {
|
||||
const result = config.parseResponse(content, context)
|
||||
config.logResult(
|
||||
{
|
||||
type: 'success',
|
||||
queryName: config.name,
|
||||
result,
|
||||
messageId: response.message.id,
|
||||
model,
|
||||
uuid,
|
||||
},
|
||||
context,
|
||||
)
|
||||
} catch (error) {
|
||||
config.logResult(
|
||||
{
|
||||
type: 'error',
|
||||
queryName: config.name,
|
||||
error: error as Error,
|
||||
uuid,
|
||||
},
|
||||
context,
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
logError(toError(error))
|
||||
}
|
||||
}
|
||||
}
|
||||
339
src/utils/hooks/execAgentHook.ts
Normal file
339
src/utils/hooks/execAgentHook.ts
Normal file
@@ -0,0 +1,339 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import type { HookEvent } from 'src/entrypoints/agentSdkTypes.js'
|
||||
import { query } from '../../query.js'
|
||||
import { logEvent } from '../../services/analytics/index.js'
|
||||
import type { AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS } from '../../services/analytics/metadata.js'
|
||||
import type { ToolUseContext } from '../../Tool.js'
|
||||
import { type Tool, toolMatchesName } from '../../Tool.js'
|
||||
import { SYNTHETIC_OUTPUT_TOOL_NAME } from '../../tools/SyntheticOutputTool/SyntheticOutputTool.js'
|
||||
import { ALL_AGENT_DISALLOWED_TOOLS } from '../../tools.js'
|
||||
import { asAgentId } from '../../types/ids.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { createAbortController } from '../abortController.js'
|
||||
import { createAttachmentMessage } from '../attachments.js'
|
||||
import { createCombinedAbortSignal } from '../combinedAbortSignal.js'
|
||||
import { logForDebugging } from '../debug.js'
|
||||
import { errorMessage } from '../errors.js'
|
||||
import type { HookResult } from '../hooks.js'
|
||||
import { createUserMessage, handleMessageFromStream } from '../messages.js'
|
||||
import { getSmallFastModel } from '../model/model.js'
|
||||
import { hasPermissionsToUseTool } from '../permissions/permissions.js'
|
||||
import { getAgentTranscriptPath, getTranscriptPath } from '../sessionStorage.js'
|
||||
import type { AgentHook } from '../settings/types.js'
|
||||
import { jsonStringify } from '../slowOperations.js'
|
||||
import { asSystemPrompt } from '../systemPromptType.js'
|
||||
import {
|
||||
addArgumentsToPrompt,
|
||||
createStructuredOutputTool,
|
||||
hookResponseSchema,
|
||||
registerStructuredOutputEnforcement,
|
||||
} from './hookHelpers.js'
|
||||
import { clearSessionHooks } from './sessionHooks.js'
|
||||
|
||||
/**
|
||||
* Execute an agent-based hook using a multi-turn LLM query
|
||||
*/
|
||||
export async function execAgentHook(
|
||||
hook: AgentHook,
|
||||
hookName: string,
|
||||
hookEvent: HookEvent,
|
||||
jsonInput: string,
|
||||
signal: AbortSignal,
|
||||
toolUseContext: ToolUseContext,
|
||||
toolUseID: string | undefined,
|
||||
// Kept for signature stability with the other exec*Hook functions.
|
||||
// Was used by hook.prompt(messages) before the .transform() was removed
|
||||
// (CC-79) — the only consumer of that was ExitPlanModeV2Tool's
|
||||
// programmatic construction, since refactored into VerifyPlanExecutionTool.
|
||||
_messages: Message[],
|
||||
agentName?: string,
|
||||
): Promise<HookResult> {
|
||||
const effectiveToolUseID = toolUseID || `hook-${randomUUID()}`
|
||||
|
||||
// Get transcript path from context
|
||||
const transcriptPath = toolUseContext.agentId
|
||||
? getAgentTranscriptPath(toolUseContext.agentId)
|
||||
: getTranscriptPath()
|
||||
const hookStartTime = Date.now()
|
||||
try {
|
||||
// Replace $ARGUMENTS with the JSON input
|
||||
const processedPrompt = addArgumentsToPrompt(hook.prompt, jsonInput)
|
||||
logForDebugging(
|
||||
`Hooks: Processing agent hook with prompt: ${processedPrompt}`,
|
||||
)
|
||||
|
||||
// Create user message directly - no need for processUserInput which would
|
||||
// trigger UserPromptSubmit hooks and cause infinite recursion
|
||||
const userMessage = createUserMessage({ content: processedPrompt })
|
||||
const agentMessages = [userMessage]
|
||||
|
||||
logForDebugging(
|
||||
`Hooks: Starting agent query with ${agentMessages.length} messages`,
|
||||
)
|
||||
|
||||
// Setup timeout and combine with parent signal
|
||||
const hookTimeoutMs = hook.timeout ? hook.timeout * 1000 : 60000
|
||||
const hookAbortController = createAbortController()
|
||||
|
||||
// Combine parent signal with timeout, and have it abort our controller
|
||||
const { signal: parentTimeoutSignal, cleanup: cleanupCombinedSignal } =
|
||||
createCombinedAbortSignal(signal, { timeoutMs: hookTimeoutMs })
|
||||
const onParentTimeout = () => hookAbortController.abort()
|
||||
parentTimeoutSignal.addEventListener('abort', onParentTimeout)
|
||||
|
||||
// Combined signal is just our controller's signal now
|
||||
const combinedSignal = hookAbortController.signal
|
||||
|
||||
try {
|
||||
// Create StructuredOutput tool with our schema
|
||||
const structuredOutputTool = createStructuredOutputTool()
|
||||
|
||||
// Filter out any existing StructuredOutput tool to avoid duplicates with different schemas
|
||||
// (e.g., when parent context has a StructuredOutput tool from --json-schema flag)
|
||||
const filteredTools = toolUseContext.options.tools.filter(
|
||||
tool => !toolMatchesName(tool, SYNTHETIC_OUTPUT_TOOL_NAME),
|
||||
)
|
||||
|
||||
// Use all available tools plus our structured output tool
|
||||
// Filter out disallowed agent tools to prevent stop hook agents from spawning subagents
|
||||
// or entering plan mode, and filter out duplicate StructuredOutput tools
|
||||
const tools: Tool[] = [
|
||||
...filteredTools.filter(
|
||||
tool => !ALL_AGENT_DISALLOWED_TOOLS.has(tool.name),
|
||||
),
|
||||
structuredOutputTool,
|
||||
]
|
||||
|
||||
const systemPrompt = asSystemPrompt([
|
||||
`You are verifying a stop condition in Claude Code. Your task is to verify that the agent completed the given plan. The conversation transcript is available at: ${transcriptPath}\nYou can read this file to analyze the conversation history if needed.
|
||||
|
||||
Use the available tools to inspect the codebase and verify the condition.
|
||||
Use as few steps as possible - be efficient and direct.
|
||||
|
||||
When done, return your result using the ${SYNTHETIC_OUTPUT_TOOL_NAME} tool with:
|
||||
- ok: true if the condition is met
|
||||
- ok: false with reason if the condition is not met`,
|
||||
])
|
||||
|
||||
const model = hook.model ?? getSmallFastModel()
|
||||
const MAX_AGENT_TURNS = 50
|
||||
|
||||
// Create unique agentId for this hook agent
|
||||
const hookAgentId = asAgentId(`hook-agent-${randomUUID()}`)
|
||||
|
||||
// Create a modified toolUseContext for the agent
|
||||
const agentToolUseContext: ToolUseContext = {
|
||||
...toolUseContext,
|
||||
agentId: hookAgentId,
|
||||
abortController: hookAbortController,
|
||||
options: {
|
||||
...toolUseContext.options,
|
||||
tools,
|
||||
mainLoopModel: model,
|
||||
isNonInteractiveSession: true,
|
||||
thinkingConfig: { type: 'disabled' as const },
|
||||
},
|
||||
setInProgressToolUseIDs: () => {},
|
||||
getAppState() {
|
||||
const appState = toolUseContext.getAppState()
|
||||
// Add session rule to allow reading transcript file
|
||||
const existingSessionRules =
|
||||
appState.toolPermissionContext.alwaysAllowRules.session ?? []
|
||||
return {
|
||||
...appState,
|
||||
toolPermissionContext: {
|
||||
...appState.toolPermissionContext,
|
||||
mode: 'dontAsk' as const,
|
||||
alwaysAllowRules: {
|
||||
...appState.toolPermissionContext.alwaysAllowRules,
|
||||
session: [...existingSessionRules, `Read(/${transcriptPath})`],
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
// Register a session-level stop hook to enforce structured output
|
||||
registerStructuredOutputEnforcement(
|
||||
toolUseContext.setAppState,
|
||||
hookAgentId,
|
||||
)
|
||||
|
||||
let structuredOutputResult: { ok: boolean; reason?: string } | null = null
|
||||
let turnCount = 0
|
||||
let hitMaxTurns = false
|
||||
|
||||
// Use query() for multi-turn execution
|
||||
for await (const message of query({
|
||||
messages: agentMessages,
|
||||
systemPrompt,
|
||||
userContext: {},
|
||||
systemContext: {},
|
||||
canUseTool: hasPermissionsToUseTool,
|
||||
toolUseContext: agentToolUseContext,
|
||||
querySource: 'hook_agent',
|
||||
})) {
|
||||
// Process stream events to update response length in the spinner
|
||||
handleMessageFromStream(
|
||||
message,
|
||||
() => {}, // onMessage - we handle messages below
|
||||
newContent =>
|
||||
toolUseContext.setResponseLength(
|
||||
length => length + newContent.length,
|
||||
),
|
||||
toolUseContext.setStreamMode ?? (() => {}),
|
||||
() => {}, // onStreamingToolUses - not needed for hooks
|
||||
)
|
||||
|
||||
// Skip streaming events for further processing
|
||||
if (
|
||||
message.type === 'stream_event' ||
|
||||
message.type === 'stream_request_start'
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Count assistant turns
|
||||
if (message.type === 'assistant') {
|
||||
turnCount++
|
||||
|
||||
// Check if we've hit the turn limit
|
||||
if (turnCount >= MAX_AGENT_TURNS) {
|
||||
hitMaxTurns = true
|
||||
logForDebugging(
|
||||
`Hooks: Agent turn ${turnCount} hit max turns, aborting`,
|
||||
)
|
||||
hookAbortController.abort()
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check for structured output in attachments
|
||||
if (
|
||||
message.type === 'attachment' &&
|
||||
message.attachment.type === 'structured_output'
|
||||
) {
|
||||
const parsed = hookResponseSchema().safeParse(message.attachment.data)
|
||||
if (parsed.success) {
|
||||
structuredOutputResult = parsed.data
|
||||
logForDebugging(
|
||||
`Hooks: Got structured output: ${jsonStringify(structuredOutputResult)}`,
|
||||
)
|
||||
// Got structured output, abort and exit
|
||||
hookAbortController.abort()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
parentTimeoutSignal.removeEventListener('abort', onParentTimeout)
|
||||
cleanupCombinedSignal()
|
||||
|
||||
// Clean up the session hook we registered for this agent
|
||||
clearSessionHooks(toolUseContext.setAppState, hookAgentId)
|
||||
|
||||
// Check if we got a result
|
||||
if (!structuredOutputResult) {
|
||||
// If we hit max turns, just log and return cancelled (no UI message)
|
||||
if (hitMaxTurns) {
|
||||
logForDebugging(
|
||||
`Hooks: Agent hook did not complete within ${MAX_AGENT_TURNS} turns`,
|
||||
)
|
||||
logEvent('tengu_agent_stop_hook_max_turns', {
|
||||
durationMs: Date.now() - hookStartTime,
|
||||
turnCount,
|
||||
agentName:
|
||||
agentName as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {
|
||||
hook,
|
||||
outcome: 'cancelled',
|
||||
}
|
||||
}
|
||||
|
||||
// For other cases (e.g., agent finished without calling structured output tool),
|
||||
// just log and return cancelled (don't show error to user)
|
||||
logForDebugging(`Hooks: Agent hook did not return structured output`)
|
||||
logEvent('tengu_agent_stop_hook_error', {
|
||||
durationMs: Date.now() - hookStartTime,
|
||||
turnCount,
|
||||
errorType: 1, // 1 = no structured output
|
||||
agentName:
|
||||
agentName as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {
|
||||
hook,
|
||||
outcome: 'cancelled',
|
||||
}
|
||||
}
|
||||
|
||||
// Return result based on structured output
|
||||
if (!structuredOutputResult.ok) {
|
||||
logForDebugging(
|
||||
`Hooks: Agent hook condition was not met: ${structuredOutputResult.reason}`,
|
||||
)
|
||||
return {
|
||||
hook,
|
||||
outcome: 'blocking',
|
||||
blockingError: {
|
||||
blockingError: `Agent hook condition was not met: ${structuredOutputResult.reason}`,
|
||||
command: hook.prompt,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Condition was met
|
||||
logForDebugging(`Hooks: Agent hook condition was met`)
|
||||
logEvent('tengu_agent_stop_hook_success', {
|
||||
durationMs: Date.now() - hookStartTime,
|
||||
turnCount,
|
||||
agentName:
|
||||
agentName as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {
|
||||
hook,
|
||||
outcome: 'success',
|
||||
message: createAttachmentMessage({
|
||||
type: 'hook_success',
|
||||
hookName,
|
||||
toolUseID: effectiveToolUseID,
|
||||
hookEvent,
|
||||
content: '',
|
||||
}),
|
||||
}
|
||||
} catch (error) {
|
||||
parentTimeoutSignal.removeEventListener('abort', onParentTimeout)
|
||||
cleanupCombinedSignal()
|
||||
|
||||
if (combinedSignal.aborted) {
|
||||
return {
|
||||
hook,
|
||||
outcome: 'cancelled',
|
||||
}
|
||||
}
|
||||
throw error
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMsg = errorMessage(error)
|
||||
logForDebugging(`Hooks: Agent hook error: ${errorMsg}`)
|
||||
logEvent('tengu_agent_stop_hook_error', {
|
||||
durationMs: Date.now() - hookStartTime,
|
||||
errorType: 2, // 2 = general error
|
||||
agentName:
|
||||
agentName as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
})
|
||||
return {
|
||||
hook,
|
||||
outcome: 'non_blocking_error',
|
||||
message: createAttachmentMessage({
|
||||
type: 'hook_non_blocking_error',
|
||||
hookName,
|
||||
toolUseID: effectiveToolUseID,
|
||||
hookEvent,
|
||||
stderr: `Error executing agent hook: ${errorMsg}`,
|
||||
stdout: '',
|
||||
exitCode: 1,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
242
src/utils/hooks/execHttpHook.ts
Normal file
242
src/utils/hooks/execHttpHook.ts
Normal file
@@ -0,0 +1,242 @@
|
||||
import axios from 'axios'
|
||||
import type { HookEvent } from 'src/entrypoints/agentSdkTypes.js'
|
||||
import { createCombinedAbortSignal } from '../combinedAbortSignal.js'
|
||||
import { logForDebugging } from '../debug.js'
|
||||
import { errorMessage } from '../errors.js'
|
||||
import { getProxyUrl, shouldBypassProxy } from '../proxy.js'
|
||||
// Import as namespace so spyOn works in tests (direct imports bypass spies)
|
||||
import * as settingsModule from '../settings/settings.js'
|
||||
import type { HttpHook } from '../settings/types.js'
|
||||
import { ssrfGuardedLookup } from './ssrfGuard.js'
|
||||
|
||||
const DEFAULT_HTTP_HOOK_TIMEOUT_MS = 10 * 60 * 1000 // 10 minutes (matches TOOL_HOOK_EXECUTION_TIMEOUT_MS)
|
||||
|
||||
/**
|
||||
* Get the sandbox proxy config for routing HTTP hook requests through the
|
||||
* sandbox network proxy when sandboxing is enabled.
|
||||
*
|
||||
* Uses dynamic import to avoid a static import cycle
|
||||
* (sandbox-adapter -> settings -> ... -> hooks -> execHttpHook).
|
||||
*/
|
||||
async function getSandboxProxyConfig(): Promise<
|
||||
{ host: string; port: number; protocol: string } | undefined
|
||||
> {
|
||||
const { SandboxManager } = await import('../sandbox/sandbox-adapter.js')
|
||||
|
||||
if (!SandboxManager.isSandboxingEnabled()) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Wait for the sandbox network proxy to finish initializing. In REPL mode,
|
||||
// SandboxManager.initialize() is fire-and-forget so the proxy may not be
|
||||
// ready yet when the first hook fires.
|
||||
await SandboxManager.waitForNetworkInitialization()
|
||||
|
||||
const proxyPort = SandboxManager.getProxyPort()
|
||||
if (!proxyPort) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return { host: '127.0.0.1', port: proxyPort, protocol: 'http' }
|
||||
}
|
||||
|
||||
/**
|
||||
* Read HTTP hook allowlist restrictions from merged settings (all sources).
|
||||
* Follows the allowedMcpServers precedent: arrays concatenate across sources.
|
||||
* When allowManagedHooksOnly is set in managed settings, only admin-defined
|
||||
* hooks run anyway, so no separate lock-down boolean is needed here.
|
||||
*/
|
||||
function getHttpHookPolicy(): {
|
||||
allowedUrls: string[] | undefined
|
||||
allowedEnvVars: string[] | undefined
|
||||
} {
|
||||
const settings = settingsModule.getInitialSettings()
|
||||
return {
|
||||
allowedUrls: settings.allowedHttpHookUrls,
|
||||
allowedEnvVars: settings.httpHookAllowedEnvVars,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Match a URL against a pattern with * as a wildcard (any characters).
|
||||
* Same semantics as the MCP server allowlist patterns.
|
||||
*/
|
||||
function urlMatchesPattern(url: string, pattern: string): boolean {
|
||||
const escaped = pattern.replace(/[.+?^${}()|[\]\\]/g, '\\$&')
|
||||
const regexStr = escaped.replace(/\*/g, '.*')
|
||||
return new RegExp(`^${regexStr}$`).test(url)
|
||||
}
|
||||
|
||||
/**
|
||||
* Strip CR, LF, and NUL bytes from a header value to prevent HTTP header
|
||||
* injection (CRLF injection) via env var values or hook-configured header
|
||||
* templates. A malicious env var like "token\r\nX-Evil: 1" would otherwise
|
||||
* inject a second header into the request.
|
||||
*/
|
||||
function sanitizeHeaderValue(value: string): string {
|
||||
// eslint-disable-next-line no-control-regex
|
||||
return value.replace(/[\r\n\x00]/g, '')
|
||||
}
|
||||
|
||||
/**
|
||||
* Interpolate $VAR_NAME and ${VAR_NAME} patterns in a string using process.env,
|
||||
* but only for variable names present in the allowlist. References to variables
|
||||
* not in the allowlist are replaced with empty strings to prevent exfiltration
|
||||
* of secrets via project-configured HTTP hooks.
|
||||
*
|
||||
* The result is sanitized to strip CR/LF/NUL bytes to prevent header injection.
|
||||
*/
|
||||
function interpolateEnvVars(
|
||||
value: string,
|
||||
allowedEnvVars: ReadonlySet<string>,
|
||||
): string {
|
||||
const interpolated = value.replace(
|
||||
/\$\{([A-Z_][A-Z0-9_]*)\}|\$([A-Z_][A-Z0-9_]*)/g,
|
||||
(_, braced, unbraced) => {
|
||||
const varName = braced ?? unbraced
|
||||
if (!allowedEnvVars.has(varName)) {
|
||||
logForDebugging(
|
||||
`Hooks: env var $${varName} not in allowedEnvVars, skipping interpolation`,
|
||||
{ level: 'warn' },
|
||||
)
|
||||
return ''
|
||||
}
|
||||
return process.env[varName] ?? ''
|
||||
},
|
||||
)
|
||||
return sanitizeHeaderValue(interpolated)
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute an HTTP hook by POSTing the hook input JSON to the configured URL.
|
||||
* Returns the raw response for the caller to interpret.
|
||||
*
|
||||
* When sandboxing is enabled, requests are routed through the sandbox network
|
||||
* proxy which enforces the domain allowlist. The proxy returns HTTP 403 for
|
||||
* blocked domains.
|
||||
*
|
||||
* Header values support $VAR_NAME and ${VAR_NAME} env var interpolation so that
|
||||
* secrets (e.g. "Authorization: Bearer $MY_TOKEN") are not stored in settings.json.
|
||||
* Only env vars explicitly listed in the hook's `allowedEnvVars` array are resolved;
|
||||
* all other references are replaced with empty strings.
|
||||
*/
|
||||
export async function execHttpHook(
|
||||
hook: HttpHook,
|
||||
_hookEvent: HookEvent,
|
||||
jsonInput: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<{
|
||||
ok: boolean
|
||||
statusCode?: number
|
||||
body: string
|
||||
error?: string
|
||||
aborted?: boolean
|
||||
}> {
|
||||
// Enforce URL allowlist before any I/O. Follows allowedMcpServers semantics:
|
||||
// undefined → no restriction; [] → block all; non-empty → must match a pattern.
|
||||
const policy = getHttpHookPolicy()
|
||||
if (policy.allowedUrls !== undefined) {
|
||||
const matched = policy.allowedUrls.some(p => urlMatchesPattern(hook.url, p))
|
||||
if (!matched) {
|
||||
const msg = `HTTP hook blocked: ${hook.url} does not match any pattern in allowedHttpHookUrls`
|
||||
logForDebugging(msg, { level: 'warn' })
|
||||
return { ok: false, body: '', error: msg }
|
||||
}
|
||||
}
|
||||
|
||||
const timeoutMs = hook.timeout
|
||||
? hook.timeout * 1000
|
||||
: DEFAULT_HTTP_HOOK_TIMEOUT_MS
|
||||
|
||||
const { signal: combinedSignal, cleanup } = createCombinedAbortSignal(
|
||||
signal,
|
||||
{ timeoutMs },
|
||||
)
|
||||
|
||||
try {
|
||||
// Build headers with env var interpolation in values
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
if (hook.headers) {
|
||||
// Intersect hook's allowedEnvVars with policy allowlist when policy is set
|
||||
const hookVars = hook.allowedEnvVars ?? []
|
||||
const effectiveVars =
|
||||
policy.allowedEnvVars !== undefined
|
||||
? hookVars.filter(v => policy.allowedEnvVars!.includes(v))
|
||||
: hookVars
|
||||
const allowedEnvVars = new Set(effectiveVars)
|
||||
for (const [name, value] of Object.entries(hook.headers)) {
|
||||
headers[name] = interpolateEnvVars(value, allowedEnvVars)
|
||||
}
|
||||
}
|
||||
|
||||
// Route through sandbox network proxy when available. The proxy enforces
|
||||
// the domain allowlist and returns 403 for blocked domains.
|
||||
const sandboxProxy = await getSandboxProxyConfig()
|
||||
|
||||
// Detect env var proxy (HTTP_PROXY / HTTPS_PROXY, respecting NO_PROXY).
|
||||
// When set, configureGlobalAgents() has already installed a request
|
||||
// interceptor that sets httpsAgent to an HttpsProxyAgent — the proxy
|
||||
// handles DNS for the target. Skip the SSRF guard in that case, same
|
||||
// as we do for the sandbox proxy, so that we don't accidentally block
|
||||
// a corporate proxy sitting on a private IP (e.g. 10.0.0.1:3128).
|
||||
const envProxyActive =
|
||||
!sandboxProxy &&
|
||||
getProxyUrl() !== undefined &&
|
||||
!shouldBypassProxy(hook.url)
|
||||
|
||||
if (sandboxProxy) {
|
||||
logForDebugging(
|
||||
`Hooks: HTTP hook POST to ${hook.url} (via sandbox proxy :${sandboxProxy.port})`,
|
||||
)
|
||||
} else if (envProxyActive) {
|
||||
logForDebugging(
|
||||
`Hooks: HTTP hook POST to ${hook.url} (via env-var proxy)`,
|
||||
)
|
||||
} else {
|
||||
logForDebugging(`Hooks: HTTP hook POST to ${hook.url}`)
|
||||
}
|
||||
|
||||
const response = await axios.post<string>(hook.url, jsonInput, {
|
||||
headers,
|
||||
signal: combinedSignal,
|
||||
responseType: 'text',
|
||||
validateStatus: () => true,
|
||||
maxRedirects: 0,
|
||||
// Explicit false prevents axios's own env-var proxy detection; when an
|
||||
// env-var proxy is configured, the global axios interceptor installed
|
||||
// by configureGlobalAgents() handles it via httpsAgent instead.
|
||||
proxy: sandboxProxy ?? false,
|
||||
// SSRF guard: validate resolved IPs, block private/link-local ranges
|
||||
// (but allow loopback for local dev). Skipped when any proxy is in
|
||||
// use — the proxy performs DNS for the target, and applying the
|
||||
// guard would instead validate the proxy's own IP, breaking
|
||||
// connections to corporate proxies on private networks.
|
||||
lookup: sandboxProxy || envProxyActive ? undefined : ssrfGuardedLookup,
|
||||
})
|
||||
|
||||
cleanup()
|
||||
|
||||
const body = response.data ?? ''
|
||||
logForDebugging(
|
||||
`Hooks: HTTP hook response status ${response.status}, body length ${body.length}`,
|
||||
)
|
||||
|
||||
return {
|
||||
ok: response.status >= 200 && response.status < 300,
|
||||
statusCode: response.status,
|
||||
body,
|
||||
}
|
||||
} catch (error) {
|
||||
cleanup()
|
||||
|
||||
if (combinedSignal.aborted) {
|
||||
return { ok: false, body: '', aborted: true }
|
||||
}
|
||||
|
||||
const errorMsg = errorMessage(error)
|
||||
logForDebugging(`Hooks: HTTP hook error: ${errorMsg}`, { level: 'error' })
|
||||
return { ok: false, body: '', error: errorMsg }
|
||||
}
|
||||
}
|
||||
211
src/utils/hooks/execPromptHook.ts
Normal file
211
src/utils/hooks/execPromptHook.ts
Normal file
@@ -0,0 +1,211 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import type { HookEvent } from 'src/entrypoints/agentSdkTypes.js'
|
||||
import { queryModelWithoutStreaming } from '../../services/api/claude.js'
|
||||
import type { ToolUseContext } from '../../Tool.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { createAttachmentMessage } from '../attachments.js'
|
||||
import { createCombinedAbortSignal } from '../combinedAbortSignal.js'
|
||||
import { logForDebugging } from '../debug.js'
|
||||
import { errorMessage } from '../errors.js'
|
||||
import type { HookResult } from '../hooks.js'
|
||||
import { safeParseJSON } from '../json.js'
|
||||
import { createUserMessage, extractTextContent } from '../messages.js'
|
||||
import { getSmallFastModel } from '../model/model.js'
|
||||
import type { PromptHook } from '../settings/types.js'
|
||||
import { asSystemPrompt } from '../systemPromptType.js'
|
||||
import { addArgumentsToPrompt, hookResponseSchema } from './hookHelpers.js'
|
||||
|
||||
/**
|
||||
* Execute a prompt-based hook using an LLM
|
||||
*/
|
||||
export async function execPromptHook(
|
||||
hook: PromptHook,
|
||||
hookName: string,
|
||||
hookEvent: HookEvent,
|
||||
jsonInput: string,
|
||||
signal: AbortSignal,
|
||||
toolUseContext: ToolUseContext,
|
||||
messages?: Message[],
|
||||
toolUseID?: string,
|
||||
): Promise<HookResult> {
|
||||
// Use provided toolUseID or generate a new one
|
||||
const effectiveToolUseID = toolUseID || `hook-${randomUUID()}`
|
||||
try {
|
||||
// Replace $ARGUMENTS with the JSON input
|
||||
const processedPrompt = addArgumentsToPrompt(hook.prompt, jsonInput)
|
||||
logForDebugging(
|
||||
`Hooks: Processing prompt hook with prompt: ${processedPrompt}`,
|
||||
)
|
||||
|
||||
// Create user message directly - no need for processUserInput which would
|
||||
// trigger UserPromptSubmit hooks and cause infinite recursion
|
||||
const userMessage = createUserMessage({ content: processedPrompt })
|
||||
|
||||
// Prepend conversation history if provided
|
||||
const messagesToQuery =
|
||||
messages && messages.length > 0
|
||||
? [...messages, userMessage]
|
||||
: [userMessage]
|
||||
|
||||
logForDebugging(
|
||||
`Hooks: Querying model with ${messagesToQuery.length} messages`,
|
||||
)
|
||||
|
||||
// Query the model with Haiku
|
||||
const hookTimeoutMs = hook.timeout ? hook.timeout * 1000 : 30000
|
||||
|
||||
// Combined signal: aborts if either the hook signal or timeout triggers
|
||||
const { signal: combinedSignal, cleanup: cleanupSignal } =
|
||||
createCombinedAbortSignal(signal, { timeoutMs: hookTimeoutMs })
|
||||
|
||||
try {
|
||||
const response = await queryModelWithoutStreaming({
|
||||
messages: messagesToQuery,
|
||||
systemPrompt: asSystemPrompt([
|
||||
`You are evaluating a hook in Claude Code.
|
||||
|
||||
Your response must be a JSON object matching one of the following schemas:
|
||||
1. If the condition is met, return: {"ok": true}
|
||||
2. If the condition is not met, return: {"ok": false, "reason": "Reason for why it is not met"}`,
|
||||
]),
|
||||
thinkingConfig: { type: 'disabled' as const },
|
||||
tools: toolUseContext.options.tools,
|
||||
signal: combinedSignal,
|
||||
options: {
|
||||
async getToolPermissionContext() {
|
||||
const appState = toolUseContext.getAppState()
|
||||
return appState.toolPermissionContext
|
||||
},
|
||||
model: hook.model ?? getSmallFastModel(),
|
||||
toolChoice: undefined,
|
||||
isNonInteractiveSession: true,
|
||||
hasAppendSystemPrompt: false,
|
||||
agents: [],
|
||||
querySource: 'hook_prompt',
|
||||
mcpTools: [],
|
||||
agentId: toolUseContext.agentId,
|
||||
outputFormat: {
|
||||
type: 'json_schema',
|
||||
schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
ok: { type: 'boolean' },
|
||||
reason: { type: 'string' },
|
||||
},
|
||||
required: ['ok'],
|
||||
additionalProperties: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
cleanupSignal()
|
||||
|
||||
// Extract text content from response
|
||||
const content = extractTextContent(response.message.content)
|
||||
|
||||
// Update response length for spinner display
|
||||
toolUseContext.setResponseLength(length => length + content.length)
|
||||
|
||||
const fullResponse = content.trim()
|
||||
logForDebugging(`Hooks: Model response: ${fullResponse}`)
|
||||
|
||||
const json = safeParseJSON(fullResponse)
|
||||
if (!json) {
|
||||
logForDebugging(
|
||||
`Hooks: error parsing response as JSON: ${fullResponse}`,
|
||||
)
|
||||
return {
|
||||
hook,
|
||||
outcome: 'non_blocking_error',
|
||||
message: createAttachmentMessage({
|
||||
type: 'hook_non_blocking_error',
|
||||
hookName,
|
||||
toolUseID: effectiveToolUseID,
|
||||
hookEvent,
|
||||
stderr: 'JSON validation failed',
|
||||
stdout: fullResponse,
|
||||
exitCode: 1,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
const parsed = hookResponseSchema().safeParse(json)
|
||||
if (!parsed.success) {
|
||||
logForDebugging(
|
||||
`Hooks: model response does not conform to expected schema: ${parsed.error.message}`,
|
||||
)
|
||||
return {
|
||||
hook,
|
||||
outcome: 'non_blocking_error',
|
||||
message: createAttachmentMessage({
|
||||
type: 'hook_non_blocking_error',
|
||||
hookName,
|
||||
toolUseID: effectiveToolUseID,
|
||||
hookEvent,
|
||||
stderr: `Schema validation failed: ${parsed.error.message}`,
|
||||
stdout: fullResponse,
|
||||
exitCode: 1,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Failed to meet condition
|
||||
if (!parsed.data.ok) {
|
||||
logForDebugging(
|
||||
`Hooks: Prompt hook condition was not met: ${parsed.data.reason}`,
|
||||
)
|
||||
return {
|
||||
hook,
|
||||
outcome: 'blocking',
|
||||
blockingError: {
|
||||
blockingError: `Prompt hook condition was not met: ${parsed.data.reason}`,
|
||||
command: hook.prompt,
|
||||
},
|
||||
preventContinuation: true,
|
||||
stopReason: parsed.data.reason,
|
||||
}
|
||||
}
|
||||
|
||||
// Condition was met
|
||||
logForDebugging(`Hooks: Prompt hook condition was met`)
|
||||
return {
|
||||
hook,
|
||||
outcome: 'success',
|
||||
message: createAttachmentMessage({
|
||||
type: 'hook_success',
|
||||
hookName,
|
||||
toolUseID: effectiveToolUseID,
|
||||
hookEvent,
|
||||
content: '',
|
||||
}),
|
||||
}
|
||||
} catch (error) {
|
||||
cleanupSignal()
|
||||
|
||||
if (combinedSignal.aborted) {
|
||||
return {
|
||||
hook,
|
||||
outcome: 'cancelled',
|
||||
}
|
||||
}
|
||||
throw error
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMsg = errorMessage(error)
|
||||
logForDebugging(`Hooks: Prompt hook error: ${errorMsg}`)
|
||||
return {
|
||||
hook,
|
||||
outcome: 'non_blocking_error',
|
||||
message: createAttachmentMessage({
|
||||
type: 'hook_non_blocking_error',
|
||||
hookName,
|
||||
toolUseID: effectiveToolUseID,
|
||||
hookEvent,
|
||||
stderr: `Error executing prompt hook: ${errorMsg}`,
|
||||
stdout: '',
|
||||
exitCode: 1,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
191
src/utils/hooks/fileChangedWatcher.ts
Normal file
191
src/utils/hooks/fileChangedWatcher.ts
Normal file
@@ -0,0 +1,191 @@
|
||||
import chokidar, { type FSWatcher } from 'chokidar'
|
||||
import { isAbsolute, join } from 'path'
|
||||
import { registerCleanup } from '../cleanupRegistry.js'
|
||||
import { logForDebugging } from '../debug.js'
|
||||
import { errorMessage } from '../errors.js'
|
||||
import {
|
||||
executeCwdChangedHooks,
|
||||
executeFileChangedHooks,
|
||||
type HookOutsideReplResult,
|
||||
} from '../hooks.js'
|
||||
import { clearCwdEnvFiles } from '../sessionEnvironment.js'
|
||||
import { getHooksConfigFromSnapshot } from './hooksConfigSnapshot.js'
|
||||
|
||||
let watcher: FSWatcher | null = null
|
||||
let currentCwd: string
|
||||
let dynamicWatchPaths: string[] = []
|
||||
let dynamicWatchPathsSorted: string[] = []
|
||||
let initialized = false
|
||||
let hasEnvHooks = false
|
||||
let notifyCallback: ((text: string, isError: boolean) => void) | null = null
|
||||
|
||||
export function setEnvHookNotifier(
|
||||
cb: ((text: string, isError: boolean) => void) | null,
|
||||
): void {
|
||||
notifyCallback = cb
|
||||
}
|
||||
|
||||
export function initializeFileChangedWatcher(cwd: string): void {
|
||||
if (initialized) return
|
||||
initialized = true
|
||||
currentCwd = cwd
|
||||
|
||||
const config = getHooksConfigFromSnapshot()
|
||||
hasEnvHooks =
|
||||
(config?.CwdChanged?.length ?? 0) > 0 ||
|
||||
(config?.FileChanged?.length ?? 0) > 0
|
||||
|
||||
if (hasEnvHooks) {
|
||||
registerCleanup(async () => dispose())
|
||||
}
|
||||
|
||||
const paths = resolveWatchPaths(config)
|
||||
if (paths.length === 0) return
|
||||
|
||||
startWatching(paths)
|
||||
}
|
||||
|
||||
function resolveWatchPaths(
|
||||
config?: ReturnType<typeof getHooksConfigFromSnapshot>,
|
||||
): string[] {
|
||||
const matchers = (config ?? getHooksConfigFromSnapshot())?.FileChanged ?? []
|
||||
|
||||
// Matcher field: filenames to watch in cwd, pipe-separated (e.g. ".envrc|.env")
|
||||
const staticPaths: string[] = []
|
||||
for (const m of matchers) {
|
||||
if (!m.matcher) continue
|
||||
for (const name of m.matcher.split('|').map(s => s.trim())) {
|
||||
if (!name) continue
|
||||
staticPaths.push(isAbsolute(name) ? name : join(currentCwd, name))
|
||||
}
|
||||
}
|
||||
|
||||
// Combine static matcher paths with dynamic paths from hook output
|
||||
return [...new Set([...staticPaths, ...dynamicWatchPaths])]
|
||||
}
|
||||
|
||||
function startWatching(paths: string[]): void {
|
||||
logForDebugging(`FileChanged: watching ${paths.length} paths`)
|
||||
watcher = chokidar.watch(paths, {
|
||||
persistent: true,
|
||||
ignoreInitial: true,
|
||||
awaitWriteFinish: { stabilityThreshold: 500, pollInterval: 200 },
|
||||
ignorePermissionErrors: true,
|
||||
})
|
||||
watcher.on('change', p => handleFileEvent(p, 'change'))
|
||||
watcher.on('add', p => handleFileEvent(p, 'add'))
|
||||
watcher.on('unlink', p => handleFileEvent(p, 'unlink'))
|
||||
}
|
||||
|
||||
function handleFileEvent(
|
||||
path: string,
|
||||
event: 'change' | 'add' | 'unlink',
|
||||
): void {
|
||||
logForDebugging(`FileChanged: ${event} ${path}`)
|
||||
void executeFileChangedHooks(path, event)
|
||||
.then(({ results, watchPaths, systemMessages }) => {
|
||||
if (watchPaths.length > 0) {
|
||||
updateWatchPaths(watchPaths)
|
||||
}
|
||||
for (const msg of systemMessages) {
|
||||
notifyCallback?.(msg, false)
|
||||
}
|
||||
for (const r of results) {
|
||||
if (!r.succeeded && r.output) {
|
||||
notifyCallback?.(r.output, true)
|
||||
}
|
||||
}
|
||||
})
|
||||
.catch(e => {
|
||||
const msg = errorMessage(e)
|
||||
logForDebugging(`FileChanged hook failed: ${msg}`, {
|
||||
level: 'error',
|
||||
})
|
||||
notifyCallback?.(msg, true)
|
||||
})
|
||||
}
|
||||
|
||||
export function updateWatchPaths(paths: string[]): void {
|
||||
if (!initialized) return
|
||||
const sorted = paths.slice().sort()
|
||||
if (
|
||||
sorted.length === dynamicWatchPathsSorted.length &&
|
||||
sorted.every((p, i) => p === dynamicWatchPathsSorted[i])
|
||||
) {
|
||||
return
|
||||
}
|
||||
dynamicWatchPaths = paths
|
||||
dynamicWatchPathsSorted = sorted
|
||||
restartWatching()
|
||||
}
|
||||
|
||||
function restartWatching(): void {
|
||||
if (watcher) {
|
||||
void watcher.close()
|
||||
watcher = null
|
||||
}
|
||||
const paths = resolveWatchPaths()
|
||||
if (paths.length > 0) {
|
||||
startWatching(paths)
|
||||
}
|
||||
}
|
||||
|
||||
export async function onCwdChangedForHooks(
|
||||
oldCwd: string,
|
||||
newCwd: string,
|
||||
): Promise<void> {
|
||||
if (oldCwd === newCwd) return
|
||||
|
||||
// Re-evaluate from the current snapshot so mid-session hook changes are picked up
|
||||
const config = getHooksConfigFromSnapshot()
|
||||
const currentHasEnvHooks =
|
||||
(config?.CwdChanged?.length ?? 0) > 0 ||
|
||||
(config?.FileChanged?.length ?? 0) > 0
|
||||
if (!currentHasEnvHooks) return
|
||||
currentCwd = newCwd
|
||||
|
||||
await clearCwdEnvFiles()
|
||||
const hookResult = await executeCwdChangedHooks(oldCwd, newCwd).catch(e => {
|
||||
const msg = errorMessage(e)
|
||||
logForDebugging(`CwdChanged hook failed: ${msg}`, {
|
||||
level: 'error',
|
||||
})
|
||||
notifyCallback?.(msg, true)
|
||||
return {
|
||||
results: [] as HookOutsideReplResult[],
|
||||
watchPaths: [] as string[],
|
||||
systemMessages: [] as string[],
|
||||
}
|
||||
})
|
||||
dynamicWatchPaths = hookResult.watchPaths
|
||||
dynamicWatchPathsSorted = hookResult.watchPaths.slice().sort()
|
||||
for (const msg of hookResult.systemMessages) {
|
||||
notifyCallback?.(msg, false)
|
||||
}
|
||||
for (const r of hookResult.results) {
|
||||
if (!r.succeeded && r.output) {
|
||||
notifyCallback?.(r.output, true)
|
||||
}
|
||||
}
|
||||
|
||||
// Re-resolve matcher paths against the new cwd
|
||||
if (initialized) {
|
||||
restartWatching()
|
||||
}
|
||||
}
|
||||
|
||||
function dispose(): void {
|
||||
if (watcher) {
|
||||
void watcher.close()
|
||||
watcher = null
|
||||
}
|
||||
dynamicWatchPaths = []
|
||||
dynamicWatchPathsSorted = []
|
||||
initialized = false
|
||||
hasEnvHooks = false
|
||||
notifyCallback = null
|
||||
}
|
||||
|
||||
export function resetFileChangedWatcherForTesting(): void {
|
||||
dispose()
|
||||
}
|
||||
192
src/utils/hooks/hookEvents.ts
Normal file
192
src/utils/hooks/hookEvents.ts
Normal file
@@ -0,0 +1,192 @@
|
||||
/**
|
||||
* Hook event system for broadcasting hook execution events.
|
||||
*
|
||||
* This module provides a generic event system that is separate from the
|
||||
* main message stream. Handlers can register to receive events and decide
|
||||
* what to do with them (e.g., convert to SDK messages, log, etc.).
|
||||
*/
|
||||
|
||||
import { HOOK_EVENTS } from 'src/entrypoints/sdk/coreTypes.js'
|
||||
|
||||
import { logForDebugging } from '../debug.js'
|
||||
|
||||
/**
|
||||
* Hook events that are always emitted regardless of the includeHookEvents
|
||||
* option. These are low-noise lifecycle events that were in the original
|
||||
* allowlist and are backwards-compatible.
|
||||
*/
|
||||
const ALWAYS_EMITTED_HOOK_EVENTS = ['SessionStart', 'Setup'] as const
|
||||
|
||||
const MAX_PENDING_EVENTS = 100
|
||||
|
||||
export type HookStartedEvent = {
|
||||
type: 'started'
|
||||
hookId: string
|
||||
hookName: string
|
||||
hookEvent: string
|
||||
}
|
||||
|
||||
export type HookProgressEvent = {
|
||||
type: 'progress'
|
||||
hookId: string
|
||||
hookName: string
|
||||
hookEvent: string
|
||||
stdout: string
|
||||
stderr: string
|
||||
output: string
|
||||
}
|
||||
|
||||
export type HookResponseEvent = {
|
||||
type: 'response'
|
||||
hookId: string
|
||||
hookName: string
|
||||
hookEvent: string
|
||||
output: string
|
||||
stdout: string
|
||||
stderr: string
|
||||
exitCode?: number
|
||||
outcome: 'success' | 'error' | 'cancelled'
|
||||
}
|
||||
|
||||
export type HookExecutionEvent =
|
||||
| HookStartedEvent
|
||||
| HookProgressEvent
|
||||
| HookResponseEvent
|
||||
export type HookEventHandler = (event: HookExecutionEvent) => void
|
||||
|
||||
const pendingEvents: HookExecutionEvent[] = []
|
||||
let eventHandler: HookEventHandler | null = null
|
||||
let allHookEventsEnabled = false
|
||||
|
||||
export function registerHookEventHandler(
|
||||
handler: HookEventHandler | null,
|
||||
): void {
|
||||
eventHandler = handler
|
||||
if (handler && pendingEvents.length > 0) {
|
||||
for (const event of pendingEvents.splice(0)) {
|
||||
handler(event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function emit(event: HookExecutionEvent): void {
|
||||
if (eventHandler) {
|
||||
eventHandler(event)
|
||||
} else {
|
||||
pendingEvents.push(event)
|
||||
if (pendingEvents.length > MAX_PENDING_EVENTS) {
|
||||
pendingEvents.shift()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function shouldEmit(hookEvent: string): boolean {
|
||||
if ((ALWAYS_EMITTED_HOOK_EVENTS as readonly string[]).includes(hookEvent)) {
|
||||
return true
|
||||
}
|
||||
return (
|
||||
allHookEventsEnabled &&
|
||||
(HOOK_EVENTS as readonly string[]).includes(hookEvent)
|
||||
)
|
||||
}
|
||||
|
||||
export function emitHookStarted(
|
||||
hookId: string,
|
||||
hookName: string,
|
||||
hookEvent: string,
|
||||
): void {
|
||||
if (!shouldEmit(hookEvent)) return
|
||||
|
||||
emit({
|
||||
type: 'started',
|
||||
hookId,
|
||||
hookName,
|
||||
hookEvent,
|
||||
})
|
||||
}
|
||||
|
||||
export function emitHookProgress(data: {
|
||||
hookId: string
|
||||
hookName: string
|
||||
hookEvent: string
|
||||
stdout: string
|
||||
stderr: string
|
||||
output: string
|
||||
}): void {
|
||||
if (!shouldEmit(data.hookEvent)) return
|
||||
|
||||
emit({
|
||||
type: 'progress',
|
||||
...data,
|
||||
})
|
||||
}
|
||||
|
||||
export function startHookProgressInterval(params: {
|
||||
hookId: string
|
||||
hookName: string
|
||||
hookEvent: string
|
||||
getOutput: () => Promise<{ stdout: string; stderr: string; output: string }>
|
||||
intervalMs?: number
|
||||
}): () => void {
|
||||
if (!shouldEmit(params.hookEvent)) return () => {}
|
||||
|
||||
let lastEmittedOutput = ''
|
||||
const interval = setInterval(() => {
|
||||
void params.getOutput().then(({ stdout, stderr, output }) => {
|
||||
if (output === lastEmittedOutput) return
|
||||
lastEmittedOutput = output
|
||||
emitHookProgress({
|
||||
hookId: params.hookId,
|
||||
hookName: params.hookName,
|
||||
hookEvent: params.hookEvent,
|
||||
stdout,
|
||||
stderr,
|
||||
output,
|
||||
})
|
||||
})
|
||||
}, params.intervalMs ?? 1000)
|
||||
interval.unref()
|
||||
|
||||
return () => clearInterval(interval)
|
||||
}
|
||||
|
||||
export function emitHookResponse(data: {
|
||||
hookId: string
|
||||
hookName: string
|
||||
hookEvent: string
|
||||
output: string
|
||||
stdout: string
|
||||
stderr: string
|
||||
exitCode?: number
|
||||
outcome: 'success' | 'error' | 'cancelled'
|
||||
}): void {
|
||||
// Always log full hook output to debug log for verbose mode debugging
|
||||
const outputToLog = data.stdout || data.stderr || data.output
|
||||
if (outputToLog) {
|
||||
logForDebugging(
|
||||
`Hook ${data.hookName} (${data.hookEvent}) ${data.outcome}:\n${outputToLog}`,
|
||||
)
|
||||
}
|
||||
|
||||
if (!shouldEmit(data.hookEvent)) return
|
||||
|
||||
emit({
|
||||
type: 'response',
|
||||
...data,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable emission of all hook event types (beyond SessionStart and Setup).
|
||||
* Called when the SDK `includeHookEvents` option is set or when running
|
||||
* in CLAUDE_CODE_REMOTE mode.
|
||||
*/
|
||||
export function setAllHookEventsEnabled(enabled: boolean): void {
|
||||
allHookEventsEnabled = enabled
|
||||
}
|
||||
|
||||
export function clearHookEventState(): void {
|
||||
eventHandler = null
|
||||
pendingEvents.length = 0
|
||||
allHookEventsEnabled = false
|
||||
}
|
||||
83
src/utils/hooks/hookHelpers.ts
Normal file
83
src/utils/hooks/hookHelpers.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import { z } from 'zod/v4'
|
||||
import type { Tool } from '../../Tool.js'
|
||||
import {
|
||||
SYNTHETIC_OUTPUT_TOOL_NAME,
|
||||
SyntheticOutputTool,
|
||||
} from '../../tools/SyntheticOutputTool/SyntheticOutputTool.js'
|
||||
import { substituteArguments } from '../argumentSubstitution.js'
|
||||
import { lazySchema } from '../lazySchema.js'
|
||||
import type { SetAppState } from '../messageQueueManager.js'
|
||||
import { hasSuccessfulToolCall } from '../messages.js'
|
||||
import { addFunctionHook } from './sessionHooks.js'
|
||||
|
||||
/**
|
||||
* Schema for hook responses (shared by prompt and agent hooks)
|
||||
*/
|
||||
export const hookResponseSchema = lazySchema(() =>
|
||||
z.object({
|
||||
ok: z.boolean().describe('Whether the condition was met'),
|
||||
reason: z
|
||||
.string()
|
||||
.describe('Reason, if the condition was not met')
|
||||
.optional(),
|
||||
}),
|
||||
)
|
||||
|
||||
/**
|
||||
* Add hook input JSON to prompt, either replacing $ARGUMENTS placeholder or appending.
|
||||
* Also supports indexed arguments like $ARGUMENTS[0], $ARGUMENTS[1], or shorthand $0, $1, etc.
|
||||
*/
|
||||
export function addArgumentsToPrompt(
|
||||
prompt: string,
|
||||
jsonInput: string,
|
||||
): string {
|
||||
return substituteArguments(prompt, jsonInput)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a StructuredOutput tool configured for hook responses.
|
||||
* Reusable by agent hooks and background verification.
|
||||
*/
|
||||
export function createStructuredOutputTool(): Tool {
|
||||
return {
|
||||
...SyntheticOutputTool,
|
||||
inputSchema: hookResponseSchema(),
|
||||
inputJSONSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
ok: {
|
||||
type: 'boolean',
|
||||
description: 'Whether the condition was met',
|
||||
},
|
||||
reason: {
|
||||
type: 'string',
|
||||
description: 'Reason, if the condition was not met',
|
||||
},
|
||||
},
|
||||
required: ['ok'],
|
||||
additionalProperties: false,
|
||||
},
|
||||
async prompt(): Promise<string> {
|
||||
return `Use this tool to return your verification result. You MUST call this tool exactly once at the end of your response.`
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a function hook that enforces structured output via SyntheticOutputTool.
|
||||
* Used by ask.tsx, execAgentHook.ts, and background verification.
|
||||
*/
|
||||
export function registerStructuredOutputEnforcement(
|
||||
setAppState: SetAppState,
|
||||
sessionId: string,
|
||||
): void {
|
||||
addFunctionHook(
|
||||
setAppState,
|
||||
sessionId,
|
||||
'Stop',
|
||||
'', // No matcher - applies to all stops
|
||||
messages => hasSuccessfulToolCall(messages, SYNTHETIC_OUTPUT_TOOL_NAME),
|
||||
`You MUST call the ${SYNTHETIC_OUTPUT_TOOL_NAME} tool to complete this request. Call this tool now.`,
|
||||
{ timeout: 5000 },
|
||||
)
|
||||
}
|
||||
400
src/utils/hooks/hooksConfigManager.ts
Normal file
400
src/utils/hooks/hooksConfigManager.ts
Normal file
@@ -0,0 +1,400 @@
|
||||
import memoize from 'lodash-es/memoize.js'
|
||||
import type { HookEvent } from 'src/entrypoints/agentSdkTypes.js'
|
||||
import { getRegisteredHooks } from '../../bootstrap/state.js'
|
||||
import type { AppState } from '../../state/AppState.js'
|
||||
import {
|
||||
getAllHooks,
|
||||
type IndividualHookConfig,
|
||||
sortMatchersByPriority,
|
||||
} from './hooksSettings.js'
|
||||
|
||||
export type MatcherMetadata = {
|
||||
fieldToMatch: string
|
||||
values: string[]
|
||||
}
|
||||
|
||||
export type HookEventMetadata = {
|
||||
summary: string
|
||||
description: string
|
||||
matcherMetadata?: MatcherMetadata
|
||||
}
|
||||
|
||||
// Hook event metadata configuration.
|
||||
// Resolver uses sorted-joined string key so that callers passing a fresh
|
||||
// toolNames array each render (e.g. HooksConfigMenu) hit the cache instead
|
||||
// of leaking a new entry per call.
|
||||
export const getHookEventMetadata = memoize(
|
||||
function (toolNames: string[]): Record<HookEvent, HookEventMetadata> {
|
||||
return {
|
||||
PreToolUse: {
|
||||
summary: 'Before tool execution',
|
||||
description:
|
||||
'Input to command is JSON of tool call arguments.\nExit code 0 - stdout/stderr not shown\nExit code 2 - show stderr to model and block tool call\nOther exit codes - show stderr to user only but continue with tool call',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'tool_name',
|
||||
values: toolNames,
|
||||
},
|
||||
},
|
||||
PostToolUse: {
|
||||
summary: 'After tool execution',
|
||||
description:
|
||||
'Input to command is JSON with fields "inputs" (tool call arguments) and "response" (tool call response).\nExit code 0 - stdout shown in transcript mode (ctrl+o)\nExit code 2 - show stderr to model immediately\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'tool_name',
|
||||
values: toolNames,
|
||||
},
|
||||
},
|
||||
PostToolUseFailure: {
|
||||
summary: 'After tool execution fails',
|
||||
description:
|
||||
'Input to command is JSON with tool_name, tool_input, tool_use_id, error, error_type, is_interrupt, and is_timeout.\nExit code 0 - stdout shown in transcript mode (ctrl+o)\nExit code 2 - show stderr to model immediately\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'tool_name',
|
||||
values: toolNames,
|
||||
},
|
||||
},
|
||||
PermissionDenied: {
|
||||
summary: 'After auto mode classifier denies a tool call',
|
||||
description:
|
||||
'Input to command is JSON with tool_name, tool_input, tool_use_id, and reason.\nReturn {"hookSpecificOutput":{"hookEventName":"PermissionDenied","retry":true}} to tell the model it may retry.\nExit code 0 - stdout shown in transcript mode (ctrl+o)\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'tool_name',
|
||||
values: toolNames,
|
||||
},
|
||||
},
|
||||
Notification: {
|
||||
summary: 'When notifications are sent',
|
||||
description:
|
||||
'Input to command is JSON with notification message and type.\nExit code 0 - stdout/stderr not shown\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'notification_type',
|
||||
values: [
|
||||
'permission_prompt',
|
||||
'idle_prompt',
|
||||
'auth_success',
|
||||
'elicitation_dialog',
|
||||
'elicitation_complete',
|
||||
'elicitation_response',
|
||||
],
|
||||
},
|
||||
},
|
||||
UserPromptSubmit: {
|
||||
summary: 'When the user submits a prompt',
|
||||
description:
|
||||
'Input to command is JSON with original user prompt text.\nExit code 0 - stdout shown to Claude\nExit code 2 - block processing, erase original prompt, and show stderr to user only\nOther exit codes - show stderr to user only',
|
||||
},
|
||||
SessionStart: {
|
||||
summary: 'When a new session is started',
|
||||
description:
|
||||
'Input to command is JSON with session start source.\nExit code 0 - stdout shown to Claude\nBlocking errors are ignored\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'source',
|
||||
values: ['startup', 'resume', 'clear', 'compact'],
|
||||
},
|
||||
},
|
||||
Stop: {
|
||||
summary: 'Right before Claude concludes its response',
|
||||
description:
|
||||
'Exit code 0 - stdout/stderr not shown\nExit code 2 - show stderr to model and continue conversation\nOther exit codes - show stderr to user only',
|
||||
},
|
||||
StopFailure: {
|
||||
summary: 'When the turn ends due to an API error',
|
||||
description:
|
||||
'Fires instead of Stop when an API error (rate limit, auth failure, etc.) ended the turn. Fire-and-forget — hook output and exit codes are ignored.',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'error',
|
||||
values: [
|
||||
'rate_limit',
|
||||
'authentication_failed',
|
||||
'billing_error',
|
||||
'invalid_request',
|
||||
'server_error',
|
||||
'max_output_tokens',
|
||||
'unknown',
|
||||
],
|
||||
},
|
||||
},
|
||||
SubagentStart: {
|
||||
summary: 'When a subagent (Agent tool call) is started',
|
||||
description:
|
||||
'Input to command is JSON with agent_id and agent_type.\nExit code 0 - stdout shown to subagent\nBlocking errors are ignored\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'agent_type',
|
||||
values: [], // Will be populated with available agent types
|
||||
},
|
||||
},
|
||||
SubagentStop: {
|
||||
summary:
|
||||
'Right before a subagent (Agent tool call) concludes its response',
|
||||
description:
|
||||
'Input to command is JSON with agent_id, agent_type, and agent_transcript_path.\nExit code 0 - stdout/stderr not shown\nExit code 2 - show stderr to subagent and continue having it run\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'agent_type',
|
||||
values: [], // Will be populated with available agent types
|
||||
},
|
||||
},
|
||||
PreCompact: {
|
||||
summary: 'Before conversation compaction',
|
||||
description:
|
||||
'Input to command is JSON with compaction details.\nExit code 0 - stdout appended as custom compact instructions\nExit code 2 - block compaction\nOther exit codes - show stderr to user only but continue with compaction',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'trigger',
|
||||
values: ['manual', 'auto'],
|
||||
},
|
||||
},
|
||||
PostCompact: {
|
||||
summary: 'After conversation compaction',
|
||||
description:
|
||||
'Input to command is JSON with compaction details and the summary.\nExit code 0 - stdout shown to user\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'trigger',
|
||||
values: ['manual', 'auto'],
|
||||
},
|
||||
},
|
||||
SessionEnd: {
|
||||
summary: 'When a session is ending',
|
||||
description:
|
||||
'Input to command is JSON with session end reason.\nExit code 0 - command completes successfully\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'reason',
|
||||
values: ['clear', 'logout', 'prompt_input_exit', 'other'],
|
||||
},
|
||||
},
|
||||
PermissionRequest: {
|
||||
summary: 'When a permission dialog is displayed',
|
||||
description:
|
||||
'Input to command is JSON with tool_name, tool_input, and tool_use_id.\nOutput JSON with hookSpecificOutput containing decision to allow or deny.\nExit code 0 - use hook decision if provided\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'tool_name',
|
||||
values: toolNames,
|
||||
},
|
||||
},
|
||||
Setup: {
|
||||
summary: 'Repo setup hooks for init and maintenance',
|
||||
description:
|
||||
'Input to command is JSON with trigger (init or maintenance).\nExit code 0 - stdout shown to Claude\nBlocking errors are ignored\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'trigger',
|
||||
values: ['init', 'maintenance'],
|
||||
},
|
||||
},
|
||||
TeammateIdle: {
|
||||
summary: 'When a teammate is about to go idle',
|
||||
description:
|
||||
'Input to command is JSON with teammate_name and team_name.\nExit code 0 - stdout/stderr not shown\nExit code 2 - show stderr to teammate and prevent idle (teammate continues working)\nOther exit codes - show stderr to user only',
|
||||
},
|
||||
TaskCreated: {
|
||||
summary: 'When a task is being created',
|
||||
description:
|
||||
'Input to command is JSON with task_id, task_subject, task_description, teammate_name, and team_name.\nExit code 0 - stdout/stderr not shown\nExit code 2 - show stderr to model and prevent task creation\nOther exit codes - show stderr to user only',
|
||||
},
|
||||
TaskCompleted: {
|
||||
summary: 'When a task is being marked as completed',
|
||||
description:
|
||||
'Input to command is JSON with task_id, task_subject, task_description, teammate_name, and team_name.\nExit code 0 - stdout/stderr not shown\nExit code 2 - show stderr to model and prevent task completion\nOther exit codes - show stderr to user only',
|
||||
},
|
||||
Elicitation: {
|
||||
summary: 'When an MCP server requests user input (elicitation)',
|
||||
description:
|
||||
'Input to command is JSON with mcp_server_name, message, and requested_schema.\nOutput JSON with hookSpecificOutput containing action (accept/decline/cancel) and optional content.\nExit code 0 - use hook response if provided\nExit code 2 - deny the elicitation\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'mcp_server_name',
|
||||
values: [],
|
||||
},
|
||||
},
|
||||
ElicitationResult: {
|
||||
summary: 'After a user responds to an MCP elicitation',
|
||||
description:
|
||||
'Input to command is JSON with mcp_server_name, action, content, mode, and elicitation_id.\nOutput JSON with hookSpecificOutput containing optional action and content to override the response.\nExit code 0 - use hook response if provided\nExit code 2 - block the response (action becomes decline)\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'mcp_server_name',
|
||||
values: [],
|
||||
},
|
||||
},
|
||||
ConfigChange: {
|
||||
summary: 'When configuration files change during a session',
|
||||
description:
|
||||
'Input to command is JSON with source (user_settings, project_settings, local_settings, policy_settings, skills) and file_path.\nExit code 0 - allow the change\nExit code 2 - block the change from being applied to the session\nOther exit codes - show stderr to user only',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'source',
|
||||
values: [
|
||||
'user_settings',
|
||||
'project_settings',
|
||||
'local_settings',
|
||||
'policy_settings',
|
||||
'skills',
|
||||
],
|
||||
},
|
||||
},
|
||||
InstructionsLoaded: {
|
||||
summary: 'When an instruction file (CLAUDE.md or rule) is loaded',
|
||||
description:
|
||||
'Input to command is JSON with file_path, memory_type (User, Project, Local, Managed), load_reason (session_start, nested_traversal, path_glob_match, include, compact), globs (optional — the paths: frontmatter patterns that matched), trigger_file_path (optional — the file Claude touched that caused the load), and parent_file_path (optional — the file that @-included this one).\nExit code 0 - command completes successfully\nOther exit codes - show stderr to user only\nThis hook is observability-only and does not support blocking.',
|
||||
matcherMetadata: {
|
||||
fieldToMatch: 'load_reason',
|
||||
values: [
|
||||
'session_start',
|
||||
'nested_traversal',
|
||||
'path_glob_match',
|
||||
'include',
|
||||
'compact',
|
||||
],
|
||||
},
|
||||
},
|
||||
WorktreeCreate: {
|
||||
summary: 'Create an isolated worktree for VCS-agnostic isolation',
|
||||
description:
|
||||
'Input to command is JSON with name (suggested worktree slug).\nStdout should contain the absolute path to the created worktree directory.\nExit code 0 - worktree created successfully\nOther exit codes - worktree creation failed',
|
||||
},
|
||||
WorktreeRemove: {
|
||||
summary: 'Remove a previously created worktree',
|
||||
description:
|
||||
'Input to command is JSON with worktree_path (absolute path to worktree).\nExit code 0 - worktree removed successfully\nOther exit codes - show stderr to user only',
|
||||
},
|
||||
CwdChanged: {
|
||||
summary: 'After the working directory changes',
|
||||
description:
|
||||
'Input to command is JSON with old_cwd and new_cwd.\nCLAUDE_ENV_FILE is set — write bash exports there to apply env to subsequent BashTool commands.\nHook output can include hookSpecificOutput.watchPaths (array of absolute paths) to register with the FileChanged watcher.\nExit code 0 - command completes successfully\nOther exit codes - show stderr to user only',
|
||||
},
|
||||
FileChanged: {
|
||||
summary: 'When a watched file changes',
|
||||
description:
|
||||
'Input to command is JSON with file_path and event (change, add, unlink).\nCLAUDE_ENV_FILE is set — write bash exports there to apply env to subsequent BashTool commands.\nThe matcher field specifies filenames to watch in the current directory (e.g. ".envrc|.env").\nHook output can include hookSpecificOutput.watchPaths (array of absolute paths) to dynamically update the watch list.\nExit code 0 - command completes successfully\nOther exit codes - show stderr to user only',
|
||||
},
|
||||
}
|
||||
},
|
||||
toolNames => toolNames.slice().sort().join(','),
|
||||
)
|
||||
|
||||
// Group hooks by event and matcher
|
||||
export function groupHooksByEventAndMatcher(
|
||||
appState: AppState,
|
||||
toolNames: string[],
|
||||
): Record<HookEvent, Record<string, IndividualHookConfig[]>> {
|
||||
const grouped: Record<HookEvent, Record<string, IndividualHookConfig[]>> = {
|
||||
PreToolUse: {},
|
||||
PostToolUse: {},
|
||||
PostToolUseFailure: {},
|
||||
PermissionDenied: {},
|
||||
Notification: {},
|
||||
UserPromptSubmit: {},
|
||||
SessionStart: {},
|
||||
SessionEnd: {},
|
||||
Stop: {},
|
||||
StopFailure: {},
|
||||
SubagentStart: {},
|
||||
SubagentStop: {},
|
||||
PreCompact: {},
|
||||
PostCompact: {},
|
||||
PermissionRequest: {},
|
||||
Setup: {},
|
||||
TeammateIdle: {},
|
||||
TaskCreated: {},
|
||||
TaskCompleted: {},
|
||||
Elicitation: {},
|
||||
ElicitationResult: {},
|
||||
ConfigChange: {},
|
||||
WorktreeCreate: {},
|
||||
WorktreeRemove: {},
|
||||
InstructionsLoaded: {},
|
||||
CwdChanged: {},
|
||||
FileChanged: {},
|
||||
}
|
||||
|
||||
const metadata = getHookEventMetadata(toolNames)
|
||||
|
||||
// Include hooks from settings files
|
||||
getAllHooks(appState).forEach(hook => {
|
||||
const eventGroup = grouped[hook.event]
|
||||
if (eventGroup) {
|
||||
// For events without matchers, use empty string as key
|
||||
const matcherKey =
|
||||
metadata[hook.event].matcherMetadata !== undefined
|
||||
? hook.matcher || ''
|
||||
: ''
|
||||
if (!eventGroup[matcherKey]) {
|
||||
eventGroup[matcherKey] = []
|
||||
}
|
||||
eventGroup[matcherKey].push(hook)
|
||||
}
|
||||
})
|
||||
|
||||
// Include registered hooks (e.g., plugin hooks)
|
||||
const registeredHooks = getRegisteredHooks()
|
||||
if (registeredHooks) {
|
||||
for (const [event, matchers] of Object.entries(registeredHooks)) {
|
||||
const hookEvent = event as HookEvent
|
||||
const eventGroup = grouped[hookEvent]
|
||||
if (!eventGroup) continue
|
||||
|
||||
for (const matcher of matchers) {
|
||||
const matcherKey = matcher.matcher || ''
|
||||
|
||||
// Only PluginHookMatcher has pluginRoot; HookCallbackMatcher (internal
|
||||
// callbacks like attributionHooks, sessionFileAccessHooks) does not.
|
||||
if ('pluginRoot' in matcher) {
|
||||
eventGroup[matcherKey] ??= []
|
||||
for (const hook of matcher.hooks) {
|
||||
eventGroup[matcherKey].push({
|
||||
event: hookEvent,
|
||||
config: hook,
|
||||
matcher: matcher.matcher,
|
||||
source: 'pluginHook',
|
||||
pluginName: matcher.pluginId,
|
||||
})
|
||||
}
|
||||
} else if (process.env.USER_TYPE === 'ant') {
|
||||
eventGroup[matcherKey] ??= []
|
||||
for (const _hook of matcher.hooks) {
|
||||
eventGroup[matcherKey].push({
|
||||
event: hookEvent,
|
||||
config: {
|
||||
type: 'command',
|
||||
command: '[ANT-ONLY] Built-in Hook',
|
||||
},
|
||||
matcher: matcher.matcher,
|
||||
source: 'builtinHook',
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return grouped
|
||||
}
|
||||
|
||||
// Get sorted matchers for a specific event
|
||||
export function getSortedMatchersForEvent(
|
||||
hooksByEventAndMatcher: Record<
|
||||
HookEvent,
|
||||
Record<string, IndividualHookConfig[]>
|
||||
>,
|
||||
event: HookEvent,
|
||||
): string[] {
|
||||
const matchers = Object.keys(hooksByEventAndMatcher[event] || {})
|
||||
return sortMatchersByPriority(matchers, hooksByEventAndMatcher, event)
|
||||
}
|
||||
|
||||
// Get hooks for a specific event and matcher
|
||||
export function getHooksForMatcher(
|
||||
hooksByEventAndMatcher: Record<
|
||||
HookEvent,
|
||||
Record<string, IndividualHookConfig[]>
|
||||
>,
|
||||
event: HookEvent,
|
||||
matcher: string | null,
|
||||
): IndividualHookConfig[] {
|
||||
// For events without matchers, hooks are stored with empty string as key
|
||||
// because the record keys must be strings.
|
||||
const matcherKey = matcher ?? ''
|
||||
return hooksByEventAndMatcher[event]?.[matcherKey] ?? []
|
||||
}
|
||||
|
||||
// Get metadata for a specific event's matcher
|
||||
export function getMatcherMetadata(
|
||||
event: HookEvent,
|
||||
toolNames: string[],
|
||||
): MatcherMetadata | undefined {
|
||||
return getHookEventMetadata(toolNames)[event].matcherMetadata
|
||||
}
|
||||
133
src/utils/hooks/hooksConfigSnapshot.ts
Normal file
133
src/utils/hooks/hooksConfigSnapshot.ts
Normal file
@@ -0,0 +1,133 @@
|
||||
import { resetSdkInitState } from '../../bootstrap/state.js'
|
||||
import { isRestrictedToPluginOnly } from '../settings/pluginOnlyPolicy.js'
|
||||
// Import as module object so spyOn works in tests (direct imports bypass spies)
|
||||
import * as settingsModule from '../settings/settings.js'
|
||||
import { resetSettingsCache } from '../settings/settingsCache.js'
|
||||
import type { HooksSettings } from '../settings/types.js'
|
||||
|
||||
let initialHooksConfig: HooksSettings | null = null
|
||||
|
||||
/**
|
||||
* Get hooks from allowed sources.
|
||||
* If allowManagedHooksOnly is set in policySettings, only managed hooks are returned.
|
||||
* If disableAllHooks is set in policySettings, no hooks are returned.
|
||||
* If disableAllHooks is set in non-managed settings, only managed hooks are returned
|
||||
* (non-managed settings cannot disable managed hooks).
|
||||
* Otherwise, returns merged hooks from all sources (backwards compatible).
|
||||
*/
|
||||
function getHooksFromAllowedSources(): HooksSettings {
|
||||
const policySettings = settingsModule.getSettingsForSource('policySettings')
|
||||
|
||||
// If managed settings disables all hooks, return empty
|
||||
if (policySettings?.disableAllHooks === true) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// If allowManagedHooksOnly is set in managed settings, only use managed hooks
|
||||
if (policySettings?.allowManagedHooksOnly === true) {
|
||||
return policySettings.hooks ?? {}
|
||||
}
|
||||
|
||||
// strictPluginOnlyCustomization: block user/project/local settings hooks.
|
||||
// Plugin hooks (registered channel, hooks.ts:1391) are NOT affected —
|
||||
// they're assembled separately and the managedOnly skip there is keyed
|
||||
// on shouldAllowManagedHooksOnly(), not on this policy. Agent frontmatter
|
||||
// hooks are gated at REGISTRATION (runAgent.ts:~535) by agent source —
|
||||
// plugin/built-in/policySettings agents register normally, user-sourced
|
||||
// agents skip registration under ["hooks"]. A blanket execution-time
|
||||
// block here would over-kill plugin agents' hooks.
|
||||
if (isRestrictedToPluginOnly('hooks')) {
|
||||
return policySettings?.hooks ?? {}
|
||||
}
|
||||
|
||||
const mergedSettings = settingsModule.getSettings_DEPRECATED()
|
||||
|
||||
// If disableAllHooks is set in non-managed settings, only managed hooks still run
|
||||
// (non-managed settings cannot override managed hooks)
|
||||
if (mergedSettings.disableAllHooks === true) {
|
||||
return policySettings?.hooks ?? {}
|
||||
}
|
||||
|
||||
// Otherwise, use all hooks (merged from all sources) - backwards compatible
|
||||
return mergedSettings.hooks ?? {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if only managed hooks should run.
|
||||
* This is true when:
|
||||
* - policySettings has allowManagedHooksOnly: true, OR
|
||||
* - disableAllHooks is set in non-managed settings (non-managed settings
|
||||
* cannot disable managed hooks, so they effectively become managed-only)
|
||||
*/
|
||||
export function shouldAllowManagedHooksOnly(): boolean {
|
||||
const policySettings = settingsModule.getSettingsForSource('policySettings')
|
||||
if (policySettings?.allowManagedHooksOnly === true) {
|
||||
return true
|
||||
}
|
||||
// If disableAllHooks is set but NOT from managed settings,
|
||||
// treat as managed-only (non-managed hooks disabled, managed hooks still run)
|
||||
if (
|
||||
settingsModule.getSettings_DEPRECATED().disableAllHooks === true &&
|
||||
policySettings?.disableAllHooks !== true
|
||||
) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if all hooks (including managed) should be disabled.
|
||||
* This is only true when managed/policy settings has disableAllHooks: true.
|
||||
* When disableAllHooks is set in non-managed settings, managed hooks still run.
|
||||
*/
|
||||
export function shouldDisableAllHooksIncludingManaged(): boolean {
|
||||
return (
|
||||
settingsModule.getSettingsForSource('policySettings')?.disableAllHooks ===
|
||||
true
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture a snapshot of the current hooks configuration
|
||||
* This should be called once during application startup
|
||||
* Respects the allowManagedHooksOnly setting
|
||||
*/
|
||||
export function captureHooksConfigSnapshot(): void {
|
||||
initialHooksConfig = getHooksFromAllowedSources()
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the hooks configuration snapshot
|
||||
* This should be called when hooks are modified through the settings
|
||||
* Respects the allowManagedHooksOnly setting
|
||||
*/
|
||||
export function updateHooksConfigSnapshot(): void {
|
||||
// Reset the session cache to ensure we read fresh settings from disk.
|
||||
// Without this, the snapshot could use stale cached settings when the user
|
||||
// edits settings.json externally and then runs /hooks - the session cache
|
||||
// may not have been invalidated yet (e.g., if the file watcher's stability
|
||||
// threshold hasn't elapsed).
|
||||
resetSettingsCache()
|
||||
initialHooksConfig = getHooksFromAllowedSources()
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current hooks configuration from snapshot
|
||||
* Falls back to settings if no snapshot exists
|
||||
* @returns The hooks configuration
|
||||
*/
|
||||
export function getHooksConfigFromSnapshot(): HooksSettings | null {
|
||||
if (initialHooksConfig === null) {
|
||||
captureHooksConfigSnapshot()
|
||||
}
|
||||
return initialHooksConfig
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset the hooks configuration snapshot (useful for testing)
|
||||
* Also resets SDK init state to prevent test pollution
|
||||
*/
|
||||
export function resetHooksConfigSnapshot(): void {
|
||||
initialHooksConfig = null
|
||||
resetSdkInitState()
|
||||
}
|
||||
271
src/utils/hooks/hooksSettings.ts
Normal file
271
src/utils/hooks/hooksSettings.ts
Normal file
@@ -0,0 +1,271 @@
|
||||
import { resolve } from 'path'
|
||||
import type { HookEvent } from 'src/entrypoints/agentSdkTypes.js'
|
||||
import { getSessionId } from '../../bootstrap/state.js'
|
||||
import type { AppState } from '../../state/AppState.js'
|
||||
import type { EditableSettingSource } from '../settings/constants.js'
|
||||
import { SOURCES } from '../settings/constants.js'
|
||||
import {
|
||||
getSettingsFilePathForSource,
|
||||
getSettingsForSource,
|
||||
} from '../settings/settings.js'
|
||||
import type { HookCommand, HookMatcher } from '../settings/types.js'
|
||||
import { DEFAULT_HOOK_SHELL } from '../shell/shellProvider.js'
|
||||
import { getSessionHooks } from './sessionHooks.js'
|
||||
|
||||
export type HookSource =
|
||||
| EditableSettingSource
|
||||
| 'policySettings'
|
||||
| 'pluginHook'
|
||||
| 'sessionHook'
|
||||
| 'builtinHook'
|
||||
|
||||
export interface IndividualHookConfig {
|
||||
event: HookEvent
|
||||
config: HookCommand
|
||||
matcher?: string
|
||||
source: HookSource
|
||||
pluginName?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if two hooks are equal (comparing only command/prompt content, not timeout)
|
||||
*/
|
||||
export function isHookEqual(
|
||||
a: HookCommand | { type: 'function'; timeout?: number },
|
||||
b: HookCommand | { type: 'function'; timeout?: number },
|
||||
): boolean {
|
||||
if (a.type !== b.type) return false
|
||||
|
||||
// Use switch for exhaustive type checking
|
||||
// Note: We only compare command/prompt content, not timeout
|
||||
// `if` is part of identity: same command with different `if` conditions
|
||||
// are distinct hooks (e.g., setup.sh if=Bash(git *) vs if=Bash(npm *)).
|
||||
const sameIf = (x: { if?: string }, y: { if?: string }) =>
|
||||
(x.if ?? '') === (y.if ?? '')
|
||||
switch (a.type) {
|
||||
case 'command':
|
||||
// shell is part of identity: same command string with different
|
||||
// shells are distinct hooks. Default 'bash' so undefined === 'bash'.
|
||||
return (
|
||||
b.type === 'command' &&
|
||||
a.command === b.command &&
|
||||
(a.shell ?? DEFAULT_HOOK_SHELL) === (b.shell ?? DEFAULT_HOOK_SHELL) &&
|
||||
sameIf(a, b)
|
||||
)
|
||||
case 'prompt':
|
||||
return b.type === 'prompt' && a.prompt === b.prompt && sameIf(a, b)
|
||||
case 'agent':
|
||||
return b.type === 'agent' && a.prompt === b.prompt && sameIf(a, b)
|
||||
case 'http':
|
||||
return b.type === 'http' && a.url === b.url && sameIf(a, b)
|
||||
case 'function':
|
||||
// Function hooks can't be compared (no stable identifier)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/** Get the display text for a hook */
|
||||
export function getHookDisplayText(
|
||||
hook: HookCommand | { type: 'callback' | 'function'; statusMessage?: string },
|
||||
): string {
|
||||
// Return custom status message if provided
|
||||
if ('statusMessage' in hook && hook.statusMessage) {
|
||||
return hook.statusMessage
|
||||
}
|
||||
|
||||
switch (hook.type) {
|
||||
case 'command':
|
||||
return hook.command
|
||||
case 'prompt':
|
||||
return hook.prompt
|
||||
case 'agent':
|
||||
return hook.prompt
|
||||
case 'http':
|
||||
return hook.url
|
||||
case 'callback':
|
||||
return 'callback'
|
||||
case 'function':
|
||||
return 'function'
|
||||
}
|
||||
}
|
||||
|
||||
export function getAllHooks(appState: AppState): IndividualHookConfig[] {
|
||||
const hooks: IndividualHookConfig[] = []
|
||||
|
||||
// Check if restricted to managed hooks only
|
||||
const policySettings = getSettingsForSource('policySettings')
|
||||
const restrictedToManagedOnly = policySettings?.allowManagedHooksOnly === true
|
||||
|
||||
// If allowManagedHooksOnly is set, don't show any hooks in the UI
|
||||
// (user/project/local are blocked, and managed hooks are intentionally hidden)
|
||||
if (!restrictedToManagedOnly) {
|
||||
// Get hooks from all editable sources
|
||||
const sources = [
|
||||
'userSettings',
|
||||
'projectSettings',
|
||||
'localSettings',
|
||||
] as EditableSettingSource[]
|
||||
|
||||
// Track which settings files we've already processed to avoid duplicates
|
||||
// (e.g., when running from home directory, userSettings and projectSettings
|
||||
// both resolve to ~/.claude/settings.json)
|
||||
const seenFiles = new Set<string>()
|
||||
|
||||
for (const source of sources) {
|
||||
const filePath = getSettingsFilePathForSource(source)
|
||||
if (filePath) {
|
||||
const resolvedPath = resolve(filePath)
|
||||
if (seenFiles.has(resolvedPath)) {
|
||||
continue
|
||||
}
|
||||
seenFiles.add(resolvedPath)
|
||||
}
|
||||
|
||||
const sourceSettings = getSettingsForSource(source)
|
||||
if (!sourceSettings?.hooks) {
|
||||
continue
|
||||
}
|
||||
|
||||
for (const [event, matchers] of Object.entries(sourceSettings.hooks)) {
|
||||
for (const matcher of matchers as HookMatcher[]) {
|
||||
for (const hookCommand of matcher.hooks) {
|
||||
hooks.push({
|
||||
event: event as HookEvent,
|
||||
config: hookCommand,
|
||||
matcher: matcher.matcher,
|
||||
source,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get session hooks
|
||||
const sessionId = getSessionId()
|
||||
const sessionHooks = getSessionHooks(appState, sessionId)
|
||||
for (const [event, matchers] of sessionHooks.entries()) {
|
||||
for (const matcher of matchers) {
|
||||
for (const hookCommand of matcher.hooks) {
|
||||
hooks.push({
|
||||
event,
|
||||
config: hookCommand,
|
||||
matcher: matcher.matcher,
|
||||
source: 'sessionHook',
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return hooks
|
||||
}
|
||||
|
||||
export function getHooksForEvent(
|
||||
appState: AppState,
|
||||
event: HookEvent,
|
||||
): IndividualHookConfig[] {
|
||||
return getAllHooks(appState).filter(hook => hook.event === event)
|
||||
}
|
||||
|
||||
export function hookSourceDescriptionDisplayString(source: HookSource): string {
|
||||
switch (source) {
|
||||
case 'userSettings':
|
||||
return 'User settings (~/.claude/settings.json)'
|
||||
case 'projectSettings':
|
||||
return 'Project settings (.claude/settings.json)'
|
||||
case 'localSettings':
|
||||
return 'Local settings (.claude/settings.local.json)'
|
||||
case 'pluginHook':
|
||||
// TODO: Get the actual plugin hook file paths instead of using glob pattern
|
||||
// We should capture the specific plugin paths during hook registration and display them here
|
||||
// e.g., "Plugin hooks (~/.claude/plugins/repos/source/example-plugin/example-plugin/hooks/hooks.json)"
|
||||
return 'Plugin hooks (~/.claude/plugins/*/hooks/hooks.json)'
|
||||
case 'sessionHook':
|
||||
return 'Session hooks (in-memory, temporary)'
|
||||
case 'builtinHook':
|
||||
return 'Built-in hooks (registered internally by Claude Code)'
|
||||
default:
|
||||
return source as string
|
||||
}
|
||||
}
|
||||
|
||||
export function hookSourceHeaderDisplayString(source: HookSource): string {
|
||||
switch (source) {
|
||||
case 'userSettings':
|
||||
return 'User Settings'
|
||||
case 'projectSettings':
|
||||
return 'Project Settings'
|
||||
case 'localSettings':
|
||||
return 'Local Settings'
|
||||
case 'pluginHook':
|
||||
return 'Plugin Hooks'
|
||||
case 'sessionHook':
|
||||
return 'Session Hooks'
|
||||
case 'builtinHook':
|
||||
return 'Built-in Hooks'
|
||||
default:
|
||||
return source as string
|
||||
}
|
||||
}
|
||||
|
||||
export function hookSourceInlineDisplayString(source: HookSource): string {
|
||||
switch (source) {
|
||||
case 'userSettings':
|
||||
return 'User'
|
||||
case 'projectSettings':
|
||||
return 'Project'
|
||||
case 'localSettings':
|
||||
return 'Local'
|
||||
case 'pluginHook':
|
||||
return 'Plugin'
|
||||
case 'sessionHook':
|
||||
return 'Session'
|
||||
case 'builtinHook':
|
||||
return 'Built-in'
|
||||
default:
|
||||
return source as string
|
||||
}
|
||||
}
|
||||
|
||||
export function sortMatchersByPriority(
|
||||
matchers: string[],
|
||||
hooksByEventAndMatcher: Record<
|
||||
string,
|
||||
Record<string, IndividualHookConfig[]>
|
||||
>,
|
||||
selectedEvent: HookEvent,
|
||||
): string[] {
|
||||
// Create a priority map based on SOURCES order (lower index = higher priority)
|
||||
const sourcePriority = SOURCES.reduce(
|
||||
(acc, source, index) => {
|
||||
acc[source] = index
|
||||
return acc
|
||||
},
|
||||
{} as Record<EditableSettingSource, number>,
|
||||
)
|
||||
|
||||
return [...matchers].sort((a, b) => {
|
||||
const aHooks = hooksByEventAndMatcher[selectedEvent]?.[a] || []
|
||||
const bHooks = hooksByEventAndMatcher[selectedEvent]?.[b] || []
|
||||
|
||||
const aSources = Array.from(new Set(aHooks.map(h => h.source)))
|
||||
const bSources = Array.from(new Set(bHooks.map(h => h.source)))
|
||||
|
||||
// Sort by highest priority source first (lowest priority number)
|
||||
// Plugin hooks get lowest priority (highest number)
|
||||
const getSourcePriority = (source: HookSource) =>
|
||||
source === 'pluginHook' || source === 'builtinHook'
|
||||
? 999
|
||||
: sourcePriority[source as EditableSettingSource]
|
||||
|
||||
const aHighestPriority = Math.min(...aSources.map(getSourcePriority))
|
||||
const bHighestPriority = Math.min(...bSources.map(getSourcePriority))
|
||||
|
||||
if (aHighestPriority !== bHighestPriority) {
|
||||
return aHighestPriority - bHighestPriority
|
||||
}
|
||||
|
||||
// If same priority, sort by matcher name
|
||||
return a.localeCompare(b)
|
||||
})
|
||||
}
|
||||
70
src/utils/hooks/postSamplingHooks.ts
Normal file
70
src/utils/hooks/postSamplingHooks.ts
Normal file
@@ -0,0 +1,70 @@
|
||||
import type { QuerySource } from '../../constants/querySource.js'
|
||||
import type { ToolUseContext } from '../../Tool.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { toError } from '../errors.js'
|
||||
import { logError } from '../log.js'
|
||||
import type { SystemPrompt } from '../systemPromptType.js'
|
||||
|
||||
// Post-sampling hook - not exposed in settings.json config (yet), only used programmatically
|
||||
|
||||
// Generic context for REPL hooks (both post-sampling and stop hooks)
|
||||
export type REPLHookContext = {
|
||||
messages: Message[] // Full message history including assistant responses
|
||||
systemPrompt: SystemPrompt
|
||||
userContext: { [k: string]: string }
|
||||
systemContext: { [k: string]: string }
|
||||
toolUseContext: ToolUseContext
|
||||
querySource?: QuerySource
|
||||
}
|
||||
|
||||
export type PostSamplingHook = (
|
||||
context: REPLHookContext,
|
||||
) => Promise<void> | void
|
||||
|
||||
// Internal registry for post-sampling hooks
|
||||
const postSamplingHooks: PostSamplingHook[] = []
|
||||
|
||||
/**
|
||||
* Register a post-sampling hook that will be called after model sampling completes
|
||||
* This is an internal API not exposed through settings
|
||||
*/
|
||||
export function registerPostSamplingHook(hook: PostSamplingHook): void {
|
||||
postSamplingHooks.push(hook)
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all registered post-sampling hooks (for testing)
|
||||
*/
|
||||
export function clearPostSamplingHooks(): void {
|
||||
postSamplingHooks.length = 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute all registered post-sampling hooks
|
||||
*/
|
||||
export async function executePostSamplingHooks(
|
||||
messages: Message[],
|
||||
systemPrompt: SystemPrompt,
|
||||
userContext: { [k: string]: string },
|
||||
systemContext: { [k: string]: string },
|
||||
toolUseContext: ToolUseContext,
|
||||
querySource?: QuerySource,
|
||||
): Promise<void> {
|
||||
const context: REPLHookContext = {
|
||||
messages,
|
||||
systemPrompt,
|
||||
userContext,
|
||||
systemContext,
|
||||
toolUseContext,
|
||||
querySource,
|
||||
}
|
||||
|
||||
for (const hook of postSamplingHooks) {
|
||||
try {
|
||||
await hook(context)
|
||||
} catch (error) {
|
||||
// Log but don't fail on hook errors
|
||||
logError(toError(error))
|
||||
}
|
||||
}
|
||||
}
|
||||
67
src/utils/hooks/registerFrontmatterHooks.ts
Normal file
67
src/utils/hooks/registerFrontmatterHooks.ts
Normal file
@@ -0,0 +1,67 @@
|
||||
import { HOOK_EVENTS, type HookEvent } from 'src/entrypoints/agentSdkTypes.js'
|
||||
import type { AppState } from 'src/state/AppState.js'
|
||||
import { logForDebugging } from '../debug.js'
|
||||
import type { HooksSettings } from '../settings/types.js'
|
||||
import { addSessionHook } from './sessionHooks.js'
|
||||
|
||||
/**
|
||||
* Register hooks from frontmatter (agent or skill) into session-scoped hooks.
|
||||
* These hooks will be active for the duration of the session/agent and cleaned up
|
||||
* when the session/agent ends.
|
||||
*
|
||||
* @param setAppState Function to update app state
|
||||
* @param sessionId Session ID to scope the hooks (agent ID for agents, session ID for skills)
|
||||
* @param hooks The hooks settings from frontmatter
|
||||
* @param sourceName Human-readable source name for logging (e.g., "agent 'my-agent'")
|
||||
* @param isAgent If true, converts Stop hooks to SubagentStop (since subagents trigger SubagentStop, not Stop)
|
||||
*/
|
||||
export function registerFrontmatterHooks(
|
||||
setAppState: (updater: (prev: AppState) => AppState) => void,
|
||||
sessionId: string,
|
||||
hooks: HooksSettings,
|
||||
sourceName: string,
|
||||
isAgent: boolean = false,
|
||||
): void {
|
||||
if (!hooks || Object.keys(hooks).length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
let hookCount = 0
|
||||
|
||||
for (const event of HOOK_EVENTS) {
|
||||
const matchers = hooks[event]
|
||||
if (!matchers || matchers.length === 0) {
|
||||
continue
|
||||
}
|
||||
|
||||
// For agents, convert Stop hooks to SubagentStop since that's what fires when an agent completes
|
||||
// (executeStopHooks uses SubagentStop when called with an agentId)
|
||||
let targetEvent: HookEvent = event
|
||||
if (isAgent && event === 'Stop') {
|
||||
targetEvent = 'SubagentStop'
|
||||
logForDebugging(
|
||||
`Converting Stop hook to SubagentStop for ${sourceName} (subagents trigger SubagentStop)`,
|
||||
)
|
||||
}
|
||||
|
||||
for (const matcherConfig of matchers) {
|
||||
const matcher = matcherConfig.matcher ?? ''
|
||||
const hooksArray = matcherConfig.hooks
|
||||
|
||||
if (!hooksArray || hooksArray.length === 0) {
|
||||
continue
|
||||
}
|
||||
|
||||
for (const hook of hooksArray) {
|
||||
addSessionHook(setAppState, sessionId, targetEvent, matcher, hook)
|
||||
hookCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (hookCount > 0) {
|
||||
logForDebugging(
|
||||
`Registered ${hookCount} frontmatter hook(s) from ${sourceName} for session ${sessionId}`,
|
||||
)
|
||||
}
|
||||
}
|
||||
64
src/utils/hooks/registerSkillHooks.ts
Normal file
64
src/utils/hooks/registerSkillHooks.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
import { HOOK_EVENTS } from 'src/entrypoints/agentSdkTypes.js'
|
||||
import type { AppState } from 'src/state/AppState.js'
|
||||
import { logForDebugging } from '../debug.js'
|
||||
import type { HooksSettings } from '../settings/types.js'
|
||||
import { addSessionHook, removeSessionHook } from './sessionHooks.js'
|
||||
|
||||
/**
|
||||
* Registers hooks from a skill's frontmatter as session hooks.
|
||||
*
|
||||
* Hooks are registered as session-scoped hooks that persist for the duration
|
||||
* of the session. If a hook has `once: true`, it will be automatically removed
|
||||
* after its first successful execution.
|
||||
*
|
||||
* @param setAppState - Function to update the app state
|
||||
* @param sessionId - The current session ID
|
||||
* @param hooks - The hooks settings from the skill's frontmatter
|
||||
* @param skillName - The name of the skill (for logging)
|
||||
* @param skillRoot - The base directory of the skill (for CLAUDE_PLUGIN_ROOT env var)
|
||||
*/
|
||||
export function registerSkillHooks(
|
||||
setAppState: (updater: (prev: AppState) => AppState) => void,
|
||||
sessionId: string,
|
||||
hooks: HooksSettings,
|
||||
skillName: string,
|
||||
skillRoot?: string,
|
||||
): void {
|
||||
let registeredCount = 0
|
||||
|
||||
for (const eventName of HOOK_EVENTS) {
|
||||
const matchers = hooks[eventName]
|
||||
if (!matchers) continue
|
||||
|
||||
for (const matcher of matchers) {
|
||||
for (const hook of matcher.hooks) {
|
||||
// For once: true hooks, use onHookSuccess callback to remove after execution
|
||||
const onHookSuccess = hook.once
|
||||
? () => {
|
||||
logForDebugging(
|
||||
`Removing one-shot hook for event ${eventName} in skill '${skillName}'`,
|
||||
)
|
||||
removeSessionHook(setAppState, sessionId, eventName, hook)
|
||||
}
|
||||
: undefined
|
||||
|
||||
addSessionHook(
|
||||
setAppState,
|
||||
sessionId,
|
||||
eventName,
|
||||
matcher.matcher || '',
|
||||
hook,
|
||||
onHookSuccess,
|
||||
skillRoot,
|
||||
)
|
||||
registeredCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (registeredCount > 0) {
|
||||
logForDebugging(
|
||||
`Registered ${registeredCount} hooks from skill '${skillName}'`,
|
||||
)
|
||||
}
|
||||
}
|
||||
447
src/utils/hooks/sessionHooks.ts
Normal file
447
src/utils/hooks/sessionHooks.ts
Normal file
@@ -0,0 +1,447 @@
|
||||
import { HOOK_EVENTS, type HookEvent } from 'src/entrypoints/agentSdkTypes.js'
|
||||
import type { AppState } from 'src/state/AppState.js'
|
||||
import type { Message } from 'src/types/message.js'
|
||||
import { logForDebugging } from '../debug.js'
|
||||
import type { AggregatedHookResult } from '../hooks.js'
|
||||
import type { HookCommand } from '../settings/types.js'
|
||||
import { isHookEqual } from './hooksSettings.js'
|
||||
|
||||
type OnHookSuccess = (
|
||||
hook: HookCommand | FunctionHook,
|
||||
result: AggregatedHookResult,
|
||||
) => void
|
||||
|
||||
/** Function hook callback - returns true if check passes, false to block */
|
||||
export type FunctionHookCallback = (
|
||||
messages: Message[],
|
||||
signal?: AbortSignal,
|
||||
) => boolean | Promise<boolean>
|
||||
|
||||
/**
|
||||
* Function hook type with callback embedded.
|
||||
* Session-scoped only, cannot be persisted to settings.json.
|
||||
*/
|
||||
export type FunctionHook = {
|
||||
type: 'function'
|
||||
id?: string // Optional unique ID for removal
|
||||
timeout?: number
|
||||
callback: FunctionHookCallback
|
||||
errorMessage: string
|
||||
statusMessage?: string
|
||||
}
|
||||
|
||||
type SessionHookMatcher = {
|
||||
matcher: string
|
||||
skillRoot?: string
|
||||
hooks: Array<{
|
||||
hook: HookCommand | FunctionHook
|
||||
onHookSuccess?: OnHookSuccess
|
||||
}>
|
||||
}
|
||||
|
||||
export type SessionStore = {
|
||||
hooks: {
|
||||
[event in HookEvent]?: SessionHookMatcher[]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map (not Record) so .set/.delete don't change the container's identity.
|
||||
* Mutator functions mutate the Map and return prev unchanged, letting
|
||||
* store.ts's Object.is(next, prev) check short-circuit and skip listener
|
||||
* notification. Session hooks are ephemeral per-agent runtime callbacks,
|
||||
* never reactively read (only getAppState() snapshots in the query loop).
|
||||
* Same pattern as agentControllers on LocalWorkflowTaskState.
|
||||
*
|
||||
* This matters under high-concurrency workflows: parallel() with N
|
||||
* schema-mode agents fires N addFunctionHook calls in one synchronous
|
||||
* tick. With a Record + spread, each call cost O(N) to copy the growing
|
||||
* map (O(N²) total) plus fired all ~30 store listeners. With Map: .set()
|
||||
* is O(1), return prev means zero listener fires.
|
||||
*/
|
||||
export type SessionHooksState = Map<string, SessionStore>
|
||||
|
||||
/**
|
||||
* Add a command or prompt hook to the session.
|
||||
* Session hooks are temporary, in-memory only, and cleared when session ends.
|
||||
*/
|
||||
export function addSessionHook(
|
||||
setAppState: (updater: (prev: AppState) => AppState) => void,
|
||||
sessionId: string,
|
||||
event: HookEvent,
|
||||
matcher: string,
|
||||
hook: HookCommand,
|
||||
onHookSuccess?: OnHookSuccess,
|
||||
skillRoot?: string,
|
||||
): void {
|
||||
addHookToSession(
|
||||
setAppState,
|
||||
sessionId,
|
||||
event,
|
||||
matcher,
|
||||
hook,
|
||||
onHookSuccess,
|
||||
skillRoot,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a function hook to the session.
|
||||
* Function hooks execute TypeScript callbacks in-memory for validation.
|
||||
* @returns The hook ID (for removal)
|
||||
*/
|
||||
export function addFunctionHook(
|
||||
setAppState: (updater: (prev: AppState) => AppState) => void,
|
||||
sessionId: string,
|
||||
event: HookEvent,
|
||||
matcher: string,
|
||||
callback: FunctionHookCallback,
|
||||
errorMessage: string,
|
||||
options?: {
|
||||
timeout?: number
|
||||
id?: string
|
||||
},
|
||||
): string {
|
||||
const id = options?.id || `function-hook-${Date.now()}-${Math.random()}`
|
||||
const hook: FunctionHook = {
|
||||
type: 'function',
|
||||
id,
|
||||
timeout: options?.timeout || 5000,
|
||||
callback,
|
||||
errorMessage,
|
||||
}
|
||||
addHookToSession(setAppState, sessionId, event, matcher, hook)
|
||||
return id
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a function hook by ID from the session.
|
||||
*/
|
||||
export function removeFunctionHook(
|
||||
setAppState: (updater: (prev: AppState) => AppState) => void,
|
||||
sessionId: string,
|
||||
event: HookEvent,
|
||||
hookId: string,
|
||||
): void {
|
||||
setAppState(prev => {
|
||||
const store = prev.sessionHooks.get(sessionId)
|
||||
if (!store) {
|
||||
return prev
|
||||
}
|
||||
|
||||
const eventMatchers = store.hooks[event] || []
|
||||
|
||||
// Remove the hook with matching ID from all matchers
|
||||
const updatedMatchers = eventMatchers
|
||||
.map(matcher => {
|
||||
const updatedHooks = matcher.hooks.filter(h => {
|
||||
if (h.hook.type !== 'function') return true
|
||||
return h.hook.id !== hookId
|
||||
})
|
||||
|
||||
return updatedHooks.length > 0
|
||||
? { ...matcher, hooks: updatedHooks }
|
||||
: null
|
||||
})
|
||||
.filter((m): m is SessionHookMatcher => m !== null)
|
||||
|
||||
const newHooks =
|
||||
updatedMatchers.length > 0
|
||||
? { ...store.hooks, [event]: updatedMatchers }
|
||||
: Object.fromEntries(
|
||||
Object.entries(store.hooks).filter(([e]) => e !== event),
|
||||
)
|
||||
|
||||
prev.sessionHooks.set(sessionId, { hooks: newHooks })
|
||||
return prev
|
||||
})
|
||||
|
||||
logForDebugging(
|
||||
`Removed function hook ${hookId} for event ${event} in session ${sessionId}`,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal helper to add a hook to session state
|
||||
*/
|
||||
function addHookToSession(
|
||||
setAppState: (updater: (prev: AppState) => AppState) => void,
|
||||
sessionId: string,
|
||||
event: HookEvent,
|
||||
matcher: string,
|
||||
hook: HookCommand | FunctionHook,
|
||||
onHookSuccess?: OnHookSuccess,
|
||||
skillRoot?: string,
|
||||
): void {
|
||||
setAppState(prev => {
|
||||
const store = prev.sessionHooks.get(sessionId) ?? { hooks: {} }
|
||||
const eventMatchers = store.hooks[event] || []
|
||||
|
||||
// Find existing matcher or create new one
|
||||
const existingMatcherIndex = eventMatchers.findIndex(
|
||||
m => m.matcher === matcher && m.skillRoot === skillRoot,
|
||||
)
|
||||
|
||||
let updatedMatchers: SessionHookMatcher[]
|
||||
if (existingMatcherIndex >= 0) {
|
||||
// Add to existing matcher
|
||||
updatedMatchers = [...eventMatchers]
|
||||
const existingMatcher = updatedMatchers[existingMatcherIndex]!
|
||||
updatedMatchers[existingMatcherIndex] = {
|
||||
matcher: existingMatcher.matcher,
|
||||
skillRoot: existingMatcher.skillRoot,
|
||||
hooks: [...existingMatcher.hooks, { hook, onHookSuccess }],
|
||||
}
|
||||
} else {
|
||||
// Create new matcher
|
||||
updatedMatchers = [
|
||||
...eventMatchers,
|
||||
{
|
||||
matcher,
|
||||
skillRoot,
|
||||
hooks: [{ hook, onHookSuccess }],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
const newHooks = { ...store.hooks, [event]: updatedMatchers }
|
||||
|
||||
prev.sessionHooks.set(sessionId, { hooks: newHooks })
|
||||
return prev
|
||||
})
|
||||
|
||||
logForDebugging(
|
||||
`Added session hook for event ${event} in session ${sessionId}`,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a specific hook from the session
|
||||
* @param setAppState The function to update the app state
|
||||
* @param sessionId The session ID
|
||||
* @param event The hook event
|
||||
* @param hook The hook command to remove
|
||||
*/
|
||||
export function removeSessionHook(
|
||||
setAppState: (updater: (prev: AppState) => AppState) => void,
|
||||
sessionId: string,
|
||||
event: HookEvent,
|
||||
hook: HookCommand,
|
||||
): void {
|
||||
setAppState(prev => {
|
||||
const store = prev.sessionHooks.get(sessionId)
|
||||
if (!store) {
|
||||
return prev
|
||||
}
|
||||
|
||||
const eventMatchers = store.hooks[event] || []
|
||||
|
||||
// Remove the hook from all matchers
|
||||
const updatedMatchers = eventMatchers
|
||||
.map(matcher => {
|
||||
const updatedHooks = matcher.hooks.filter(
|
||||
h => !isHookEqual(h.hook, hook),
|
||||
)
|
||||
|
||||
return updatedHooks.length > 0
|
||||
? { ...matcher, hooks: updatedHooks }
|
||||
: null
|
||||
})
|
||||
.filter((m): m is SessionHookMatcher => m !== null)
|
||||
|
||||
const newHooks =
|
||||
updatedMatchers.length > 0
|
||||
? { ...store.hooks, [event]: updatedMatchers }
|
||||
: { ...store.hooks }
|
||||
|
||||
if (updatedMatchers.length === 0) {
|
||||
delete newHooks[event]
|
||||
}
|
||||
|
||||
prev.sessionHooks.set(sessionId, { ...store, hooks: newHooks })
|
||||
return prev
|
||||
})
|
||||
|
||||
logForDebugging(
|
||||
`Removed session hook for event ${event} in session ${sessionId}`,
|
||||
)
|
||||
}
|
||||
|
||||
// Extended hook matcher that includes optional skillRoot for skill-scoped hooks
|
||||
export type SessionDerivedHookMatcher = {
|
||||
matcher: string
|
||||
hooks: HookCommand[]
|
||||
skillRoot?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert session hook matchers to regular hook matchers
|
||||
* @param sessionMatchers The session hook matchers to convert
|
||||
* @returns Regular hook matchers (with optional skillRoot preserved)
|
||||
*/
|
||||
function convertToHookMatchers(
|
||||
sessionMatchers: SessionHookMatcher[],
|
||||
): SessionDerivedHookMatcher[] {
|
||||
return sessionMatchers.map(sm => ({
|
||||
matcher: sm.matcher,
|
||||
skillRoot: sm.skillRoot,
|
||||
// Filter out function hooks - they can't be persisted to HookMatcher format
|
||||
hooks: sm.hooks
|
||||
.map(h => h.hook)
|
||||
.filter((h): h is HookCommand => h.type !== 'function'),
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all session hooks for a specific event (excluding function hooks)
|
||||
* @param appState The app state
|
||||
* @param sessionId The session ID
|
||||
* @param event Optional event to filter by
|
||||
* @returns Hook matchers for the event, or all hooks if no event specified
|
||||
*/
|
||||
export function getSessionHooks(
|
||||
appState: AppState,
|
||||
sessionId: string,
|
||||
event?: HookEvent,
|
||||
): Map<HookEvent, SessionDerivedHookMatcher[]> {
|
||||
const store = appState.sessionHooks.get(sessionId)
|
||||
if (!store) {
|
||||
return new Map()
|
||||
}
|
||||
|
||||
const result = new Map<HookEvent, SessionDerivedHookMatcher[]>()
|
||||
|
||||
if (event) {
|
||||
const sessionMatchers = store.hooks[event]
|
||||
if (sessionMatchers) {
|
||||
result.set(event, convertToHookMatchers(sessionMatchers))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
for (const evt of HOOK_EVENTS) {
|
||||
const sessionMatchers = store.hooks[evt]
|
||||
if (sessionMatchers) {
|
||||
result.set(evt, convertToHookMatchers(sessionMatchers))
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type FunctionHookMatcher = {
|
||||
matcher: string
|
||||
hooks: FunctionHook[]
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all session function hooks for a specific event
|
||||
* Function hooks are kept separate because they can't be persisted to HookMatcher format.
|
||||
* @param appState The app state
|
||||
* @param sessionId The session ID
|
||||
* @param event Optional event to filter by
|
||||
* @returns Function hook matchers for the event
|
||||
*/
|
||||
export function getSessionFunctionHooks(
|
||||
appState: AppState,
|
||||
sessionId: string,
|
||||
event?: HookEvent,
|
||||
): Map<HookEvent, FunctionHookMatcher[]> {
|
||||
const store = appState.sessionHooks.get(sessionId)
|
||||
if (!store) {
|
||||
return new Map()
|
||||
}
|
||||
|
||||
const result = new Map<HookEvent, FunctionHookMatcher[]>()
|
||||
|
||||
const extractFunctionHooks = (
|
||||
sessionMatchers: SessionHookMatcher[],
|
||||
): FunctionHookMatcher[] => {
|
||||
return sessionMatchers
|
||||
.map(sm => ({
|
||||
matcher: sm.matcher,
|
||||
hooks: sm.hooks
|
||||
.map(h => h.hook)
|
||||
.filter((h): h is FunctionHook => h.type === 'function'),
|
||||
}))
|
||||
.filter(m => m.hooks.length > 0)
|
||||
}
|
||||
|
||||
if (event) {
|
||||
const sessionMatchers = store.hooks[event]
|
||||
if (sessionMatchers) {
|
||||
const functionMatchers = extractFunctionHooks(sessionMatchers)
|
||||
if (functionMatchers.length > 0) {
|
||||
result.set(event, functionMatchers)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
for (const evt of HOOK_EVENTS) {
|
||||
const sessionMatchers = store.hooks[evt]
|
||||
if (sessionMatchers) {
|
||||
const functionMatchers = extractFunctionHooks(sessionMatchers)
|
||||
if (functionMatchers.length > 0) {
|
||||
result.set(evt, functionMatchers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the full hook entry (including callbacks) for a specific session hook
|
||||
*/
|
||||
export function getSessionHookCallback(
|
||||
appState: AppState,
|
||||
sessionId: string,
|
||||
event: HookEvent,
|
||||
matcher: string,
|
||||
hook: HookCommand | FunctionHook,
|
||||
):
|
||||
| {
|
||||
hook: HookCommand | FunctionHook
|
||||
onHookSuccess?: OnHookSuccess
|
||||
}
|
||||
| undefined {
|
||||
const store = appState.sessionHooks.get(sessionId)
|
||||
if (!store) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const eventMatchers = store.hooks[event]
|
||||
if (!eventMatchers) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Find the hook in the matchers
|
||||
for (const matcherEntry of eventMatchers) {
|
||||
if (matcherEntry.matcher === matcher || matcher === '') {
|
||||
const hookEntry = matcherEntry.hooks.find(h => isHookEqual(h.hook, hook))
|
||||
if (hookEntry) {
|
||||
return hookEntry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all session hooks for a specific session
|
||||
* @param setAppState The function to update the app state
|
||||
* @param sessionId The session ID
|
||||
*/
|
||||
export function clearSessionHooks(
|
||||
setAppState: (updater: (prev: AppState) => AppState) => void,
|
||||
sessionId: string,
|
||||
): void {
|
||||
setAppState(prev => {
|
||||
prev.sessionHooks.delete(sessionId)
|
||||
return prev
|
||||
})
|
||||
|
||||
logForDebugging(`Cleared all session hooks for session ${sessionId}`)
|
||||
}
|
||||
267
src/utils/hooks/skillImprovement.ts
Normal file
267
src/utils/hooks/skillImprovement.ts
Normal file
@@ -0,0 +1,267 @@
|
||||
import { feature } from 'bun:bundle'
|
||||
import { getInvokedSkillsForAgent } from '../../bootstrap/state.js'
|
||||
import { getFeatureValue_CACHED_MAY_BE_STALE } from '../../services/analytics/growthbook.js'
|
||||
import {
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
type AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
logEvent,
|
||||
} from '../../services/analytics/index.js'
|
||||
import { queryModelWithoutStreaming } from '../../services/api/claude.js'
|
||||
import { getEmptyToolPermissionContext } from '../../Tool.js'
|
||||
import type { Message } from '../../types/message.js'
|
||||
import { createAbortController } from '../abortController.js'
|
||||
import { count } from '../array.js'
|
||||
import { getCwd } from '../cwd.js'
|
||||
import { toError } from '../errors.js'
|
||||
import { logError } from '../log.js'
|
||||
import {
|
||||
createUserMessage,
|
||||
extractTag,
|
||||
extractTextContent,
|
||||
} from '../messages.js'
|
||||
import { getSmallFastModel } from '../model/model.js'
|
||||
import { jsonParse } from '../slowOperations.js'
|
||||
import { asSystemPrompt } from '../systemPromptType.js'
|
||||
import {
|
||||
type ApiQueryHookConfig,
|
||||
createApiQueryHook,
|
||||
} from './apiQueryHookHelper.js'
|
||||
import { registerPostSamplingHook } from './postSamplingHooks.js'
|
||||
|
||||
const TURN_BATCH_SIZE = 5
|
||||
|
||||
export type SkillUpdate = {
|
||||
section: string
|
||||
change: string
|
||||
reason: string
|
||||
}
|
||||
|
||||
function formatRecentMessages(messages: Message[]): string {
|
||||
return messages
|
||||
.filter(m => m.type === 'user' || m.type === 'assistant')
|
||||
.map(m => {
|
||||
const role = m.type === 'user' ? 'User' : 'Assistant'
|
||||
const content = m.message.content
|
||||
if (typeof content === 'string')
|
||||
return `${role}: ${content.slice(0, 500)}`
|
||||
const text = content
|
||||
.filter(
|
||||
(b): b is Extract<typeof b, { type: 'text' }> => b.type === 'text',
|
||||
)
|
||||
.map(b => b.text)
|
||||
.join('\n')
|
||||
return `${role}: ${text.slice(0, 500)}`
|
||||
})
|
||||
.join('\n\n')
|
||||
}
|
||||
|
||||
function findProjectSkill() {
|
||||
const skills = getInvokedSkillsForAgent(null)
|
||||
for (const [, info] of skills) {
|
||||
if (info.skillPath.startsWith('projectSettings:')) {
|
||||
return info
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
function createSkillImprovementHook() {
|
||||
let lastAnalyzedCount = 0
|
||||
let lastAnalyzedIndex = 0
|
||||
|
||||
const config: ApiQueryHookConfig<SkillUpdate[]> = {
|
||||
name: 'skill_improvement',
|
||||
|
||||
async shouldRun(context) {
|
||||
if (context.querySource !== 'repl_main_thread') {
|
||||
return false
|
||||
}
|
||||
|
||||
if (!findProjectSkill()) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only run every TURN_BATCH_SIZE user messages
|
||||
const userCount = count(context.messages, m => m.type === 'user')
|
||||
if (userCount - lastAnalyzedCount < TURN_BATCH_SIZE) {
|
||||
return false
|
||||
}
|
||||
|
||||
lastAnalyzedCount = userCount
|
||||
return true
|
||||
},
|
||||
|
||||
buildMessages(context) {
|
||||
const projectSkill = findProjectSkill()!
|
||||
// Only analyze messages since the last check — the skill definition
|
||||
// provides enough context for the classifier to understand corrections
|
||||
const newMessages = context.messages.slice(lastAnalyzedIndex)
|
||||
lastAnalyzedIndex = context.messages.length
|
||||
|
||||
return [
|
||||
createUserMessage({
|
||||
content: `You are analyzing a conversation where a user is executing a skill (a repeatable process).
|
||||
Your job: identify if the user's recent messages contain preferences, requests, or corrections that should be permanently added to the skill definition for future runs.
|
||||
|
||||
<skill_definition>
|
||||
${projectSkill.content}
|
||||
</skill_definition>
|
||||
|
||||
<recent_messages>
|
||||
${formatRecentMessages(newMessages)}
|
||||
</recent_messages>
|
||||
|
||||
Look for:
|
||||
- Requests to add, change, or remove steps: "can you also ask me X", "please do Y too", "don't do Z"
|
||||
- Preferences about how steps should work: "ask me about energy levels", "note the time", "use a casual tone"
|
||||
- Corrections: "no, do X instead", "always use Y", "make sure to..."
|
||||
|
||||
Ignore:
|
||||
- Routine conversation that doesn't generalize (one-time answers, chitchat)
|
||||
- Things the skill already does
|
||||
|
||||
Output a JSON array inside <updates> tags. Each item: {"section": "which step/section to modify or 'new step'", "change": "what to add/modify", "reason": "which user message prompted this"}.
|
||||
Output <updates>[]</updates> if no updates are needed.`,
|
||||
}),
|
||||
]
|
||||
},
|
||||
|
||||
systemPrompt:
|
||||
'You detect user preferences and process improvements during skill execution. Flag anything the user asks for that should be remembered for next time.',
|
||||
|
||||
useTools: false,
|
||||
|
||||
parseResponse(content) {
|
||||
const updatesStr = extractTag(content, 'updates')
|
||||
if (!updatesStr) {
|
||||
return []
|
||||
}
|
||||
try {
|
||||
return jsonParse(updatesStr) as SkillUpdate[]
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
},
|
||||
|
||||
logResult(result, context) {
|
||||
if (result.type === 'success' && result.result.length > 0) {
|
||||
const projectSkill = findProjectSkill()
|
||||
const skillName = projectSkill?.skillName ?? 'unknown'
|
||||
|
||||
logEvent('tengu_skill_improvement_detected', {
|
||||
updateCount: result.result
|
||||
.length as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
uuid: result.uuid as AnalyticsMetadata_I_VERIFIED_THIS_IS_NOT_CODE_OR_FILEPATHS,
|
||||
// _PROTO_skill_name routes to the privileged skill_name BQ column.
|
||||
_PROTO_skill_name:
|
||||
skillName as AnalyticsMetadata_I_VERIFIED_THIS_IS_PII_TAGGED,
|
||||
})
|
||||
|
||||
context.toolUseContext.setAppState(prev => ({
|
||||
...prev,
|
||||
skillImprovement: {
|
||||
suggestion: { skillName, updates: result.result },
|
||||
},
|
||||
}))
|
||||
}
|
||||
},
|
||||
|
||||
getModel: getSmallFastModel,
|
||||
}
|
||||
|
||||
return createApiQueryHook(config)
|
||||
}
|
||||
|
||||
export function initSkillImprovement(): void {
|
||||
if (
|
||||
feature('SKILL_IMPROVEMENT') &&
|
||||
getFeatureValue_CACHED_MAY_BE_STALE('tengu_copper_panda', false)
|
||||
) {
|
||||
registerPostSamplingHook(createSkillImprovementHook())
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Apply skill improvements by calling a side-channel LLM to rewrite the skill file.
|
||||
* Fire-and-forget — does not block the main conversation.
|
||||
*/
|
||||
export async function applySkillImprovement(
|
||||
skillName: string,
|
||||
updates: SkillUpdate[],
|
||||
): Promise<void> {
|
||||
if (!skillName) return
|
||||
|
||||
const { join } = await import('path')
|
||||
const fs = await import('fs/promises')
|
||||
|
||||
// Skills live at .claude/skills/<name>/SKILL.md relative to CWD
|
||||
const filePath = join(getCwd(), '.claude', 'skills', skillName, 'SKILL.md')
|
||||
|
||||
let currentContent: string
|
||||
try {
|
||||
currentContent = await fs.readFile(filePath, 'utf-8')
|
||||
} catch {
|
||||
logError(
|
||||
new Error(`Failed to read skill file for improvement: ${filePath}`),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
const updateList = updates.map(u => `- ${u.section}: ${u.change}`).join('\n')
|
||||
|
||||
const response = await queryModelWithoutStreaming({
|
||||
messages: [
|
||||
createUserMessage({
|
||||
content: `You are editing a skill definition file. Apply the following improvements to the skill.
|
||||
|
||||
<current_skill_file>
|
||||
${currentContent}
|
||||
</current_skill_file>
|
||||
|
||||
<improvements>
|
||||
${updateList}
|
||||
</improvements>
|
||||
|
||||
Rules:
|
||||
- Integrate the improvements naturally into the existing structure
|
||||
- Preserve frontmatter (--- block) exactly as-is
|
||||
- Preserve the overall format and style
|
||||
- Do not remove existing content unless an improvement explicitly replaces it
|
||||
- Output the complete updated file inside <updated_file> tags`,
|
||||
}),
|
||||
],
|
||||
systemPrompt: asSystemPrompt([
|
||||
'You edit skill definition files to incorporate user preferences. Output only the updated file content.',
|
||||
]),
|
||||
thinkingConfig: { type: 'disabled' as const },
|
||||
tools: [],
|
||||
signal: createAbortController().signal,
|
||||
options: {
|
||||
getToolPermissionContext: async () => getEmptyToolPermissionContext(),
|
||||
model: getSmallFastModel(),
|
||||
toolChoice: undefined,
|
||||
isNonInteractiveSession: false,
|
||||
hasAppendSystemPrompt: false,
|
||||
temperatureOverride: 0,
|
||||
agents: [],
|
||||
querySource: 'skill_improvement_apply',
|
||||
mcpTools: [],
|
||||
},
|
||||
})
|
||||
|
||||
const responseText = extractTextContent(response.message.content).trim()
|
||||
|
||||
const updatedContent = extractTag(responseText, 'updated_file')
|
||||
if (!updatedContent) {
|
||||
logError(
|
||||
new Error('Skill improvement apply: no updated_file tag in response'),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
await fs.writeFile(filePath, updatedContent, 'utf-8')
|
||||
} catch (e) {
|
||||
logError(toError(e))
|
||||
}
|
||||
}
|
||||
294
src/utils/hooks/ssrfGuard.ts
Normal file
294
src/utils/hooks/ssrfGuard.ts
Normal file
@@ -0,0 +1,294 @@
|
||||
import type { AddressFamily, LookupAddress as AxiosLookupAddress } from 'axios'
|
||||
import { lookup as dnsLookup } from 'dns'
|
||||
import { isIP } from 'net'
|
||||
|
||||
/**
|
||||
* SSRF guard for HTTP hooks.
|
||||
*
|
||||
* Blocks private, link-local, and other non-routable address ranges to prevent
|
||||
* project-configured HTTP hooks from reaching cloud metadata endpoints
|
||||
* (169.254.169.254) or internal infrastructure.
|
||||
*
|
||||
* Loopback (127.0.0.0/8, ::1) is intentionally ALLOWED — local dev policy
|
||||
* servers are a primary HTTP hook use case.
|
||||
*
|
||||
* When a global proxy or the sandbox network proxy is in use, the guard is
|
||||
* effectively bypassed for the target host because the proxy performs DNS
|
||||
* resolution. The sandbox proxy enforces its own domain allowlist.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Returns true if the address is in a range that HTTP hooks should not reach.
|
||||
*
|
||||
* Blocked IPv4:
|
||||
* 0.0.0.0/8 "this" network
|
||||
* 10.0.0.0/8 private
|
||||
* 100.64.0.0/10 shared address space / CGNAT (some cloud metadata, e.g. Alibaba 100.100.100.200)
|
||||
* 169.254.0.0/16 link-local (cloud metadata)
|
||||
* 172.16.0.0/12 private
|
||||
* 192.168.0.0/16 private
|
||||
*
|
||||
* Blocked IPv6:
|
||||
* :: unspecified
|
||||
* fc00::/7 unique local
|
||||
* fe80::/10 link-local
|
||||
* ::ffff:<v4> mapped IPv4 in a blocked range
|
||||
*
|
||||
* Allowed (returns false):
|
||||
* 127.0.0.0/8 loopback (local dev hooks)
|
||||
* ::1 loopback
|
||||
* everything else
|
||||
*/
|
||||
export function isBlockedAddress(address: string): boolean {
|
||||
const v = isIP(address)
|
||||
if (v === 4) {
|
||||
return isBlockedV4(address)
|
||||
}
|
||||
if (v === 6) {
|
||||
return isBlockedV6(address)
|
||||
}
|
||||
// Not a valid IP literal — let the real DNS path handle it (this function
|
||||
// is only called on results from dns.lookup, which always returns valid IPs)
|
||||
return false
|
||||
}
|
||||
|
||||
function isBlockedV4(address: string): boolean {
|
||||
const parts = address.split('.').map(Number)
|
||||
const [a, b] = parts
|
||||
if (
|
||||
parts.length !== 4 ||
|
||||
a === undefined ||
|
||||
b === undefined ||
|
||||
parts.some(n => Number.isNaN(n))
|
||||
) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Loopback explicitly allowed
|
||||
if (a === 127) return false
|
||||
|
||||
// 0.0.0.0/8
|
||||
if (a === 0) return true
|
||||
// 10.0.0.0/8
|
||||
if (a === 10) return true
|
||||
// 169.254.0.0/16 — link-local, cloud metadata
|
||||
if (a === 169 && b === 254) return true
|
||||
// 172.16.0.0/12
|
||||
if (a === 172 && b >= 16 && b <= 31) return true
|
||||
// 100.64.0.0/10 — shared address space (RFC 6598, CGNAT). Some cloud
|
||||
// providers use this range for metadata endpoints (e.g. Alibaba Cloud at
|
||||
// 100.100.100.200).
|
||||
if (a === 100 && b >= 64 && b <= 127) return true
|
||||
// 192.168.0.0/16
|
||||
if (a === 192 && b === 168) return true
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
function isBlockedV6(address: string): boolean {
|
||||
const lower = address.toLowerCase()
|
||||
|
||||
// ::1 loopback explicitly allowed
|
||||
if (lower === '::1') return false
|
||||
|
||||
// :: unspecified
|
||||
if (lower === '::') return true
|
||||
|
||||
// IPv4-mapped IPv6 (0:0:0:0:0:ffff:X:Y in any representation — ::ffff:a.b.c.d,
|
||||
// ::ffff:XXXX:YYYY, expanded, or partially expanded). Extract the embedded
|
||||
// IPv4 address and delegate to the v4 check. Without this, hex-form mapped
|
||||
// addresses (e.g. ::ffff:a9fe:a9fe = 169.254.169.254) bypass the guard.
|
||||
const mappedV4 = extractMappedIPv4(lower)
|
||||
if (mappedV4 !== null) {
|
||||
return isBlockedV4(mappedV4)
|
||||
}
|
||||
|
||||
// fc00::/7 — unique local addresses (fc00:: through fdff::)
|
||||
if (lower.startsWith('fc') || lower.startsWith('fd')) {
|
||||
return true
|
||||
}
|
||||
|
||||
// fe80::/10 — link-local. The /10 means fe80 through febf, but the first
|
||||
// hextet is always fe80 in practice (RFC 4291 requires the next 54 bits
|
||||
// to be zero). Check both to be safe.
|
||||
const firstHextet = lower.split(':')[0]
|
||||
if (
|
||||
firstHextet &&
|
||||
firstHextet.length === 4 &&
|
||||
firstHextet >= 'fe80' &&
|
||||
firstHextet <= 'febf'
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Expand `::` and optional trailing dotted-decimal so an IPv6 address is
|
||||
* represented as exactly 8 hex groups. Returns null if expansion is not
|
||||
* well-formed (the caller has already validated with isIP, so this is
|
||||
* defensive).
|
||||
*/
|
||||
function expandIPv6Groups(addr: string): number[] | null {
|
||||
// Handle trailing dotted-decimal IPv4 (e.g. ::ffff:169.254.169.254).
|
||||
// Replace it with its two hex groups so the rest of the expansion is uniform.
|
||||
let tailHextets: number[] = []
|
||||
if (addr.includes('.')) {
|
||||
const lastColon = addr.lastIndexOf(':')
|
||||
const v4 = addr.slice(lastColon + 1)
|
||||
addr = addr.slice(0, lastColon)
|
||||
const octets = v4.split('.').map(Number)
|
||||
if (
|
||||
octets.length !== 4 ||
|
||||
octets.some(n => !Number.isInteger(n) || n < 0 || n > 255)
|
||||
) {
|
||||
return null
|
||||
}
|
||||
tailHextets = [
|
||||
(octets[0]! << 8) | octets[1]!,
|
||||
(octets[2]! << 8) | octets[3]!,
|
||||
]
|
||||
}
|
||||
|
||||
// Expand `::` (at most one) into the right number of zero groups.
|
||||
const dbl = addr.indexOf('::')
|
||||
let head: string[]
|
||||
let tail: string[]
|
||||
if (dbl === -1) {
|
||||
head = addr.split(':')
|
||||
tail = []
|
||||
} else {
|
||||
const headStr = addr.slice(0, dbl)
|
||||
const tailStr = addr.slice(dbl + 2)
|
||||
head = headStr === '' ? [] : headStr.split(':')
|
||||
tail = tailStr === '' ? [] : tailStr.split(':')
|
||||
}
|
||||
|
||||
const target = 8 - tailHextets.length
|
||||
const fill = target - head.length - tail.length
|
||||
if (fill < 0) return null
|
||||
|
||||
const hex = [...head, ...new Array<string>(fill).fill('0'), ...tail]
|
||||
const nums = hex.map(h => parseInt(h, 16))
|
||||
if (nums.some(n => Number.isNaN(n) || n < 0 || n > 0xffff)) {
|
||||
return null
|
||||
}
|
||||
nums.push(...tailHextets)
|
||||
return nums.length === 8 ? nums : null
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the embedded IPv4 address from an IPv4-mapped IPv6 address
|
||||
* (0:0:0:0:0:ffff:X:Y) in any valid representation — compressed, expanded,
|
||||
* hex groups, or trailing dotted-decimal. Returns null if the address is
|
||||
* not an IPv4-mapped IPv6 address.
|
||||
*/
|
||||
function extractMappedIPv4(addr: string): string | null {
|
||||
const g = expandIPv6Groups(addr)
|
||||
if (!g) return null
|
||||
// IPv4-mapped: first 80 bits zero, next 16 bits ffff, last 32 bits = IPv4
|
||||
if (
|
||||
g[0] === 0 &&
|
||||
g[1] === 0 &&
|
||||
g[2] === 0 &&
|
||||
g[3] === 0 &&
|
||||
g[4] === 0 &&
|
||||
g[5] === 0xffff
|
||||
) {
|
||||
const hi = g[6]!
|
||||
const lo = g[7]!
|
||||
return `${hi >> 8}.${hi & 0xff}.${lo >> 8}.${lo & 0xff}`
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* A dns.lookup-compatible function that resolves a hostname and rejects
|
||||
* addresses in blocked ranges. Used as the `lookup` option in axios request
|
||||
* config so that the validated IP is the one the socket connects to — no
|
||||
* rebinding window between validation and connection.
|
||||
*
|
||||
* IP literals in the hostname are validated directly without DNS.
|
||||
*
|
||||
* Signature matches axios's `lookup` config option (not Node's dns.lookup).
|
||||
*/
|
||||
export function ssrfGuardedLookup(
|
||||
hostname: string,
|
||||
options: object,
|
||||
callback: (
|
||||
err: Error | null,
|
||||
address: AxiosLookupAddress | AxiosLookupAddress[],
|
||||
family?: AddressFamily,
|
||||
) => void,
|
||||
): void {
|
||||
const wantsAll = 'all' in options && options.all === true
|
||||
|
||||
// If hostname is already an IP literal, validate it directly. dns.lookup
|
||||
// would short-circuit too, but checking here gives a clearer error and
|
||||
// avoids any platform-specific lookup behavior for literals.
|
||||
const ipVersion = isIP(hostname)
|
||||
if (ipVersion !== 0) {
|
||||
if (isBlockedAddress(hostname)) {
|
||||
callback(ssrfError(hostname, hostname), '')
|
||||
return
|
||||
}
|
||||
const family = ipVersion === 6 ? 6 : 4
|
||||
if (wantsAll) {
|
||||
callback(null, [{ address: hostname, family }])
|
||||
} else {
|
||||
callback(null, hostname, family)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
dnsLookup(hostname, { all: true }, (err, addresses) => {
|
||||
if (err) {
|
||||
callback(err, '')
|
||||
return
|
||||
}
|
||||
|
||||
for (const { address } of addresses) {
|
||||
if (isBlockedAddress(address)) {
|
||||
callback(ssrfError(hostname, address), '')
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
const first = addresses[0]
|
||||
if (!first) {
|
||||
callback(
|
||||
Object.assign(new Error(`ENOTFOUND ${hostname}`), {
|
||||
code: 'ENOTFOUND',
|
||||
hostname,
|
||||
}),
|
||||
'',
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
const family = first.family === 6 ? 6 : 4
|
||||
if (wantsAll) {
|
||||
callback(
|
||||
null,
|
||||
addresses.map(a => ({
|
||||
address: a.address,
|
||||
family: a.family === 6 ? 6 : 4,
|
||||
})),
|
||||
)
|
||||
} else {
|
||||
callback(null, first.address, family)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function ssrfError(hostname: string, address: string): NodeJS.ErrnoException {
|
||||
const err = new Error(
|
||||
`HTTP hook blocked: ${hostname} resolves to ${address} (private/link-local address). Loopback (127.0.0.1, ::1) is allowed for local dev.`,
|
||||
)
|
||||
return Object.assign(err, {
|
||||
code: 'ERR_HTTP_HOOK_BLOCKED_ADDRESS',
|
||||
hostname,
|
||||
address,
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user