try a context approach

perf3ct
2025-03-10 03:34:48 +00:00
parent adaac46fbf
commit cf0e9242a0
4 changed files with 803 additions and 370 deletions


@@ -9,6 +9,8 @@ import providerManager from "../../services/llm/embeddings/providers.js";
import type { Message, ChatCompletionOptions } from "../../services/llm/ai_interface.js";
// Import this way to prevent immediate instantiation
import * as aiServiceManagerModule from "../../services/llm/ai_service_manager.js";
import triliumContextService from "../../services/llm/trilium_context_service.js";
import sql from "../../services/sql.js";
// Define basic interfaces
interface ChatMessage {
@@ -290,132 +292,126 @@ async function deleteSession(req: Request, res: Response) {
}
/**
* Find relevant notes using vector embeddings
* Find relevant notes based on search query
*/
async function findRelevantNotes(query: string, contextNoteId: string | null = null, limit = 5): Promise<NoteSource[]> {
async function findRelevantNotes(content: string, contextNoteId: string | null = null, limit = 5): Promise<NoteSource[]> {
try {
// Only proceed if database is initialized
// If the database is not initialized, we can't run a vector search
if (!isDatabaseInitialized()) {
log.info('Database not initialized, skipping vector search');
return [{
noteId: "root",
title: "Database not initialized yet",
content: "Please wait for database initialization to complete."
}];
return [];
}
// Get the default embedding provider
let providerId;
try {
// @ts-ignore - embeddingsDefaultProvider exists but might not be in the TypeScript definitions
providerId = await options.getOption('embeddingsDefaultProvider') || 'openai';
} catch (error) {
log.info('Could not get default embedding provider, using mock data');
return [{
noteId: "root",
title: "Embeddings not configured",
content: "Embedding provider not available"
}];
// Check if embeddings are available
const enabledProviders = await providerManager.getEnabledEmbeddingProviders();
if (enabledProviders.length === 0) {
log.info("No embedding providers available, can't find relevant notes");
return [];
}
const provider = providerManager.getEmbeddingProvider(providerId);
if (!provider) {
log.info(`Embedding provider ${providerId} not found, using mock data`);
return [{
noteId: "root",
title: "Embeddings not available",
content: "No embedding provider available"
}];
// If content is too short, don't bother
if (content.length < 3) {
return [];
}
// Generate embedding for the query
const embedding = await provider.generateEmbeddings(query);
// Get the embedding for the query
const provider = enabledProviders[0];
const embedding = await provider.generateEmbeddings(content);
// Find similar notes
const modelId = 'default'; // Use default model for the provider
const similarNotes = await vectorStore.findSimilarNotes(
embedding, providerId, modelId, limit, 0.6 // Lower threshold to find more results
);
// If a context note was provided, check if we should include its children
let results;
if (contextNoteId) {
const contextNote = becca.getNote(contextNoteId);
if (contextNote) {
const childNotes = contextNote.getChildNotes();
if (childNotes.length > 0) {
// Add relevant children that weren't already included
const childIds = new Set(childNotes.map(note => note.noteId));
const existingIds = new Set(similarNotes.map(note => note.noteId));
// For branch context, get notes specifically from that branch
// Find children that aren't already in the similar notes
const missingChildIds = Array.from(childIds).filter(id => !existingIds.has(id));
// TODO: This is a simplified implementation - we need to
// properly get all notes in the subtree starting from contextNoteId
// Add up to 3 children that weren't already included
for (const noteId of missingChildIds.slice(0, 3)) {
similarNotes.push({
// For now, just get direct children of the context note
const contextNote = becca.notes[contextNoteId];
if (!contextNote) {
return [];
}
const childBranches = await sql.getRows(`
SELECT branches.* FROM branches
WHERE branches.parentNoteId = ?
AND branches.isDeleted = 0
`, [contextNoteId]);
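// The TODO above could eventually be handled with a recursive CTE instead of
// direct children only. A minimal sketch, assuming SQLite and that a
// sql.getColumn helper (returning the first column as an array) exists
// alongside sql.getRows:
async function getSubtreeNoteIdsSketch(rootNoteId: string): Promise<string[]> {
    return await sql.getColumn(`
        WITH RECURSIVE subtree(noteId) AS (
            SELECT ?
            UNION
            SELECT branches.noteId FROM branches
            JOIN subtree ON branches.parentNoteId = subtree.noteId
            WHERE branches.isDeleted = 0
        )
        SELECT noteId FROM subtree`, [rootNoteId]);
}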
const childNoteIds = childBranches.map((branch: any) => branch.noteId);
// Include the context note itself
childNoteIds.push(contextNoteId);
// Find similar notes in this context
results = [];
for (const noteId of childNoteIds) {
const noteEmbedding = await vectorStore.getEmbeddingForNote(
noteId,
provider.name,
provider.getConfig().model
);
if (noteEmbedding) {
const similarity = vectorStore.cosineSimilarity(
embedding,
noteEmbedding.embedding
);
if (similarity > 0.65) {
results.push({
noteId,
similarity: 0.75 // Fixed similarity score for context children
similarity
});
}
}
}
// Sort by similarity
results.sort((a, b) => b.similarity - a.similarity);
results = results.slice(0, limit);
} else {
// General search across all notes
results = await vectorStore.findSimilarNotes(
embedding,
provider.name,
provider.getConfig().model,
limit
);
}
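// (Two retrieval modes, then: context-scoped scoring over the direct children
// of contextNoteId, or a global findSimilarNotes search when no context note
// is given.)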
// Get note content for context
return await Promise.all(similarNotes.map(async ({ noteId, similarity }) => {
const note = becca.getNote(noteId);
if (!note) {
return {
noteId,
title: "Unknown Note",
similarity
};
// Format the results
const sources: NoteSource[] = [];
for (const result of results) {
const note = becca.notes[result.noteId];
if (!note) continue;
let noteContent: string | undefined = undefined;
if (note.type === 'text') {
const content = note.getContent();
// Handle both string and Buffer types
noteContent = typeof content === 'string' ? content :
content instanceof Buffer ? content.toString('utf8') : undefined;
}
// Get note content
let content = '';
try {
// @ts-ignore - Content can be string or Buffer
const noteContent = await note.getContent();
content = typeof noteContent === 'string' ? noteContent : noteContent.toString('utf8');
// Truncate content if it's too long (for performance)
if (content.length > 2000) {
content = content.substring(0, 2000) + "...";
}
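// (2000 characters is roughly 500 tokens under the common ~4-chars-per-token
// heuristic, which keeps any single source from dominating the prompt.)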
} catch (e) {
log.error(`Error getting content for note ${noteId}: ${e}`);
}
// Get a branch ID for navigation
let branchId;
try {
const branches = note.getBranches();
if (branches.length > 0) {
branchId = branches[0].branchId;
}
} catch (e) {
log.error(`Error getting branch for note ${noteId}: ${e}`);
}
return {
noteId,
sources.push({
noteId: result.noteId,
title: note.title,
content,
similarity,
branchId
};
}));
} catch (error) {
log.error(`Error finding relevant notes: ${error}`);
// Return empty array on error
content: noteContent,
similarity: result.similarity,
branchId: note.getBranches()[0]?.branchId
});
}
return sources;
} catch (error: any) {
log.error(`Error finding relevant notes: ${error.message}`);
return [];
}
}
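// For reference, a minimal sketch of the cosine similarity that
// vectorStore.cosineSimilarity is assumed to compute above:
function cosineSimilaritySketch(a: number[], b: number[]): number {
    let dot = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
        dot += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }
    // Guard against zero-length vectors to avoid dividing by zero
    const denominator = Math.sqrt(normA) * Math.sqrt(normB);
    return denominator === 0 ? 0 : dot / denominator;
}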
/**
* Build a context string from relevant notes
* Build context from notes
*/
function buildContextFromNotes(sources: NoteSource[], query: string): string {
console.log("Building context from notes with query:", query);
@@ -449,265 +445,237 @@ Now, based on the above notes, please answer: ${query}`;
}
/**
* Send a message to an LLM chat session and get a response
* Send a message to the AI
*/
async function sendMessage(req: Request, res: Response) {
try {
const { sessionId, content, temperature, maxTokens, provider, model } = req.body;
console.log("Received message request:", {
sessionId,
contentLength: content ? content.length : 0,
contentPreview: content ? content.substring(0, 50) + (content.length > 50 ? '...' : '') : 'undefined',
temperature,
maxTokens,
provider,
model
});
if (!sessionId) {
throw new Error('Session ID is required');
}
// Extract the message content and session details from the request body
const { content, sessionId, useAdvancedContext = false } = req.body || {};
// Validate the content
if (!content || typeof content !== 'string' || content.trim().length === 0) {
throw new Error('Content cannot be empty');
}
// Check if streaming is requested
const wantsStream = (req.headers as any)['accept']?.includes('text/event-stream');
// Get or create the session
let session: ChatSession;
// If client wants streaming, set up SSE response
if (wantsStream) {
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
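// (Each res.write below emits one SSE frame of the form "data: <json>\n\n";
// the no-cache and keep-alive headers stop proxies from buffering the stream.)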
// Get chat session
let session = sessions.get(sessionId);
if (!session) {
const newSession = await createSession(req, res);
if (!newSession) {
throw new Error('Failed to create session');
}
// Add required properties to match ChatSession interface
session = {
...newSession,
messages: [],
lastActive: new Date(),
metadata: {}
};
sessions.set(sessionId, session);
if (sessionId && sessions.has(sessionId)) {
session = sessions.get(sessionId)!;
session.lastActive = new Date();
} else {
const result = await createSession(req, res);
if (!result?.id) {
throw new Error('Failed to create a new session');
}
session = sessions.get(result.id)!;
}
// Add user message to session
const userMessage: ChatMessage = {
role: 'user',
content: content,
timestamp: new Date()
// Check if AI services are available
if (!safelyUseAIManager()) {
throw new Error('AI services are not available');
}
// Get the AI service manager
const aiServiceManager = aiServiceManagerModule.default.getInstance();
// Get the default service - just use the first available one
const availableProviders = aiServiceManager.getAvailableProviders();
let service = null;
if (availableProviders.length > 0) {
// Use the first available provider
const providerName = availableProviders[0];
// We know the manager has a 'services' property from our code inspection,
// but TypeScript doesn't know that from the interface.
// This is a workaround to access it
service = (aiServiceManager as any).services[providerName];
}
if (!service) {
throw new Error('No AI service is available');
}
// Create user message
const userMessage: Message = {
role: 'user',
content
};
// Add message to session
session.messages.push({
role: 'user',
content,
timestamp: new Date()
});
// Log a preview of the message
log.info(`Processing LLM message: "${content.substring(0, 50)}${content.length > 50 ? '...' : ''}"`);
// Information to return to the client
let aiResponse = '';
let sourceNotes: NoteSource[] = [];
// If Advanced Context is enabled, we use the improved method
if (useAdvancedContext) {
// Use the Trilium-specific approach
const contextNoteId = session.noteContext || null;
const results = await triliumContextService.processQuery(content, service, contextNoteId);
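// (Judging from the usage below, processQuery is assumed to resolve to an
// object shaped like { context: string, notes: NoteSource[] }.)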
// Get the generated context
const context = results.context;
sourceNotes = results.notes;
// Add system message with the context
const contextMessage: Message = {
role: 'system',
content: context
};
console.log("Created user message:", {
role: userMessage.role,
contentLength: userMessage.content?.length || 0,
contentPreview: userMessage.content?.substring(0, 50) + (userMessage.content?.length > 50 ? '...' : '') || 'undefined'
});
session.messages.push(userMessage);
// Get context for query
const sources = await findRelevantNotes(content, session.noteContext || null);
// Format messages for AI with proper type casting
// Format all messages for the AI
const aiMessages: Message[] = [
{ role: 'system', content: 'You are a helpful assistant for Trilium Notes. When providing answers, use only the context provided in the notes. If the information is not in the notes, say so.' },
{ role: 'user', content: buildContextFromNotes(sources, content) }
contextMessage,
...session.messages.slice(-10).map(msg => ({
role: msg.role,
content: msg.content
}))
];
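// (Only the last 10 session messages are forwarded, which bounds prompt
// growth for long-running sessions.)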
// Ensure we're not sending empty content
console.log("Final message content length:", aiMessages[1].content.length);
console.log("Final message content preview:", aiMessages[1].content.substring(0, 100));
try {
// Send initial SSE message with session info
const sourcesForResponse = sources.map(({ noteId, title, similarity, branchId }) => ({
noteId,
title,
similarity: similarity ? Math.round(similarity * 100) / 100 : undefined,
branchId
}));
res.write(`data: ${JSON.stringify({
type: 'init',
session: {
id: sessionId,
messages: session.messages.slice(0, -1), // Don't include the new message yet
sources: sourcesForResponse
}
})}\n\n`);
// Get AI response with streaming enabled
const aiResponse = await aiServiceManagerModule.default.generateChatCompletion(aiMessages, {
temperature,
maxTokens,
model: provider ? `${provider}:${model}` : model,
stream: true
});
if (aiResponse.stream) {
// Create an empty assistant message
const assistantMessage: ChatMessage = {
role: 'assistant',
content: '',
timestamp: new Date()
};
session.messages.push(assistantMessage);
// Stream the response chunks
await aiResponse.stream(async (chunk) => {
if (chunk.text) {
// Update the message content
assistantMessage.content += chunk.text;
// Send chunk to client
res.write(`data: ${JSON.stringify({
type: 'chunk',
text: chunk.text,
done: chunk.done
})}\n\n`);
}
if (chunk.done) {
// Send final message with complete response
res.write(`data: ${JSON.stringify({
type: 'done',
session: {
id: sessionId,
messages: session.messages,
sources: sourcesForResponse
}
})}\n\n`);
res.end();
}
});
return; // Early return for streaming
} else {
// Fallback for non-streaming response
const assistantMessage: ChatMessage = {
role: 'assistant',
content: aiResponse.text,
timestamp: new Date()
};
session.messages.push(assistantMessage);
// Send complete response
res.write(`data: ${JSON.stringify({
type: 'done',
session: {
id: sessionId,
messages: session.messages,
sources: sourcesForResponse
}
})}\n\n`);
res.end();
return;
}
} catch (error: any) {
// Send error in streaming format
res.write(`data: ${JSON.stringify({
type: 'error',
error: `AI service error: ${error.message}`
})}\n\n`);
res.end();
return;
}
}
// Non-streaming API continues with normal JSON response...
// Get chat session
let session = sessions.get(sessionId);
if (!session) {
const newSession = await createSession(req, res);
if (!newSession) {
throw new Error('Failed to create session');
}
// Add required properties to match ChatSession interface
session = {
...newSession,
messages: [],
lastActive: new Date(),
metadata: {}
// Configure chat options from session metadata
const chatOptions: ChatCompletionOptions = {
temperature: session.metadata.temperature || 0.7,
maxTokens: session.metadata.maxTokens,
model: session.metadata.model
// 'provider' property has been removed as it's not in the ChatCompletionOptions type
};
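// (A sketch of the options shape inferred from usage in this file; the
// authoritative definition is ChatCompletionOptions in ai_interface.ts:
//   { temperature?: number; maxTokens?: number; model?: string; stream?: boolean })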
sessions.set(sessionId, session);
// Get streaming response if requested
const acceptHeader = req.get('Accept');
if (acceptHeader && acceptHeader.includes('text/event-stream')) {
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
let messageContent = '';
// Stream the response
await service.sendChatCompletion(
aiMessages,
chatOptions,
(chunk: string) => {
messageContent += chunk;
res.write(`data: ${JSON.stringify({ content: chunk })}\n\n`);
}
);
// Close the stream
res.write('data: [DONE]\n\n');
res.end();
// Store the full response
aiResponse = messageContent;
} else {
// Non-streaming approach
aiResponse = await service.sendChatCompletion(aiMessages, chatOptions);
}
} else {
// Original approach - find relevant notes through direct embedding comparison
const relevantNotes = await findRelevantNotes(
content,
session.noteContext || null,
5
);
sourceNotes = relevantNotes;
// Build context from relevant notes
const context = buildContextFromNotes(relevantNotes, content);
// Add system message with the context
const contextMessage: Message = {
role: 'system',
content: context
};
// Format all messages for the AI
const aiMessages: Message[] = [
contextMessage,
...session.messages.slice(-10).map(msg => ({
role: msg.role,
content: msg.content
}))
];
// Configure chat options from session metadata
const chatOptions: ChatCompletionOptions = {
temperature: session.metadata.temperature || 0.7,
maxTokens: session.metadata.maxTokens,
model: session.metadata.model
// 'provider' property has been removed as it's not in the ChatCompletionOptions type
};
// Get streaming response if requested
const acceptHeader = req.get('Accept');
if (acceptHeader && acceptHeader.includes('text/event-stream')) {
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
let messageContent = '';
// Stream the response
await service.sendChatCompletion(
aiMessages,
chatOptions,
(chunk: string) => {
messageContent += chunk;
res.write(`data: ${JSON.stringify({ content: chunk })}\n\n`);
}
);
// Close the stream
res.write('data: [DONE]\n\n');
res.end();
// Store the full response
aiResponse = messageContent;
} else {
// Non-streaming approach
aiResponse = await service.sendChatCompletion(aiMessages, chatOptions);
}
}
// Add user message to session
const userMessage: ChatMessage = {
role: 'user',
content: content,
timestamp: new Date()
};
console.log("Created user message:", {
role: userMessage.role,
contentLength: userMessage.content?.length || 0,
contentPreview: userMessage.content?.substring(0, 50) + (userMessage.content?.length > 50 ? '...' : '') || 'undefined'
});
session.messages.push(userMessage);
// Get context for query
const sources = await findRelevantNotes(content, session.noteContext || null);
// Format messages for AI with proper type casting
const aiMessages: Message[] = [
{ role: 'system', content: 'You are a helpful assistant for Trilium Notes. When providing answers, use only the context provided in the notes. If the information is not in the notes, say so.' },
{ role: 'user', content: buildContextFromNotes(sources, content) }
];
// Ensure we're not sending empty content
console.log("Final message content length:", aiMessages[1].content.length);
console.log("Final message content preview:", aiMessages[1].content.substring(0, 100));
try {
// Get AI response using the safe accessor methods
const aiResponse = await aiServiceManagerModule.default.generateChatCompletion(aiMessages, {
temperature,
maxTokens,
model: provider ? `${provider}:${model}` : model,
stream: false
// For non-streaming requests, store the assistant's message and return a JSON
// response (the streaming path has already written its response above)
const acceptHeader = req.get('Accept');
if (!acceptHeader || !acceptHeader.includes('text/event-stream')) {
// Store the assistant's response in the session
session.messages.push({
role: 'assistant',
content: aiResponse,
timestamp: new Date()
});
// Add assistant message to session
const assistantMessage: ChatMessage = {
role: 'assistant',
content: aiResponse.text,
timestamp: new Date()
};
session.messages.push(assistantMessage);
// Format sources for the response (without content to reduce payload size)
const sourcesForResponse = sources.map(({ noteId, title, similarity, branchId }) => ({
noteId,
title,
similarity: similarity ? Math.round(similarity * 100) / 100 : undefined,
branchId
}));
// Return the response
return {
id: sessionId,
messages: session.messages,
sources: sourcesForResponse,
provider: aiResponse.provider,
model: aiResponse.model
content: aiResponse,
sources: sourceNotes.map(note => ({
noteId: note.noteId,
title: note.title,
similarity: note.similarity,
branchId: note.branchId
}))
};
} catch (error: any) {
log.error(`AI service error: ${error.message}`);
throw new Error(`AI service error: ${error.message}`);
} else {
// For streaming responses, we've already sent the data
// But we still need to add the message to the session
session.messages.push({
role: 'assistant',
content: aiResponse,
timestamp: new Date()
});
}
} catch (error: any) {
log.error(`Error sending message: ${error.message}`);
throw error;
log.error(`Error sending message to LLM: ${error.message}`);
throw new Error(`Failed to send message: ${error.message}`);
}
}
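// For completeness, a minimal sketch of a browser-side consumer for the
// streaming path above. The endpoint path and payload are illustrative
// assumptions (EventSource cannot POST, so fetch + ReadableStream is used),
// and frame parsing is simplified: it assumes frames do not split across chunks.
async function consumeLlmStreamSketch(sessionId: string, content: string): Promise<string> {
    const resp = await fetch(`/api/llm/sessions/${sessionId}/messages`, {
        method: 'POST',
        headers: {
            'Accept': 'text/event-stream',
            'Content-Type': 'application/json'
        },
        body: JSON.stringify({ content, useAdvancedContext: true })
    });
    const reader = resp.body!.getReader();
    const decoder = new TextDecoder();
    let full = '';
    while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        // Each frame looks like "data: <json>\n\n"; "[DONE]" marks the end
        for (const frame of decoder.decode(value, { stream: true }).split('\n\n')) {
            const data = frame.replace(/^data: /, '').trim();
            if (!data || data === '[DONE]') continue;
            const parsed = JSON.parse(data);
            if (parsed.content) full += parsed.content;
        }
    }
    return full;
}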