Commit b929f8b

WIP for context management ratio
1 parent 16ac72a · commit b929f8b

File tree: 3 files changed, +99 -50 lines

  asimov/caches/cache.py
  asimov/services/inference_clients.py
  asimov/utils/token_counter.py

asimov/caches/cache.py

Lines changed: 5 additions & 6 deletions

@@ -38,21 +38,20 @@ async def apply_key_modifications(self, key: str) -> str:
 
     @asynccontextmanager
     async def with_prefix(self, prefix: str):
-        old_prefix = await self.get_prefix()
-        self._prefix.set(prefix)
+        token = self._prefix.set(prefix)
         try:
             yield self
         finally:
-            self._prefix.set(old_prefix)
+            self._prefix.reset(token)
 
     @asynccontextmanager
     async def with_suffix(self, suffix: str):
-        old_suffix = await self.get_suffix()
-        self._suffix.set(suffix)
+        token = self._suffix.set(suffix)
        try:
             yield self
         finally:
-            self._suffix.set(old_suffix)
+            self._suffix.reset(token)
+
 
     def __getitem__(self, key: str):
         return self.get(key)
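
Side note on the change above: ContextVar.set() returns a Token that records the previous value, and reset(token) restores exactly that prior state, so nested or interleaved with_prefix blocks unwind correctly where re-setting a manually saved value could clobber another task's prefix. A minimal standalone sketch of the pattern (module-level names here are illustrative, not the cache class itself):

import asyncio
from contextlib import asynccontextmanager
from contextvars import ContextVar

_prefix: ContextVar[str] = ContextVar("prefix", default="")

@asynccontextmanager
async def with_prefix(prefix: str):
    # set() returns a Token remembering the value this context held before
    token = _prefix.set(prefix)
    try:
        yield
    finally:
        # reset() rolls back to exactly the state before this set(),
        # which stays correct even when contexts are nested
        _prefix.reset(token)

async def main():
    async with with_prefix("outer:"):
        async with with_prefix("inner:"):
            assert _prefix.get() == "inner:"
        assert _prefix.get() == "outer:"   # restored by reset(token)
    assert _prefix.get() == ""             # back to the default

asyncio.run(main())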

asimov/services/inference_clients.py

Lines changed: 46 additions & 44 deletions

@@ -21,6 +21,7 @@
 import google.auth
 
 from asimov.asimov_base import AsimovBase
+from asimov.utils.token_counter import approx_tokens_from_serialized_messages
 from asimov.graph import NonRetryableException
 
 logger = logging.getLogger(__name__)

@@ -160,7 +161,7 @@ async def tool_chain(
                 Awaitable[tuple[str, List[Tuple[Callable, Dict[str, Any]]], Hashable]],
             ]
         ] = None,
-        fifo_context: bool = False,
+        fifo_ratio: Optional[float] = None,
     ):
         mode = None
         if mode_swap_callback:

@@ -208,6 +209,46 @@ async def tool_chain(
 
                 last_mode_cached_message[mode] = len(serialized_messages) - 1
 
+            tokens = approx_tokens_from_serialized_messages(serialized_messages)
+
+            if fifo_ratio and (tokens / 200000) > fifo_ratio:
+                logger.info(
+                    f"ContextLengthExceeded ({e}), tossing early messages and retrying"
+                )
+                # If we hit context length, remove a handful of assistant,user message pairs from the middle
+                # A handful so that we can hopefully get at least a couple cache hits with this
+                # truncated history before having to drop messages again.
+
+                # We want the earliest thing we remove to be an assistant message (requesting the next tool call), which have odd indices
+                start_remove = int(len(serialized_messages) / 3)
+                if start_remove % 2 != 1:
+                    start_remove += 1
+                # And the last thing we remove should be a user message (with tool response), which have even indices
+                end_remove = int(len(serialized_messages) * 2 / 3)
+                if end_remove % 2 != 0:
+                    end_remove -= 1
+                logger.debug(
+                    f"Removing messages {start_remove} through {end_remove} from serialized messages"
+                )
+                end_remove += 1  # inclusive
+                serialized_messages = (
+                    serialized_messages[:start_remove]
+                    + serialized_messages[end_remove:]
+                )
+                for mode in last_mode_cached_message.keys():
+                    # Delete markers if they are in the removed range
+                    if (
+                        start_remove
+                        <= last_mode_cached_message[mode]
+                        < end_remove
+                    ):
+                        del last_mode_cached_message[mode]
+                    # And adjust indices of anything that got "slid" back
+                    elif last_mode_cached_message[mode] >= end_remove:
+                        last_mode_cached_message[mode] -= (
+                            end_remove - start_remove
+                        )
+
             for retry in range(1, 5):
                 try:
                     resp = await self._tool_chain_stream(

@@ -229,49 +270,10 @@ async def tool_chain(
                     await self._trace(serialized_messages, resp)
                     break
                 except ContextLengthExceeded as e:
-                    if fifo_context:
-                        logger.info(
-                            f"ContextLengthExceeded ({e}), tossing early messages and retrying"
-                        )
-                        # If we hit context length, remove a handful of assistant,user message pairs from the middle
-                        # A handful so that we can hopefully get at least a couple cache hits with this
-                        # truncated history before having to drop messages again.
-
-                        # We want the earliest thing we remove to be an assistant message (requesting the next tool call), which have odd indices
-                        start_remove = int(len(serialized_messages) / 3)
-                        if start_remove % 2 != 1:
-                            start_remove += 1
-                        # And the last thing we remove should be a user message (with tool response), which have even indices
-                        end_remove = int(len(serialized_messages) * 2 / 3)
-                        if end_remove % 2 != 0:
-                            end_remove -= 1
-                        logger.debug(
-                            f"Removing messages {start_remove} through {end_remove} from serialized messages"
-                        )
-                        end_remove += 1  # inclusive
-                        serialized_messages = (
-                            serialized_messages[:start_remove]
-                            + serialized_messages[end_remove:]
-                        )
-                        for mode in last_mode_cached_message.keys():
-                            # Delete markers if they are in the removed range
-                            if (
-                                start_remove
-                                <= last_mode_cached_message[mode]
-                                < end_remove
-                            ):
-                                del last_mode_cached_message[mode]
-                            # And adjust indices of anything that got "slid" back
-                            elif last_mode_cached_message[mode] >= end_remove:
-                                last_mode_cached_message[mode] -= (
-                                    end_remove - start_remove
-                                )
-                        continue
-                    else:
-                        logger.info(
-                            "Non-retryable exception hit (context length), bailing"
-                        )
-                        return serialized_messages
+                    logger.info(
+                        "Non-retryable exception hit (context length), bailing"
+                    )
+                    return serialized_messages
                 except NonRetryableException as e:
                     logger.info(f"Non-retryable exception hit ({e}), bailing")
                     raise
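
The fifo_ratio check above makes the trimming proactive: with the 200,000-token window assumed in the diff, a ratio of 0.75 would start dropping history once the estimate passes roughly 150,000 tokens, instead of waiting for the provider to raise ContextLengthExceeded. A minimal standalone sketch of the middle-third trim, keeping the assistant/user alternation described in the diff comments (the helper name trim_middle_third is hypothetical, not part of the library):

from typing import Any, Dict, List

def trim_middle_third(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop roughly the middle third of a tool-calling transcript.

    Assumes the layout used in the diff: even indices hold user/tool-result
    messages and odd indices hold assistant messages, so the removed slice
    starts on an assistant turn and ends on a user turn, keeping the
    remaining history alternating cleanly.
    """
    start_remove = len(messages) // 3
    if start_remove % 2 != 1:          # first removed message: assistant (odd index)
        start_remove += 1
    end_remove = len(messages) * 2 // 3
    if end_remove % 2 != 0:            # last removed message: user/tool result (even index)
        end_remove -= 1
    end_remove += 1                    # make the upper bound inclusive
    return messages[:start_remove] + messages[end_remove:]

# e.g. a 12-message history keeps indices 0-4 and 9-11, dropping 5-8
history = [{"i": i} for i in range(12)]
print([m["i"] for m in trim_middle_third(history)])

Removing one contiguous middle block, rather than popping single messages, leaves the untouched prefix stable, so prompt-cache hits can survive a few more iterations before the next trim, which is what the "couple cache hits" comment in the diff is after.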

asimov/utils/token_counter.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# utils/approx_token_count.py
+from typing import Any, Dict, List
+import math
+
+AVG_CHARS_PER_TOKEN = 4  # heuristic - you can tweak if your data skews long/short
+TOKENS_PER_MSG = 4  # ChatML fixed overhead (role, separators, etc.)
+TOKENS_PER_NAME = -1  # spec quirk: "name" field shaves one token
+END_OF_REQ_TOKENS = 2  # every request implicitly ends with: <assistant|ANSWER>
+
+def approx_tokens_from_serialized_messages(
+    serialized_messages: List[Dict[str, Any]],
+    avg_chars_per_token: int = AVG_CHARS_PER_TOKEN,
+) -> int:
+    """
+    Fast, model-agnostic token estimate for a ChatML message array.
+
+    Parameters
+    ----------
+    serialized_messages : list[dict]
+        Your [{role, content:[{type,text}]}] structure.
+    avg_chars_per_token : int, optional
+        How many characters you assume map to one token (default 4).
+
+    Returns
+    -------
+    int
+        Estimated prompt token count.
+    """
+    total_tokens = 0
+
+    for msg in serialized_messages:
+        total_tokens += TOKENS_PER_MSG
+
+        # role string itself
+        total_tokens += math.ceil(len(msg["role"]) / avg_chars_per_token)
+
+        if "name" in msg:
+            total_tokens += TOKENS_PER_NAME
+
+        for part in msg["content"]:
+            if part["type"] == "text":
+                total_tokens += math.ceil(len(part["text"]) / avg_chars_per_token)
+            else:
+                # non-text parts: fall back to raw length heuristic
+                total_tokens += math.ceil(len(str(part)) / avg_chars_per_token)
+
+    total_tokens += END_OF_REQ_TOKENS
+    return max(total_tokens, 0)
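
A hedged sketch of how the estimator and the new fifo_ratio knob fit together. The import path comes from the diff; CONTEXT_WINDOW = 200_000 mirrors the hard-coded divisor above, and the 0.75 threshold and sample messages are illustrative values used only in this example:

from asimov.utils.token_counter import approx_tokens_from_serialized_messages

CONTEXT_WINDOW = 200_000   # window size assumed by the check in the diff
fifo_ratio = 0.75          # hypothetical caller-supplied threshold

serialized_messages = [
    {"role": "user", "content": [{"type": "text", "text": "Summarize the build log."}]},
    {"role": "assistant", "content": [{"type": "text", "text": "Calling the log tool..."}]},
]

tokens = approx_tokens_from_serialized_messages(serialized_messages)

# Mirror the proactive check from the diff: trim history once the estimated
# prompt size crosses fifo_ratio of the window, instead of waiting for a
# ContextLengthExceeded error from the provider.
if fifo_ratio and (tokens / CONTEXT_WINDOW) > fifo_ratio:
    print(f"~{tokens} tokens, above {fifo_ratio:.0%} of the window: would trim middle messages")
else:
    print(f"~{tokens} tokens, under the threshold: no trimming needed")

At roughly 4 characters per token the estimate is deliberately coarse; it only needs to be close enough to trigger trimming before the tokenizer-enforced limit is actually hit.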
