From 038705483bfd12d05f090ea1fff15493b8ca31da Mon Sep 17 00:00:00 2001 From: Elian Doran Date: Wed, 1 Apr 2026 12:34:14 +0300 Subject: [PATCH] refactor(llm): integrate tools requiring context --- apps/server/spec/etapi/mcp.spec.ts | 1 + .../services/llm/providers/base_provider.ts | 12 ++-- apps/server/src/services/llm/tools/index.ts | 2 +- .../src/services/llm/tools/note_tools.ts | 54 ++++++++---------- .../src/services/llm/tools/tool_registry.ts | 56 +++++++++++++++---- apps/server/src/services/mcp/mcp_server.ts | 1 + 6 files changed, 77 insertions(+), 49 deletions(-) diff --git a/apps/server/spec/etapi/mcp.spec.ts b/apps/server/spec/etapi/mcp.spec.ts index 06c6ba7c1f..2663934b54 100644 --- a/apps/server/spec/etapi/mcp.spec.ts +++ b/apps/server/spec/etapi/mcp.spec.ts @@ -115,6 +115,7 @@ describe("mcp", () => { expect(toolNames).toContain("search_notes"); expect(toolNames).toContain("read_note"); expect(toolNames).toContain("create_note"); + expect(toolNames).not.toContain("get_current_note"); }); }); diff --git a/apps/server/src/services/llm/providers/base_provider.ts b/apps/server/src/services/llm/providers/base_provider.ts index ce8a5bd9b4..1b34ccada4 100644 --- a/apps/server/src/services/llm/providers/base_provider.ts +++ b/apps/server/src/services/llm/providers/base_provider.ts @@ -9,7 +9,8 @@ import type { LlmMessage } from "@triliumnext/commons"; import becca from "../../../becca/becca.js"; import { getSkillsSummary } from "../skills/index.js"; -import { allToolRegistries, currentNoteTools } from "../tools/index.js"; +import { allToolRegistries } from "../tools/index.js"; +import type { ToolContext } from "../tools/tool_registry.js"; import type { LlmProvider, LlmProviderConfig, ModelInfo, ModelPricing, StreamResult } from "../types.js"; const DEFAULT_MAX_TOKENS = 8096; @@ -128,13 +129,12 @@ export abstract class BaseProvider implements LlmProvider { this.addWebSearchTool(tools); } - if (config.contextNoteId) { - Object.assign(tools, currentNoteTools(config.contextNoteId)); - } - if (config.enableNoteTools) { + const context: ToolContext | undefined = config.contextNoteId + ? { contextNoteId: config.contextNoteId } + : undefined; for (const registry of allToolRegistries) { - Object.assign(tools, registry.toToolSet()); + Object.assign(tools, registry.toToolSet(context)); } } diff --git a/apps/server/src/services/llm/tools/index.ts b/apps/server/src/services/llm/tools/index.ts index e295a7c841..b4b10b4884 100644 --- a/apps/server/src/services/llm/tools/index.ts +++ b/apps/server/src/services/llm/tools/index.ts @@ -3,7 +3,7 @@ * These reuse the same logic as ETAPI without any HTTP overhead. */ -export { noteTools, currentNoteTools } from "./note_tools.js"; +export { noteTools } from "./note_tools.js"; export { attributeTools } from "./attribute_tools.js"; export { hierarchyTools } from "./hierarchy_tools.js"; export { skillTools } from "../skills/index.js"; diff --git a/apps/server/src/services/llm/tools/note_tools.ts b/apps/server/src/services/llm/tools/note_tools.ts index 9f3131f499..df5634f30b 100644 --- a/apps/server/src/services/llm/tools/note_tools.ts +++ b/apps/server/src/services/llm/tools/note_tools.ts @@ -2,7 +2,6 @@ * LLM tools for note operations (search, read, create, update, append). */ -import { tool } from "ai"; import { z } from "zod"; import becca from "../../../becca/becca.js"; @@ -11,7 +10,7 @@ import markdownImport from "../../import/markdown.js"; import noteService from "../../notes.js"; import SearchContext from "../../search/search_context.js"; import searchService from "../../search/services/search.js"; -import { defineTools } from "./tool_registry.js"; +import { defineTools, type ToolContext } from "./tool_registry.js"; /** * Convert note content to a format suitable for LLM consumption. @@ -228,34 +227,27 @@ export const noteTools = defineTools({ return { error: err instanceof Error ? err.message : "Failed to create note" }; } } + }, + + get_current_note: { + description: "Read the content of the note the user is currently viewing. Call this when the user asks about or refers to their current note.", + inputSchema: z.object({}), + needsContext: true as const, + execute: async (_args: Record, { contextNoteId }: ToolContext) => { + const note = becca.getNote(contextNoteId); + if (!note) { + return { error: "Note not found" }; + } + if (!note.isContentAvailable()) { + return { error: "Note is protected" }; + } + + return { + noteId: note.noteId, + title: note.getTitleOrProtected(), + type: note.type, + content: getNoteContentForLlm(note) + }; + } } }); - -/** - * Read the content of the note the user is currently viewing. - * Created dynamically so it captures the contextNoteId. - */ -export function currentNoteTools(contextNoteId: string) { - return { - get_current_note: tool({ - description: "Read the content of the note the user is currently viewing. Call this when the user asks about or refers to their current note.", - inputSchema: z.object({}), - execute: async () => { - const note = becca.getNote(contextNoteId); - if (!note) { - return { error: "Note not found" }; - } - if (!note.isContentAvailable()) { - return { error: "Note is protected" }; - } - - return { - noteId: note.noteId, - title: note.getTitleOrProtected(), - type: note.type, - content: getNoteContentForLlm(note) - }; - } - }) - }; -} diff --git a/apps/server/src/services/llm/tools/tool_registry.ts b/apps/server/src/services/llm/tools/tool_registry.ts index d39aa7c11d..f31d943393 100644 --- a/apps/server/src/services/llm/tools/tool_registry.ts +++ b/apps/server/src/services/llm/tools/tool_registry.ts @@ -1,6 +1,7 @@ /** * Lightweight wrapper around AI tool definitions that carries extra metadata - * (e.g. `mutates`) while remaining compatible with the Vercel AI SDK ToolSet. + * (e.g. `mutates`, `needsContext`) while remaining compatible with the Vercel + * AI SDK ToolSet. * * Each tool module calls `defineTools({ ... })` to declare its tools. * Consumers can then: @@ -12,14 +13,32 @@ import { tool } from "ai"; import type { z } from "zod"; import type { ToolSet } from "ai"; -export interface ToolDefinition { +/** Context passed to tools that declare `needsContext: true`. */ +export interface ToolContext { + contextNoteId: string; +} + +interface ToolDefinitionBase { description: string; inputSchema: z.ZodType; - execute: (args: any) => Promise; /** Whether this tool modifies data (needs CLS + transaction wrapping). */ mutates?: boolean; } +/** A tool that does not require a note context. */ +export interface StaticToolDefinition extends ToolDefinitionBase { + needsContext?: false; + execute: (args: any) => Promise; +} + +/** A tool that requires a note context (e.g. "current note"). */ +export interface ContextToolDefinition extends ToolDefinitionBase { + needsContext: true; + execute: (args: any, context: ToolContext) => Promise; +} + +export type ToolDefinition = StaticToolDefinition | ContextToolDefinition; + /** * A named collection of tool definitions that can be iterated or converted * to an AI SDK ToolSet. @@ -32,15 +51,30 @@ export class ToolRegistry implements Iterable<[string, ToolDefinition]> { return Object.entries(this.tools)[Symbol.iterator](); } - /** Convert to an AI SDK ToolSet for use with the LLM chat providers. */ - toToolSet(): ToolSet { + /** + * Convert to an AI SDK ToolSet for use with the LLM chat providers. + * + * If `context` is provided, context-aware tools are included with the + * context bound into their execute function. Otherwise they are skipped. + */ + toToolSet(context?: ToolContext): ToolSet { const set: ToolSet = {}; for (const [name, def] of this) { - set[name] = tool({ - description: def.description, - inputSchema: def.inputSchema, - execute: def.execute - }); + if (def.needsContext) { + if (!context) continue; + const boundExecute = (args: any) => def.execute(args, context); + set[name] = tool({ + description: def.description, + inputSchema: def.inputSchema, + execute: boundExecute + }); + } else { + set[name] = tool({ + description: def.description, + inputSchema: def.inputSchema, + execute: def.execute + }); + } } return set; } @@ -52,7 +86,7 @@ export class ToolRegistry implements Iterable<[string, ToolDefinition]> { * ```ts * export const noteTools = defineTools({ * search_notes: { description: "...", inputSchema: z.object({...}), execute: async (args) => {...} }, - * create_note: { description: "...", inputSchema: z.object({...}), execute: async (args) => {...}, mutates: true }, + * get_current_note: { description: "...", inputSchema: z.object({}), execute: async (args, ctx) => {...}, needsContext: true }, * }); * ``` */ diff --git a/apps/server/src/services/mcp/mcp_server.ts b/apps/server/src/services/mcp/mcp_server.ts index e74af59080..7e1624a5df 100644 --- a/apps/server/src/services/mcp/mcp_server.ts +++ b/apps/server/src/services/mcp/mcp_server.ts @@ -42,6 +42,7 @@ export function createMcpServer(): McpServer { for (const registry of allToolRegistries) { for (const [name, def] of registry) { + if (def.needsContext) continue; registerTool(server, name, def); } }