2222from asimov .asimov_base import AsimovBase
2323from asimov .graph import NonRetryableException
2424
25+ logger = logging .getLogger (__name__ )
2526tracer = opentelemetry .trace .get_tracer (__name__ )
2627opentelemetry .instrumentation .httpx .HTTPXClientInstrumentor ().instrument ()
2728
28-
2929class InferenceException (Exception ):
3030 """
3131 A generic exception for inference errors.
@@ -106,6 +106,7 @@ async def _trace(self, request, response):
106106 )
107107
108108 if self .trace_cb :
109+ logger .debug (f"Request { self ._trace_id } cost { self ._cost } " )
109110 await self .trace_cb (self ._trace_id , request , response , self ._cost )
110111 self ._cost = InferenceCost ()
111112 self ._trace_id += 1
@@ -221,30 +222,26 @@ async def tool_chain(
221222 await self ._trace (serialized_messages , resp )
222223 break
223224 except ValueError as e :
224- print (f"ValueError hit ({ e } ), bailing" )
225+ logger . info (f"ValueError hit ({ e } ), bailing" )
225226 return serialized_messages
226227 except NonRetryableException :
227- print ("Non-retryable exception hit, bailing" )
228+ logger . info ("Non-retryable exception hit, bailing" )
228229 raise
229230 except InferenceException as e :
230- print ("inference exception" , e )
231+ logger . info ("inference exception %s " , e )
231232 await asyncio .sleep (3 ** retry )
232233 if retry > 3 :
233234 # Modify messages to try and cache bust in case we have a poison message or similar
234235 serialized_messages [0 ]["content" ][0 ][
235236 "text"
236237 ] += "\n \n TFJeD9K6smAnr6sUcllj"
237238 continue
238- except Exception as e :
239- print ("generic inference exception" , e )
240- import traceback
241-
242- traceback .print_exc ()
243-
239+ except Exception :
240+ logger .warning ("generic inference exception" , exc_info = True )
244241 await asyncio .sleep (3 ** retry )
245242 continue
246243 else :
247- print ("Retries exceeded, bailing!" )
244+ logger . info ("Retries exceeded, bailing!" )
248245 raise RetriesExceeded ()
249246
250247 serialized_messages .append (
@@ -459,7 +456,13 @@ async def _tool_chain_stream(
459456 "tool_choice" : {"type" : tool_choice },
460457 }
461458 if system :
462- request ["system" ] = system
459+ request ["system" ] = [
460+ {
461+ "type" : "text" ,
462+ "text" : system ,
463+ "cache_control" : {"type" : "ephemeral" },
464+ }
465+ ]
463466
464467 try :
465468 response = await client .invoke_model_with_response_stream (
@@ -487,6 +490,12 @@ async def _tool_chain_stream(
487490 self ._cost .input_tokens += chunk_json ["message" ]["usage" ][
488491 "input_tokens"
489492 ]
493+ self ._cost .cache_read_input_tokens += chunk_json ["message" ][
494+ "usage"
495+ ].get ("cache_read_input_tokens" , 0 )
496+ self ._cost .cache_write_input_tokens += chunk_json ["message" ][
497+ "usage"
498+ ].get ("cache_creation_input_tokens" , 0 )
490499 elif chunk_type == "content_block_start" :
491500 block_type = chunk_json ["content_block" ]["type" ]
492501 if block_type == "text" :
@@ -495,13 +504,20 @@ async def _tool_chain_stream(
495504 "text" : "" ,
496505 }
497506 elif block_type == "tool_use" :
498- current_block [ "tool_use" ] = {
507+ current_block = {
499508 "type" : "tool_use" ,
500509 "id" : chunk_json ["content_block" ]["id" ],
501510 "name" : chunk_json ["content_block" ]["name" ],
502511 "input" : {},
503512 }
504513 current_json = ""
514+ elif block_type == "thinking" :
515+ current_block = {
516+ "type" : "thinking" ,
517+ "thinking" : chunk_json ["content_block" ]["thinking" ],
518+ }
519+ elif block_type == "redacted_thinking" :
520+ current_block = chunk_json ["content_block" ]
505521 elif chunk_type == "content_block_delta" :
506522 if chunk_json ["delta" ]["type" ] == "text_delta" :
507523 current_block ["text" ] += chunk_json ["delta" ]["text" ]
@@ -517,6 +533,12 @@ async def _tool_chain_stream(
517533 await middleware (current_block )
518534 except ValueError :
519535 pass
536+ elif chunk_json ["delta" ]["type" ] == "thinking_delta" :
537+ current_block ["thinking" ] += chunk_json ["delta" ]["thinking" ]
538+ elif chunk_json ["delta" ]["type" ] == "signature_delta" :
539+ current_block ["signature" ] = chunk_json ["delta" ][
540+ "signature"
541+ ]
520542 elif chunk_type == "content_block_stop" :
521543 current_content .append (current_block )
522544
@@ -527,6 +549,12 @@ async def _tool_chain_stream(
527549 if chunk_json ["delta" ].get ("stop_reason" ) == "tool_use" :
528550 self ._cost .output_tokens += chunk_json ["usage" ]["output_tokens" ]
529551 break
552+ elif chunk_type == "error" :
553+ if chunk_json ["error" ]["type" ] == "invalid_request_error" :
554+ raise ValueError (chunk_json ["error" ]["type" ])
555+ raise InferenceException (chunk_json ["error" ]["type" ])
556+ elif chunk_type == "ping" :
557+ pass
530558
531559 return current_content
532560
@@ -862,8 +890,7 @@ async def _tool_chain_stream(
862890 elif chunk_type == "ping" :
863891 pass
864892 else :
865- print ("Unknown message type from Anthropic stream." )
866- print (chunk_json )
893+ logger .warning ("Unknown message type from Anthropic stream." )
867894
868895 return current_content
869896
@@ -1251,7 +1278,7 @@ async def get_generation(
12511278 )
12521279 self ._cost .output_tokens += body ["usage" ]["completion_tokens" ]
12531280 except KeyError :
1254- logging .warning (f"Malformed usage? { repr (body )} " )
1281+ logger .warning (f"Malformed usage? { repr (body )} " )
12551282
12561283 await self ._trace (
12571284 request ["messages" ],
@@ -1525,7 +1552,6 @@ async def _populate_cost(self, id: str):
15251552 continue
15261553
15271554 if response .status_code != 200 :
1528- print (response .status_code , await response .aread ())
15291555 await asyncio .sleep (0.5 )
15301556 continue
15311557
0 commit comments