mirror of
https://github.com/zadam/trilium.git
synced 2025-11-15 17:55:52 +01:00
set up more reasonable context window and dimension sizes
This commit is contained in:
@@ -1,22 +1,212 @@
|
||||
import type { EmbeddingProvider, EmbeddingConfig, NoteEmbeddingContext } from './embeddings_interface.js';
|
||||
import log from "../../log.js";
|
||||
import { LLM_CONSTANTS } from "../../../routes/api/llm.js";
|
||||
import options from "../../options.js";
|
||||
|
||||
/**
|
||||
* Base class that implements common functionality for embedding providers
|
||||
*/
|
||||
export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
|
||||
abstract name: string;
|
||||
name: string = "base";
|
||||
protected config: EmbeddingConfig;
|
||||
protected apiKey?: string;
|
||||
protected baseUrl: string;
|
||||
protected modelInfoCache = new Map<string, any>();
|
||||
|
||||
constructor(config: EmbeddingConfig) {
|
||||
this.config = config;
|
||||
this.apiKey = config.apiKey;
|
||||
this.baseUrl = config.baseUrl || "";
|
||||
}
|
||||
|
||||
/**
 * Returns this provider's embedding configuration.
 *
 * Returns a shallow defensive copy so callers cannot mutate the provider's
 * internal config. (The previous body contained two consecutive `return`
 * statements — diff residue — the second of which was unreachable; the
 * defensive-copy variant is kept.)
 */
getConfig(): EmbeddingConfig {
    return { ...this.config };
}
|
||||
|
||||
/**
 * The dimensionality of the embedding vectors this provider is configured for.
 */
getDimension(): number {
    const { dimension } = this.config;
    return dimension;
}
|
||||
|
||||
/**
 * Hook for provider-specific startup work.
 * The base implementation is a no-op; subclasses may override.
 */
async initialize(): Promise<void> {
    // Intentionally empty.
}
|
||||
|
||||
/**
|
||||
* Generate embeddings for a single text
|
||||
*/
|
||||
abstract generateEmbeddings(text: string): Promise<Float32Array>;
|
||||
abstract generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]>;
|
||||
|
||||
/**
|
||||
* Get the appropriate batch size for this provider
|
||||
* Override in provider implementations if needed
|
||||
*/
|
||||
protected async getBatchSize(): Promise<number> {
|
||||
// Try to get the user-configured batch size
|
||||
let configuredBatchSize: number | null = null;
|
||||
|
||||
try {
|
||||
const batchSizeStr = await options.getOption('embeddingBatchSize');
|
||||
if (batchSizeStr) {
|
||||
configuredBatchSize = parseInt(batchSizeStr, 10);
|
||||
}
|
||||
} catch (error) {
|
||||
log.error(`Error getting batch size from options: ${error}`);
|
||||
}
|
||||
|
||||
// If user has configured a specific batch size, use that
|
||||
if (configuredBatchSize && !isNaN(configuredBatchSize) && configuredBatchSize > 0) {
|
||||
return configuredBatchSize;
|
||||
}
|
||||
|
||||
// Otherwise use the provider-specific default from constants
|
||||
return this.config.batchSize ||
|
||||
LLM_CONSTANTS.BATCH_SIZE[this.name.toUpperCase() as keyof typeof LLM_CONSTANTS.BATCH_SIZE] ||
|
||||
LLM_CONSTANTS.BATCH_SIZE.DEFAULT;
|
||||
}
|
||||
|
||||
/**
 * Process a batch of texts with adaptive handling.
 * This method will try to process the batch and reduce batch size if encountering errors.
 *
 * @param items - full list of items to process, in order
 * @param processFn - processes one batch and returns its results
 * @param isBatchSizeError - predicate deciding whether an error means the batch was too large
 * @returns concatenated results of all successful batches; items that fail even
 *          at batch size 1 are skipped (and recorded in the failure summary)
 * @throws the last error seen, when not a single item succeeded
 */
