mirror of
https://github.com/zadam/trilium.git
synced 2025-11-01 19:05:59 +01:00
545 lines
21 KiB
TypeScript
545 lines
21 KiB
TypeScript
import sql from '../../sql.js'
|
|
import { randomString } from "../../../services/utils.js";
|
|
import dateUtils from "../../../services/date_utils.js";
|
|
import log from "../../log.js";
|
|
import { embeddingToBuffer, bufferToEmbedding, cosineSimilarity, enhancedCosineSimilarity, selectOptimalEmbedding, adaptEmbeddingDimensions } from "./vector_utils.js";
|
|
import type { EmbeddingResult } from "./types.js";
|
|
import entityChangesService from "../../../services/entity_changes.js";
|
|
import type { EntityChange } from "../../../services/entity_changes_interface.js";
|
|
import { EMBEDDING_CONSTANTS } from "../constants/embedding_constants.js";
|
|
import { SEARCH_CONSTANTS } from '../constants/search_constants.js';
|
|
import type { NoteEmbeddingContext } from "./embeddings_interface.js";
|
|
import becca from "../../../becca/becca.js";
|
|
import { isNoteExcludedFromAIById } from "../utils/ai_exclusion_utils.js";
|
|
import { getSelectedEmbeddingProvider } from '../config/configuration_helpers.js';
|
|
|
|
interface Similarity {
|
|
noteId: string;
|
|
similarity: number;
|
|
contentType: string;
|
|
bonuses?: Record<string, number>; // Optional for debugging
|
|
}
|
|
|
|
/**
|
|
* Creates or updates an embedding for a note
|
|
*/
|
|
export async function storeNoteEmbedding(
|
|
noteId: string,
|
|
providerId: string,
|
|
modelId: string,
|
|
embedding: Float32Array
|
|
): Promise<string> {
|
|
const dimension = embedding.length;
|
|
const embeddingBlob = embeddingToBuffer(embedding);
|
|
const now = dateUtils.localNowDateTime();
|
|
const utcNow = dateUtils.utcNowDateTime();
|
|
|
|
// Check if an embedding already exists for this note and provider/model
|
|
const existingEmbed = await getEmbeddingForNote(noteId, providerId, modelId);
|
|
let embedId;
|
|
|
|
if (existingEmbed) {
|
|
// Update existing embedding
|
|
embedId = existingEmbed.embedId;
|
|
await sql.execute(`
|
|
UPDATE note_embeddings
|
|
SET embedding = ?, dimension = ?, version = version + 1,
|
|
dateModified = ?, utcDateModified = ?
|
|
WHERE embedId = ?`,
|
|
[embeddingBlob, dimension, now, utcNow, embedId]
|
|
);
|
|
} else {
|
|
// Create new embedding
|
|
embedId = randomString(16);
|
|
await sql.execute(`
|
|
INSERT INTO note_embeddings
|
|
(embedId, noteId, providerId, modelId, dimension, embedding,
|
|
dateCreated, utcDateCreated, dateModified, utcDateModified)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
[embedId, noteId, providerId, modelId, dimension, embeddingBlob,
|
|
now, utcNow, now, utcNow]
|
|
);
|
|
}
|
|
|
|
// Create entity change record for syncing
|
|
interface EmbeddingRow {
|
|
embedId: string;
|
|
noteId: string;
|
|
providerId: string;
|
|
modelId: string;
|
|
dimension: number;
|
|
version: number;
|
|
dateCreated: string;
|
|
utcDateCreated: string;
|
|
dateModified: string;
|
|
utcDateModified: string;
|
|
}
|
|
|
|
const row = await sql.getRow<EmbeddingRow>(`
|
|
SELECT embedId, noteId, providerId, modelId, dimension, version,
|
|
dateCreated, utcDateCreated, dateModified, utcDateModified
|
|
FROM note_embeddings
|
|
WHERE embedId = ?`,
|
|
[embedId]
|
|
);
|
|
|
|
if (row) {
|
|
// Skip the actual embedding data for the hash since it's large
|
|
const ec: EntityChange = {
|
|
entityName: "note_embeddings",
|
|
entityId: embedId,
|
|
hash: `${row.noteId}|${row.providerId}|${row.modelId}|${row.dimension}|${row.version}|${row.utcDateModified}`,
|
|
utcDateChanged: row.utcDateModified,
|
|
isSynced: true,
|
|
isErased: false
|
|
};
|
|
|
|
entityChangesService.putEntityChange(ec);
|
|
}
|
|
|
|
return embedId;
|
|
}
|
|
|
|
/**
|
|
* Retrieves embedding for a specific note
|
|
*/
|
|
export async function getEmbeddingForNote(noteId: string, providerId: string, modelId: string): Promise<EmbeddingResult | null> {
|
|
const row = await sql.getRow(`
|
|
SELECT embedId, noteId, providerId, modelId, dimension, embedding, version,
|
|
dateCreated, utcDateCreated, dateModified, utcDateModified
|
|
FROM note_embeddings
|
|
WHERE noteId = ? AND providerId = ? AND modelId = ?`,
|
|
[noteId, providerId, modelId]
|
|
);
|
|
|
|
if (!row) {
|
|
return null;
|
|
}
|
|
|
|
// Need to cast row to any as it doesn't have type information
|
|
const rowData = row as any;
|
|
|
|
return {
|
|
...rowData,
|
|
embedding: bufferToEmbedding(rowData.embedding, rowData.dimension)
|
|
};
|
|
}
|
|
|
|
// Create an interface that represents the embedding row from the database
|
|
interface EmbeddingRow {
|
|
embedId: string;
|
|
noteId: string;
|
|
providerId: string;
|
|
modelId: string;
|
|
dimension: number;
|
|
embedding: Buffer;
|
|
title?: string;
|
|
type?: string;
|
|
mime?: string;
|
|
isDeleted?: number;
|
|
}
|
|
|
|
// Interface for enhanced embedding with query model information
|
|
interface EnhancedEmbeddingRow extends EmbeddingRow {
|
|
queryProviderId: string;
|
|
queryModelId: string;
|
|
}
|
|
|
|
/**
|
|
* Finds similar notes based on vector similarity
|
|
*/
|
|
export async function findSimilarNotes(
|
|
embedding: Float32Array,
|
|
providerId: string,
|
|
modelId: string,
|
|
limit = SEARCH_CONSTANTS.VECTOR_SEARCH.DEFAULT_MAX_RESULTS,
|
|
threshold?: number, // Made optional to use constants
|
|
useFallback = true // Whether to try other providers if no embeddings found
|
|
): Promise<{noteId: string, similarity: number, contentType?: string}[]> {
|
|
// Import constants dynamically to avoid circular dependencies
|
|
const llmModule = await import('../../../routes/api/llm.js');
|
|
// Use default threshold if not provided
|
|
const actualThreshold = threshold || SEARCH_CONSTANTS.VECTOR_SEARCH.EXACT_MATCH_THRESHOLD;
|
|
|
|
try {
|
|
log.info(`Finding similar notes with provider: ${providerId}, model: ${modelId}, dimension: ${embedding.length}, threshold: ${actualThreshold}`);
|
|
|
|
// First try to find embeddings for the exact provider and model
|
|
const embeddings = await sql.getRows(`
|
|
SELECT ne.embedId, ne.noteId, ne.providerId, ne.modelId, ne.dimension, ne.embedding,
|
|
n.isDeleted, n.title, n.type, n.mime
|
|
FROM note_embeddings ne
|
|
JOIN notes n ON ne.noteId = n.noteId
|
|
WHERE ne.providerId = ? AND ne.modelId = ? AND n.isDeleted = 0
|
|
`, [providerId, modelId]) as EmbeddingRow[];
|
|
|
|
if (embeddings && embeddings.length > 0) {
|
|
log.info(`Found ${embeddings.length} embeddings for provider ${providerId}, model ${modelId}`);
|
|
|
|
// Add query model information to each embedding for cross-model comparison
|
|
const enhancedEmbeddings: EnhancedEmbeddingRow[] = embeddings.map(e => {
|
|
return {
|
|
embedId: e.embedId,
|
|
noteId: e.noteId,
|
|
providerId: e.providerId,
|
|
modelId: e.modelId,
|
|
dimension: e.dimension,
|
|
embedding: e.embedding,
|
|
title: e.title,
|
|
type: e.type,
|
|
mime: e.mime,
|
|
isDeleted: e.isDeleted,
|
|
queryProviderId: providerId,
|
|
queryModelId: modelId
|
|
};
|
|
});
|
|
|
|
return await processEmbeddings(embedding, enhancedEmbeddings, actualThreshold, limit);
|
|
}
|
|
|
|
// If no embeddings found and fallback is allowed, try other providers
|
|
if (useFallback) {
|
|
log.info(`No embeddings found for ${providerId}/${modelId}, trying fallback providers`);
|
|
|
|
// Define the type for embedding metadata
|
|
interface EmbeddingMetadata {
|
|
providerId: string;
|
|
modelId: string;
|
|
count: number;
|
|
dimension: number;
|
|
}
|
|
|
|
// Get all available embedding metadata
|
|
const availableEmbeddings = await sql.getRows(`
|
|
SELECT DISTINCT providerId, modelId, COUNT(*) as count, dimension
|
|
FROM note_embeddings
|
|
GROUP BY providerId, modelId
|
|
ORDER BY dimension DESC, count DESC
|
|
`) as EmbeddingMetadata[];
|
|
|
|
if (availableEmbeddings.length > 0) {
|
|
log.info(`Available embeddings: ${JSON.stringify(availableEmbeddings.map(e => ({
|
|
providerId: e.providerId,
|
|
modelId: e.modelId,
|
|
count: e.count,
|
|
dimension: e.dimension
|
|
})))}`);
|
|
|
|
// Import the vector utils
|
|
const { selectOptimalEmbedding } = await import('./vector_utils.js');
|
|
|
|
// Get user dimension strategy preference
|
|
const options = (await import('../../options.js')).default;
|
|
const dimensionStrategy = await options.getOption('embeddingDimensionStrategy') || 'native';
|
|
log.info(`Using embedding dimension strategy: ${dimensionStrategy}`);
|
|
|
|
// Find the best alternative based on highest dimension for 'native' strategy
|
|
if (dimensionStrategy === 'native') {
|
|
const bestAlternative = selectOptimalEmbedding(availableEmbeddings);
|
|
|
|
if (bestAlternative) {
|
|
log.info(`Using highest-dimension fallback: ${bestAlternative.providerId}/${bestAlternative.modelId} (${bestAlternative.dimension}D)`);
|
|
|
|
// Get embeddings for this provider/model
|
|
const alternativeEmbeddings = await sql.getRows(`
|
|
SELECT ne.embedId, ne.noteId, ne.providerId, ne.modelId, ne.dimension, ne.embedding,
|
|
n.isDeleted, n.title, n.type, n.mime
|
|
FROM note_embeddings ne
|
|
JOIN notes n ON ne.noteId = n.noteId
|
|
WHERE ne.providerId = ? AND ne.modelId = ? AND n.isDeleted = 0
|
|
`, [bestAlternative.providerId, bestAlternative.modelId]) as EmbeddingRow[];
|
|
|
|
if (alternativeEmbeddings && alternativeEmbeddings.length > 0) {
|
|
// Add query model information to each embedding for cross-model comparison
|
|
const enhancedEmbeddings: EnhancedEmbeddingRow[] = alternativeEmbeddings.map(e => {
|
|
return {
|
|
embedId: e.embedId,
|
|
noteId: e.noteId,
|
|
providerId: e.providerId,
|
|
modelId: e.modelId,
|
|
dimension: e.dimension,
|
|
embedding: e.embedding,
|
|
title: e.title,
|
|
type: e.type,
|
|
mime: e.mime,
|
|
isDeleted: e.isDeleted,
|
|
queryProviderId: providerId,
|
|
queryModelId: modelId
|
|
};
|
|
});
|
|
|
|
return await processEmbeddings(embedding, enhancedEmbeddings, actualThreshold, limit);
|
|
}
|
|
}
|
|
} else {
|
|
// Try providers using the new configuration system
|
|
if (useFallback) {
|
|
log.info('No embeddings found for specified provider, trying fallback providers...');
|
|
|
|
// Use the new configuration system - no string parsing!
|
|
const selectedProvider = await getSelectedEmbeddingProvider();
|
|
const preferredProviders = selectedProvider ? [selectedProvider] : [];
|
|
|
|
log.info(`Using selected provider: ${selectedProvider || 'none'}`);
|
|
|
|
// Try providers in precedence order
|
|
for (const provider of preferredProviders) {
|
|
const providerEmbeddings = availableEmbeddings.filter(e => e.providerId === provider);
|
|
|
|
if (providerEmbeddings.length > 0) {
|
|
// Choose the model with the most embeddings
|
|
const bestModel = providerEmbeddings.sort((a, b) => b.count - a.count)[0];
|
|
log.info(`Found fallback provider: ${provider}, model: ${bestModel.modelId}, dimension: ${bestModel.dimension}`);
|
|
|
|
// The 'regenerate' strategy would go here if needed
|
|
// We're no longer supporting the 'adapt' strategy
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
log.info('No suitable fallback embeddings found, returning empty results');
|
|
}
|
|
|
|
return [];
|
|
} catch (error) {
|
|
log.error(`Error finding similar notes: ${error}`);
|
|
return [];
|
|
}
|
|
}
|
|
|
|
// Helper function to process embeddings and calculate similarities
|
|
async function processEmbeddings(queryEmbedding: Float32Array, embeddings: any[], threshold: number, limit: number) {
|
|
const {
|
|
enhancedCosineSimilarity,
|
|
bufferToEmbedding,
|
|
ContentType,
|
|
PerformanceProfile,
|
|
detectContentType,
|
|
vectorDebugConfig
|
|
} = await import('./vector_utils.js');
|
|
|
|
// Store original debug settings but keep debug disabled
|
|
const originalDebugEnabled = vectorDebugConfig.enabled;
|
|
const originalLogLevel = vectorDebugConfig.logLevel;
|
|
|
|
// Keep debug disabled for normal operation
|
|
vectorDebugConfig.enabled = false;
|
|
vectorDebugConfig.recordStats = false;
|
|
|
|
const options = (await import('../../options.js')).default;
|
|
|
|
// Define weighting factors with defaults that can be overridden by settings
|
|
interface SimilarityWeights {
|
|
exactTitleMatch: number;
|
|
titleContainsQuery: number;
|
|
partialTitleMatch: number;
|
|
// Add more weights as needed - examples:
|
|
sameType?: number;
|
|
attributeMatch?: number;
|
|
recentlyCreated?: number;
|
|
recentlyModified?: number;
|
|
}
|
|
|
|
// Default weights that match our previous hardcoded values
|
|
const defaultWeights: SimilarityWeights = {
|
|
exactTitleMatch: 0.3,
|
|
titleContainsQuery: 0.2,
|
|
partialTitleMatch: 0.1,
|
|
sameType: 0.05,
|
|
attributeMatch: 0.05,
|
|
recentlyCreated: 0.05,
|
|
recentlyModified: 0.05
|
|
};
|
|
|
|
// Get weights from options if they exist
|
|
const weights: SimilarityWeights = { ...defaultWeights };
|
|
try {
|
|
const customWeightsJSON = EMBEDDING_CONSTANTS;
|
|
if (customWeightsJSON) {
|
|
try {
|
|
const customWeights = EMBEDDING_CONSTANTS;
|
|
// Override defaults with any custom weights
|
|
Object.assign(weights, customWeights);
|
|
log.info(`Using custom similarity weights: ${JSON.stringify(weights)}`);
|
|
} catch (e) {
|
|
log.error(`Error parsing custom similarity weights: ${e}`);
|
|
}
|
|
}
|
|
} catch (e) {
|
|
// Use defaults if no custom weights
|
|
}
|
|
|
|
/**
|
|
* Calculate similarity bonuses based on various factors
|
|
*/
|
|
function calculateSimilarityBonuses(
|
|
embedding: any,
|
|
note: any,
|
|
queryText: string,
|
|
weights: SimilarityWeights
|
|
): { bonuses: Record<string, number>, totalBonus: number } {
|
|
const bonuses: Record<string, number> = {};
|
|
|
|
// Skip if we don't have query text
|
|
if (!queryText || !note.title) {
|
|
return { bonuses, totalBonus: 0 };
|
|
}
|
|
|
|
const titleLower = note.title.toLowerCase();
|
|
const queryLower = queryText.toLowerCase();
|
|
|
|
// 1. Exact title match
|
|
if (titleLower === queryLower) {
|
|
bonuses.exactTitleMatch = weights.exactTitleMatch;
|
|
}
|
|
// 2. Title contains the entire query
|
|
else if (titleLower.includes(queryLower)) {
|
|
bonuses.titleContainsQuery = weights.titleContainsQuery;
|
|
}
|
|
// 3. Partial term matching
|
|
else {
|
|
// Split query into terms and check if title contains them
|
|
const queryTerms = queryLower.split(/\s+/).filter((term: string) => term.length > 2);
|
|
let matchCount = 0;
|
|
|
|
for (const term of queryTerms) {
|
|
if (titleLower.includes(term)) {
|
|
matchCount++;
|
|
}
|
|
}
|
|
|
|
if (matchCount > 0 && queryTerms.length > 0) {
|
|
// Calculate proportion of matching terms and apply a scaled bonus
|
|
const matchProportion = matchCount / queryTerms.length;
|
|
bonuses.partialTitleMatch = weights.partialTitleMatch * matchProportion;
|
|
}
|
|
}
|
|
|
|
// 4. Add more factors as needed here
|
|
// Example: Same note type bonus
|
|
// if (note.type && weights.sameType) {
|
|
// // Note: This would need to be compared with the query context to be meaningful
|
|
// // For now, this is a placeholder for demonstration
|
|
// bonuses.sameType = weights.sameType;
|
|
// }
|
|
|
|
// Calculate total bonus
|
|
const totalBonus = Object.values(bonuses).reduce((sum, bonus) => sum + bonus, 0);
|
|
|
|
return { bonuses, totalBonus };
|
|
}
|
|
|
|
const similarities: Similarity[] = [];
|
|
|
|
try {
|
|
// Try to extract the original query text if it was added to the metadata
|
|
// This will help us determine title matches
|
|
const queryText = queryEmbedding.hasOwnProperty('originalQuery')
|
|
? (queryEmbedding as any).originalQuery
|
|
: '';
|
|
|
|
for (const e of embeddings) {
|
|
// Check if this note is excluded from AI features
|
|
if (isNoteExcludedFromAIById(e.noteId)) {
|
|
continue; // Skip this note if it has the AI exclusion label
|
|
}
|
|
|
|
const embVector = bufferToEmbedding(e.embedding, e.dimension);
|
|
|
|
// Detect content type from mime type if available
|
|
let contentType = ContentType.GENERAL_TEXT;
|
|
if (e.mime) {
|
|
contentType = detectContentType(e.mime);
|
|
// Debug logging removed to avoid console spam
|
|
}
|
|
|
|
// Select performance profile based on embedding size and use case
|
|
// For most similarity searches, BALANCED is a good default
|
|
const performanceProfile = PerformanceProfile.BALANCED;
|
|
|
|
// Determine if this is cross-model comparison
|
|
const isCrossModel = e.providerId !== e.queryProviderId || e.modelId !== e.queryModelId;
|
|
|
|
// Calculate similarity with content-aware parameters
|
|
let similarity = enhancedCosineSimilarity(
|
|
queryEmbedding,
|
|
embVector,
|
|
true, // normalize vectors to ensure consistent comparison
|
|
e.queryModelId, // source model ID
|
|
e.modelId, // target model ID
|
|
contentType, // content-specific padding strategy
|
|
performanceProfile
|
|
);
|
|
|
|
// Calculate and apply similarity bonuses
|
|
const { bonuses, totalBonus } = calculateSimilarityBonuses(
|
|
queryEmbedding,
|
|
e,
|
|
queryText,
|
|
weights
|
|
);
|
|
|
|
if (totalBonus > 0) {
|
|
similarity += totalBonus;
|
|
|
|
// Log significant bonuses for debugging
|
|
const significantBonuses = Object.entries(bonuses)
|
|
.filter(([_, value]) => value >= 0.05)
|
|
.map(([key, value]) => `${key}: +${value.toFixed(2)}`)
|
|
.join(', ');
|
|
|
|
if (significantBonuses) {
|
|
log.info(`Added bonuses for note "${e.title}" (${e.noteId}): ${significantBonuses}`);
|
|
}
|
|
|
|
// Cap similarity at 1.0 to maintain expected range
|
|
similarity = Math.min(similarity, 1.0);
|
|
}
|
|
|
|
if (similarity >= threshold) {
|
|
similarities.push({
|
|
noteId: e.noteId,
|
|
similarity: similarity,
|
|
contentType: contentType.toString(),
|
|
// Optionally include bonuses for debugging/analysis
|
|
// bonuses: bonuses
|
|
});
|
|
}
|
|
}
|
|
|
|
return similarities
|
|
.sort((a, b) => b.similarity - a.similarity)
|
|
.slice(0, limit);
|
|
} finally {
|
|
// Restore original debug settings
|
|
vectorDebugConfig.enabled = originalDebugEnabled;
|
|
vectorDebugConfig.logLevel = originalLogLevel;
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Delete embeddings for a note
|
|
*
|
|
* @param noteId - The ID of the note
|
|
* @param providerId - Optional provider ID to delete embeddings only for a specific provider
|
|
* @param modelId - Optional model ID to delete embeddings only for a specific model
|
|
*/
|
|
export async function deleteNoteEmbeddings(noteId: string, providerId?: string, modelId?: string) {
|
|
let query = "DELETE FROM note_embeddings WHERE noteId = ?";
|
|
const params: any[] = [noteId];
|
|
|
|
if (providerId) {
|
|
query += " AND providerId = ?";
|
|
params.push(providerId);
|
|
|
|
if (modelId) {
|
|
query += " AND modelId = ?";
|
|
params.push(modelId);
|
|
}
|
|
}
|
|
|
|
await sql.execute(query, params);
|
|
}
|