Skip to content

Commit a4a784b

Browse files
committed
Update default agent model and enhance ChatKit event handling
- Changed the default agent model from `ollama/gpt-oss:120b-cloud` to `openai/gpt-4.1` in configuration files and migration scripts. - Improved the `_process_streaming_impl` method in `main.py` to handle multiple thread items, ensuring proper loading of pending tool calls. - Enhanced item ID management in `ChatKitDataStore` to prevent duplicate entries and ensure consistent ID generation for events. - Updated TypeScript functions to align with Python implementations, ensuring consistent handling of function calls and responses. These changes enhance the integration with OpenAI's services and improve the overall functionality and reliability of the ChatKit system.
1 parent 4b609c4 commit a4a784b

File tree

12 files changed

+588
-114
lines changed

12 files changed

+588
-114
lines changed

.github/workflows/ci-cd.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ env:
2020
VITE_SUPABASE_URL: http://127.0.0.1:54321
2121
VITE_SUPABASE_ANON_KEY: sb_publishable_ACJWlzQHlZjBrEguHvfOxg_3BJgxAaH
2222
# Supabase edge function secrets
23-
DEFAULT_AGENT_MODEL: ollama/gpt-oss:120b-cloud
23+
DEFAULT_AGENT_MODEL: openai/gpt-4.1
2424
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
2525
OLLAMA_API_KEY: ${{ secrets.OLLAMA_API_KEY }}
2626
HF_TOKEN: ${{ secrets.HF_TOKEN }}

src/api/main.py

Lines changed: 197 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
from openai import AsyncOpenAI
1313
from chatkit.agents import simple_to_agent_input, stream_agent_response, AgentContext, ClientToolCall
1414
from chatkit.server import StreamingResult, ChatKitServer
15-
from chatkit.types import ThreadMetadata, UserMessageItem, ThreadStreamEvent, UserMessageTextContent, ClientToolCallItem
15+
from chatkit.types import ThreadMetadata, UserMessageItem, ThreadStreamEvent, UserMessageTextContent, ClientToolCallItem, ThreadsAddClientToolOutputReq, StreamingReq, ThreadItemDoneEvent, ThreadItemAddedEvent, AssistantMessageItem
16+
from datetime import datetime
1617
from chatkit.store import Store, AttachmentStore
18+
from chatkit.server import DEFAULT_PAGE_SIZE
1719
from supabase import create_client, Client
1820
from .stores import ChatKitDataStore, ChatKitAttachmentStore, TContext
1921

@@ -383,6 +385,61 @@ def __init__(
383385
):
384386
super().__init__(data_store, attachment_store)
385387

