@@ -48,10 +48,6 @@ class AnthropicMessage:
4848 content : List [AnthropicMessageContent ]
4949
5050
51- class ModelFamily (Enum ):
52- Anthropic = "Anthropic"
53-
54-
5551class InferenceClient (ABC ):
5652 @abstractmethod
5753 async def connect_and_listen (
@@ -67,10 +63,10 @@ async def get_generation(
6763
6864
6965class BedrockInferenceClient (InferenceClient ):
70- def __init__ (self , model : str ):
66+ def __init__ (self , model : str , region_name = "us-east-1" ):
7167 self .model = model
68+ self .region_name = region_name
7269 self .session = aioboto3 .Session ()
73- self .model_family = ModelFamily .Anthropic
7470 self .anthropic_version = "bedrock-2023-05-31"
7571
7672 async def get_generation (
@@ -99,7 +95,7 @@ async def get_generation(
9995
10096 async with self .session .client (
10197 service_name = "bedrock-runtime" ,
102- region_name = "us-east-1" ,
98+ region_name = self . region_name ,
10399 ) as client :
104100 response = await client .invoke_model (
105101 body = json .dumps (body .__dict__ ),
@@ -179,16 +175,22 @@ def __init__(
179175 async def get_generation (
180176 self , messages : List [ChatMessage ], max_tokens = 4096 , top_p = 0.5 , temperature = 0.5
181177 ):
178+ system = None
179+ if messages [0 ]["role" ] == "system" :
180+ system = {
181+ "system" : [
182+ {"type" : "text" , "text" : messages [0 ]["content" ]}
183+ | (
184+ {"cache_control" : {"type" : "ephemeral" }}
185+ if messages [0 ].get ("cache_marker" )
186+ else {}
187+ )
188+ ]
189+ }
190+ messages = messages [1 :]
191+
182192 request = {
183193 "model" : self .model ,
184- "system" : [
185- {"type" : "text" , "text" : messages [0 ]["content" ]}
186- | (
187- {"cache_control" : {"type" : "ephemeral" }}
188- if messages [0 ].get ("cache_marker" )
189- else {}
190- )
191- ],
192194 "top_p" : top_p ,
193195 "temperature" : temperature ,
194196 "max_tokens" : max_tokens ,
@@ -199,10 +201,14 @@ async def get_generation(
199201 if msg .get ("cache_marker" )
200202 else {}
201203 )
202- for msg in messages [ 1 :]
204+ for msg in messages
203205 ],
204206 "stream" : False ,
205207 }
208+
209+ if system :
210+ request .update (system )
211+
206212 async with httpx .AsyncClient () as client :
207213 response = await client .post (
208214 self .api_url ,
@@ -223,16 +229,23 @@ async def get_generation(
223229 async def connect_and_listen (
224230 self , messages : List [ChatMessage ], max_tokens = 4096 , top_p = 0.5 , temperature = 0.5
225231 ):
232+
233+ system = None
234+ if messages [0 ]["role" ] == "system" :
235+ system = {
236+ "system" : [
237+ {"type" : "text" , "text" : messages [0 ]["content" ]}
238+ | (
239+ {"cache_control" : {"type" : "ephemeral" }}
240+ if messages [0 ].get ("cache_marker" )
241+ else {}
242+ )
243+ ]
244+ }
245+ messages = messages [1 :]
246+
226247 request = {
227248 "model" : self .model ,
228- "system" : [
229- {"type" : "text" , "text" : messages [0 ]["content" ]}
230- | (
231- {"cache_control" : {"type" : "ephemeral" }}
232- if messages [0 ].get ("cache_marker" )
233- else {}
234- )
235- ],
236249 "top_p" : top_p ,
237250 "temperature" : temperature ,
238251 "max_tokens" : max_tokens ,
@@ -248,11 +261,14 @@ async def connect_and_listen(
248261 )
249262 ],
250263 }
251- for msg in messages [ 1 :]
264+ for msg in messages
252265 ],
253266 "stream" : True ,
254267 }
255268
269+ if system :
270+ request .update (system )
271+
256272 async with httpx .AsyncClient () as client :
257273 async with client .stream (
258274 "POST" ,
@@ -272,9 +288,6 @@ async def connect_and_listen(
272288 if response .status_code != 200 :
273289 message_logs = [{"role" : msg ["role" ]} for msg in messages [1 :]]
274290
275- print (line )
276- pprint (message_logs )
277-
278291 if line .startswith ("data: " ):
279292 data = json .loads (line [6 :])
280293 chunk_type = data ["type" ]