2121import google .auth
2222
2323from asimov .asimov_base import AsimovBase
24+ from asimov .utils .token_counter import approx_tokens_from_serialized_messages
2425from asimov .graph import NonRetryableException
2526
2627logger = logging .getLogger (__name__ )
@@ -160,7 +161,7 @@ async def tool_chain(
160161 Awaitable [tuple [str , List [Tuple [Callable , Dict [str , Any ]]], Hashable ]],
161162 ]
162163 ] = None ,
163- fifo_context : bool = False ,
164+ fifo_ratio : Optional [ float ] = None ,
164165 ):
165166 mode = None
166167 if mode_swap_callback :
@@ -208,6 +209,46 @@ async def tool_chain(
208209
209210 last_mode_cached_message [mode ] = len (serialized_messages ) - 1
210211
212+ tokens = approx_tokens_from_serialized_messages (serialized_messages )
213+
214+ if fifo_ratio and (tokens / 200000 ) > fifo_ratio :
215+ logger .info (
216+ f"ContextLengthExceeded ({ e } ), tossing early messages and retrying"
217+ )
218+ # If we hit context length, remove a handful of assistant,user message pairs from the middle
219+ # A handful so that we can hopefully get at least a couple cache hits with this
220+ # truncated history before having to drop messages again.
221+
222+ # We want the earliest thing we remove to be an assistant message (requesting the next tool call), which have odd indices
223+ start_remove = int (len (serialized_messages ) / 3 )
224+ if start_remove % 2 != 1 :
225+ start_remove += 1
226+ # And the last thing we remove should be a user message (with tool response), which have even indices
227+ end_remove = int (len (serialized_messages ) * 2 / 3 )
228+ if end_remove % 2 != 0 :
229+ end_remove -= 1
230+ logger .debug (
231+ f"Removing messages { start_remove } through { end_remove } from serialized messages"
232+ )
233+ end_remove += 1 # inclusive
234+ serialized_messages = (
235+ serialized_messages [:start_remove ]
236+ + serialized_messages [end_remove :]
237+ )
238+ for mode in last_mode_cached_message .keys ():
239+ # Delete markers if they are in the removed range
240+ if (
241+ start_remove
242+ <= last_mode_cached_message [mode ]
243+ < end_remove
244+ ):
245+ del last_mode_cached_message [mode ]
246+ # And adjust indices of anything that got "slid" back
247+ elif last_mode_cached_message [mode ] >= end_remove :
248+ last_mode_cached_message [mode ] -= (
249+ end_remove - start_remove
250+ )
251+
211252 for retry in range (1 , 5 ):
212253 try :
213254 resp = await self ._tool_chain_stream (
@@ -229,49 +270,10 @@ async def tool_chain(
229270 await self ._trace (serialized_messages , resp )
230271 break
231272 except ContextLengthExceeded as e :
232- if fifo_context :
233- logger .info (
234- f"ContextLengthExceeded ({ e } ), tossing early messages and retrying"
235- )
236- # If we hit context length, remove a handful of assistant,user message pairs from the middle
237- # A handful so that we can hopefully get at least a couple cache hits with this
238- # truncated history before having to drop messages again.
239-
240- # We want the earliest thing we remove to be an assistant message (requesting the next tool call), which have odd indices
241- start_remove = int (len (serialized_messages ) / 3 )
242- if start_remove % 2 != 1 :
243- start_remove += 1
244- # And the last thing we remove should be a user message (with tool response), which have even indices
245- end_remove = int (len (serialized_messages ) * 2 / 3 )
246- if end_remove % 2 != 0 :
247- end_remove -= 1
248- logger .debug (
249- f"Removing messages { start_remove } through { end_remove } from serialized messages"
250- )
251- end_remove += 1 # inclusive
252- serialized_messages = (
253- serialized_messages [:start_remove ]
254- + serialized_messages [end_remove :]
255- )
256- for mode in last_mode_cached_message .keys ():
257- # Delete markers if they are in the removed range
258- if (
259- start_remove
260- <= last_mode_cached_message [mode ]
261- < end_remove
262- ):
263- del last_mode_cached_message [mode ]
264- # And adjust indices of anything that got "slid" back
265- elif last_mode_cached_message [mode ] >= end_remove :
266- last_mode_cached_message [mode ] -= (
267- end_remove - start_remove
268- )
269- continue
270- else :
271- logger .info (
272- "Non-retryable exception hit (context length), bailing"
273- )
274- return serialized_messages
273+ logger .info (
274+ "Non-retryable exception hit (context length), bailing"
275+ )
276+ return serialized_messages
275277 except NonRetryableException as e :
276278 logger .info (f"Non-retryable exception hit ({ e } ), bailing" )
277279 raise
0 commit comments