
Commit 538e483

feat: return logprobs in OpenAIChatGenerator and OpenAIResponsesChatGenerator (#10035)
* Return logprobs
* Add log probs in responses
* Get logprobs from streaming
* Fix linting
* Update tests
* Fix formatting
* Update
* Fix tests
* Loosen up tests
* updates
* Update logprobs
* linting
* Fix test
1 parent 4db967d commit 538e483

6 files changed: 109 additions, 50 deletions

haystack/components/generators/chat/openai.py

Lines changed: 32 additions & 35 deletions
@@ -22,7 +22,7 @@
 from pydantic import BaseModel
 
 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
+from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message, _serialize_object
 from haystack.dataclasses import (
     AsyncStreamingCallbackT,
     ChatMessage,
@@ -563,16 +563,17 @@ def _convert_chat_completion_to_chat_message(
             _arguments=arguments_str,
         )
 
-    chat_message = ChatMessage.from_assistant(
-        text=text,
-        tool_calls=tool_calls,
-        meta={
-            "model": completion.model,
-            "index": choice.index,
-            "finish_reason": choice.finish_reason,
-            "usage": _serialize_usage(completion.usage),
-        },
-    )
+    logprobs = _serialize_object(choice.logprobs) if choice.logprobs else None
+    meta = {
+        "model": completion.model,
+        "index": choice.index,
+        "finish_reason": choice.finish_reason,
+        "usage": _serialize_object(completion.usage),
+    }
+    if logprobs:
+        meta["logprobs"] = logprobs
+
+    chat_message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
 
     return chat_message
 
@@ -610,7 +611,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
             meta={
                 "model": chunk.model,
                 "received_at": datetime.now().isoformat(),
-                "usage": _serialize_usage(chunk.usage),
+                "usage": _serialize_object(chunk.usage),
             },
         )
 
@@ -643,7 +644,7 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
                 "tool_calls": choice.delta.tool_calls,
                 "finish_reason": choice.finish_reason,
                 "received_at": datetime.now().isoformat(),
-                "usage": _serialize_usage(chunk.usage),
+                "usage": _serialize_object(chunk.usage),
             },
         )
         return chunk_message
@@ -658,6 +659,23 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
     # NOTE: We may need to revisit this if OpenAI allows planning/thinking content before tool calls like
     # Anthropic Claude
     resolved_index = 0
+
+    # Initialize meta dictionary
+    meta = {
+        "model": chunk.model,
+        "index": choice.index,
+        "tool_calls": choice.delta.tool_calls,
+        "finish_reason": choice.finish_reason,
+        "received_at": datetime.now().isoformat(),
+        "usage": _serialize_object(chunk.usage),
+    }
+
+    # check if logprobs are present
+    # logprobs are returned only for text content
+    logprobs = _serialize_object(choice.logprobs) if choice.logprobs else None
+    if logprobs:
+        meta["logprobs"] = logprobs
+
     chunk_message = StreamingChunk(
         content=choice.delta.content or "",
         component_info=component_info,
@@ -666,27 +684,6 @@ def _convert_chat_completion_chunk_to_streaming_chunk(
         # and previous_chunks is length 1 then this is the start of text content.
         start=len(previous_chunks) == 1,
         finish_reason=finish_reason_mapping.get(choice.finish_reason) if choice.finish_reason else None,
-        meta={
-            "model": chunk.model,
-            "index": choice.index,
-            "tool_calls": choice.delta.tool_calls,
-            "finish_reason": choice.finish_reason,
-            "received_at": datetime.now().isoformat(),
-            "usage": _serialize_usage(chunk.usage),
-        },
+        meta=meta,
     )
     return chunk_message
-
-
-def _serialize_usage(usage):
-    """Convert OpenAI usage object to serializable dict recursively"""
-    if hasattr(usage, "model_dump"):
-        return usage.model_dump()
-    elif hasattr(usage, "__dict__"):
-        return {k: _serialize_usage(v) for k, v in usage.__dict__.items() if not k.startswith("_")}
-    elif isinstance(usage, dict):
-        return {k: _serialize_usage(v) for k, v in usage.items()}
-    elif isinstance(usage, list):
-        return [_serialize_usage(item) for item in usage]
-    else:
-        return usage
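
With this change, logprobs are requested through the generation kwargs and surface in the reply's metadata. A minimal usage sketch mirroring the updated test_live_run; the exact shape of meta["logprobs"] is whatever the OpenAI client returns for choice.logprobs, flattened to plain dicts by _serialize_object:

```python
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

# Requires OPENAI_API_KEY in the environment, as in the integration tests.
generator = OpenAIChatGenerator(generation_kwargs={"n": 1, "logprobs": True})
result = generator.run([ChatMessage.from_user("What's the capital of France?")])

reply = result["replies"][0]
print(reply.text)                   # e.g. "The capital of France is Paris."
print(reply.meta["finish_reason"])  # "stop"
print(reply.meta.get("logprobs"))   # serialized token logprobs; the key is absent if not requested
```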

haystack/components/generators/chat/openai_responses.py

Lines changed: 28 additions & 6 deletions
@@ -13,6 +13,7 @@
 from pydantic import BaseModel
 
 from haystack import component, default_from_dict, default_to_dict, logging
+from haystack.components.generators.utils import _serialize_object
 from haystack.dataclasses import (
     AsyncStreamingCallbackT,
     ChatMessage,
@@ -516,10 +517,17 @@ def _convert_response_to_chat_message(responses: Union[Response, ParsedResponse]
 
     tool_calls = []
     reasoning = None
+    logprobs: list[dict] = []
     for output in responses.output:
         if isinstance(output, ResponseOutputRefusal):
             logger.warning("OpenAI returned a refusal output: {output}", output=output)
             continue
+
+        if output.type == "message":
+            for content in output.content:
+                if hasattr(content, "logprobs") and content.logprobs is not None:
+                    logprobs.append(_serialize_object(content.logprobs))
+
         if output.type == "reasoning":
             # openai doesn't return the reasoning tokens, but we can view summary if its enabled
             # https://platform.openai.com/docs/guides/reasoning#reasoning-summaries
@@ -547,11 +555,17 @@ def _convert_response_to_chat_message(responses: Union[Response, ParsedResponse]
                 _name=output.name,
                 _arguments=output.arguments,
             )
+            arguments = {}
 
     # we save the response as dict because it contains resp_id etc.
     meta = responses.to_dict()
+
     # remove output from meta because it contains toolcalls, reasoning, text etc.
     meta.pop("output")
+
+    if logprobs:
+        meta["logprobs"] = logprobs
+
     chat_message = ChatMessage.from_assistant(
         text=responses.output_text if responses.output_text else None,
         reasoning=reasoning,
@@ -569,6 +583,7 @@ def _convert_response_chunk_to_streaming_chunk( # pylint: disable=too-many-retu
     Converts the streaming response chunk from the OpenAI Responses API to a StreamingChunk.
 
     :param chunk: The chunk returned by the OpenAI Responses API.
+    :param previous_chunks: A list of previously received StreamingChunks.
     :param component_info: An optional `ComponentInfo` object containing information about the component that
         generated the chunk, such as the component name and type.
     :returns:
@@ -676,18 +691,22 @@ def _convert_streaming_chunks_to_chat_message(chunks: list[StreamingChunk]) -> C
 
     :returns: The ChatMessage.
     """
+
     # Get the full text by concatenating all text chunks
     text = "".join([chunk.content for chunk in chunks])
+    logprobs = []
+    for chunk in chunks:
+        if chunk.meta.get("logprobs"):
+            logprobs.append(chunk.meta.get("logprobs"))
 
     # Gather reasoning information if present
     reasoning_id = None
     reasoning_text = ""
     for chunk in chunks:
-        if not chunk.reasoning:
-            continue
-        reasoning_text += chunk.reasoning.reasoning_text
-        if chunk.reasoning.extra.get("id"):
-            reasoning_id = chunk.reasoning.extra.get("id")
+        if chunk.reasoning:
+            reasoning_text += chunk.reasoning.reasoning_text
+            if chunk.reasoning.extra.get("id"):
+                reasoning_id = chunk.reasoning.extra.get("id")
 
     # Process tool calls if present in any chunk
     tool_call_data: dict[str, dict[str, Any]] = {}  # Track tool calls by id
@@ -731,7 +750,10 @@ def _convert_streaming_chunks_to_chat_message(chunks: list[StreamingChunk]) -> C
     )
 
     # We dump the entire final response into meta to be consistent with non-streaming response
-    final_response = chunks[-1].meta.get("response")
+    final_response = chunks[-1].meta.get("response") or {}
+    final_response.pop("output", None)
+    if logprobs:
+        final_response["logprobs"] = logprobs
 
     # Add reasoning content if both id and text are available
     reasoning = None
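
For the Responses API, the change collects one serialized logprobs entry per "message" output content item and stores the list under meta["logprobs"]. A hedged sketch of how a caller would request and read them, following the updated integration test (the model name and the include value are taken from that test, not a recommendation):

```python
from haystack.components.generators.chat.openai_responses import OpenAIResponsesChatGenerator
from haystack.dataclasses import ChatMessage

# The Responses API exposes logprobs through the `include` option rather than a boolean flag.
generator = OpenAIResponsesChatGenerator(
    model="gpt-4",
    generation_kwargs={"include": ["message.output_text.logprobs"]},
)
result = generator.run([ChatMessage.from_user("What's the capital of France?")])

reply = result["replies"][0]
# A list with one serialized logprobs entry per message content item, if the API returned any.
print(reply.meta.get("logprobs"))
```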

haystack/components/generators/utils.py

Lines changed: 21 additions & 0 deletions
@@ -84,6 +84,10 @@ def _convert_streaming_chunks_to_chat_message(chunks: list[StreamingChunk]) -> C
     :returns: The ChatMessage.
     """
     text = "".join([chunk.content for chunk in chunks])
+    logprobs = []
+    for chunk in chunks:
+        if chunk.meta.get("logprobs"):
+            logprobs.append(chunk.meta.get("logprobs"))
     tool_calls = []
 
     # Process tool calls if present in any chunk
@@ -134,4 +138,21 @@ def _convert_streaming_chunks_to_chat_message(chunks: list[StreamingChunk]) -> C
         "usage": chunks[-1].meta.get("usage"),  # last chunk has the final usage data if available
     }
 
+    if logprobs:
+        meta["logprobs"] = logprobs
+
     return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
+
+
+def _serialize_object(obj):
+    """Convert an object to a serializable dict recursively"""
+    if hasattr(obj, "model_dump"):
+        return obj.model_dump()
+    elif hasattr(obj, "__dict__"):
+        return {k: _serialize_object(v) for k, v in obj.__dict__.items() if not k.startswith("_")}
+    elif isinstance(obj, dict):
+        return {k: _serialize_object(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [_serialize_object(item) for item in obj]
+    else:
+        return obj
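
The old `_serialize_usage` helper from openai.py is generalized here into `_serialize_object`, so the same recursion can flatten usage objects, logprobs, or anything else the client hands back. A rough illustration of its behavior on plain Python values; `FakeLogprob` below is a made-up stand-in for an SDK object without `model_dump()`, not an OpenAI type:

```python
from haystack.components.generators.utils import _serialize_object

class FakeLogprob:
    """Illustrative stand-in for a client object that only exposes __dict__."""
    def __init__(self, token: str, logprob: float):
        self.token = token
        self.logprob = logprob
        self._client = object()  # underscore-prefixed attributes are dropped

print(_serialize_object([FakeLogprob("Paris", -0.01), {"bytes": None}]))
# -> [{'token': 'Paris', 'logprob': -0.01}, {'bytes': None}]
```

Pydantic models (anything exposing `model_dump`) take the first branch, so OpenAI client objects serialize in one step; the remaining branches handle bare objects, dicts, lists, and scalars.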
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    If logprobs are enabled in the generation kwargs, return logprobs in ChatMessage.meta for `OpenAIChatGenerator` and `OpenAIResponsesChatGenerator`.
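
The same applies to streaming: each text StreamingChunk may carry a logprobs entry in its meta, and the chunk-to-message conversion gathers those entries into a list on the final reply. A sketch assuming the same generation kwargs as the updated streaming test:

```python
from haystack.components.generators.chat.openai import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

def on_chunk(chunk: StreamingChunk) -> None:
    # Per-chunk logprobs are attached only for text content, when the API returns them.
    if chunk.meta.get("logprobs"):
        print(chunk.content, chunk.meta["logprobs"])

generator = OpenAIChatGenerator(
    streaming_callback=on_chunk,
    generation_kwargs={"stream_options": {"include_usage": True}, "logprobs": True},
)
result = generator.run([ChatMessage.from_user("What's the capital of France?")])

# The aggregated reply collects the per-chunk entries into a list under meta["logprobs"].
print(result["replies"][0].meta.get("logprobs"))
```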

test/components/generators/chat/test_openai.py

Lines changed: 6 additions & 2 deletions
@@ -281,6 +281,7 @@ def test_to_dict_with_parameters(self, monkeypatch, calendar_event_model):
                 "max_completion_tokens": 10,
                 "some_test_param": "test-params",
                 "response_format": calendar_event_model,
+                "logprobs": True,
             },
             tools=[tool],
             tools_strict=True,
@@ -303,6 +304,7 @@ def test_to_dict_with_parameters(self, monkeypatch, calendar_event_model):
             "generation_kwargs": {
                 "max_completion_tokens": 10,
                 "some_test_param": "test-params",
+                "logprobs": True,
                 "response_format": {
                     "type": "json_schema",
                     "json_schema": {
@@ -804,14 +806,15 @@ def test_run_with_response_format_and_streaming_pydantic_model(self, calendar_ev
     @pytest.mark.integration
     def test_live_run(self):
         chat_messages = [ChatMessage.from_user("What's the capital of France")]
-        component = OpenAIChatGenerator(generation_kwargs={"n": 1})
+        component = OpenAIChatGenerator(generation_kwargs={"n": 1, "logprobs": True})
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
         assert "Paris" in message.text
         assert "gpt-4o" in message.meta["model"]
         assert message.meta["finish_reason"] == "stop"
         assert message.meta["usage"]["prompt_tokens"] > 0
+        assert message.meta["logprobs"] is not None
 
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
@@ -987,7 +990,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
 
         callback = Callback()
         component = OpenAIChatGenerator(
-            streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}}
+            streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}, "logprobs": True}
         )
         results = component.run([ChatMessage.from_user("What's the capital of France?")])
 
@@ -1002,6 +1005,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
         metadata = message.meta
         assert "gpt-4o" in metadata["model"]
         assert metadata["finish_reason"] == "stop"
+        assert metadata["logprobs"] is not None
 
         # Usage information checks
         assert isinstance(metadata.get("usage"), dict), "meta.usage not a dict"

test/components/generators/chat/test_openai_responses.py

Lines changed: 18 additions & 7 deletions
@@ -539,15 +539,18 @@ def warm_up(self):
     @pytest.mark.integration
     def test_live_run(self):
         chat_messages = [ChatMessage.from_user("What's the capital of France")]
-        component = OpenAIResponsesChatGenerator()
+        component = OpenAIResponsesChatGenerator(
+            model="gpt-4", generation_kwargs={"include": ["message.output_text.logprobs"]}
+        )
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
         assert "Paris" in message.text
-        assert "gpt-5-mini" in message.meta["model"]
+        assert "gpt-4" in message.meta["model"]
         assert message.meta["status"] == "completed"
         assert message.meta["usage"]["total_tokens"] > 0
         assert message.meta["id"] is not None
+        assert message.meta["logprobs"] is not None
 
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
@@ -692,7 +695,9 @@ def __call__(self, chunk: StreamingChunk) -> None:
             self.responses += chunk.content if chunk.content else ""
 
         callback = Callback()
-        component = OpenAIResponsesChatGenerator(streaming_callback=callback)
+        component = OpenAIResponsesChatGenerator(
+            model="gpt-4", streaming_callback=callback, generation_kwargs={"include": ["message.output_text.logprobs"]}
+        )
         results = component.run([ChatMessage.from_user("What's the capital of France?")])
 
         # Basic response checks
@@ -704,8 +709,8 @@ def __call__(self, chunk: StreamingChunk) -> None:
 
         # Metadata checks
         metadata = message.meta
-        assert "gpt-5-mini" in metadata["model"]
-
+        assert "gpt-4" in metadata["model"]
+        assert metadata["logprobs"] is not None
         # Usage information checks
         assert isinstance(metadata.get("usage"), dict), "meta.usage not a dict"
         usage = metadata["usage"]
@@ -755,7 +760,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
     def test_live_run_with_tools_streaming(self, tools):
         chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")]
 
-        component = OpenAIResponsesChatGenerator(tools=tools, streaming_callback=print_streaming_chunk)
+        component = OpenAIResponsesChatGenerator(model="gpt-5", tools=tools, streaming_callback=print_streaming_chunk)
         results = component.run(chat_messages)
         assert len(results["replies"]) == 1
         message = results["replies"][0]
@@ -764,12 +769,18 @@ def test_live_run_with_tools_streaming(self, tools):
         assert not message.text
         assert message.tool_calls
         tool_calls = message.tool_calls
-        assert len(tool_calls) > 0
+        assert len(tool_calls) == 2
 
         for tool_call in tool_calls:
            assert isinstance(tool_call, ToolCall)
            assert tool_call.tool_name == "weather"
 
+        arguments = [tool_call.arguments for tool_call in tool_calls]
+        # Extract city names (handle cases like "Berlin, Germany" -> "Berlin")
+        city_values = [arg["city"].split(",")[0].strip().lower() for arg in arguments]
+        assert "berlin" in city_values and "paris" in city_values
+        assert len(city_values) == 2
+
     @pytest.mark.skipif(
         not os.environ.get("OPENAI_API_KEY", None),
         reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