388+
async def _process_streaming_impl(
389+
self, request: StreamingReq, context: TContext
390+
) -> AsyncIterator[ThreadStreamEvent]:
391+
# Override to fix threads.add_client_tool_output handler
392+
# The library loads only 1 item, but we need to load more to find pending tool calls
393+
# if an assistant message was saved after the tool call
394+
395+
if isinstance(request, ThreadsAddClientToolOutputReq):
396+
thread = await self.store.load_thread(
397+
request.params.thread_id, context=context
398+
)
399+
# Load DEFAULT_PAGE_SIZE items instead of just 1 to find pending tool calls
400+
items = await self.store.load_thread_items(
401+
thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
402+
)
403+
logger.info(f"[_process_streaming_impl] Loaded {len(items.data)} items for thread {thread.id}")
404+
logger.info(f"[_process_streaming_impl] Item types: {[item.type if hasattr(item, 'type') else type(item).__name__ for item in items.data]}")
405+
tool_call = next(
406+
(
407+
item
408+
for item in items.data
409+
if isinstance(item, ClientToolCallItem)
410+
and item.status == "pending"
411+
),
412+
None,
413+
)
414+
if not tool_call:
415+
logger.error(f"[_process_streaming_impl] No pending ClientToolCallItem found in {len(items.data)} items")
416+
logger.error(f"[_process_streaming_impl] Items: {items.data}")
417+
raise ValueError(
418+
f"Last thread item in {thread.id} was not a ClientToolCallItem"
419+
)
420+
421+
tool_call.output = request.params.result
422+
tool_call.status = "completed"
423+
424+
await self.store.save_item(thread.id, tool_call, context=context)
425+
426+
# Safety against dangling pending tool calls if there are
427+
# multiple in a row, which should be impossible, and
428+
# integrations should ultimately filter out pending tool calls
429+
# when creating input response messages.
430+
await self._cleanup_pending_client_tool_call(thread, context)
431+
432+
async for event in self._process_events(
433+
thread,
434+
context,
435+
lambda: self.respond(thread, None, context),
436+
):
437+
yield event
438+
else:
439+
# For all other cases, use the parent's implementation
440+
async for event in super()._process_streaming_impl(request, context):
441+
yield event
442+
386443
async def respond(
387444
self,
388445
thread: ThreadMetadata,
@@ -494,9 +551,55 @@ def sanitize_item(item):
494551
return sanitized
495552

496553
# For regular messages, keep only role and content
554+
# But ensure content is properly formatted for the Agents SDK
555+
# For agent inputs, assistant messages should use input_text content (not output_text)
556+
role = item.get("role")
557+
content = item.get("content", [])
558+
559+
if isinstance(content, list):
560+
# Convert content items to the format Agents SDK expects
561+
formatted_content = []
562+
for c in content:
563+
if isinstance(c, dict):
564+
# Convert output_text to input_text for assistant messages (agent inputs use input_text)
565+
content_type = c.get("type")
566+
if content_type == "output_text":
567+
# Convert output_text to input_text for agent inputs
568+
formatted_content.append({
569+
"type": "input_text",
570+
"text": c.get("text", ""),
571+
})
572+
elif content_type == "input_text":
573+
# Already correct format
574+
formatted_content.append({
575+
"type": "input_text",
576+
"text": c.get("text", ""),
577+
})
578+
elif "text" in c:
579+
# Unknown type but has text, convert to input_text
580+
formatted_content.append({
581+
"type": "input_text",
582+
"text": c.get("text", ""),
583+
})
584+
else:
585+
# Unknown format, try to preserve it
586+
formatted_content.append(c)
587+
elif isinstance(c, str):
588+
# Plain string, wrap in input_text
589+
formatted_content.append({
590+
"type": "input_text",
591+
"text": c,
592+
})
593+
else:
594+
formatted_content.append(c)
595+
content = formatted_content
596+
elif isinstance(content, str):
597+
# Plain string, wrap in input_text array
598+
content = [{"type": "input_text", "text": content}]
599+
497600
return {
498-
"role": item.get("role"),
499-
"content": item.get("content"),
601+
"role": role,
602+
"content": content,
500603
}
501604
return item
502605

@@ -598,7 +701,97 @@ def sanitize_item(item):
598701
)
599702
logger.info(f"[python-respond] Runner.run_streamed returned, result type: {type(result)}")
600703

