mirror of
https://github.com/zadam/trilium.git
synced 2025-11-03 20:06:08 +01:00
tool calling is close to working
getting closer to calling tools... we definitely need this closer to tool execution... agentic tool calling is...kind of working?
This commit is contained in:
@@ -7,6 +7,10 @@ import { MessagePreparationStage } from './stages/message_preparation_stage.js';
|
||||
import { ModelSelectionStage } from './stages/model_selection_stage.js';
|
||||
import { LLMCompletionStage } from './stages/llm_completion_stage.js';
|
||||
import { ResponseProcessingStage } from './stages/response_processing_stage.js';
|
||||
import { ToolCallingStage } from './stages/tool_calling_stage.js';
|
||||
import { VectorSearchStage } from './stages/vector_search_stage.js';
|
||||
import toolRegistry from '../tools/tool_registry.js';
|
||||
import toolInitializer from '../tools/tool_initializer.js';
|
||||
import log from '../../log.js';
|
||||
|
||||
/**
|
||||
@@ -22,6 +26,8 @@ export class ChatPipeline {
|
||||
modelSelection: ModelSelectionStage;
|
||||
llmCompletion: LLMCompletionStage;
|
||||
responseProcessing: ResponseProcessingStage;
|
||||
toolCalling: ToolCallingStage;
|
||||
vectorSearch: VectorSearchStage;
|
||||
};
|
||||
|
||||
config: ChatPipelineConfig;
|
||||
@@ -40,7 +46,9 @@ export class ChatPipeline {
|
||||
messagePreparation: new MessagePreparationStage(),
|
||||
modelSelection: new ModelSelectionStage(),
|
||||
llmCompletion: new LLMCompletionStage(),
|
||||
responseProcessing: new ResponseProcessingStage()
|
||||
responseProcessing: new ResponseProcessingStage(),
|
||||
toolCalling: new ToolCallingStage(),
|
||||
vectorSearch: new VectorSearchStage()
|
||||
};
|
||||
|
||||
// Set default configuration values
|
||||
@@ -87,6 +95,34 @@ export class ChatPipeline {
|
||||
contentLength += message.content.length;
|
||||
}
|
||||
|
||||
// Initialize tools if needed
|
||||
try {
|
||||
const toolCount = toolRegistry.getAllTools().length;
|
||||
|
||||
// If there are no tools registered, initialize them
|
||||
if (toolCount === 0) {
|
||||
log.info('No tools found in registry, initializing tools...');
|
||||
await toolInitializer.initializeTools();
|
||||
log.info(`Tools initialized, now have ${toolRegistry.getAllTools().length} tools`);
|
||||
} else {
|
||||
log.info(`Found ${toolCount} tools already registered`);
|
||||
}
|
||||
} catch (error: any) {
|
||||
log.error(`Error checking/initializing tools: ${error.message || String(error)}`);
|
||||
}
|
||||
|
||||
// First, select the appropriate model based on query complexity and content length
|
||||
const modelSelectionStartTime = Date.now();
|
||||
const modelSelection = await this.stages.modelSelection.execute({
|
||||
options: input.options,
|
||||
query: input.query,
|
||||
contentLength
|
||||
});
|
||||
this.updateStageMetrics('modelSelection', modelSelectionStartTime);
|
||||
|
||||
// Determine if we should use tools or semantic context
|
||||
const useTools = modelSelection.options.enableTools === true;
|
||||
|
||||
// Determine which pipeline flow to use
|
||||
let context: string | undefined;
|
||||
|
||||
@@ -102,27 +138,63 @@ export class ChatPipeline {
|
||||
});
|
||||
context = agentContext.context;
|
||||
this.updateStageMetrics('agentToolsContext', contextStartTime);
|
||||
} else {
|
||||
// Get semantic context for regular queries
|
||||
} else if (!useTools) {
|
||||
// Only get semantic context if tools are NOT enabled
|
||||
// When tools are enabled, we'll let the LLM request context via tools instead
|
||||
log.info('Getting semantic context for note using pipeline stages');
|
||||
|
||||
// First use the vector search stage to find relevant notes
|
||||
const vectorSearchStartTime = Date.now();
|
||||
log.info(`Executing vector search stage for query: "${input.query?.substring(0, 50)}..."`);
|
||||
|
||||
const vectorSearchResult = await this.stages.vectorSearch.execute({
|
||||
query: input.query || '',
|
||||
noteId: input.noteId,
|
||||
options: {
|
||||
maxResults: 10,
|
||||
useEnhancedQueries: true,
|
||||
threshold: 0.6
|
||||
}
|
||||
});
|
||||
|
||||
this.updateStageMetrics('vectorSearch', vectorSearchStartTime);
|
||||
|
||||
log.info(`Vector search found ${vectorSearchResult.searchResults.length} relevant notes`);
|
||||
|
||||
// Then pass to the semantic context stage to build the formatted context
|
||||
const semanticContext = await this.stages.semanticContextExtraction.execute({
|
||||
noteId: input.noteId,
|
||||
query: input.query,
|
||||
messages: input.messages
|
||||
});
|
||||
|
||||
context = semanticContext.context;
|
||||
this.updateStageMetrics('semanticContextExtraction', contextStartTime);
|
||||
} else {
|
||||
log.info('Tools are enabled - using minimal direct context to avoid race conditions');
|
||||
// Get context from current note directly without semantic search
|
||||
if (input.noteId) {
|
||||
try {
|
||||
const contextExtractor = new (await import('../../llm/context/index.js')).ContextExtractor();
|
||||
// Just get the direct content of the current note
|
||||
context = await contextExtractor.extractContext(input.noteId, {
|
||||
includeContent: true,
|
||||
includeParents: true,
|
||||
includeChildren: true,
|
||||
includeLinks: true,
|
||||
includeSimilar: false // Skip semantic search to avoid race conditions
|
||||
});
|
||||
log.info(`Direct context extracted (${context.length} chars) without semantic search`);
|
||||
} catch (error: any) {
|
||||
log.error(`Error extracting direct context: ${error.message}`);
|
||||
context = ""; // Fallback to empty context if extraction fails
|
||||
}
|
||||
} else {
|
||||
context = ""; // No note ID, so no context
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Select the appropriate model based on query complexity and content length
|
||||
const modelSelectionStartTime = Date.now();
|
||||
const modelSelection = await this.stages.modelSelection.execute({
|
||||
options: input.options,
|
||||
query: input.query,
|
||||
contentLength
|
||||
});
|
||||
this.updateStageMetrics('modelSelection', modelSelectionStartTime);
|
||||
|
||||
// Prepare messages with context and system prompt
|
||||
const messagePreparationStartTime = Date.now();
|
||||
const preparedMessages = await this.stages.messagePreparation.execute({
|
||||
@@ -167,17 +239,106 @@ export class ChatPipeline {
|
||||
});
|
||||
}
|
||||
|
||||
// For non-streaming responses, process the full response
|
||||
// Process any tool calls in the response
|
||||
let currentMessages = preparedMessages.messages;
|
||||
let currentResponse = completion.response;
|
||||
let needsFollowUp = false;
|
||||
let toolCallIterations = 0;
|
||||
const maxToolCallIterations = this.config.maxToolCallIterations;
|
||||
|
||||
// Check if tools were enabled in the options
|
||||
const toolsEnabled = modelSelection.options.enableTools !== false;
|
||||
|
||||
log.info(`========== TOOL CALL PROCESSING ==========`);
|
||||
log.info(`Tools enabled: ${toolsEnabled}`);
|
||||
log.info(`Tool calls in response: ${currentResponse.tool_calls ? currentResponse.tool_calls.length : 0}`);
|
||||
log.info(`Current response format: ${typeof currentResponse}`);
|
||||
log.info(`Response keys: ${Object.keys(currentResponse).join(', ')}`);
|
||||
|
||||
// Detailed tool call inspection
|
||||
if (currentResponse.tool_calls) {
|
||||
currentResponse.tool_calls.forEach((tool, idx) => {
|
||||
log.info(`Tool call ${idx+1}: ${JSON.stringify(tool)}`);
|
||||
});
|
||||
}
|
||||
|
||||
// Process tool calls if present and tools are enabled
|
||||
if (toolsEnabled && currentResponse.tool_calls && currentResponse.tool_calls.length > 0) {
|
||||
log.info(`Response contains ${currentResponse.tool_calls.length} tool calls, processing...`);
|
||||
|
||||
// Start tool calling loop
|
||||
log.info(`Starting tool calling loop with max ${maxToolCallIterations} iterations`);
|
||||
|
||||
do {
|
||||
log.info(`Tool calling iteration ${toolCallIterations + 1}`);
|
||||
|
||||
// Execute tool calling stage
|
||||
const toolCallingStartTime = Date.now();
|
||||
const toolCallingResult = await this.stages.toolCalling.execute({
|
||||
response: currentResponse,
|
||||
messages: currentMessages,
|
||||
options: modelSelection.options
|
||||
});
|
||||
this.updateStageMetrics('toolCalling', toolCallingStartTime);
|
||||
|
||||
// Update state for next iteration
|
||||
currentMessages = toolCallingResult.messages;
|
||||
needsFollowUp = toolCallingResult.needsFollowUp;
|
||||
|
||||
// Make another call to the LLM if needed
|
||||
if (needsFollowUp) {
|
||||
log.info(`Tool execution completed, making follow-up LLM call (iteration ${toolCallIterations + 1})...`);
|
||||
|
||||
// Generate a new LLM response with the updated messages
|
||||
const followUpStartTime = Date.now();
|
||||
log.info(`Sending follow-up request to LLM with ${currentMessages.length} messages (including tool results)`);
|
||||
|
||||
const followUpCompletion = await this.stages.llmCompletion.execute({
|
||||
messages: currentMessages,
|
||||
options: modelSelection.options
|
||||
});
|
||||
this.updateStageMetrics('llmCompletion', followUpStartTime);
|
||||
|
||||
// Update current response for next iteration
|
||||
currentResponse = followUpCompletion.response;
|
||||
|
||||
// Check for more tool calls
|
||||
const hasMoreToolCalls = !!(currentResponse.tool_calls && currentResponse.tool_calls.length > 0);
|
||||
|
||||
if (hasMoreToolCalls) {
|
||||
log.info(`Follow-up response contains ${currentResponse.tool_calls?.length || 0} more tool calls`);
|
||||
} else {
|
||||
log.info(`Follow-up response contains no more tool calls - completing tool loop`);
|
||||
}
|
||||
|
||||
// Continue loop if there are more tool calls
|
||||
needsFollowUp = hasMoreToolCalls;
|
||||
}
|
||||
|
||||
// Increment iteration counter
|
||||
toolCallIterations++;
|
||||
|
||||
} while (needsFollowUp && toolCallIterations < maxToolCallIterations);
|
||||
|
||||
// If we hit max iterations but still have tool calls, log a warning
|
||||
if (toolCallIterations >= maxToolCallIterations && needsFollowUp) {
|
||||
log.error(`Reached maximum tool call iterations (${maxToolCallIterations}), stopping`);
|
||||
}
|
||||
|
||||
log.info(`Completed ${toolCallIterations} tool call iterations`);
|
||||
}
|
||||
|
||||
// For non-streaming responses, process the final response
|
||||
const processStartTime = Date.now();
|
||||
const processed = await this.stages.responseProcessing.execute({
|
||||
response: completion.response,
|
||||
response: currentResponse,
|
||||
options: input.options
|
||||
});
|
||||
this.updateStageMetrics('responseProcessing', processStartTime);
|
||||
|
||||
// Combine response with processed text, using accumulated text if streamed
|
||||
const finalResponse: ChatResponse = {
|
||||
...completion.response,
|
||||
...currentResponse,
|
||||
text: accumulatedText || processed.text
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user