2727tracer = opentelemetry .trace .get_tracer (__name__ )
2828opentelemetry .instrumentation .httpx .HTTPXClientInstrumentor ().instrument ()
2929
30+
3031class InferenceException (Exception ):
3132 """
3233 A generic exception for inference errors.
@@ -229,7 +230,9 @@ async def tool_chain(
229230 break
230231 except ContextLengthExceeded as e :
231232 if fifo_context :
232- logger .info (f"ContextLengthExceeded ({ e } ), tossing early messages and retrying" )
233+ logger .info (
234+ f"ContextLengthExceeded ({ e } ), tossing early messages and retrying"
235+ )
233236 # If we hit context length, remove a handful of assistant,user message pairs from the middle
234237 # A handful so that we can hopefully get at least a couple cache hits with this
235238 # truncated history before having to drop messages again.
@@ -246,17 +249,28 @@ async def tool_chain(
246249 f"Removing messages { start_remove } through { end_remove } from serialized messages"
247250 )
248251 end_remove += 1 # inclusive
249- serialized_messages = serialized_messages [:start_remove ] + serialized_messages [end_remove :]
252+ serialized_messages = (
253+ serialized_messages [:start_remove ]
254+ + serialized_messages [end_remove :]
255+ )
250256 for mode in last_mode_cached_message .keys ():
251257 # Delete markers if they are in the removed range
252- if start_remove <= last_mode_cached_message [mode ] < end_remove :
258+ if (
259+ start_remove
260+ <= last_mode_cached_message [mode ]
261+ < end_remove
262+ ):
253263 del last_mode_cached_message [mode ]
254264 # And adjust indices of anything that got "slid" back
255265 elif last_mode_cached_message [mode ] >= end_remove :
256- last_mode_cached_message [mode ] -= end_remove - start_remove
266+ last_mode_cached_message [mode ] -= (
267+ end_remove - start_remove
268+ )
257269 continue
258270 else :
259- logger .info ("Non-retryable exception hit (context length), bailing" )
271+ logger .info (
272+ "Non-retryable exception hit (context length), bailing"
273+ )
260274 return serialized_messages
261275 except NonRetryableException as e :
262276 logger .info (f"Non-retryable exception hit ({ e } ), bailing" )
@@ -573,9 +587,7 @@ async def _tool_chain_stream(
573587 elif chunk_json ["delta" ]["type" ] == "thinking_delta" :
574588 current_block ["thinking" ] += chunk_json ["delta" ]["thinking" ]
575589 elif chunk_json ["delta" ]["type" ] == "signature_delta" :
576- current_block ["signature" ] = chunk_json ["delta" ][
577- "signature"
578- ]
590+ current_block ["signature" ] = chunk_json ["delta" ]["signature" ]
579591 elif chunk_type == "content_block_stop" :
580592 current_content .append (current_block )
581593
@@ -1489,8 +1501,10 @@ def wrap_tool_schema(tool):
14891501 processed_messages = []
14901502
14911503 for message in serialized_messages :
1492- if (not self .include_cache_control ) and isinstance (message ['content' ], list ):
1493- for blk in message ['content' ]:
1504+ if (not self .include_cache_control ) and isinstance (
1505+ message ["content" ], list
1506+ ):
1507+ for blk in message ["content" ]:
14941508 blk .pop ("cache_control" , None )
14951509
14961510 if (
@@ -1683,8 +1697,10 @@ async def _tool_chain_stream(
16831697
16841698 openrouter_messages = []
16851699 for message in serialized_messages :
1686- if (not self .include_cache_control ) and isinstance (message ['content' ], list ):
1687- for blk in message ['content' ]:
1700+ if (not self .include_cache_control ) and isinstance (
1701+ message ["content" ], list
1702+ ):
1703+ for blk in message ["content" ]:
16881704 blk .pop ("cache_control" , None )
16891705
16901706 if (
@@ -1808,10 +1824,10 @@ async def _tool_chain_stream(
18081824 if "id" in data :
18091825 id = data ["id" ]
18101826 if data .get ("error" ):
1811- if "context" in str (data [' error' ]):
1812- raise ContextLengthExceeded (str (data [' error' ]))
1813- elif "invalid_request_error" in str (data [' error' ]):
1814- raise NonRetryableException (str (data [' error' ]))
1827+ if "context" in str (data [" error" ]):
1828+ raise ContextLengthExceeded (str (data [" error" ]))
1829+ elif "invalid_request_error" in str (data [" error" ]):
1830+ raise NonRetryableException (str (data [" error" ]))
18151831 raise InferenceException (
18161832 data ["error" ]["message" ] + f" ({ data ['error' ]} )"
18171833 )
0 commit comments