|
22 | 22 | from dataclasses import dataclass |
23 | 23 | from typing import Dict, List, Optional |
24 | 24 |
|
25 | | -import httpx |
| 25 | +import aiohttp |
26 | 26 | import requests |
27 | 27 | from kubernetes import client, config, watch |
28 | 28 |
|
@@ -308,22 +308,29 @@ def get_endpoint_info(self) -> List[EndpointInfo]: |
308 | 308 | model_info=self._get_model_info(model), |
309 | 309 | ) |
310 | 310 | endpoint_infos.append(endpoint_info) |
| 311 | + return endpoint_infos |
| 312 | + |
| 313 | + async def initialize_client_sessions(self) -> None: |
| 314 | + """ |
| 315 | + Initialize aiohttp ClientSession objects for prefill and decode endpoints. |
| 316 | + This must be called from an async context during app startup. |
| 317 | + """ |
311 | 318 | if ( |
312 | 319 | self.prefill_model_labels is not None |
313 | 320 | and self.decode_model_labels is not None |
314 | 321 | ): |
| 322 | + endpoint_infos = self.get_endpoint_info() |
315 | 323 | for endpoint_info in endpoint_infos: |
316 | 324 | if endpoint_info.model_label in self.prefill_model_labels: |
317 | | - self.app.state.prefill_client = httpx.AsyncClient( |
| 325 | + self.app.state.prefill_client = aiohttp.ClientSession( |
318 | 326 | base_url=endpoint_info.url, |
319 | | - timeout=None, |
| 327 | + timeout=aiohttp.ClientTimeout(total=None), |
320 | 328 | ) |
321 | 329 | elif endpoint_info.model_label in self.decode_model_labels: |
322 | | - self.app.state.decode_client = httpx.AsyncClient( |
| 330 | + self.app.state.decode_client = aiohttp.ClientSession( |
323 | 331 | base_url=endpoint_info.url, |
324 | | - timeout=None, |
| 332 | + timeout=aiohttp.ClientTimeout(total=None), |
325 | 333 | ) |
326 | | - return endpoint_infos |
327 | 334 |
|
328 | 335 |
|
329 | 336 | class K8sPodIPServiceDiscovery(ServiceDiscovery): |
@@ -629,20 +636,7 @@ def _add_engine( |
629 | 636 | namespace=self.namespace, |
630 | 637 | model_info=model_info, |
631 | 638 | ) |
632 | | - if ( |
633 | | - self.prefill_model_labels is not None |
634 | | - and self.decode_model_labels is not None |
635 | | - ): |
636 | | - if model_label in self.prefill_model_labels: |
637 | | - self.app.state.prefill_client = httpx.AsyncClient( |
638 | | - base_url=f"http://{engine_ip}:{self.port}", |
639 | | - timeout=None, |
640 | | - ) |
641 | | - elif model_label in self.decode_model_labels: |
642 | | - self.app.state.decode_client = httpx.AsyncClient( |
643 | | - base_url=f"http://{engine_ip}:{self.port}", |
644 | | - timeout=None, |
645 | | - ) |
| 639 | + |
646 | 640 | # Store model information in the endpoint info |
647 | 641 | self.available_engines[engine_name].model_info = model_info |
648 | 642 |
|
@@ -720,6 +714,28 @@ def close(self): |
720 | 714 | self.k8s_watcher.stop() |
721 | 715 | self.watcher_thread.join() |
722 | 716 |
|
| 717 | + async def initialize_client_sessions(self) -> None: |
| 718 | + """ |
| 719 | + Initialize aiohttp ClientSession objects for prefill and decode endpoints. |
| 720 | + This must be called from an async context during app startup. |
| 721 | + """ |
| 722 | + if ( |
| 723 | + self.prefill_model_labels is not None |
| 724 | + and self.decode_model_labels is not None |
| 725 | + ): |
| 726 | + endpoint_infos = self.get_endpoint_info() |
| 727 | + for endpoint_info in endpoint_infos: |
| 728 | + if endpoint_info.model_label in self.prefill_model_labels: |
| 729 | + self.app.state.prefill_client = aiohttp.ClientSession( |
| 730 | + base_url=endpoint_info.url, |
| 731 | + timeout=aiohttp.ClientTimeout(total=None), |
| 732 | + ) |
| 733 | + elif endpoint_info.model_label in self.decode_model_labels: |
| 734 | + self.app.state.decode_client = aiohttp.ClientSession( |
| 735 | + base_url=endpoint_info.url, |
| 736 | + timeout=aiohttp.ClientTimeout(total=None), |
| 737 | + ) |
| 738 | + |
723 | 739 |
|
724 | 740 | class K8sServiceNameServiceDiscovery(ServiceDiscovery): |
725 | 741 | def __init__( |
@@ -1024,20 +1040,7 @@ def _add_engine(self, engine_name: str, model_names: List[str], model_label: str |
1024 | 1040 | namespace=self.namespace, |
1025 | 1041 | model_info=model_info, |
1026 | 1042 | ) |
1027 | | - if ( |
1028 | | - self.prefill_model_labels is not None |
1029 | | - and self.decode_model_labels is not None |
1030 | | - ): |
1031 | | - if model_label in self.prefill_model_labels: |
1032 | | - self.app.state.prefill_client = httpx.AsyncClient( |
1033 | | - base_url=f"http://{engine_name}:{self.port}", |
1034 | | - timeout=None, |
1035 | | - ) |
1036 | | - elif model_label in self.decode_model_labels: |
1037 | | - self.app.state.decode_client = httpx.AsyncClient( |
1038 | | - base_url=f"http://{engine_name}:{self.port}", |
1039 | | - timeout=None, |
1040 | | - ) |
| 1043 | + |
1041 | 1044 | # Store model information in the endpoint info |
1042 | 1045 | self.available_engines[engine_name].model_info = model_info |
1043 | 1046 |
|
@@ -1114,6 +1117,28 @@ def close(self): |
1114 | 1117 | self.k8s_watcher.stop() |
1115 | 1118 | self.watcher_thread.join() |
1116 | 1119 |
|
| 1120 | + async def initialize_client_sessions(self) -> None: |
| 1121 | + """ |
| 1122 | + Initialize aiohttp ClientSession objects for prefill and decode endpoints. |
| 1123 | + This must be called from an async context during app startup. |
| 1124 | + """ |
| 1125 | + if ( |
| 1126 | + self.prefill_model_labels is not None |
| 1127 | + and self.decode_model_labels is not None |
| 1128 | + ): |
| 1129 | + endpoint_infos = self.get_endpoint_info() |
| 1130 | + for endpoint_info in endpoint_infos: |
| 1131 | + if endpoint_info.model_label in self.prefill_model_labels: |
| 1132 | + self.app.state.prefill_client = aiohttp.ClientSession( |
| 1133 | + base_url=endpoint_info.url, |
| 1134 | + timeout=aiohttp.ClientTimeout(total=None), |
| 1135 | + ) |
| 1136 | + elif endpoint_info.model_label in self.decode_model_labels: |
| 1137 | + self.app.state.decode_client = aiohttp.ClientSession( |
| 1138 | + base_url=endpoint_info.url, |
| 1139 | + timeout=aiohttp.ClientTimeout(total=None), |
| 1140 | + ) |
| 1141 | + |
1117 | 1142 |
|
1118 | 1143 | def _create_service_discovery( |
1119 | 1144 | service_discovery_type: ServiceDiscoveryType, *args, **kwargs |
|
0 commit comments