diff --git a/apps/client/src/services/llm_chat.ts b/apps/client/src/services/llm_chat.ts
index fa0a0279d3..cd6ab3e63f 100644
--- a/apps/client/src/services/llm_chat.ts
+++ b/apps/client/src/services/llm_chat.ts
@@ -27,7 +27,8 @@ export interface StreamCallbacks {
 export async function streamChatCompletion(
     messages: LlmMessage[],
     config: LlmChatConfig,
-    callbacks: StreamCallbacks
+    callbacks: StreamCallbacks,
+    abortSignal?: AbortSignal
 ): Promise<void> {
     const headers = await server.getHeaders();
 
@@ -37,7 +38,8 @@ export async function streamChatCompletion(
             ...headers,
             "Content-Type": "application/json"
         } as HeadersInit,
-        body: JSON.stringify({ messages, config })
+        body: JSON.stringify({ messages, config }),
+        signal: abortSignal
     });
 
     if (!response.ok) {
diff --git a/apps/client/src/translations/en/translation.json b/apps/client/src/translations/en/translation.json
index 6931954e65..142798e442 100644
--- a/apps/client/src/translations/en/translation.json
+++ b/apps/client/src/translations/en/translation.json
@@ -1677,7 +1677,8 @@
         "note_context_enabled": "Click to disable note context: {{title}}",
         "note_context_disabled": "Click to include current note in context",
         "no_provider_message": "No AI provider configured. Add one to start chatting.",
-        "add_provider": "Add AI Provider"
+        "add_provider": "Add AI Provider",
+        "stop": "Stop"
     },
     "sidebar_chat": {
         "title": "AI Chat",
diff --git a/apps/client/src/widgets/type_widgets/llm_chat/ChatInputBar.css b/apps/client/src/widgets/type_widgets/llm_chat/ChatInputBar.css
index 4599e6a511..07adaacf3c 100644
--- a/apps/client/src/widgets/type_widgets/llm_chat/ChatInputBar.css
+++ b/apps/client/src/widgets/type_widgets/llm_chat/ChatInputBar.css
@@ -48,6 +48,10 @@
     opacity: 0.4;
 }
 
+.llm-chat-stop-btn {
+    color: var(--danger-color, #dc3545);
+}
+
 /* Model selector */
 .llm-chat-model-selector {
     display: flex;
diff --git a/apps/client/src/widgets/type_widgets/llm_chat/ChatInputBar.tsx b/apps/client/src/widgets/type_widgets/llm_chat/ChatInputBar.tsx
index 6491a595b0..b4515d2bb4 100644
--- a/apps/client/src/widgets/type_widgets/llm_chat/ChatInputBar.tsx
+++ b/apps/client/src/widgets/type_widgets/llm_chat/ChatInputBar.tsx
@@ -228,11 +228,11 @@
             )}
diff --git a/apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts b/apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts
index 63cbf4bbf4..f52fb6cc61 100644
--- a/apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts
+++ b/apps/client/src/widgets/type_widgets/llm_chat/useLlmChat.ts
@@ -62,6 +62,8 @@
     clearMessages: () => void;
     /** Refresh the provider/models list */
     refreshModels: () => void;
+    /** Stop the current generation */
+    stopStreaming: () => void;
 }
 
 export function useLlmChat(
@@ -89,6 +91,7 @@
     const [isCheckingProvider, setIsCheckingProvider] = useState(true);
     const messagesEndRef = useRef<HTMLDivElement>(null);
     const textareaRef = useRef<HTMLTextAreaElement>(null);
+    const abortControllerRef = useRef<AbortController | null>(null);
 
     // Refs to get fresh values in getContent (avoids stale closures)
     const messagesRef = useRef(messages);
@@ -251,6 +254,56 @@
             streamOptions.enableExtendedThinking = enableExtendedThinking;
         }
 
+        const abortController = new AbortController();
+        abortControllerRef.current = abortController;
+
+        /** Shared cleanup: finalize collected content and reset streaming state. */
+        function finalizeStream() {
+            // Mark any in-progress tool calls as stopped so they don't show infinite spinners
+            for (const [i, block] of contentBlocks.entries()) {
+                if (block.type === "tool_call" && !block.toolCall.result) {
+                    contentBlocks[i] = {
+                        type: "tool_call",
+                        toolCall: { ...block.toolCall, result: "[Stopped]", isError: true }
+                    };
+                }
+            }
+
+            const finalNewMessages: StoredMessage[] = [];
+
+            if (thinkingContent) {
+                finalNewMessages.push({
+                    id: randomString(),
+                    role: "assistant",
+                    content: thinkingContent,
+                    createdAt: new Date().toISOString(),
+                    type: "thinking"
+                });
+            }
+
+            if (contentBlocks.length > 0) {
+                finalNewMessages.push({
+                    id: randomString(),
+                    role: "assistant",
+                    content: contentBlocks,
+                    createdAt: new Date().toISOString(),
+                    citations: citations.length > 0 ? citations : undefined,
+                    usage
+                });
+            }
+
+            if (finalNewMessages.length > 0) {
+                setMessages([...newMessages, ...finalNewMessages]);
+            }
+
+            setStreamingContent("");
+            setStreamingBlocks([]);
+            setStreamingThinking("");
+            setPendingCitations([]);
+            setIsStreaming(false);
+            abortControllerRef.current = null;
+        }
+
         await streamChatCompletion(
             apiMessages,
             streamOptions,
@@ -320,42 +373,19 @@
                     setIsStreaming(false);
                 },
                 onDone: () => {
-                    const finalNewMessages: StoredMessage[] = [];
-
-                    if (thinkingContent) {
-                        finalNewMessages.push({
-                            id: randomString(),
-                            role: "assistant",
-                            content: thinkingContent,
-                            createdAt: new Date().toISOString(),
-                            type: "thinking"
-                        });
-                    }
-
-                    if (contentBlocks.length > 0) {
-                        finalNewMessages.push({
-                            id: randomString(),
-                            role: "assistant",
-                            content: contentBlocks,
-                            createdAt: new Date().toISOString(),
-                            citations: citations.length > 0 ? citations : undefined,
-                            usage
-                        });
-                    }
-
-                    if (finalNewMessages.length > 0) {
-                        const allMessages = [...newMessages, ...finalNewMessages];
-                        setMessages(allMessages);
-                    }
-
-                    setStreamingContent("");
-                    setStreamingBlocks([]);
-                    setStreamingThinking("");
-                    setPendingCitations([]);
-                    setIsStreaming(false);
+                    finalizeStream();
                 }
+            },
+            abortController.signal
+        ).catch((e) => {
+            // AbortError is expected when user stops generation
+            if (e instanceof DOMException && e.name === "AbortError") {
+                finalizeStream();
+            } else {
+                // Re-throw other errors so they are not swallowed
+                throw e;
             }
-        );
+        });
     }, [input, isStreaming, messages, selectedModel, enableWebSearch, enableNoteTools, enableExtendedThinking, contextNoteId, supportsExtendedThinking, setMessages]);
 
     const handleKeyDown = useCallback((e: KeyboardEvent) => {
@@ -365,6 +395,13 @@
         }
     }, [handleSubmit]);
 
+    /** Stop the current generation by aborting the SSE connection. */
+    const stopStreaming = useCallback(() => {
+        if (abortControllerRef.current) {
+            abortControllerRef.current.abort();
+        }
+    }, []);
+
     return {
         // State
         messages,
@@ -402,6 +439,7 @@
         loadFromContent,
         getContent,
         clearMessages,
-        refreshModels
+        refreshModels,
+        stopStreaming
     };
 }