Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 95 additions & 51 deletions lib/llm/src/preprocessor/prompt/template/oai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ fn may_be_fix_tool_schema(tools: serde_json::Value) -> Option<Value> {
}

fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
// If messages[content] is provided as a list containing ONLY text parts,
// concatenate them into a string to match chat template expectations.
// Mixed content types are left for chat templates to handle.
// Flatten content arrays into strings with placeholders for multimodal content.
// This mimics vLLM's preprocessing so templates receive simple strings.
// - Text-only arrays: concatenate text parts with newlines
// - Multimodal arrays: interleave text with <image>, <video>, <audio> placeholders

let Some(arr) = messages.as_array() else {
return Value::from_serialize(&messages);
Expand All @@ -87,32 +88,61 @@ fn may_be_fix_msg_content(messages: serde_json::Value) -> Value {
.map(|msg| {
match msg.get("content") {
Some(serde_json::Value::Array(content_array)) => {
let is_text_only_array = !content_array.is_empty()
&& content_array.iter().all(|part| {
part.get("type")
.and_then(|type_field| type_field.as_str())
.map(|type_str| type_str == "text")
.unwrap_or(false)
});

if is_text_only_array {
let mut modified_msg = msg.clone();
if let Some(msg_object) = modified_msg.as_object_mut() {
let text_parts: Vec<&str> = content_array
.iter()
.filter_map(|part| part.get("text")?.as_str())
.collect();
let concatenated_text = text_parts.join("\n");

msg_object.insert(
"content".to_string(),
serde_json::Value::String(concatenated_text),
);
if content_array.is_empty() {
return msg.clone();
}

// Check if this is a text-only array
let is_text_only = content_array.iter().all(|part| {
part.get("type")
.and_then(|t| t.as_str())
.map(|t| t == "text")
.unwrap_or(false)
});

let mut modified_msg = msg.clone();
if let Some(msg_object) = modified_msg.as_object_mut() {
let mut content_string = String::new();

for (idx, part) in content_array.iter().enumerate() {
if idx > 0 {
// Use newline for text-only, space for multimodal
content_string.push(if is_text_only { '\n' } else { ' ' });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: if you have two text portions back to back - would you have a newline in between those?

}

let part_type = part.get("type").and_then(|t| t.as_str()).unwrap_or("");

match part_type {
"text" => {
if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
content_string.push_str(text);
}
}
"image_url" => {
content_string.push_str("<image>");
}
"video_url" => {
content_string.push_str("<video>");
}
"audio_url" => {
content_string.push_str("<audio>");
}
_ => {
// Unknown type - skip or add placeholder
tracing::warn!(
"Unknown content type in message: {}",
part_type
);
}
}
}
modified_msg // Concatenated string content
} else {
msg.clone() // Mixed content or non-text only

msg_object.insert(
"content".to_string(),
serde_json::Value::String(content_string),
);
}
modified_msg
}
_ => msg.clone(), // String content or missing content - return unchanged
}
Expand Down Expand Up @@ -321,8 +351,19 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
ctx
};

// Select template based on whether tools are present
let tmpl: minijinja::Template<'_, '_> = if has_tools {
self.env.get_template("tool_use")?
// For tools, try tool_use first (custom template), fall back to default
match self.env.get_template("tool_use") {
Ok(t) => {
tracing::debug!("Using 'tool_use' template");
t
}
Err(_) => {
tracing::debug!("'tool_use' template not found, using 'default'");
self.env.get_template("default")?
}
}
} else {
self.env.get_template("default")?
};
Expand Down Expand Up @@ -571,7 +612,7 @@ mod tests {
);
}

/// Tests that content arrays with mixed types (text + non-text) remain as arrays.
/// Tests that content arrays with mixed types (text + non-text) are flattened with placeholders.
#[test]
fn test_may_be_fix_msg_content_mixed_types() {
let json_str = r#"{
Expand All @@ -591,16 +632,14 @@ mod tests {
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();

// Verify: Mixed content types are preserved as array for template handling
assert!(messages[0]["content"].is_array());
let content_array = messages[0]["content"].as_array().unwrap();
assert_eq!(content_array.len(), 3);
assert_eq!(content_array[0]["type"], "text");
assert_eq!(content_array[1]["type"], "image_url");
assert_eq!(content_array[2]["type"], "text");
// Verify: Mixed content types are flattened into a single string with placeholders
assert_eq!(
messages[0]["content"],
serde_json::Value::String("Check this image: <image> What do you see?".to_string())
);
}

/// Tests that content arrays containing only non-text types remain as arrays.
/// Tests that content arrays containing only non-text types are flattened with placeholders.
#[test]
fn test_may_be_fix_msg_content_non_text_only() {
let json_str = r#"{
Expand All @@ -619,12 +658,11 @@ mod tests {
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();

// Verify: Non-text content arrays are preserved for template handling
assert!(messages[0]["content"].is_array());
let content_array = messages[0]["content"].as_array().unwrap();
assert_eq!(content_array.len(), 2);
assert_eq!(content_array[0]["type"], "image_url");
assert_eq!(content_array[1]["type"], "image_url");
// Verify: Non-text content arrays are flattened with placeholders
assert_eq!(
messages[0]["content"],
serde_json::Value::String("<image> <image>".to_string())
);
}

#[test]
Expand Down Expand Up @@ -692,7 +730,7 @@ NORMAL MODE
assert!(result2.unwrap().contains("NORMAL MODE"));
}

/// Tests mixed content type scenarios.
/// Tests mixed content type scenarios are flattened with appropriate placeholders.
#[test]
fn test_may_be_fix_msg_content_multiple_content_types() {
// Scenario 1: Multiple different content types (text + image + audio)
Expand All @@ -715,11 +753,15 @@ NORMAL MODE
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();

// Mixed types should preserve array structure
assert!(messages[0]["content"].is_array());
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 5);
// Mixed types should be flattened with placeholders
assert_eq!(
messages[0]["content"],
serde_json::Value::String(
"Listen to this: <audio> And look at: <image> What do you think?".to_string()
)
);

// Scenario 2: Unknown/future content types mixed with text
// Scenario 2: Video content type mixed with text
let json_str = r#"{
"model": "gpt-4o",
"messages": [
Expand All @@ -737,9 +779,11 @@ NORMAL MODE
let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap();
let messages = serde_json::to_value(request.messages()).unwrap();

// Unknown types mixed with text should preserve array
assert!(messages[0]["content"].is_array());
assert_eq!(messages[0]["content"].as_array().unwrap().len(), 3);
// Video types mixed with text should be flattened with placeholders
assert_eq!(
messages[0]["content"],
serde_json::Value::String("Check this: <video> Interesting?".to_string())
);
}

#[test]
Expand Down
Loading