Mirror of https://github.com/zadam/trilium.git
rip out openai custom implementation in favor of sdk
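
The change in a nutshell: instead of hand-building the `/v1/chat/completions` endpoint, the Authorization header, and the JSON body around `fetch`, the service now hands all of that to the official `openai` npm package. A minimal sketch of the adopted pattern (the model name and prompt are placeholders, not taken from this commit):

```ts
import OpenAI from 'openai';

// The SDK takes over the endpoint construction, auth headers, and JSON
// serialization that the removed code did by hand.
const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

const completion = await client.chat.completions.create({
    model: 'gpt-4o-mini', // placeholder model
    messages: [{ role: 'user', content: 'Hello!' }]
});

console.log(completion.choices[0].message.content);
```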
@@ -1,11 +1,12 @@
 import options from '../../options.js';
 import { BaseAIService } from '../base_ai_service.js';
 import type { ChatCompletionOptions, ChatResponse, Message } from '../ai_interface.js';
-import { PROVIDER_CONSTANTS } from '../constants/provider_constants.js';
 import type { OpenAIOptions } from './provider_options.js';
 import { getOpenAIOptions } from './providers.js';
+import OpenAI from 'openai';
 
 export class OpenAIService extends BaseAIService {
+    private openai: OpenAI | null = null;
 
     constructor() {
         super('OpenAI');
     }
@@ -14,6 +15,16 @@ export class OpenAIService extends BaseAIService {
         return super.isAvailable() && !!options.getOption('openaiApiKey');
     }
 
+    private getClient(apiKey: string, baseUrl?: string): OpenAI {
+        if (!this.openai) {
+            this.openai = new OpenAI({
+                apiKey,
+                baseURL: baseUrl
+            });
+        }
+        return this.openai;
+    }
+
     async generateChatCompletion(messages: Message[], opts: ChatCompletionOptions = {}): Promise<ChatResponse> {
         if (!this.isAvailable()) {
             throw new Error('OpenAI service is not available. Check API key and AI settings.');
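
One caveat on the `getClient` helper added above: the client is built once and cached, so an API key or base URL changed in settings afterwards is silently ignored until the service instance is recreated. A sketch of a hypothetical variant (not part of this commit) that invalidates the cache when either value changes:

```ts
import OpenAI from 'openai';

// Hypothetical hardening: rebuild the client whenever the key or base URL
// differs from the values it was constructed with.
class CachedOpenAIClient {
    private openai: OpenAI | null = null;
    private cacheKey: string | null = null;

    getClient(apiKey: string, baseUrl?: string): OpenAI {
        const key = `${apiKey}|${baseUrl ?? ''}`;
        if (!this.openai || this.cacheKey !== key) {
            this.openai = new OpenAI({ apiKey, baseURL: baseUrl });
            this.cacheKey = key;
        }
        return this.openai;
    }
}
```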
@@ -21,6 +32,9 @@ export class OpenAIService extends BaseAIService {
         }
 
         // Get provider-specific options from the central provider manager
        const providerOptions = getOpenAIOptions(opts);
 
+        // Initialize the OpenAI client
+        const client = this.getClient(providerOptions.apiKey, providerOptions.baseUrl);
+
         const systemPrompt = this.getSystemPrompt(providerOptions.systemPrompt || options.getOption('aiSystemPrompt'));
@@ -31,20 +45,10 @@ export class OpenAIService extends BaseAIService {
             : [{ role: 'system', content: systemPrompt }, ...messages];
 
         try {
-            // Fix endpoint construction - ensure we don't double up on /v1
-            const normalizedBaseUrl = providerOptions.baseUrl.replace(/\/+$/, '');
-            const endpoint = normalizedBaseUrl.includes('/v1')
-                ? `${normalizedBaseUrl}/chat/completions`
-                : `${normalizedBaseUrl}/v1/chat/completions`;
-
-            // Create request body directly from provider options
-            const requestBody: any = {
+            // Create params object for the OpenAI SDK
+            const params: OpenAI.Chat.ChatCompletionCreateParams = {
                 model: providerOptions.model,
-                messages: messagesWithSystem,
-            };
-
-            // Extract API parameters from provider options
-            const apiParams = {
+                messages: messagesWithSystem as OpenAI.Chat.ChatCompletionMessageParam[],
                 temperature: providerOptions.temperature,
                 max_tokens: providerOptions.max_tokens,
                 stream: providerOptions.stream,
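
A detail behind the hunk above: the removed code merged `apiParams` into `requestBody` while filtering out `undefined` values by hand. With the SDK that filter is unnecessary, since `JSON.stringify` (which serializes the request body) already omits `undefined` properties. A small illustration; the helper is hypothetical and mirrors the removed loop:

```ts
// Hypothetical helper equivalent to the removed Object.entries filter:
// keep only the properties that are actually defined.
function stripUndefined<T extends Record<string, unknown>>(obj: T): Partial<T> {
    return Object.fromEntries(
        Object.entries(obj).filter(([, value]) => value !== undefined)
    ) as Partial<T>;
}

// JSON.stringify drops undefined properties on its own, which is why the
// typed params object can be passed to the SDK as-is:
console.log(JSON.stringify({ model: 'x', temperature: undefined })); // {"model":"x"}
```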
@@ -53,51 +57,138 @@ export class OpenAIService extends BaseAIService {
                 presence_penalty: providerOptions.presence_penalty
             };
 
-            // Merge API parameters, filtering out undefined values
-            Object.entries(apiParams).forEach(([key, value]) => {
-                if (value !== undefined) {
-                    requestBody[key] = value;
-                }
-            });
-
             // Add tools if enabled
             if (providerOptions.enableTools && providerOptions.tools && providerOptions.tools.length > 0) {
-                requestBody.tools = providerOptions.tools;
+                params.tools = providerOptions.tools as OpenAI.Chat.ChatCompletionTool[];
             }
 
             if (providerOptions.tool_choice) {
-                requestBody.tool_choice = providerOptions.tool_choice;
+                params.tool_choice = providerOptions.tool_choice as OpenAI.Chat.ChatCompletionToolChoiceOption;
             }
 
-            const response = await fetch(endpoint, {
-                method: 'POST',
-                headers: {
-                    'Content-Type': 'application/json',
-                    'Authorization': `Bearer ${providerOptions.apiKey}`
-                },
-                body: JSON.stringify(requestBody)
-            });
+            // If streaming is requested
+            if (providerOptions.stream) {
+                params.stream = true;
+
+                const stream = await client.chat.completions.create(params);
+                let fullText = '';
+
+                // If a direct callback is provided, use it
+                if (providerOptions.streamCallback) {
+                    // Process the stream with the callback
+                    try {
+                        // The stream is an AsyncIterable
+                        if (Symbol.asyncIterator in stream) {
+                            for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
+                                const content = chunk.choices[0]?.delta?.content || '';
+                                if (content) {
+                                    fullText += content;
+                                    await providerOptions.streamCallback(content, false, chunk);
+                                }
+
+                                // If this is the last chunk
+                                if (chunk.choices[0]?.finish_reason) {
+                                    await providerOptions.streamCallback('', true, chunk);
+                                }
+                            }
+                        } else {
+                            console.error('Stream is not iterable, falling back to non-streaming response');
+
+                            // If we get a non-streaming response somehow
+                            if ('choices' in stream) {
+                                const content = stream.choices[0]?.message?.content || '';
+                                fullText = content;
+                                if (providerOptions.streamCallback) {
+                                    await providerOptions.streamCallback(content, true, stream);
+                                }
+                            }
+                        }
+                    } catch (error) {
+                        console.error('Error processing stream:', error);
+                        throw error;
+                    }
+
+                    return {
+                        text: fullText,
+                        model: params.model,
+                        provider: this.getName(),
+                        usage: {} // Usage stats aren't available with streaming
+                    };
+                } else {
+                    // Use the more flexible stream interface
+                    return {
+                        text: '', // Initial empty text, will be filled by stream processing
+                        model: params.model,
+                        provider: this.getName(),
+                        usage: {}, // Usage stats aren't available with streaming
+                        stream: async (callback) => {
+                            let completeText = '';
+
+                            try {
+                                // The stream is an AsyncIterable
+                                if (Symbol.asyncIterator in stream) {
+                                    for await (const chunk of stream as AsyncIterable<OpenAI.Chat.ChatCompletionChunk>) {
+                                        const content = chunk.choices[0]?.delta?.content || '';
+                                        const isDone = !!chunk.choices[0]?.finish_reason;
+
+                                        if (content) {
+                                            completeText += content;
+                                        }
+
+                                        // Call the provided callback with the StreamChunk interface
+                                        await callback({
+                                            text: content,
+                                            done: isDone
+                                        });
+
+                                        if (isDone) {
+                                            break;
+                                        }
+                                    }
+                                } else {
+                                    console.warn('Stream is not iterable, falling back to non-streaming response');
+
+                                    // If we get a non-streaming response somehow
+                                    if ('choices' in stream) {
+                                        const content = stream.choices[0]?.message?.content || '';
+                                        completeText = content;
+                                        await callback({
+                                            text: content,
+                                            done: true
+                                        });
+                                    }
+                                }
+                            } catch (error) {
+                                console.error('Error processing stream:', error);
+                                throw error;
+                            }
+
+                            return completeText;
+                        }
+                    };
+                }
+            } else {
+                // Non-streaming response
+                params.stream = false;
+
+                const completion = await client.chat.completions.create(params);
+
+                if (!('choices' in completion)) {
+                    throw new Error('Unexpected response format from OpenAI API');
+                }
-            if (!response.ok) {
-                const errorBody = await response.text();
-                throw new Error(`OpenAI API error: ${response.status} ${response.statusText} - ${errorBody}`);
-            }
-
-            const data = await response.json();
-
-            return {
-                text: data.choices[0].message.content,
-                model: data.model,
-                provider: this.getName(),
-                usage: {
-                    promptTokens: data.usage?.prompt_tokens,
-                    completionTokens: data.usage?.completion_tokens,
-                    totalTokens: data.usage?.total_tokens
-                },
-                tool_calls: data.choices[0].message.tool_calls
-            };
+
+                return {
+                    text: completion.choices[0].message.content || '',
+                    model: completion.model,
+                    provider: this.getName(),
+                    usage: {
+                        promptTokens: completion.usage?.prompt_tokens,
+                        completionTokens: completion.usage?.completion_tokens,
+                        totalTokens: completion.usage?.total_tokens
+                    },
+                    tool_calls: completion.choices[0].message.tool_calls
+                };
+            }
         } catch (error) {
             console.error('OpenAI service error:', error);
             throw error;
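
For orientation, consuming the new return shape could look like the sketch below. The `stream` callback contract ({ text, done } chunks, resolving to the complete text) comes straight from the diff; the surrounding setup and the `stream: true` option are assumptions about the rest of the codebase:

```ts
const service = new OpenAIService();

const response = await service.generateChatCompletion(
    [{ role: 'user', content: 'Summarize this note.' }],
    { stream: true } // assumes ChatCompletionOptions exposes a stream flag
);

// Without a direct streamCallback, the response carries a stream() function
// that feeds { text, done } chunks to the callback and resolves with the
// accumulated text once the stream finishes.
if (response.stream) {
    const fullText = await response.stream(async (chunk) => {
        process.stdout.write(chunk.text);
        if (chunk.done) {
            console.log('\n[stream complete]');
        }
    });
    console.log(`Received ${fullText.length} characters.`);
}
```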