Skip to content

Commit 8803e55

Browse files
committed
auto context reduction
1 parent 17d7ab5 commit 8803e55

File tree

1 file changed

+45
-6
lines changed

1 file changed

+45
-6
lines changed

asimov/services/inference_clients.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,16 @@ class InferenceException(Exception):
3030
"""
3131
A generic exception for inference errors.
3232
Should be safe to retry.
33-
ValueError is raised if the request is un-retryable (e.g. parameters are malformed, or the request is too large).
33+
ValueError is raised if the request is un-retryable (e.g. parameters are malformed).
3434
"""
3535

3636
pass
3737

3838

39+
class ContextLengthExceeded(InferenceException, ValueError):
40+
pass
41+
42+
3943
class RetriesExceeded(InferenceException):
4044
"""
4145
Raised when the maximum number of retries is exceeded.
@@ -154,6 +158,7 @@ async def tool_chain(
154158
Awaitable[tuple[str, List[Tuple[Callable, Dict[str, Any]]], Hashable]],
155159
]
156160
] = None,
161+
fifo_context: bool = False,
157162
):
158163
mode = None
159164
if mode_swap_callback:
@@ -221,12 +226,43 @@ async def tool_chain(
221226

222227
await self._trace(serialized_messages, resp)
223228
break
229+
except ContextLengthExceeded as e:
230+
if fifo_context:
231+
logger.info(f"ContextLengthExceeded ({e}), tossing early messages and retrying")
232+
# If we hit context length, remove a handful of assistant,user message pairs from the middle
233+
# A handful so that we can hopefully get at least a couple cache hits with this
234+
# truncated history before having to drop messages again.
235+
236+
# We want the earliest thing we remove to be an assistant message (requesting the next tool call), which have odd indices
237+
start_remove = int(len(serialized_messages) / 3)
238+
if start_remove % 2 != 1:
239+
start_remove += 1
240+
# And the last thing we remove should be a user message (with tool response), which have even indices
241+
end_remove = int(len(serialized_messages) * 2 / 3)
242+
if end_remove - start_remove % 2 != 0:
243+
end_remove -= 1
244+
logger.debug(
245+
f"Removing messages {start_remove} to {end_remove} from serialized messages"
246+
)
247+
end_remove += 1 # inclusive
248+
serialized_messages = serialized_messages[:start_remove] + serialized_messages[end_remove:]
249+
for mode in last_mode_cached_message.keys():
250+
# Delete markers if they are in the removed range
251+
if start_remove <= last_mode_cached_message[mode] < end_remove:
252+
del last_mode_cached_message[mode]
253+
# And adjust indices of anything that got "slid" back
254+
elif last_mode_cached_message[mode] >= end_remove:
255+
last_mode_cached_message[mode] -= end_remove - start_remove
256+
continue
257+
else:
258+
logger.info("Non-retryable exception hit (context length), bailing")
259+
return serialized_messages
260+
except NonRetryableException as e:
261+
logger.info(f"Non-retryable exception hit ({e}), bailing")
262+
raise
224263
except ValueError as e:
225264
logger.info(f"ValueError hit ({e}), bailing")
226265
return serialized_messages
227-
except NonRetryableException:
228-
logger.info("Non-retryable exception hit, bailing")
229-
raise
230266
except InferenceException as e:
231267
logger.info("inference exception %s", e)
232268
await asyncio.sleep(3**retry)
@@ -629,6 +665,7 @@ async def get_generation(
629665
)
630666

631667
if response.status_code == 400:
668+
# TODO: ContextLengthExceeded
632669
raise ValueError(await response.aread())
633670
elif response.status_code != 200:
634671
raise InferenceException(await response.aread())
@@ -1354,7 +1391,7 @@ async def connect_and_listen(
13541391
request["messages"],
13551392
[{"text": out}],
13561393
)
1357-
1394+
13581395
@property
13591396
def include_cache_control(self):
13601397
return "anthropic" in self.model
@@ -1715,7 +1752,9 @@ async def _tool_chain_stream(
17151752
if "id" in data:
17161753
id = data["id"]
17171754
if data.get("error"):
1718-
if "invalid_request_error" in str(data['error']):
1755+
if "context" in str(data['error']):
1756+
raise ContextLengthExceeded(str(data['error']))
1757+
elif "invalid_request_error" in str(data['error']):
17191758
raise NonRetryableException(str(data['error']))
17201759
raise InferenceException(
17211760
data["error"]["message"] + f" ({data['error']})"

0 commit comments

Comments
 (0)