mirror of
https://github.com/zadam/trilium.git
synced 2026-05-07 10:05:40 +02:00
feat(llm): add stop generation button (#9341)
This commit is contained in:
@@ -27,7 +27,8 @@ export interface StreamCallbacks {
|
||||
export async function streamChatCompletion(
|
||||
messages: LlmMessage[],
|
||||
config: LlmChatConfig,
|
||||
callbacks: StreamCallbacks
|
||||
callbacks: StreamCallbacks,
|
||||
abortSignal?: AbortSignal
|
||||
): Promise<void> {
|
||||
const headers = await server.getHeaders();
|
||||
|
||||
@@ -37,7 +38,8 @@ export async function streamChatCompletion(
|
||||
...headers,
|
||||
"Content-Type": "application/json"
|
||||
} as HeadersInit,
|
||||
body: JSON.stringify({ messages, config })
|
||||
body: JSON.stringify({ messages, config }),
|
||||
signal: abortSignal
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
|
||||
@@ -1677,7 +1677,8 @@
|
||||
"note_context_enabled": "Click to disable note context: {{title}}",
|
||||
"note_context_disabled": "Click to include current note in context",
|
||||
"no_provider_message": "No AI provider configured. Add one to start chatting.",
|
||||
"add_provider": "Add AI Provider"
|
||||
"add_provider": "Add AI Provider",
|
||||
"stop": "Stop"
|
||||
},
|
||||
"sidebar_chat": {
|
||||
"title": "AI Chat",
|
||||
|
||||
@@ -48,6 +48,10 @@
|
||||
opacity: 0.4;
|
||||
}
|
||||
|
||||
.llm-chat-stop-btn {
|
||||
color: var(--danger-color, #dc3545);
|
||||
}
|
||||
|
||||
/* Model selector */
|
||||
.llm-chat-model-selector {
|
||||
display: flex;
|
||||
|
||||
@@ -228,11 +228,11 @@ export default function ChatInputBar({
|
||||
)}
|
||||
</div>
|
||||
<ActionButton
|
||||
icon={chat.isStreaming ? "bx bx-loader-alt bx-spin" : "bx bx-send"}
|
||||
text={chat.isStreaming ? t("llm_chat.sending") : t("llm_chat.send")}
|
||||
onClick={handleSubmit}
|
||||
disabled={chat.isStreaming || !chat.input.trim()}
|
||||
className="llm-chat-send-btn"
|
||||
icon={chat.isStreaming ? "bx bx-stop" : "bx bx-send"}
|
||||
text={chat.isStreaming ? t("llm_chat.stop") : t("llm_chat.send")}
|
||||
onClick={chat.isStreaming ? chat.stopStreaming : handleSubmit}
|
||||
disabled={!chat.isStreaming && !chat.input.trim()}
|
||||
className={`llm-chat-send-btn ${chat.isStreaming ? "llm-chat-stop-btn" : ""}`}
|
||||
/>
|
||||
</div>
|
||||
</form>
|
||||
|
||||
@@ -62,6 +62,8 @@ export interface UseLlmChatReturn {
|
||||
clearMessages: () => void;
|
||||
/** Refresh the provider/models list */
|
||||
refreshModels: () => void;
|
||||
/** Stop the current generation */
|
||||
stopStreaming: () => void;
|
||||
}
|
||||
|
||||
export function useLlmChat(
|
||||
@@ -89,6 +91,7 @@ export function useLlmChat(
|
||||
const [isCheckingProvider, setIsCheckingProvider] = useState<boolean>(true);
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
|
||||
// Refs to get fresh values in getContent (avoids stale closures)
|
||||
const messagesRef = useRef(messages);
|
||||
@@ -251,6 +254,56 @@ export function useLlmChat(
|
||||
streamOptions.enableExtendedThinking = enableExtendedThinking;
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
abortControllerRef.current = abortController;
|
||||
|
||||
/** Shared cleanup: finalize collected content and reset streaming state. */
|
||||
function finalizeStream() {
|
||||
// Mark any in-progress tool calls as stopped so they don't show infinite spinners
|
||||
for (const [i, block] of contentBlocks.entries()) {
|
||||
if (block.type === "tool_call" && !block.toolCall.result) {
|
||||
contentBlocks[i] = {
|
||||
type: "tool_call",
|
||||
toolCall: { ...block.toolCall, result: "[Stopped]", isError: true }
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const finalNewMessages: StoredMessage[] = [];
|
||||
|
||||
if (thinkingContent) {
|
||||
finalNewMessages.push({
|
||||
id: randomString(),
|
||||
role: "assistant",
|
||||
content: thinkingContent,
|
||||
createdAt: new Date().toISOString(),
|
||||
type: "thinking"
|
||||
});
|
||||
}
|
||||
|
||||
if (contentBlocks.length > 0) {
|
||||
finalNewMessages.push({
|
||||
id: randomString(),
|
||||
role: "assistant",
|
||||
content: contentBlocks,
|
||||
createdAt: new Date().toISOString(),
|
||||
citations: citations.length > 0 ? citations : undefined,
|
||||
usage
|
||||
});
|
||||
}
|
||||
|
||||
if (finalNewMessages.length > 0) {
|
||||
setMessages([...newMessages, ...finalNewMessages]);
|
||||
}
|
||||
|
||||
setStreamingContent("");
|
||||
setStreamingBlocks([]);
|
||||
setStreamingThinking("");
|
||||
setPendingCitations([]);
|
||||
setIsStreaming(false);
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
|
||||
await streamChatCompletion(
|
||||
apiMessages,
|
||||
streamOptions,
|
||||
@@ -320,42 +373,19 @@ export function useLlmChat(
|
||||
setIsStreaming(false);
|
||||
},
|
||||
onDone: () => {
|
||||
const finalNewMessages: StoredMessage[] = [];
|
||||
|
||||
if (thinkingContent) {
|
||||
finalNewMessages.push({
|
||||
id: randomString(),
|
||||
role: "assistant",
|
||||
content: thinkingContent,
|
||||
createdAt: new Date().toISOString(),
|
||||
type: "thinking"
|
||||
});
|
||||
}
|
||||
|
||||
if (contentBlocks.length > 0) {
|
||||
finalNewMessages.push({
|
||||
id: randomString(),
|
||||
role: "assistant",
|
||||
content: contentBlocks,
|
||||
createdAt: new Date().toISOString(),
|
||||
citations: citations.length > 0 ? citations : undefined,
|
||||
usage
|
||||
});
|
||||
}
|
||||
|
||||
if (finalNewMessages.length > 0) {
|
||||
const allMessages = [...newMessages, ...finalNewMessages];
|
||||
setMessages(allMessages);
|
||||
}
|
||||
|
||||
setStreamingContent("");
|
||||
setStreamingBlocks([]);
|
||||
setStreamingThinking("");
|
||||
setPendingCitations([]);
|
||||
setIsStreaming(false);
|
||||
finalizeStream();
|
||||
}
|
||||
},
|
||||
abortController.signal
|
||||
).catch((e) => {
|
||||
// AbortError is expected when user stops generation
|
||||
if (e instanceof DOMException && e.name === "AbortError") {
|
||||
finalizeStream();
|
||||
} else {
|
||||
// Re-throw other errors so they are not swallowed
|
||||
throw e;
|
||||
}
|
||||
);
|
||||
});
|
||||
}, [input, isStreaming, messages, selectedModel, enableWebSearch, enableNoteTools, enableExtendedThinking, contextNoteId, supportsExtendedThinking, setMessages]);
|
||||
|
||||
const handleKeyDown = useCallback((e: KeyboardEvent) => {
|
||||
@@ -365,6 +395,13 @@ export function useLlmChat(
|
||||
}
|
||||
}, [handleSubmit]);
|
||||
|
||||
/** Stop the current generation by aborting the SSE connection. */
|
||||
const stopStreaming = useCallback(() => {
|
||||
if (abortControllerRef.current) {
|
||||
abortControllerRef.current.abort();
|
||||
}
|
||||
}, []);
|
||||
|
||||
return {
|
||||
// State
|
||||
messages,
|
||||
@@ -402,6 +439,7 @@ export function useLlmChat(
|
||||
loadFromContent,
|
||||
getContent,
|
||||
clearMessages,
|
||||
refreshModels
|
||||
refreshModels,
|
||||
stopStreaming
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user