Commit 17d7ab5

fix up bedrock client
1 parent c723b3b commit 17d7ab5

File tree

1 file changed: +43 -17 lines changed


asimov/services/inference_clients.py

Lines changed: 43 additions & 17 deletions
@@ -22,10 +22,10 @@
 from asimov.asimov_base import AsimovBase
 from asimov.graph import NonRetryableException

+logger = logging.getLogger(__name__)
 tracer = opentelemetry.trace.get_tracer(__name__)
 opentelemetry.instrumentation.httpx.HTTPXClientInstrumentor().instrument()

-
 class InferenceException(Exception):
     """
     A generic exception for inference errors.
@@ -106,6 +106,7 @@ async def _trace(self, request, response):
         )

         if self.trace_cb:
+            logger.debug(f"Request {self._trace_id} cost {self._cost}")
             await self.trace_cb(self._trace_id, request, response, self._cost)
         self._cost = InferenceCost()
         self._trace_id += 1
@@ -221,30 +222,26 @@ async def tool_chain(
                 await self._trace(serialized_messages, resp)
                 break
             except ValueError as e:
-                print(f"ValueError hit ({e}), bailing")
+                logger.info(f"ValueError hit ({e}), bailing")
                 return serialized_messages
             except NonRetryableException:
-                print("Non-retryable exception hit, bailing")
+                logger.info("Non-retryable exception hit, bailing")
                 raise
             except InferenceException as e:
-                print("inference exception", e)
+                logger.info("inference exception %s", e)
                 await asyncio.sleep(3**retry)
                 if retry > 3:
                     # Modify messages to try and cache bust in case we have a poison message or similar
                     serialized_messages[0]["content"][0][
                         "text"
                     ] += "\n\nTFJeD9K6smAnr6sUcllj"
                 continue
-            except Exception as e:
-                print("generic inference exception", e)
-                import traceback
-
-                traceback.print_exc()
-
+            except Exception:
+                logger.warning("generic inference exception", exc_info=True)
                 await asyncio.sleep(3**retry)
                 continue
         else:
-            print("Retries exceeded, bailing!")
+            logger.info("Retries exceeded, bailing!")
             raise RetriesExceeded()

         serialized_messages.append(
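
The retry-loop prints above now go through the module logger, so callers only see these messages when standard library logging is configured. A minimal sketch of that setup; the logger name is inferred from the file path, since the module calls logging.getLogger(__name__):

import logging

# Show the INFO-level retry/bailing messages from the inference client;
# raising the module logger to DEBUG also surfaces the per-request cost line added above.
logging.basicConfig(level=logging.INFO)
logging.getLogger("asimov.services.inference_clients").setLevel(logging.DEBUG)
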
@@ -459,7 +456,13 @@ async def _tool_chain_stream(
             "tool_choice": {"type": tool_choice},
         }
         if system:
-            request["system"] = system
+            request["system"] = [
+                {
+                    "type": "text",
+                    "text": system,
+                    "cache_control": {"type": "ephemeral"},
+                }
+            ]

         try:
             response = await client.invoke_model_with_response_stream(
@@ -487,6 +490,12 @@ async def _tool_chain_stream(
                     self._cost.input_tokens += chunk_json["message"]["usage"][
                         "input_tokens"
                     ]
+                    self._cost.cache_read_input_tokens += chunk_json["message"][
+                        "usage"
+                    ].get("cache_read_input_tokens", 0)
+                    self._cost.cache_write_input_tokens += chunk_json["message"][
+                        "usage"
+                    ].get("cache_creation_input_tokens", 0)
                 elif chunk_type == "content_block_start":
                     block_type = chunk_json["content_block"]["type"]
                     if block_type == "text":
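
The two additions above are the two halves of prompt caching: the system prompt goes out as a cacheable content block, and the message_start usage payload comes back with optional cache counters. A rough sketch of both shapes; the variable names, prompt text, and token counts here are illustrative, not from the repository:

# System prompt as a cacheable block, matching the request built above.
system_blocks = [
    {
        "type": "text",
        "text": "You are a coding assistant.",  # illustrative prompt text
        "cache_control": {"type": "ephemeral"},
    }
]

# Example message_start usage payload; the cache fields may be absent,
# which is why the client reads them with .get(..., 0).
usage = {
    "input_tokens": 1400,
    "cache_read_input_tokens": 1100,
    "cache_creation_input_tokens": 0,
}
cache_read = usage.get("cache_read_input_tokens", 0)
cache_write = usage.get("cache_creation_input_tokens", 0)
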
@@ -495,13 +504,20 @@ async def _tool_chain_stream(
                             "text": "",
                         }
                     elif block_type == "tool_use":
-                        current_block["tool_use"] = {
+                        current_block = {
                             "type": "tool_use",
                             "id": chunk_json["content_block"]["id"],
                             "name": chunk_json["content_block"]["name"],
                             "input": {},
                         }
                         current_json = ""
+                    elif block_type == "thinking":
+                        current_block = {
+                            "type": "thinking",
+                            "thinking": chunk_json["content_block"]["thinking"],
+                        }
+                    elif block_type == "redacted_thinking":
+                        current_block = chunk_json["content_block"]
                 elif chunk_type == "content_block_delta":
                     if chunk_json["delta"]["type"] == "text_delta":
                         current_block["text"] += chunk_json["delta"]["text"]
@@ -517,6 +533,12 @@ async def _tool_chain_stream(
                                 await middleware(current_block)
                         except ValueError:
                             pass
+                    elif chunk_json["delta"]["type"] == "thinking_delta":
+                        current_block["thinking"] += chunk_json["delta"]["thinking"]
+                    elif chunk_json["delta"]["type"] == "signature_delta":
+                        current_block["signature"] = chunk_json["delta"][
+                            "signature"
+                        ]
                 elif chunk_type == "content_block_stop":
                     current_content.append(current_block)

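
With the new thinking branches above, a streamed turn that includes extended thinking accumulates blocks shaped roughly like this by the time content_block_stop appends them. All values below are illustrative, and the tool name is hypothetical:

# Illustrative current_content after a turn with thinking plus a tool call.
current_content = [
    {
        "type": "thinking",
        "thinking": "The user wants current data, so call the search tool.",
        "signature": "EoABs...",  # filled in by the signature_delta branch
    },
    {"type": "text", "text": "Let me look that up."},
    {
        "type": "tool_use",
        "id": "toolu_01",
        "name": "search",  # hypothetical tool name
        "input": {"query": "latest release notes"},
    },
]
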
@@ -527,6 +549,12 @@ async def _tool_chain_stream(
                     if chunk_json["delta"].get("stop_reason") == "tool_use":
                         self._cost.output_tokens += chunk_json["usage"]["output_tokens"]
                         break
+                elif chunk_type == "error":
+                    if chunk_json["error"]["type"] == "invalid_request_error":
+                        raise ValueError(chunk_json["error"]["type"])
+                    raise InferenceException(chunk_json["error"]["type"])
+                elif chunk_type == "ping":
+                    pass

         return current_content

@@ -862,8 +890,7 @@ async def _tool_chain_stream(
                 elif chunk_type == "ping":
                     pass
                 else:
-                    print("Unknown message type from Anthropic stream.")
-                    print(chunk_json)
+                    logger.warning("Unknown message type from Anthropic stream.")

         return current_content

@@ -1251,7 +1278,7 @@ async def get_generation(
             )
             self._cost.output_tokens += body["usage"]["completion_tokens"]
         except KeyError:
-            logging.warning(f"Malformed usage? {repr(body)}")
+            logger.warning(f"Malformed usage? {repr(body)}")

         await self._trace(
             request["messages"],
@@ -1525,7 +1552,6 @@ async def _populate_cost(self, id: str):
                 continue

             if response.status_code != 200:
-                print(response.status_code, await response.aread())
                 await asyncio.sleep(0.5)
                 continue
