
Commit 6db54ff

test: cherrypick vllm e2e fixes into release 0.3.1 (#1664)
1 parent 45e727f commit 6db54ff

File tree: 2 files changed (+90, −8 lines)


tests/serve/conftest.py

Lines changed: 71 additions & 0 deletions
@@ -12,3 +12,74 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import logging
+import os
+
+import pytest
+
+# List of models used in the serve tests
+SERVE_TEST_MODELS = [
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+    "llava-hf/llava-1.5-7b-hf",
+]
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture(scope="session")
+def predownload_models():
+    # Check for HF_TOKEN in environment
+    hf_token = os.environ.get("HF_TOKEN")
+    if hf_token:
+        logger.info("HF_TOKEN found in environment")
+    else:
+        logger.warning(
+            "HF_TOKEN not found in environment. "
+            "Some models may fail to download or you may encounter rate limits. "
+            "Get a token from https://huggingface.co/settings/tokens"
+        )
+
+    try:
+        from huggingface_hub import snapshot_download
+
+        for model_id in SERVE_TEST_MODELS:
+            logger.info(f"Pre-downloading model: {model_id}")
+
+            try:
+                # Download the full model snapshot (includes all files)
+                # HuggingFace will handle caching automatically
+                snapshot_download(
+                    repo_id=model_id,
+                    token=hf_token,
+                )
+                logger.info(f"Successfully pre-downloaded: {model_id}")
+
+            except Exception as e:
+                logger.error(f"Failed to pre-download {model_id}: {e}")
+                # Don't fail the fixture - let individual tests handle missing models
+
+    except ImportError:
+        logger.warning(
+            "huggingface_hub not installed. "
+            "Models will be downloaded during test execution."
+        )
+
+    yield
+
+
+# Automatically use the predownload fixture for all serve tests
+def pytest_collection_modifyitems(config, items):
+    for item in items:
+        # Skip items that don't have fixturenames (like MypyFileItem)
+        if not hasattr(item, "fixturenames"):
+            continue
+
+        # Only apply to tests in the serve directory
+        if "serve" in str(item.path):
+            # Check if the test already uses the fixture
+            if "predownload_models" not in item.fixturenames:
+                # Don't add if test explicitly marks to skip model download
+                if not item.get_closest_marker("skip_model_download"):
+                    item.fixturenames = list(item.fixturenames)
+                    item.fixturenames.append("predownload_models")
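
The collection hook above auto-attaches the session-scoped predownload_models fixture to every test collected under tests/serve, unless the test opts out with a skip_model_download marker. A minimal sketch of that opt-out path follows; only the marker name comes from the conftest, while the marker registration and the test itself are hypothetical:

# Hypothetical opt-out example; only "skip_model_download" is taken from
# the conftest above. Registering the marker avoids unknown-marker warnings.
import pytest


def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "skip_model_download: do not attach the predownload_models fixture",
    )


@pytest.mark.skip_model_download
def test_needs_no_weights():
    # pytest_collection_modifyitems sees this marker via get_closest_marker()
    # and leaves the item's fixturenames untouched, so nothing is downloaded.
    assert True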

tests/serve/test_dynamo_serve.py

Lines changed: 19 additions & 8 deletions
@@ -67,14 +67,14 @@
             ],
             "max_tokens": 150,  # Reduced from 500
             "temperature": 0.1,
-            "seed": 0,
+            # "seed": 0,
         },
         payload_completions={
             "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
             "prompt": text_prompt,
             "max_tokens": 150,
             "temperature": 0.1,
-            "seed": 0,
+            # "seed": 0,
         },
         repeat_count=10,
         expected_log=[],
@@ -159,7 +159,7 @@
     "multimodal_agg": (
         DeploymentGraph(
             module="graphs.agg:Frontend",
-            config="configs/agg.yaml",
+            config="configs/agg-llava.yaml",
             directory="/workspace/examples/multimodal",
             endpoints=["v1/chat/completions"],
             response_handlers=[
@@ -257,12 +257,22 @@ def __init__(self, graph: DeploymentGraph, request, port=8000, timeout=900):
         if graph.config:
             command.extend(["-f", os.path.join(graph.directory, graph.config)])

-        command.extend(["--Frontend.port", str(port)])
-
-        health_check_urls = [(f"http://localhost:{port}/v1/models", self._check_model)]
-
+        # Handle multimodal deployments differently
         if "multimodal" in graph.directory:
+            # Set DYNAMO_PORT environment variable for multimodal
+            env = os.environ.copy()
+            env["DYNAMO_PORT"] = str(port)
             health_check_urls = []
+            # Don't add health check on port since multimodal uses DYNAMO_PORT
+            health_check_ports = []
+        else:
+            # Regular LLM deployments
+            command.extend(["--Frontend.port", str(port)])
+            health_check_urls = [
+                (f"http://localhost:{port}/v1/models", self._check_model)
+            ]
+            health_check_ports = [port]
+            env = None

         self.port = port

@@ -271,11 +281,12 @@ def __init__(self, graph: DeploymentGraph, request, port=8000, timeout=900):
             timeout=timeout,
             display_output=True,
             working_dir=graph.directory,
-            health_check_ports=[port],
+            health_check_ports=health_check_ports,
             health_check_urls=health_check_urls,
             delayed_start=graph.delayed_start,
             stragglers=["http"],
             log_dir=request.node.name,
+            env=env,  # Pass the environment variables
        )

    def _check_model(self, response):
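
Because passing a dict as env to a child process replaces its environment wholesale, the multimodal branch copies os.environ before overlaying DYNAMO_PORT. A minimal sketch of those mechanics, assuming the process wrapper ultimately delegates to subprocess; the launch helper and the commented-out command are illustrative, not the harness's real API:

import os
import subprocess


def launch(command, env=None):
    # env=None means "inherit the parent environment unchanged"; a dict
    # replaces it entirely, which is why callers copy os.environ first
    # and then overlay their additions.
    return subprocess.Popen(command, env=env)


env = os.environ.copy()
env["DYNAMO_PORT"] = str(8000)
# launch(["dynamo", "serve", "graphs.agg:Frontend"], env=env)  # hypothetical command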
