|
12 | 12 | from openai import AsyncOpenAI |
13 | 13 | from chatkit.agents import simple_to_agent_input, stream_agent_response, AgentContext, ClientToolCall |
14 | 14 | from chatkit.server import StreamingResult, ChatKitServer |
15 | | -from chatkit.types import ThreadMetadata, UserMessageItem, ThreadStreamEvent, UserMessageTextContent, ClientToolCallItem |
| 15 | +from chatkit.types import ThreadMetadata, UserMessageItem, ThreadStreamEvent, UserMessageTextContent, ClientToolCallItem, ThreadsAddClientToolOutputReq, StreamingReq, ThreadItemDoneEvent, ThreadItemAddedEvent, AssistantMessageItem |
| 16 | +from datetime import datetime |
16 | 17 | from chatkit.store import Store, AttachmentStore |
| 18 | +from chatkit.server import DEFAULT_PAGE_SIZE |
17 | 19 | from supabase import create_client, Client |
18 | 20 | from .stores import ChatKitDataStore, ChatKitAttachmentStore, TContext |
19 | 21 |
|
@@ -383,6 +385,61 @@ def __init__( |
383 | 385 | ): |
384 | 386 | super().__init__(data_store, attachment_store) |
385 | 387 |
|
| 388 | + async def _process_streaming_impl( |
| 389 | + self, request: StreamingReq, context: TContext |
| 390 | + ) -> AsyncIterator[ThreadStreamEvent]: |
| 391 | + # Override to fix threads.add_client_tool_output handler |
| 392 | + # The library loads only 1 item, but we need to load more to find pending tool calls |
| 393 | + # if an assistant message was saved after the tool call |
| 394 | + |
| 395 | + if isinstance(request, ThreadsAddClientToolOutputReq): |
| 396 | + thread = await self.store.load_thread( |
| 397 | + request.params.thread_id, context=context |
| 398 | + ) |
| 399 | + # Load DEFAULT_PAGE_SIZE items instead of just 1 to find pending tool calls |
| 400 | + items = await self.store.load_thread_items( |
| 401 | + thread.id, None, DEFAULT_PAGE_SIZE, "desc", context |
| 402 | + ) |
| 403 | + logger.info(f"[_process_streaming_impl] Loaded {len(items.data)} items for thread {thread.id}") |
| 404 | + logger.info(f"[_process_streaming_impl] Item types: {[item.type if hasattr(item, 'type') else type(item).__name__ for item in items.data]}") |
| 405 | + tool_call = next( |
| 406 | + ( |
| 407 | + item |
| 408 | + for item in items.data |
| 409 | + if isinstance(item, ClientToolCallItem) |
| 410 | + and item.status == "pending" |
| 411 | + ), |
| 412 | + None, |
| 413 | + ) |
| 414 | + if not tool_call: |
| 415 | + logger.error(f"[_process_streaming_impl] No pending ClientToolCallItem found in {len(items.data)} items") |
| 416 | + logger.error(f"[_process_streaming_impl] Items: {items.data}") |
| 417 | + raise ValueError( |
| 418 | + f"Last thread item in {thread.id} was not a ClientToolCallItem" |
| 419 | + ) |
| 420 | + |
| 421 | + tool_call.output = request.params.result |
| 422 | + tool_call.status = "completed" |
| 423 | + |
| 424 | + await self.store.save_item(thread.id, tool_call, context=context) |
| 425 | + |
| 426 | + # Safety against dangling pending tool calls if there are |
| 427 | + # multiple in a row, which should be impossible, and |
| 428 | + # integrations should ultimately filter out pending tool calls |
| 429 | + # when creating input response messages. |
| 430 | + await self._cleanup_pending_client_tool_call(thread, context) |
| 431 | + |
| 432 | + async for event in self._process_events( |
| 433 | + thread, |
| 434 | + context, |
| 435 | + lambda: self.respond(thread, None, context), |
| 436 | + ): |
| 437 | + yield event |
| 438 | + else: |
| 439 | + # For all other cases, use the parent's implementation |
| 440 | + async for event in super()._process_streaming_impl(request, context): |
| 441 | + yield event |
| 442 | + |
386 | 443 | async def respond( |
387 | 444 | self, |
388 | 445 | thread: ThreadMetadata, |
@@ -494,9 +551,55 @@ def sanitize_item(item): |
494 | 551 | return sanitized |
495 | 552 |
|
496 | 553 | # For regular messages, keep only role and content |
| 554 | + # But ensure content is properly formatted for the Agents SDK |
| 555 | + # For agent inputs, assistant messages should use input_text content (not output_text) |
| 556 | + role = item.get("role") |
| 557 | + content = item.get("content", []) |
| 558 | + |
| 559 | + if isinstance(content, list): |
| 560 | + # Convert content items to the format Agents SDK expects |
| 561 | + formatted_content = [] |
| 562 | + for c in content: |
| 563 | + if isinstance(c, dict): |
| 564 | + # Convert output_text to input_text for assistant messages (agent inputs use input_text) |
| 565 | + content_type = c.get("type") |
| 566 | + if content_type == "output_text": |
| 567 | + # Convert output_text to input_text for agent inputs |
| 568 | + formatted_content.append({ |
| 569 | + "type": "input_text", |
| 570 | + "text": c.get("text", ""), |
| 571 | + }) |
| 572 | + elif content_type == "input_text": |
| 573 | + # Already correct format |
| 574 | + formatted_content.append({ |
| 575 | + "type": "input_text", |
| 576 | + "text": c.get("text", ""), |
| 577 | + }) |
| 578 | + elif "text" in c: |
| 579 | + # Unknown type but has text, convert to input_text |
| 580 | + formatted_content.append({ |
| 581 | + "type": "input_text", |
| 582 | + "text": c.get("text", ""), |
| 583 | + }) |
| 584 | + else: |
| 585 | + # Unknown format, try to preserve it |
| 586 | + formatted_content.append(c) |
| 587 | + elif isinstance(c, str): |
| 588 | + # Plain string, wrap in input_text |
| 589 | + formatted_content.append({ |
| 590 | + "type": "input_text", |
| 591 | + "text": c, |
| 592 | + }) |
| 593 | + else: |
| 594 | + formatted_content.append(c) |
| 595 | + content = formatted_content |
| 596 | + elif isinstance(content, str): |
| 597 | + # Plain string, wrap in input_text array |
| 598 | + content = [{"type": "input_text", "text": content}] |
| 599 | + |
497 | 600 | return { |
498 | | - "role": item.get("role"), |
499 | | - "content": item.get("content"), |
| 601 | + "role": role, |
| 602 | + "content": content, |
500 | 603 | } |
501 | 604 | return item |
502 | 605 |
|
@@ -598,7 +701,97 @@ def sanitize_item(item): |
598 | 701 | ) |
599 | 702 | logger.info(f"[python-respond] Runner.run_streamed returned, result type: {type(result)}") |
600 | 703 |
|
601 | | - async for event in stream_agent_response(agent_context, result): |
| 704 | + # Wrap stream_agent_response to fix __fake_id__ in ThreadItemAddedEvent and ThreadItemDoneEvent items |
| 705 | + # CRITICAL: If items are saved with __fake_id__, they will overwrite each other due to PRIMARY KEY constraint |
| 706 | + # CRITICAL: Both thread.item.added and thread.item.done must have the SAME ID so the frontend recognizes them as the same item |
| 707 | + # This ensures ChatKit items have proper IDs (defense-in-depth - add_thread_item also fixes IDs) |
| 708 | + async def fix_chatkit_event_ids(events): |
| 709 | + event_count = 0 |
| 710 | + # Track IDs we've generated for items, so thread.item.added and thread.item.done use the same ID |
| 711 | + item_id_map: dict[str, str] = {} # Maps original __fake_id__ to generated ID |
| 712 | + |
| 713 | + async for event in events: |
| 714 | + event_count += 1 |
| 715 | + event_type = event.type if hasattr(event, 'type') else type(event).__name__ |
| 716 | + logger.info(f"[python-respond] Event #{event_count}: {event_type}") |
| 717 | + |
| 718 | + # Fix __fake_id__ in ThreadItemAddedEvent items |
| 719 | + if isinstance(event, ThreadItemAddedEvent) and hasattr(event, 'item'): |
| 720 | + item = event.item |
| 721 | + original_id = item.id if hasattr(item, 'id') else 'N/A' |
| 722 | + item_type = item.type if hasattr(item, 'type') else type(item).__name__ |
| 723 | + content_preview = "" |
| 724 | + content_length = 0 |
| 725 | + if isinstance(item, AssistantMessageItem) and item.content: |
| 726 | + # Get first 50 chars of content for logging |
| 727 | + first_content = item.content[0] if item.content else None |
| 728 | + if first_content and hasattr(first_content, 'text'): |
| 729 | + content_length = len(first_content.text) |
| 730 | + content_preview = first_content.text[:50] + "..." if len(first_content.text) > 50 else first_content.text |
| 731 | + logger.info(f"[python-respond] ThreadItemAddedEvent: type={item_type}, id={original_id}, content_length={content_length}, content_preview={content_preview}") |
| 732 | + |
| 733 | + if hasattr(item, 'id') and (item.id == '__fake_id__' or not item.id or item.id == 'N/A'): |
| 734 | + # Check if we've already generated an ID for this item (from a previous event) |
| 735 | + if original_id in item_id_map: |
| 736 | + item.id = item_id_map[original_id] |
| 737 | + logger.info(f"[python-respond] Reusing ID for ThreadItemAddedEvent: {original_id} -> {item.id}") |
| 738 | + else: |
| 739 | + logger.error(f"[python-respond] CRITICAL: Fixing __fake_id__ for {type(item).__name__} in ThreadItemAddedEvent (original_id={original_id})") |
| 740 | + thread_meta = ThreadMetadata(id=thread.id, created_at=datetime.now()) |
| 741 | + if isinstance(item, ClientToolCallItem): |
| 742 | + item_type_for_id = "tool_call" |
| 743 | + elif isinstance(item, AssistantMessageItem): |
| 744 | + item_type_for_id = "message" |
| 745 | + elif isinstance(item, UserMessageItem): |
| 746 | + item_type_for_id = "message" |
| 747 | + else: |
| 748 | + item_type_for_id = "message" |
| 749 | + item.id = self.store.generate_item_id(item_type_for_id, thread_meta, context) |
| 750 | + item_id_map[original_id] = item.id |
| 751 | + logger.info(f"[python-respond] Fixed ID in ThreadItemAddedEvent: {original_id} -> {item.id}") |
| 752 | + else: |
| 753 | + logger.info(f"[python-respond] Item {type(item).__name__} already has valid ID: {original_id}") |
| 754 | + |
| 755 | + # Fix __fake_id__ in ThreadItemDoneEvent items before they're saved |
| 756 | + if isinstance(event, ThreadItemDoneEvent) and hasattr(event, 'item'): |
| 757 | + item = event.item |
| 758 | + original_id = item.id if hasattr(item, 'id') else 'N/A' |
| 759 | + item_type = item.type if hasattr(item, 'type') else type(item).__name__ |
| 760 | + content_preview = "" |
| 761 | + content_length = 0 |
| 762 | + if isinstance(item, AssistantMessageItem) and item.content: |
| 763 | + # Get first 50 chars of content for logging |
| 764 | + first_content = item.content[0] if item.content else None |
| 765 | + if first_content and hasattr(first_content, 'text'): |
| 766 | + content_length = len(first_content.text) |
| 767 | + content_preview = first_content.text[:50] + "..." if len(first_content.text) > 50 else first_content.text |
| 768 | + logger.info(f"[python-respond] ThreadItemDoneEvent: type={item_type}, id={original_id}, content_length={content_length}, content_preview={content_preview}") |
| 769 | + |
| 770 | + if hasattr(item, 'id') and (item.id == '__fake_id__' or not item.id or item.id == 'N/A'): |
| 771 | + # Check if we've already generated an ID for this item (from thread.item.added) |
| 772 | + if original_id in item_id_map: |
| 773 | + item.id = item_id_map[original_id] |
| 774 | + logger.info(f"[python-respond] Reusing ID for ThreadItemDoneEvent: {original_id} -> {item.id}") |
| 775 | + else: |
| 776 | + logger.error(f"[python-respond] CRITICAL: Fixing __fake_id__ for {type(item).__name__} in ThreadItemDoneEvent (original_id={original_id})") |
| 777 | + thread_meta = ThreadMetadata(id=thread.id, created_at=datetime.now()) |
| 778 | + if isinstance(item, ClientToolCallItem): |
| 779 | + item_type_for_id = "tool_call" |
| 780 | + elif isinstance(item, AssistantMessageItem): |
| 781 | + item_type_for_id = "message" |
| 782 | + elif isinstance(item, UserMessageItem): |
| 783 | + item_type_for_id = "message" |
| 784 | + else: |
| 785 | + item_type_for_id = "message" |
| 786 | + item.id = self.store.generate_item_id(item_type_for_id, thread_meta, context) |
| 787 | + item_id_map[original_id] = item.id |
| 788 | + logger.info(f"[python-respond] Fixed ID in ThreadItemDoneEvent: {original_id} -> {item.id}") |
| 789 | + else: |
| 790 | + logger.info(f"[python-respond] Item {type(item).__name__} already has valid ID: {original_id}") |
| 791 | + yield event |
| 792 | + |
| 793 | + # Stream events with fixed IDs |
| 794 | + async for event in fix_chatkit_event_ids(stream_agent_response(agent_context, result)): |
602 | 795 | yield event |
603 | 796 |
|
604 | 797 | server = MyChatKitServer(data_store, attachment_store) |
|
0 commit comments