Skip to content

Commit ca38042

Browse files
committed
fmt
1 parent 813ffdf commit ca38042

File tree

1 file changed

+32
-16
lines changed

1 file changed

+32
-16
lines changed

asimov/services/inference_clients.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
tracer = opentelemetry.trace.get_tracer(__name__)
2828
opentelemetry.instrumentation.httpx.HTTPXClientInstrumentor().instrument()
2929

30+
3031
class InferenceException(Exception):
3132
"""
3233
A generic exception for inference errors.
@@ -229,7 +230,9 @@ async def tool_chain(
229230
break
230231
except ContextLengthExceeded as e:
231232
if fifo_context:
232-
logger.info(f"ContextLengthExceeded ({e}), tossing early messages and retrying")
233+
logger.info(
234+
f"ContextLengthExceeded ({e}), tossing early messages and retrying"
235+
)
233236
# If we hit context length, remove a handful of assistant,user message pairs from the middle
234237
# A handful so that we can hopefully get at least a couple cache hits with this
235238
# truncated history before having to drop messages again.
@@ -246,17 +249,28 @@ async def tool_chain(
246249
f"Removing messages {start_remove} through {end_remove} from serialized messages"
247250
)
248251
end_remove += 1 # inclusive
249-
serialized_messages = serialized_messages[:start_remove] + serialized_messages[end_remove:]
252+
serialized_messages = (
253+
serialized_messages[:start_remove]
254+
+ serialized_messages[end_remove:]
255+
)
250256
for mode in last_mode_cached_message.keys():
251257
# Delete markers if they are in the removed range
252-
if start_remove <= last_mode_cached_message[mode] < end_remove:
258+
if (
259+
start_remove
260+
<= last_mode_cached_message[mode]
261+
< end_remove
262+
):
253263
del last_mode_cached_message[mode]
254264
# And adjust indices of anything that got "slid" back
255265
elif last_mode_cached_message[mode] >= end_remove:
256-
last_mode_cached_message[mode] -= end_remove - start_remove
266+
last_mode_cached_message[mode] -= (
267+
end_remove - start_remove
268+
)
257269
continue
258270
else:
259-
logger.info("Non-retryable exception hit (context length), bailing")
271+
logger.info(
272+
"Non-retryable exception hit (context length), bailing"
273+
)
260274
return serialized_messages
261275
except NonRetryableException as e:
262276
logger.info(f"Non-retryable exception hit ({e}), bailing")
@@ -573,9 +587,7 @@ async def _tool_chain_stream(
573587
elif chunk_json["delta"]["type"] == "thinking_delta":
574588
current_block["thinking"] += chunk_json["delta"]["thinking"]
575589
elif chunk_json["delta"]["type"] == "signature_delta":
576-
current_block["signature"] = chunk_json["delta"][
577-
"signature"
578-
]
590+
current_block["signature"] = chunk_json["delta"]["signature"]
579591
elif chunk_type == "content_block_stop":
580592
current_content.append(current_block)
581593

@@ -1489,8 +1501,10 @@ def wrap_tool_schema(tool):
14891501
processed_messages = []
14901502

14911503
for message in serialized_messages:
1492-
if (not self.include_cache_control) and isinstance(message['content'], list):
1493-
for blk in message['content']:
1504+
if (not self.include_cache_control) and isinstance(
1505+
message["content"], list
1506+
):
1507+
for blk in message["content"]:
14941508
blk.pop("cache_control", None)
14951509

14961510
if (
@@ -1683,8 +1697,10 @@ async def _tool_chain_stream(
16831697

16841698
openrouter_messages = []
16851699
for message in serialized_messages:
1686-
if (not self.include_cache_control) and isinstance(message['content'], list):
1687-
for blk in message['content']:
1700+
if (not self.include_cache_control) and isinstance(
1701+
message["content"], list
1702+
):
1703+
for blk in message["content"]:
16881704
blk.pop("cache_control", None)
16891705

16901706
if (
@@ -1808,10 +1824,10 @@ async def _tool_chain_stream(
18081824
if "id" in data:
18091825
id = data["id"]
18101826
if data.get("error"):
1811-
if "context" in str(data['error']):
1812-
raise ContextLengthExceeded(str(data['error']))
1813-
elif "invalid_request_error" in str(data['error']):
1814-
raise NonRetryableException(str(data['error']))
1827+
if "context" in str(data["error"]):
1828+
raise ContextLengthExceeded(str(data["error"]))
1829+
elif "invalid_request_error" in str(data["error"]):
1830+
raise NonRetryableException(str(data["error"]))
18151831
raise InferenceException(
18161832
data["error"]["message"] + f" ({data['error']})"
18171833
)

0 commit comments

Comments
 (0)