Commit efde2a7

Change system messages to optional
1 parent 82f5568 · commit efde2a7

File tree

1 file changed: +41 additions, -28 deletions

asimov/services/inference_clients.py

Lines changed: 41 additions & 28 deletions
@@ -48,10 +48,6 @@ class AnthropicMessage:
     content: List[AnthropicMessageContent]


-class ModelFamily(Enum):
-    Anthropic = "Anthropic"
-
-
 class InferenceClient(ABC):
     @abstractmethod
     async def connect_and_listen(
@@ -67,10 +63,10 @@ async def get_generation(


 class BedrockInferenceClient(InferenceClient):
-    def __init__(self, model: str):
+    def __init__(self, model: str, region_name="us-east-1"):
         self.model = model
+        self.region_name = region_name
         self.session = aioboto3.Session()
-        self.model_family = ModelFamily.Anthropic
         self.anthropic_version = "bedrock-2023-05-31"

     async def get_generation(
@@ -99,7 +95,7 @@ async def get_generation(

         async with self.session.client(
             service_name="bedrock-runtime",
-            region_name="us-east-1",
+            region_name=self.region_name,
         ) as client:
             response = await client.invoke_model(
                 body=json.dumps(body.__dict__),
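
The region is no longer hardcoded: the constructor takes a region_name argument (defaulting to the old "us-east-1" value) and threads it through to the Bedrock runtime client. A minimal usage sketch, assuming the package imports as its file path suggests; the model ID is a hypothetical placeholder, not something this diff specifies:

from asimov.services.inference_clients import BedrockInferenceClient

# region_name now defaults to "us-east-1" but can be overridden per client
client = BedrockInferenceClient(
    model="anthropic.claude-3-sonnet",  # hypothetical model ID
    region_name="eu-west-1",
)
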
@@ -179,16 +175,22 @@ def __init__(
     async def get_generation(
         self, messages: List[ChatMessage], max_tokens=4096, top_p=0.5, temperature=0.5
     ):
+        system = None
+        if messages[0]["role"] == "system":
+            system = {
+                "system": [
+                    {"type": "text", "text": messages[0]["content"]}
+                    | (
+                        {"cache_control": {"type": "ephemeral"}}
+                        if messages[0].get("cache_marker")
+                        else {}
+                    )
+                ]
+            }
+            messages = messages[1:]
+
         request = {
             "model": self.model,
-            "system": [
-                {"type": "text", "text": messages[0]["content"]}
-                | (
-                    {"cache_control": {"type": "ephemeral"}}
-                    if messages[0].get("cache_marker")
-                    else {}
-                )
-            ],
             "top_p": top_p,
             "temperature": temperature,
             "max_tokens": max_tokens,
@@ -199,10 +201,14 @@ async def get_generation(
                     if msg.get("cache_marker")
                     else {}
                 )
-                for msg in messages[1:]
+                for msg in messages
             ],
             "stream": False,
         }
+
+        if system:
+            request.update(system)
+
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 self.api_url,
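
Net effect for callers of get_generation: a leading message with role "system" is peeled off and sent as the request's top-level "system" field, and when no such message is present the key is omitted entirely, rather than indexing messages[0] unconditionally as before. A sketch of both calling styles, assuming ChatMessage dicts shaped as in this diff and an async calling context; connect_and_listen below receives the identical treatment for the streaming path:

# With a system prompt: consumed into the request's "system" field.
await client.get_generation([
    {"role": "system", "content": "You are terse.", "cache_marker": True},
    {"role": "user", "content": "Hello!"},
])

# Without one: the request carries no "system" key at all.
await client.get_generation([{"role": "user", "content": "Hello!"}])
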
@@ -223,16 +229,23 @@ async def get_generation(
     async def connect_and_listen(
         self, messages: List[ChatMessage], max_tokens=4096, top_p=0.5, temperature=0.5
     ):
+
+        system = None
+        if messages[0]["role"] == "system":
+            system = {
+                "system": [
+                    {"type": "text", "text": messages[0]["content"]}
+                    | (
+                        {"cache_control": {"type": "ephemeral"}}
+                        if messages[0].get("cache_marker")
+                        else {}
+                    )
+                ]
+            }
+            messages = messages[1:]
+
         request = {
             "model": self.model,
-            "system": [
-                {"type": "text", "text": messages[0]["content"]}
-                | (
-                    {"cache_control": {"type": "ephemeral"}}
-                    if messages[0].get("cache_marker")
-                    else {}
-                )
-            ],
             "top_p": top_p,
             "temperature": temperature,
             "max_tokens": max_tokens,
@@ -248,11 +261,14 @@ async def connect_and_listen(
                        )
                    ],
                }
-                for msg in messages[1:]
+                for msg in messages
            ],
            "stream": True,
        }

+        if system:
+            request.update(system)
+
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
@@ -272,9 +288,6 @@ async def connect_and_listen(
                if response.status_code != 200:
                    message_logs = [{"role": msg["role"]} for msg in messages[1:]]

-                    print(line)
-                    pprint(message_logs)
-
                if line.startswith("data: "):
                    data = json.loads(line[6:])
                    chunk_type = data["type"]