mirror of
https://github.com/zadam/trilium.git
synced 2025-11-01 19:05:59 +01:00
302 lines
12 KiB
TypeScript
302 lines
12 KiB
TypeScript
import sql from "../../sql.js";
|
|
import { randomString } from "../../../services/utils.js";
|
|
import dateUtils from "../../../services/date_utils.js";
|
|
import log from "../../log.js";
|
|
import { embeddingToBuffer, bufferToEmbedding, cosineSimilarity, enhancedCosineSimilarity, selectOptimalEmbedding, adaptEmbeddingDimensions } from "./vector_utils.js";
|
|
import type { EmbeddingResult } from "./types.js";
|
|
import entityChangesService from "../../../services/entity_changes.js";
|
|
import type { EntityChange } from "../../../services/entity_changes_interface.js";
|
|
|
|
/**
|
|
* Creates or updates an embedding for a note
|
|
*/
|
|
export async function storeNoteEmbedding(
|
|
noteId: string,
|
|
providerId: string,
|
|
modelId: string,
|
|
embedding: Float32Array
|
|
): Promise<string> {
|
|
const dimension = embedding.length;
|
|
const embeddingBlob = embeddingToBuffer(embedding);
|
|
const now = dateUtils.localNowDateTime();
|
|
const utcNow = dateUtils.utcNowDateTime();
|
|
|
|
// Check if an embedding already exists for this note and provider/model
|
|
const existingEmbed = await getEmbeddingForNote(noteId, providerId, modelId);
|
|
let embedId;
|
|
|
|
if (existingEmbed) {
|
|
// Update existing embedding
|
|
embedId = existingEmbed.embedId;
|
|
await sql.execute(`
|
|
UPDATE note_embeddings
|
|
SET embedding = ?, dimension = ?, version = version + 1,
|
|
dateModified = ?, utcDateModified = ?
|
|
WHERE embedId = ?`,
|
|
[embeddingBlob, dimension, now, utcNow, embedId]
|
|
);
|
|
} else {
|
|
// Create new embedding
|
|
embedId = randomString(16);
|
|
await sql.execute(`
|
|
INSERT INTO note_embeddings
|
|
(embedId, noteId, providerId, modelId, dimension, embedding,
|
|
dateCreated, utcDateCreated, dateModified, utcDateModified)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
[embedId, noteId, providerId, modelId, dimension, embeddingBlob,
|
|
now, utcNow, now, utcNow]
|
|
);
|
|
}
|
|
|
|
// Create entity change record for syncing
|
|
interface EmbeddingRow {
|
|
embedId: string;
|
|
noteId: string;
|
|
providerId: string;
|
|
modelId: string;
|
|
dimension: number;
|
|
version: number;
|
|
dateCreated: string;
|
|
utcDateCreated: string;
|
|
dateModified: string;
|
|
utcDateModified: string;
|
|
}
|
|
|
|
const row = await sql.getRow<EmbeddingRow>(`
|
|
SELECT embedId, noteId, providerId, modelId, dimension, version,
|
|
dateCreated, utcDateCreated, dateModified, utcDateModified
|
|
FROM note_embeddings
|
|
WHERE embedId = ?`,
|
|
[embedId]
|
|
);
|
|
|
|
if (row) {
|
|
// Skip the actual embedding data for the hash since it's large
|
|
const ec: EntityChange = {
|
|
entityName: "note_embeddings",
|
|
entityId: embedId,
|
|
hash: `${row.noteId}|${row.providerId}|${row.modelId}|${row.dimension}|${row.version}|${row.utcDateModified}`,
|
|
utcDateChanged: row.utcDateModified,
|
|
isSynced: true,
|
|
isErased: false
|
|
};
|
|
|
|
entityChangesService.putEntityChange(ec);
|
|
}
|
|
|
|
return embedId;
|
|
}
|
|
|
|
/**
|
|
* Retrieves embedding for a specific note
|
|
*/
|
|
export async function getEmbeddingForNote(noteId: string, providerId: string, modelId: string): Promise<EmbeddingResult | null> {
|
|
const row = await sql.getRow(`
|
|
SELECT embedId, noteId, providerId, modelId, dimension, embedding, version,
|
|
dateCreated, utcDateCreated, dateModified, utcDateModified
|
|
FROM note_embeddings
|
|
WHERE noteId = ? AND providerId = ? AND modelId = ?`,
|
|
[noteId, providerId, modelId]
|
|
);
|
|
|
|
if (!row) {
|
|
return null;
|
|
}
|
|
|
|
// Need to cast row to any as it doesn't have type information
|
|
const rowData = row as any;
|
|
|
|
return {
|
|
...rowData,
|
|
embedding: bufferToEmbedding(rowData.embedding, rowData.dimension)
|
|
};
|
|
}
|
|
|
|
/**
|
|
* Finds similar notes based on vector similarity
|
|
*/
|
|
export async function findSimilarNotes(
|
|
embedding: Float32Array,
|
|
providerId: string,
|
|
modelId: string,
|
|
limit = 10,
|
|
threshold?: number, // Made optional to use constants
|
|
useFallback = true // Whether to try other providers if no embeddings found
|
|
): Promise<{noteId: string, similarity: number}[]> {
|
|
// Import constants dynamically to avoid circular dependencies
|
|
const llmModule = await import('../../../routes/api/llm.js');
|
|
// Use a default threshold of 0.65 if not provided
|
|
const actualThreshold = threshold || 0.65;
|
|
|
|
try {
|
|
log.info(`Finding similar notes with provider: ${providerId}, model: ${modelId}, dimension: ${embedding.length}, threshold: ${actualThreshold}`);
|
|
|
|
// First try to find embeddings for the exact provider and model
|
|
const embeddings = await sql.getRows(`
|
|
SELECT ne.embedId, ne.noteId, ne.providerId, ne.modelId, ne.dimension, ne.embedding,
|
|
n.isDeleted, n.title, n.type, n.mime
|
|
FROM note_embeddings ne
|
|
JOIN notes n ON ne.noteId = n.noteId
|
|
WHERE ne.providerId = ? AND ne.modelId = ? AND n.isDeleted = 0
|
|
`, [providerId, modelId]);
|
|
|
|
if (embeddings && embeddings.length > 0) {
|
|
log.info(`Found ${embeddings.length} embeddings for provider ${providerId}, model ${modelId}`);
|
|
return await processEmbeddings(embedding, embeddings, actualThreshold, limit);
|
|
}
|
|
|
|
// If no embeddings found and fallback is allowed, try other providers
|
|
if (useFallback) {
|
|
log.info(`No embeddings found for ${providerId}/${modelId}, trying fallback providers`);
|
|
|
|
// Define the type for embedding metadata
|
|
interface EmbeddingMetadata {
|
|
providerId: string;
|
|
modelId: string;
|
|
count: number;
|
|
dimension: number;
|
|
}
|
|
|
|
// Get all available embedding metadata
|
|
const availableEmbeddings = await sql.getRows(`
|
|
SELECT DISTINCT providerId, modelId, COUNT(*) as count, dimension
|
|
FROM note_embeddings
|
|
GROUP BY providerId, modelId
|
|
ORDER BY dimension DESC, count DESC
|
|
`) as EmbeddingMetadata[];
|
|
|
|
if (availableEmbeddings.length > 0) {
|
|
log.info(`Available embeddings: ${JSON.stringify(availableEmbeddings.map(e => ({
|
|
providerId: e.providerId,
|
|
modelId: e.modelId,
|
|
count: e.count,
|
|
dimension: e.dimension
|
|
})))}`);
|
|
|
|
// Import the vector utils
|
|
const { selectOptimalEmbedding } = await import('./vector_utils.js');
|
|
|
|
// Get user dimension strategy preference
|
|
const options = (await import('../../options.js')).default;
|
|
const dimensionStrategy = await options.getOption('embeddingDimensionStrategy') || 'native';
|
|
log.info(`Using embedding dimension strategy: ${dimensionStrategy}`);
|
|
|
|
// Find the best alternative based on highest dimension for 'native' strategy
|
|
if (dimensionStrategy === 'native') {
|
|
const bestAlternative = selectOptimalEmbedding(availableEmbeddings);
|
|
|
|
if (bestAlternative) {
|
|
log.info(`Using highest-dimension fallback: ${bestAlternative.providerId}/${bestAlternative.modelId} (${bestAlternative.dimension}D)`);
|
|
|
|
// Get embeddings for this provider/model
|
|
const alternativeEmbeddings = await sql.getRows(`
|
|
SELECT ne.embedId, ne.noteId, ne.providerId, ne.modelId, ne.dimension, ne.embedding,
|
|
n.isDeleted, n.title, n.type, n.mime
|
|
FROM note_embeddings ne
|
|
JOIN notes n ON ne.noteId = n.noteId
|
|
WHERE ne.providerId = ? AND ne.modelId = ? AND n.isDeleted = 0
|
|
`, [bestAlternative.providerId, bestAlternative.modelId]);
|
|
|
|
if (alternativeEmbeddings && alternativeEmbeddings.length > 0) {
|
|
return await processEmbeddings(embedding, alternativeEmbeddings, actualThreshold, limit);
|
|
}
|
|
}
|
|
} else {
|
|
// Use dedicated embedding provider precedence from options for other strategies
|
|
let preferredProviders: string[] = [];
|
|
const embeddingPrecedence = await options.getOption('embeddingProviderPrecedence');
|
|
|
|
if (embeddingPrecedence) {
|
|
// For "comma,separated,values"
|
|
if (embeddingPrecedence.includes(',')) {
|
|
preferredProviders = embeddingPrecedence.split(',').map(p => p.trim());
|
|
}
|
|
// For JSON array ["value1", "value2"]
|
|
else if (embeddingPrecedence.startsWith('[') && embeddingPrecedence.endsWith(']')) {
|
|
try {
|
|
preferredProviders = JSON.parse(embeddingPrecedence);
|
|
} catch (e) {
|
|
log.error(`Error parsing embedding precedence: ${e}`);
|
|
preferredProviders = [embeddingPrecedence]; // Fallback to using as single value
|
|
}
|
|
}
|
|
// For a single value
|
|
else {
|
|
preferredProviders = [embeddingPrecedence];
|
|
}
|
|
}
|
|
|
|
log.info(`Using provider precedence: ${preferredProviders.join(', ')}`);
|
|
|
|
// Try providers in precedence order
|
|
for (const provider of preferredProviders) {
|
|
const providerEmbeddings = availableEmbeddings.filter(e => e.providerId === provider);
|
|
|
|
if (providerEmbeddings.length > 0) {
|
|
// Choose the model with the most embeddings
|
|
const bestModel = providerEmbeddings.sort((a, b) => b.count - a.count)[0];
|
|
log.info(`Found fallback provider: ${provider}, model: ${bestModel.modelId}, dimension: ${bestModel.dimension}`);
|
|
|
|
// The 'regenerate' strategy would go here if needed
|
|
// We're no longer supporting the 'adapt' strategy
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
log.info('No suitable fallback embeddings found, returning empty results');
|
|
}
|
|
|
|
return [];
|
|
} catch (error) {
|
|
log.error(`Error finding similar notes: ${error}`);
|
|
return [];
|
|
}
|
|
}
|
|
|
|
// Helper function to process embeddings and calculate similarities
|
|
async function processEmbeddings(queryEmbedding: Float32Array, embeddings: any[], threshold: number, limit: number) {
|
|
const { enhancedCosineSimilarity, bufferToEmbedding } = await import('./vector_utils.js');
|
|
const similarities = [];
|
|
|
|
for (const e of embeddings) {
|
|
const embVector = bufferToEmbedding(e.embedding, e.dimension);
|
|
const similarity = enhancedCosineSimilarity(queryEmbedding, embVector);
|
|
|
|
if (similarity >= threshold) {
|
|
similarities.push({
|
|
noteId: e.noteId,
|
|
similarity: similarity
|
|
});
|
|
}
|
|
}
|
|
|
|
return similarities
|
|
.sort((a, b) => b.similarity - a.similarity)
|
|
.slice(0, limit);
|
|
}
|
|
|
|
/**
|
|
* Delete embeddings for a note
|
|
*
|
|
* @param noteId - The ID of the note
|
|
* @param providerId - Optional provider ID to delete embeddings only for a specific provider
|
|
* @param modelId - Optional model ID to delete embeddings only for a specific model
|
|
*/
|
|
export async function deleteNoteEmbeddings(noteId: string, providerId?: string, modelId?: string) {
|
|
let query = "DELETE FROM note_embeddings WHERE noteId = ?";
|
|
const params: any[] = [noteId];
|
|
|
|
if (providerId) {
|
|
query += " AND providerId = ?";
|
|
params.push(providerId);
|
|
|
|
if (modelId) {
|
|
query += " AND modelId = ?";
|
|
params.push(modelId);
|
|
}
|
|
}
|
|
|
|
await sql.execute(query, params);
|
|
}
|