@@ -30,12 +30,16 @@ class InferenceException(Exception):
3030 """
3131 A generic exception for inference errors.
3232 Should be safe to retry.
33- ValueError is raised if the request is un-retryable (e.g. parameters are malformed, or the request is too large ).
33+ ValueError is raised if the request is un-retryable (e.g. parameters are malformed).
3434 """
3535
3636 pass
3737
3838
39+ class ContextLengthExceeded (InferenceException , ValueError ):
40+ pass
41+
42+
3943class RetriesExceeded (InferenceException ):
4044 """
4145 Raised when the maximum number of retries is exceeded.
@@ -154,6 +158,7 @@ async def tool_chain(
154158 Awaitable [tuple [str , List [Tuple [Callable , Dict [str , Any ]]], Hashable ]],
155159 ]
156160 ] = None ,
161+ fifo_context : bool = False ,
157162 ):
158163 mode = None
159164 if mode_swap_callback :
@@ -221,12 +226,43 @@ async def tool_chain(
221226
222227 await self ._trace (serialized_messages , resp )
223228 break
229+ except ContextLengthExceeded as e :
230+ if fifo_context :
231+ logger .info (f"ContextLengthExceeded ({ e } ), tossing early messages and retrying" )
232+ # If we hit context length, remove a handful of assistant,user message pairs from the middle
233+ # A handful so that we can hopefully get at least a couple cache hits with this
234+ # truncated history before having to drop messages again.
235+
236+ # We want the earliest thing we remove to be an assistant message (requesting the next tool call), which have odd indices
237+ start_remove = int (len (serialized_messages ) / 3 )
238+ if start_remove % 2 != 1 :
239+ start_remove += 1
240+ # And the last thing we remove should be a user message (with tool response), which have even indices
241+ end_remove = int (len (serialized_messages ) * 2 / 3 )
242+ if end_remove - start_remove % 2 != 0 :
243+ end_remove -= 1
244+ logger .debug (
245+ f"Removing messages { start_remove } to { end_remove } from serialized messages"
246+ )
247+ end_remove += 1 # inclusive
248+ serialized_messages = serialized_messages [:start_remove ] + serialized_messages [end_remove :]
249+ for mode in last_mode_cached_message .keys ():
250+ # Delete markers if they are in the removed range
251+ if start_remove <= last_mode_cached_message [mode ] < end_remove :
252+ del last_mode_cached_message [mode ]
253+ # And adjust indices of anything that got "slid" back
254+ elif last_mode_cached_message [mode ] >= end_remove :
255+ last_mode_cached_message [mode ] -= end_remove - start_remove
256+ continue
257+ else :
258+ logger .info ("Non-retryable exception hit (context length), bailing" )
259+ return serialized_messages
260+ except NonRetryableException as e :
261+ logger .info (f"Non-retryable exception hit ({ e } ), bailing" )
262+ raise
224263 except ValueError as e :
225264 logger .info (f"ValueError hit ({ e } ), bailing" )
226265 return serialized_messages
227- except NonRetryableException :
228- logger .info ("Non-retryable exception hit, bailing" )
229- raise
230266 except InferenceException as e :
231267 logger .info ("inference exception %s" , e )
232268 await asyncio .sleep (3 ** retry )
@@ -629,6 +665,7 @@ async def get_generation(
629665 )
630666
631667 if response .status_code == 400 :
668+ # TODO: ContextLengthExceeded
632669 raise ValueError (await response .aread ())
633670 elif response .status_code != 200 :
634671 raise InferenceException (await response .aread ())
@@ -1354,7 +1391,7 @@ async def connect_and_listen(
13541391 request ["messages" ],
13551392 [{"text" : out }],
13561393 )
1357-
1394+
13581395 @property
13591396 def include_cache_control (self ):
13601397 return "anthropic" in self .model
@@ -1715,7 +1752,9 @@ async def _tool_chain_stream(
17151752 if "id" in data :
17161753 id = data ["id" ]
17171754 if data .get ("error" ):
1718- if "invalid_request_error" in str (data ['error' ]):
1755+ if "context" in str (data ['error' ]):
1756+ raise ContextLengthExceeded (str (data ['error' ]))
1757+ elif "invalid_request_error" in str (data ['error' ]):
17191758 raise NonRetryableException (str (data ['error' ]))
17201759 raise InferenceException (
17211760 data ["error" ]["message" ] + f" ({ data ['error' ]} )"
0 commit comments