Skip to content

Commit cc675fc

Browse files
sjrl and Amnah199
authored and committed
fix: Fix _convert_streaming_chunks_to_chat_message (#9566)
* Fix conversion * Add reno * Add unit test
1 parent 054aacf commit cc675fc

File tree

3 files changed

+68
-8
lines changed

3 files changed

+68
-8
lines changed

haystack/components/generators/utils.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,24 +84,20 @@ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> C
8484
tool_call_data: Dict[int, Dict[str, str]] = {} # Track tool calls by index
8585
for chunk in chunks:
8686
if chunk.tool_calls:
87-
# We do this to make sure mypy is happy, but we enforce index is not None in the StreamingChunk dataclass if
88-
# tool_call is present
89-
assert chunk.index is not None
90-
9187
for tool_call in chunk.tool_calls:
9288
# We use the index of the tool_call to track the tool call across chunks since the ID is not always
9389
# provided
9490
if tool_call.index not in tool_call_data:
95-
tool_call_data[chunk.index] = {"id": "", "name": "", "arguments": ""}
91+
tool_call_data[tool_call.index] = {"id": "", "name": "", "arguments": ""}
9692

9793
# Save the ID if present
9894
if tool_call.id is not None:
99-
tool_call_data[chunk.index]["id"] = tool_call.id
95+
tool_call_data[tool_call.index]["id"] = tool_call.id
10096

10197
if tool_call.tool_name is not None:
102-
tool_call_data[chunk.index]["name"] += tool_call.tool_name
98+
tool_call_data[tool_call.index]["name"] += tool_call.tool_name
10399
if tool_call.arguments is not None:
104-
tool_call_data[chunk.index]["arguments"] += tool_call.arguments
100+
tool_call_data[tool_call.index]["arguments"] += tool_call.arguments
105101

106102
# Convert accumulated tool call data into ToolCall objects
107103
sorted_keys = sorted(tool_call_data.keys())
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Fix `_convert_streaming_chunks_to_chat_message`, which is used to convert Haystack StreamingChunks into a Haystack ChatMessage. This fixes the scenario where one StreamingChunk contains two ToolCallDeltas in StreamingChunk.tool_calls. With this fix, both ToolCallDeltas are correctly saved, whereas before they were overwriting each other. This only occurs with some LLM providers like Mistral (and not OpenAI) due to how the provider returns tool calls.

test/components/generators/test_utils.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,63 @@ def test_convert_streaming_chunks_to_chat_message_tool_calls_in_any_chunk():
325325
},
326326
"prompt_tokens_details": {"audio_tokens": 0, "cached_tokens": 0},
327327
}
328+
329+
330+
def test_convert_streaming_chunk_to_chat_message_two_tool_calls_in_same_chunk():
331+
chunks = [
332+
StreamingChunk(
333+
content="",
334+
meta={
335+
"model": "mistral-small-latest",
336+
"index": 0,
337+
"tool_calls": None,
338+
"finish_reason": None,
339+
"usage": None,
340+
},
341+
component_info=ComponentInfo(
342+
type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
343+
name=None,
344+
),
345+
),
346+
StreamingChunk(
347+
content="",
348+
meta={
349+
"model": "mistral-small-latest",
350+
"index": 0,
351+
"finish_reason": "tool_calls",
352+
"usage": {
353+
"completion_tokens": 35,
354+
"prompt_tokens": 77,
355+
"total_tokens": 112,
356+
"completion_tokens_details": None,
357+
"prompt_tokens_details": None,
358+
},
359+
},
360+
component_info=ComponentInfo(
361+
type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
362+
name=None,
363+
),
364+
index=0,
365+
tool_calls=[
366+
ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
367+
ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
368+
],
369+
start=True,
370+
finish_reason="tool_calls",
371+
),
372+
]
373+
374+
# Convert chunks to a chat message
375+
result = _convert_streaming_chunks_to_chat_message(chunks=chunks)
376+
377+
assert not result.texts
378+
assert not result.text
379+
380+
# Verify both tool calls were found and processed
381+
assert len(result.tool_calls) == 2
382+
assert result.tool_calls[0].id == "FL1FFlqUG"
383+
assert result.tool_calls[0].tool_name == "weather"
384+
assert result.tool_calls[0].arguments == {"city": "Paris"}
385+
assert result.tool_calls[1].id == "xSuhp66iB"
386+
assert result.tool_calls[1].tool_name == "weather"
387+
assert result.tool_calls[1].arguments == {"city": "Berlin"}

0 commit comments

Comments
 (0)