protected async processWithAdaptiveBatch<T>(
    items: T[],
    processFn: (batch: T[]) => Promise<any[]>,
    isBatchSizeError: (error: any) => boolean
): Promise<any[]> {
    const results: any[] = [];
    const failures: { index: number, error: string }[] = [];
    let currentBatchSize = await this.getBatchSize();
    let lastError: Error | null = null;

    // Process items in batches; `i` only advances on success (or on a
    // single-item failure), so a failed batch is retried at a smaller size.
    for (let i = 0; i < items.length;) {
        const batch = items.slice(i, i + currentBatchSize);

        try {
            // Process the current batch
            const batchResults = await processFn(batch);
            results.push(...batchResults);
            i += batch.length;
        }
        catch (error: any) {
            lastError = error;
            const errorMessage = error.message || 'Unknown error';

            // Check if this is a batch size related error
            if (isBatchSizeError(error) && currentBatchSize > 1) {
                // Reduce batch size (halving, floor 1) and retry the same offset
                const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
                console.warn(`Batch size error detected, reducing batch size from ${currentBatchSize} to ${newBatchSize}: ${errorMessage}`);
                currentBatchSize = newBatchSize;
            }
            else if (currentBatchSize === 1) {
                // If we're already at batch size 1, we can't reduce further, so log the error and skip this item
                log.error(`Error processing item at index ${i} with batch size 1: ${errorMessage}`);
                failures.push({ index: i, error: errorMessage });
                i++; // Move to the next item
            }
            else {
                // For other errors, retry with a smaller batch size as a precaution
                const newBatchSize = Math.max(1, Math.floor(currentBatchSize / 2));
                console.warn(`Error processing batch, reducing batch size from ${currentBatchSize} to ${newBatchSize} as a precaution: ${errorMessage}`);
                currentBatchSize = newBatchSize;
            }
        }
    }

    // If all items failed and we have a last error, throw it
    if (results.length === 0 && failures.length > 0 && lastError) {
        throw lastError;
    }

    // If some items failed but others succeeded, log the summary
    // NOTE(review): warnings go to console.warn while errors use the log
    // module — confirm whether these should be unified on the log module.
    if (failures.length > 0) {
        console.warn(`Processed ${results.length} items successfully, but ${failures.length} items failed`);
    }

    return results;
}
|
||||
|
||||
/**
|
||||
* Detect if an error is related to batch size limits
|
||||
* Override in provider-specific implementations
|
||||
*/
|
||||
protected isBatchSizeError(error: any): boolean {
|
||||
const errorMessage = error?.message || '';
|
||||
const batchSizeErrorPatterns = [
|
||||
'batch size', 'too many items', 'too many inputs',
|
||||
'input too large', 'payload too large', 'context length',
|
||||
'token limit', 'rate limit', 'request too large'
|
||||
];
|
||||
|
||||
return batchSizeErrorPatterns.some(pattern =>
|
||||
errorMessage.toLowerCase().includes(pattern.toLowerCase())
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings for multiple texts
|
||||
* Default implementation processes texts one by one
|
||||
*/
|
||||
async generateBatchEmbeddings(texts: string[]): Promise<Float32Array[]> {
|
||||
if (texts.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
return await this.processWithAdaptiveBatch(
|
||||
texts,
|
||||
async (batch) => {
|
||||
const batchResults = await Promise.all(
|
||||
batch.map(text => this.generateEmbeddings(text))
|
||||
);
|
||||
return batchResults;
|
||||
},
|
||||
this.isBatchSizeError
|
||||
);
|
||||
}
|
||||
catch (error: any) {
|
||||
const errorMessage = error.message || "Unknown error";
|
||||
log.error(`Batch embedding error for provider ${this.name}: ${errorMessage}`);
|
||||
throw new Error(`${this.name} batch embedding error: ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings for a note with its context
|
||||
*/
|
||||
async generateNoteEmbeddings(context: NoteEmbeddingContext): Promise<Float32Array> {
|
||||
const text = [context.title || "", context.content || ""].filter(Boolean).join(" ");
|
||||
return this.generateEmbeddings(text);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings for multiple notes with their contexts
|
||||
*/
|
||||
async generateBatchNoteEmbeddings(contexts: NoteEmbeddingContext[]): Promise<Float32Array[]> {
|
||||
if (contexts.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
return await this.processWithAdaptiveBatch(
|
||||
contexts,
|
||||
async (batch) => {
|
||||
const batchResults = await Promise.all(
|
||||
batch.map(context => this.generateNoteEmbeddings(context))
|
||||
);
|
||||
return batchResults;
|
||||
},
|
||||
this.isBatchSizeError
|
||||
);
|
||||
}
|
||||
catch (error: any) {
|
||||
const errorMessage = error.message || "Unknown error";
|
||||
log.error(`Batch note embedding error for provider ${this.name}: ${errorMessage}`);
|
||||
throw new Error(`${this.name} batch note embedding error: ${errorMessage}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleans and normalizes text for embeddings by removing excessive whitespace
|
||||
@@ -157,20 +347,4 @@ export abstract class BaseEmbeddingProvider implements EmbeddingProvider {
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
 * Default implementation that converts note context to text and generates embeddings
 *
 * NOTE(review): this duplicates a `generateNoteEmbeddings` defined earlier in
 * this file and calls `generateNoteContextText`, which is not visible here —
 * this looks like stale diff residue (the pre-change version); confirm which
 * definition is current and remove the other.
 */
async generateNoteEmbeddings(context: NoteEmbeddingContext): Promise<Float32Array> {
    const text = this.generateNoteContextText(context);
    return this.generateEmbeddings(text);
}
|
||||
|
||||
/**
 * Default implementation that processes notes in batch
 *
 * NOTE(review): this duplicates a `generateBatchNoteEmbeddings` defined
 * earlier in this file — likely stale diff residue (the pre-change version);
 * confirm which definition is current and remove the other.
 */
async generateBatchNoteEmbeddings(contexts: NoteEmbeddingContext[]): Promise<Float32Array[]> {
    const texts = contexts.map(ctx => this.generateNoteContextText(ctx));
    return this.generateBatchEmbeddings(texts);
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user