mirror of
https://github.com/zadam/trilium.git
synced 2026-05-07 01:45:36 +02:00
refactor(llm): integrate tools requiring context
This commit is contained in:
@@ -115,6 +115,7 @@ describe("mcp", () => {
|
||||
expect(toolNames).toContain("search_notes");
|
||||
expect(toolNames).toContain("read_note");
|
||||
expect(toolNames).toContain("create_note");
|
||||
expect(toolNames).not.toContain("get_current_note");
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -9,7 +9,8 @@ import type { LlmMessage } from "@triliumnext/commons";
|
||||
|
||||
import becca from "../../../becca/becca.js";
|
||||
import { getSkillsSummary } from "../skills/index.js";
|
||||
import { allToolRegistries, currentNoteTools } from "../tools/index.js";
|
||||
import { allToolRegistries } from "../tools/index.js";
|
||||
import type { ToolContext } from "../tools/tool_registry.js";
|
||||
import type { LlmProvider, LlmProviderConfig, ModelInfo, ModelPricing, StreamResult } from "../types.js";
|
||||
|
||||
const DEFAULT_MAX_TOKENS = 8096;
|
||||
@@ -128,13 +129,12 @@ export abstract class BaseProvider implements LlmProvider {
|
||||
this.addWebSearchTool(tools);
|
||||
}
|
||||
|
||||
if (config.contextNoteId) {
|
||||
Object.assign(tools, currentNoteTools(config.contextNoteId));
|
||||
}
|
||||
|
||||
if (config.enableNoteTools) {
|
||||
const context: ToolContext | undefined = config.contextNoteId
|
||||
? { contextNoteId: config.contextNoteId }
|
||||
: undefined;
|
||||
for (const registry of allToolRegistries) {
|
||||
Object.assign(tools, registry.toToolSet());
|
||||
Object.assign(tools, registry.toToolSet(context));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
* These reuse the same logic as ETAPI without any HTTP overhead.
|
||||
*/
|
||||
|
||||
export { noteTools, currentNoteTools } from "./note_tools.js";
|
||||
export { noteTools } from "./note_tools.js";
|
||||
export { attributeTools } from "./attribute_tools.js";
|
||||
export { hierarchyTools } from "./hierarchy_tools.js";
|
||||
export { skillTools } from "../skills/index.js";
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
* LLM tools for note operations (search, read, create, update, append).
|
||||
*/
|
||||
|
||||
import { tool } from "ai";
|
||||
import { z } from "zod";
|
||||
|
||||
import becca from "../../../becca/becca.js";
|
||||
@@ -11,7 +10,7 @@ import markdownImport from "../../import/markdown.js";
|
||||
import noteService from "../../notes.js";
|
||||
import SearchContext from "../../search/search_context.js";
|
||||
import searchService from "../../search/services/search.js";
|
||||
import { defineTools } from "./tool_registry.js";
|
||||
import { defineTools, type ToolContext } from "./tool_registry.js";
|
||||
|
||||
/**
|
||||
* Convert note content to a format suitable for LLM consumption.
|
||||
@@ -228,34 +227,27 @@ export const noteTools = defineTools({
|
||||
return { error: err instanceof Error ? err.message : "Failed to create note" };
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
get_current_note: {
|
||||
description: "Read the content of the note the user is currently viewing. Call this when the user asks about or refers to their current note.",
|
||||
inputSchema: z.object({}),
|
||||
needsContext: true as const,
|
||||
execute: async (_args: Record<string, never>, { contextNoteId }: ToolContext) => {
|
||||
const note = becca.getNote(contextNoteId);
|
||||
if (!note) {
|
||||
return { error: "Note not found" };
|
||||
}
|
||||
if (!note.isContentAvailable()) {
|
||||
return { error: "Note is protected" };
|
||||
}
|
||||
|
||||
return {
|
||||
noteId: note.noteId,
|
||||
title: note.getTitleOrProtected(),
|
||||
type: note.type,
|
||||
content: getNoteContentForLlm(note)
|
||||
};
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Read the content of the note the user is currently viewing.
|
||||
* Created dynamically so it captures the contextNoteId.
|
||||
*/
|
||||
export function currentNoteTools(contextNoteId: string) {
|
||||
return {
|
||||
get_current_note: tool({
|
||||
description: "Read the content of the note the user is currently viewing. Call this when the user asks about or refers to their current note.",
|
||||
inputSchema: z.object({}),
|
||||
execute: async () => {
|
||||
const note = becca.getNote(contextNoteId);
|
||||
if (!note) {
|
||||
return { error: "Note not found" };
|
||||
}
|
||||
if (!note.isContentAvailable()) {
|
||||
return { error: "Note is protected" };
|
||||
}
|
||||
|
||||
return {
|
||||
noteId: note.noteId,
|
||||
title: note.getTitleOrProtected(),
|
||||
type: note.type,
|
||||
content: getNoteContentForLlm(note)
|
||||
};
|
||||
}
|
||||
})
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Lightweight wrapper around AI tool definitions that carries extra metadata
|
||||
* (e.g. `mutates`) while remaining compatible with the Vercel AI SDK ToolSet.
|
||||
* (e.g. `mutates`, `needsContext`) while remaining compatible with the Vercel
|
||||
* AI SDK ToolSet.
|
||||
*
|
||||
* Each tool module calls `defineTools({ ... })` to declare its tools.
|
||||
* Consumers can then:
|
||||
@@ -12,14 +13,32 @@ import { tool } from "ai";
|
||||
import type { z } from "zod";
|
||||
import type { ToolSet } from "ai";
|
||||
|
||||
export interface ToolDefinition {
|
||||
/** Context passed to tools that declare `needsContext: true`. */
|
||||
export interface ToolContext {
|
||||
contextNoteId: string;
|
||||
}
|
||||
|
||||
interface ToolDefinitionBase {
|
||||
description: string;
|
||||
inputSchema: z.ZodType;
|
||||
execute: (args: any) => Promise<unknown>;
|
||||
/** Whether this tool modifies data (needs CLS + transaction wrapping). */
|
||||
mutates?: boolean;
|
||||
}
|
||||
|
||||
/** A tool that does not require a note context. */
|
||||
export interface StaticToolDefinition extends ToolDefinitionBase {
|
||||
needsContext?: false;
|
||||
execute: (args: any) => Promise<unknown>;
|
||||
}
|
||||
|
||||
/** A tool that requires a note context (e.g. "current note"). */
|
||||
export interface ContextToolDefinition extends ToolDefinitionBase {
|
||||
needsContext: true;
|
||||
execute: (args: any, context: ToolContext) => Promise<unknown>;
|
||||
}
|
||||
|
||||
export type ToolDefinition = StaticToolDefinition | ContextToolDefinition;
|
||||
|
||||
/**
|
||||
* A named collection of tool definitions that can be iterated or converted
|
||||
* to an AI SDK ToolSet.
|
||||
@@ -32,15 +51,30 @@ export class ToolRegistry implements Iterable<[string, ToolDefinition]> {
|
||||
return Object.entries(this.tools)[Symbol.iterator]();
|
||||
}
|
||||
|
||||
/** Convert to an AI SDK ToolSet for use with the LLM chat providers. */
|
||||
toToolSet(): ToolSet {
|
||||
/**
|
||||
* Convert to an AI SDK ToolSet for use with the LLM chat providers.
|
||||
*
|
||||
* If `context` is provided, context-aware tools are included with the
|
||||
* context bound into their execute function. Otherwise they are skipped.
|
||||
*/
|
||||
toToolSet(context?: ToolContext): ToolSet {
|
||||
const set: ToolSet = {};
|
||||
for (const [name, def] of this) {
|
||||
set[name] = tool({
|
||||
description: def.description,
|
||||
inputSchema: def.inputSchema,
|
||||
execute: def.execute
|
||||
});
|
||||
if (def.needsContext) {
|
||||
if (!context) continue;
|
||||
const boundExecute = (args: any) => def.execute(args, context);
|
||||
set[name] = tool({
|
||||
description: def.description,
|
||||
inputSchema: def.inputSchema,
|
||||
execute: boundExecute
|
||||
});
|
||||
} else {
|
||||
set[name] = tool({
|
||||
description: def.description,
|
||||
inputSchema: def.inputSchema,
|
||||
execute: def.execute
|
||||
});
|
||||
}
|
||||
}
|
||||
return set;
|
||||
}
|
||||
@@ -52,7 +86,7 @@ export class ToolRegistry implements Iterable<[string, ToolDefinition]> {
|
||||
* ```ts
|
||||
* export const noteTools = defineTools({
|
||||
* search_notes: { description: "...", inputSchema: z.object({...}), execute: async (args) => {...} },
|
||||
* create_note: { description: "...", inputSchema: z.object({...}), execute: async (args) => {...}, mutates: true },
|
||||
* get_current_note: { description: "...", inputSchema: z.object({}), execute: async (args, ctx) => {...}, needsContext: true },
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
|
||||
@@ -42,6 +42,7 @@ export function createMcpServer(): McpServer {
|
||||
|
||||
for (const registry of allToolRegistries) {
|
||||
for (const [name, def] of registry) {
|
||||
if (def.needsContext) continue;
|
||||
registerTool(server, name, def);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user