601-
async for event in stream_agent_response(agent_context, result):
704+
# Wrap stream_agent_response to fix __fake_id__ in ThreadItemAddedEvent and ThreadItemDoneEvent items
705+
# CRITICAL: If items are saved with __fake_id__, they will overwrite each other due to PRIMARY KEY constraint
706+
# CRITICAL: Both thread.item.added and thread.item.done must have the SAME ID so the frontend recognizes them as the same item
707+
# This ensures ChatKit items have proper IDs (defense-in-depth - add_thread_item also fixes IDs)
708+
async def fix_chatkit_event_ids(events):
709+
event_count = 0
710+
# Track IDs we've generated for items, so thread.item.added and thread.item.done use the same ID
711+
item_id_map: dict[str, str] = {} # Maps original __fake_id__ to generated ID
712+
713+
async for event in events:
714+
event_count += 1
715+
event_type = event.type if hasattr(event, 'type') else type(event).__name__
716+
logger.info(f"[python-respond] Event #{event_count}: {event_type}")
717+
718+
# Fix __fake_id__ in ThreadItemAddedEvent items
719+
if isinstance(event, ThreadItemAddedEvent) and hasattr(event, 'item'):
720+
item = event.item
721+
original_id = item.id if hasattr(item, 'id') else 'N/A'
722+
item_type = item.type if hasattr(item, 'type') else type(item).__name__
723+
content_preview = ""
724+
content_length = 0
725+
if isinstance(item, AssistantMessageItem) and item.content:
726+
# Get first 50 chars of content for logging
727+
first_content = item.content[0] if item.content else None
728+
if first_content and hasattr(first_content, 'text'):
729+
content_length = len(first_content.text)
730+
content_preview = first_content.text[:50] + "..." if len(first_content.text) > 50 else first_content.text
731+
logger.info(f"[python-respond] ThreadItemAddedEvent: type={item_type}, id={original_id}, content_length={content_length}, content_preview={content_preview}")
732+
733+
if hasattr(item, 'id') and (item.id == '__fake_id__' or not item.id or item.id == 'N/A'):
734+
# Check if we've already generated an ID for this item (from a previous event)
735+
if original_id in item_id_map:
736+
item.id = item_id_map[original_id]
737+
logger.info(f"[python-respond] Reusing ID for ThreadItemAddedEvent: {original_id} -> {item.id}")
738+
else:
739+
logger.error(f"[python-respond] CRITICAL: Fixing __fake_id__ for {type(item).__name__} in ThreadItemAddedEvent (original_id={original_id})")
740+
thread_meta = ThreadMetadata(id=thread.id, created_at=datetime.now())
741+
if isinstance(item, ClientToolCallItem):
742+
item_type_for_id = "tool_call"
743+
elif isinstance(item, AssistantMessageItem):
744+
item_type_for_id = "message"
745+
elif isinstance(item, UserMessageItem):
746+
item_type_for_id = "message"
747+
else:
748+
item_type_for_id = "message"
749+
item.id = self.store.generate_item_id(item_type_for_id, thread_meta, context)
750+
item_id_map[original_id] = item.id
751+
logger.info(f"[python-respond] Fixed ID in ThreadItemAddedEvent: {original_id} -> {item.id}")
752+
else:
753+
logger.info(f"[python-respond] Item {type(item).__name__} already has valid ID: {original_id}")
754+
755+
# Fix __fake_id__ in ThreadItemDoneEvent items before they're saved
756+
if isinstance(event, ThreadItemDoneEvent) and hasattr(event, 'item'):
757+
item = event.item
758+
original_id = item.id if hasattr(item, 'id') else 'N/A'
759+
item_type = item.type if hasattr(item, 'type') else type(item).__name__
760+
content_preview = ""
761+
content_length = 0
762+
if isinstance(item, AssistantMessageItem) and item.content:
763+
# Get first 50 chars of content for logging
764+
first_content = item.content[0] if item.content else None
765+
if first_content and hasattr(first_content, 'text'):
766+
content_length = len(first_content.text)
767+
content_preview = first_content.text[:50] + "..." if len(first_content.text) > 50 else first_content.text
768+
logger.info(f"[python-respond] ThreadItemDoneEvent: type={item_type}, id={original_id}, content_length={content_length}, content_preview={content_preview}")
769+
770+
if hasattr(item, 'id') and (item.id == '__fake_id__' or not item.id or item.id == 'N/A'):
771+
# Check if we've already generated an ID for this item (from thread.item.added)
772+
if original_id in item_id_map:
773+
item.id = item_id_map[original_id]
774+
logger.info(f"[python-respond] Reusing ID for ThreadItemDoneEvent: {original_id} -> {item.id}")
775+
else:
776+
logger.error(f"[python-respond] CRITICAL: Fixing __fake_id__ for {type(item).__name__} in ThreadItemDoneEvent (original_id={original_id})")
777+
thread_meta = ThreadMetadata(id=thread.id, created_at=datetime.now())
778+
if isinstance(item, ClientToolCallItem):
779+
item_type_for_id = "tool_call"
780+
elif isinstance(item, AssistantMessageItem):
781+
item_type_for_id = "message"
782+
elif isinstance(item, UserMessageItem):
783+
item_type_for_id = "message"
784+
else:
785+
item_type_for_id = "message"
786+
item.id = self.store.generate_item_id(item_type_for_id, thread_meta, context)
787+
item_id_map[original_id] = item.id
788+
logger.info(f"[python-respond] Fixed ID in ThreadItemDoneEvent: {original_id} -> {item.id}")
789+
else:
790+
logger.info(f"[python-respond] Item {type(item).__name__} already has valid ID: {original_id}")
791+
yield event
792+
793+
# Stream events with fixed IDs
794+
async for event in fix_chatkit_event_ids(stream_agent_response(agent_context, result)):
602795
yield event
603796

604797
server = MyChatKitServer(data_store, attachment_store)

src/api/stores.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any
33
from datetime import datetime
44
import os
5+
import logging
56

67
from fastapi import HTTPException
78
from chatkit.store import Store, AttachmentStore
@@ -18,6 +19,8 @@
1819
)
1920
from openai import AsyncOpenAI
2021

22+
logger = logging.getLogger(__name__)
23+
2124

2225
class TContext(dict):
2326
"""Request-scoped context passed through ChatKit and Store.
@@ -91,6 +94,31 @@ async def add_thread_item(
9194
if not context.user_id:
9295
raise HTTPException(status_code=400, detail="Missing user_id")
9396

97+
item_id = item.id if hasattr(item, 'id') else 'N/A'
98+
logger.info(f"[add_thread_item] Adding item to thread {thread_id}: type={item.type if hasattr(item, 'type') else type(item).__name__}, id={item_id}")
99+
100+
# CRITICAL: Check if ID is invalid for any item type
101+
# If items are saved with __fake_id__, they will overwrite each other due to PRIMARY KEY constraint
102+
if item_id == '__fake_id__' or not item_id or item_id == 'N/A':
103+
logger.error(f"[add_thread_item] WARNING: Item has invalid ID: {item_id}, type={item.type if hasattr(item, 'type') else type(item).__name__}")
104+
# Generate a proper ID if missing
105+
# Create a minimal ThreadMetadata for generate_item_id
106+
thread_meta = ThreadMetadata(id=thread_id, created_at=datetime.now())
107+
# Determine item type for ID generation
108+
if isinstance(item, ClientToolCallItem):
109+
item_type_for_id = "tool_call"
110+
logger.info(f"[add_thread_item] ClientToolCallItem: status={item.status}, name={item.name}, call_id={item.call_id}")
111+
elif isinstance(item, AssistantMessageItem):
112+
item_type_for_id = "message"
113+
elif isinstance(item, UserMessageItem):
114+
item_type_for_id = "message"
115+
else:
116+
item_type_for_id = "message" # Default fallback
117+
item.id = self.generate_item_id(item_type_for_id, thread_meta, context)
118+
logger.info(f"[add_thread_item] Generated new ID for item: {item.id}")
119+
elif isinstance(item, ClientToolCallItem):
120+
logger.info(f"[add_thread_item] ClientToolCallItem: status={item.status}, name={item.name}, call_id={item.call_id}, id={item_id}")
121+
94122
import httpx
95123

96124
client = self._get_client(context)

supabase/config.toml.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ enabled = true
66
[edge_runtime.secrets]
77
ANTHROPIC_API_KEY = "your-anthropic-api-key"
88
# Default model for agents
9-
DEFAULT_AGENT_MODEL = "ollama/gpt-oss:120b-cloud"
9+
DEFAULT_AGENT_MODEL = "openai/gpt-4.1"
1010
# Add your Hugging Face token here for AI features
1111
HF_TOKEN = "your-huggingface-token"
1212
OLLAMA_API_KEY = "your-ollama-cloud-api-key"

supabase/functions/agent-chat-v2/chatkit/server.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,9 @@ export class ChatKitServer<TCtx = TContext> {
444444
const thread = await this.store.load_thread(req.params.thread_id, context);
445445
// Load recent items to find the pending client_tool_call
446446
// Match Python: items = await self.store.load_thread_items(thread.id, None, 1, "desc", context)
447-
const items = await this.store.load_thread_items(thread.id, null, 1, 'desc', context);
447+
// BUT: We need to load more items to find the pending tool call if an assistant message was saved after it
448+
// Load DEFAULT_PAGE_SIZE items to find the pending tool call
449+
const items = await this.store.load_thread_items(thread.id, null, DEFAULT_PAGE_SIZE, 'desc', context);
448450
// Match Python: tool_call = next((item for item in items.data if isinstance(item, ClientToolCallItem) and item.status == "pending"), None)
449451
const toolCall = items.data.find((item: ThreadItem) => {
450452
const typedItem = item as { type: string; status?: string };

0 commit comments

Comments
 (0)