Improve token usage estimate for images (#12419)
Fixes #11845. Adjust context/token estimation for inline image `data:*;base64,...` URLs so we do not count the raw base64 payload as model-visible text.

What changed:
- keep the existing JSON-length estimator as the baseline
- detect only inline base64 `data:` image URLs in message and function-call output content items
- subtract only the base64 payload bytes (preserving data URL prefix + JSON overhead)
- add a fixed per-image estimate of 340 bytes (~85 tokens at the repo’s 4-bytes/token heuristic)

This avoids large overestimates from MCP image tool outputs while leaving normal image URLs (`https://`, `file://`, non-base64 `data:` URLs) unchanged.

Tests:
- message image data URL estimate regression
- function-call output image data URL estimate regression
- non-base64 image URLs unchanged
- non-base64 `data:` URLs unchanged
- non-image base64 data URLs (`data:application/octet-stream;base64,...`) unchanged
- multiple inline images apply multiple fixed costs
- text-only items unchanged
This commit is contained in:
parent
b17148f13a
commit
3586fcb802
2 changed files with 286 additions and 3 deletions
|
|
@ -419,6 +419,12 @@ fn estimate_item_token_count(item: &ResponseItem) -> i64 {
|
|||
approx_tokens_from_byte_count_i64(model_visible_bytes)
|
||||
}
|
||||
|
||||
/// Flat model-visible byte charge applied to each image input.
///
/// Byte counts are later converted to tokens with a 4-bytes/token heuristic,
/// so this charge corresponds to roughly 85 tokens per image.
const IMAGE_BYTES_ESTIMATE: i64 = 85 * 4;
|
||||
|
||||
pub(crate) fn estimate_response_item_model_visible_bytes(item: &ResponseItem) -> i64 {
|
||||
match item {
|
||||
ResponseItem::GhostSnapshot { .. } => 0,
|
||||
|
|
@ -429,12 +435,97 @@ pub(crate) fn estimate_response_item_model_visible_bytes(item: &ResponseItem) ->
|
|||
| ResponseItem::Compaction {
|
||||
encrypted_content: content,
|
||||
} => i64::try_from(estimate_reasoning_length(content.len())).unwrap_or(i64::MAX),
|
||||
item => serde_json::to_string(item)
|
||||
.map(|serialized| i64::try_from(serialized.len()).unwrap_or(i64::MAX))
|
||||
.unwrap_or_default(),
|
||||
item => {
|
||||
let raw = serde_json::to_string(item)
|
||||
.map(|serialized| i64::try_from(serialized.len()).unwrap_or(i64::MAX))
|
||||
.unwrap_or_default();
|
||||
let (payload_bytes, image_count) = image_data_url_estimate_adjustment(item);
|
||||
if payload_bytes == 0 || image_count == 0 {
|
||||
raw
|
||||
} else {
|
||||
// Replace raw base64 payload bytes with a fixed per-image cost.
|
||||
// We intentionally preserve the data URL prefix and JSON wrapper
|
||||
// bytes already included in `raw`.
|
||||
raw.saturating_sub(payload_bytes)
|
||||
.saturating_add(image_count.saturating_mul(IMAGE_BYTES_ESTIMATE))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the base64 payload byte length for inline image data URLs that are
/// eligible for token-estimation discounting.
///
/// Only `data:image/...;base64,...` URLs qualify. The `data:` scheme, the
/// `image/` media-type prefix and the `base64` marker are all matched ASCII
/// case-insensitively, and ASCII whitespace surrounding the media type or any
/// `;`-separated parameter is ignored (so `data:image/png; base64,...` is
/// recognised as well).
///
/// Returns `Some(len)` with the number of bytes after the first `,` for a
/// qualifying URL, and `None` for everything else (non-`data:` URLs, URLs
/// without a `,`, non-image media types, or a missing `base64` marker), so
/// those stay at their raw serialized size.
fn base64_data_url_payload_len(url: &str) -> Option<usize> {
    const SCHEME: &str = "data:";
    const IMAGE_TYPE_PREFIX: &str = "image/";
    const BASE64_MARKER: &str = "base64";

    // `str::get` yields `None` both for too-short input and for a cut that
    // would land inside a multi-byte character, so this never panics.
    let has_prefix_ignore_case = |text: &str, prefix: &str| {
        text.get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    };

    if !has_prefix_ignore_case(url, SCHEME) {
        return None;
    }

    // The scheme check succeeded, so `SCHEME.len()` is a valid char boundary.
    let (metadata, payload) = url[SCHEME.len()..].split_once(',')?;

    // Inspect the media type and parameters without decoding the payload; this
    // keeps the estimator cheap. Each part is trimmed so that padded forms such
    // as `image/png; base64` are treated like their compact equivalents.
    let mut parts = metadata
        .split(';')
        .map(|part| part.trim_matches(|c: char| c.is_ascii_whitespace()));

    let mime_type = parts.next().unwrap_or_default();
    if !has_prefix_ignore_case(mime_type, IMAGE_TYPE_PREFIX) {
        return None;
    }

    parts
        .any(|part| part.eq_ignore_ascii_case(BASE64_MARKER))
        .then_some(payload.len())
}
|
||||
|
||||
/// Scans one response item for discount-eligible inline image data URLs and
|
||||
/// returns:
|
||||
/// - total base64 payload bytes to subtract from raw serialized size
|
||||
/// - count of qualifying images to replace with `IMAGE_BYTES_ESTIMATE`
|
||||
fn image_data_url_estimate_adjustment(item: &ResponseItem) -> (i64, i64) {
|
||||
let mut payload_bytes = 0i64;
|
||||
let mut image_count = 0i64;
|
||||
|
||||
let mut accumulate = |image_url: &str| {
|
||||
if let Some(payload_len) = base64_data_url_payload_len(image_url) {
|
||||
payload_bytes =
|
||||
payload_bytes.saturating_add(i64::try_from(payload_len).unwrap_or(i64::MAX));
|
||||
image_count = image_count.saturating_add(1);
|
||||
}
|
||||
};
|
||||
|
||||
match item {
|
||||
ResponseItem::Message { content, .. } => {
|
||||
for content_item in content {
|
||||
if let ContentItem::InputImage { image_url } = content_item {
|
||||
accumulate(image_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { output, .. } => {
|
||||
if let FunctionCallOutputBody::ContentItems(items) = &output.body {
|
||||
for content_item in items {
|
||||
if let FunctionCallOutputContentItem::InputImage { image_url } = content_item {
|
||||
accumulate(image_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
(payload_bytes, image_count)
|
||||
}
|
||||
|
||||
fn is_model_generated_item(item: &ResponseItem) -> bool {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } => role == "assistant",
|
||||
|
|
|
|||
|
|
@ -1240,3 +1240,195 @@ fn normalize_mixed_inserts_and_removals_panics_in_debug() {
|
|||
let mut h = create_history_with_items(items);
|
||||
h.normalize_history(&default_input_modalities());
|
||||
}
|
||||
|
||||
#[test]
fn image_data_url_payload_does_not_dominate_message_estimate() {
    let base64_body = "A".repeat(100_000);

    // Two user messages with the same caption; only the first carries an
    // inline base64 image.
    let with_image = ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![
            ContentItem::InputText {
                text: "Here is the screenshot".to_string(),
            },
            ContentItem::InputImage {
                image_url: format!("data:image/png;base64,{base64_body}"),
            },
        ],
        end_turn: None,
        phase: None,
    };
    let caption_only = ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputText {
            text: "Here is the screenshot".to_string(),
        }],
        end_turn: None,
        phase: None,
    };

    let serialized_len = serde_json::to_string(&with_image).unwrap().len() as i64;
    let image_estimate = estimate_response_item_model_visible_bytes(&with_image);
    let caption_estimate = estimate_response_item_model_visible_bytes(&caption_only);

    // The payload bytes are swapped for the fixed per-image charge.
    let want = serialized_len - base64_body.len() as i64 + IMAGE_BYTES_ESTIMATE;
    assert_eq!(image_estimate, want);
    // Cheaper than the raw serialization, yet the image still costs something.
    assert!(image_estimate < serialized_len);
    assert!(image_estimate > caption_estimate);
}
|
||||
|
||||
#[test]
fn image_data_url_payload_does_not_dominate_function_call_output_estimate() {
    let base64_body = "B".repeat(50_000);

    // Tool output made of a short text note followed by an inline image.
    let tool_output = ResponseItem::FunctionCallOutput {
        call_id: "call-abc".to_string(),
        output: FunctionCallOutputPayload::from_content_items(vec![
            FunctionCallOutputContentItem::InputText {
                text: "Screenshot captured".to_string(),
            },
            FunctionCallOutputContentItem::InputImage {
                image_url: format!("data:image/png;base64,{base64_body}"),
            },
        ]),
    };

    let serialized_len = serde_json::to_string(&tool_output).unwrap().len() as i64;
    let got = estimate_response_item_model_visible_bytes(&tool_output);

    // Raw size minus the payload, plus one fixed per-image charge.
    let want = serialized_len - base64_body.len() as i64 + IMAGE_BYTES_ESTIMATE;
    assert_eq!(got, want);
    assert!(got < serialized_len);
}
|
||||
|
||||
#[test]
fn non_base64_image_urls_are_unchanged() {
    // One remote URL inside a message, one local file URL inside a tool output.
    let remote_image_message = ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputImage {
            image_url: "https://example.com/foo.png".to_string(),
        }],
        end_turn: None,
        phase: None,
    };
    let local_image_output = ResponseItem::FunctionCallOutput {
        call_id: "call-1".to_string(),
        output: FunctionCallOutputPayload::from_content_items(vec![
            FunctionCallOutputContentItem::InputImage {
                image_url: "file:///tmp/foo.png".to_string(),
            },
        ]),
    };

    // Neither URL is an inline data URL, so the estimate must equal the raw
    // serialized length in both cases.
    for item in &[remote_image_message, local_image_output] {
        assert_eq!(
            estimate_response_item_model_visible_bytes(item),
            serde_json::to_string(item).unwrap().len() as i64
        );
    }
}
|
||||
|
||||
#[test]
fn data_url_without_base64_marker_is_unchanged() {
    // An image data URL whose body is literal SVG markup rather than base64.
    let svg_url = "data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg'/>".to_string();
    let svg_message = ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputImage { image_url: svg_url }],
        end_turn: None,
        phase: None,
    };

    let got = estimate_response_item_model_visible_bytes(&svg_message);
    let serialized_len = serde_json::to_string(&svg_message).unwrap().len() as i64;

    // No `base64` marker means no discount.
    assert_eq!(got, serialized_len);
}
|
||||
|
||||
#[test]
fn non_image_base64_data_url_is_unchanged() {
    // Base64 data URL whose media type is not `image/*`.
    let binary_url = format!(
        "data:application/octet-stream;base64,{}",
        "C".repeat(4_096)
    );
    let tool_output = ResponseItem::FunctionCallOutput {
        call_id: "call-octet".to_string(),
        output: FunctionCallOutputPayload::from_content_items(vec![
            FunctionCallOutputContentItem::InputImage {
                image_url: binary_url,
            },
        ]),
    };

    let got = estimate_response_item_model_visible_bytes(&tool_output);
    let serialized_len = serde_json::to_string(&tool_output).unwrap().len() as i64;

    // Only image media types are discounted; this one stays at raw size.
    assert_eq!(got, serialized_len);
}
|
||||
|
||||
#[test]
fn mixed_case_data_url_markers_are_adjusted() {
    let base64_body = "F".repeat(1_024);

    // Upper-case scheme and `BASE64` marker must be treated like lower-case.
    let shouting_message = ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputImage {
            image_url: format!("DATA:image/png;BASE64,{base64_body}"),
        }],
        end_turn: None,
        phase: None,
    };

    let serialized_len = serde_json::to_string(&shouting_message).unwrap().len() as i64;
    let got = estimate_response_item_model_visible_bytes(&shouting_message);

    let want = serialized_len - base64_body.len() as i64 + IMAGE_BYTES_ESTIMATE;
    assert_eq!(got, want);
}
|
||||
|
||||
#[test]
fn multiple_inline_images_apply_multiple_fixed_costs() {
    let png_body = "D".repeat(100);
    let jpeg_body = "E".repeat(200);

    // One text part followed by two inline images of different media types.
    let two_image_message = ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![
            ContentItem::InputText {
                text: "images".to_string(),
            },
            ContentItem::InputImage {
                image_url: format!("data:image/png;base64,{png_body}"),
            },
            ContentItem::InputImage {
                image_url: format!("data:image/jpeg;base64,{jpeg_body}"),
            },
        ],
        end_turn: None,
        phase: None,
    };

    let got = estimate_response_item_model_visible_bytes(&two_image_message);
    let serialized_len = serde_json::to_string(&two_image_message).unwrap().len() as i64;
    let combined_payload = (png_body.len() + jpeg_body.len()) as i64;

    // Both payloads are removed and each image is charged the fixed cost once.
    let want = serialized_len - combined_payload + (2 * IMAGE_BYTES_ESTIMATE);
    assert_eq!(got, want);
}
|
||||
|
||||
#[test]
fn text_only_items_unchanged() {
    // Assistant message with a single text part and no images at all.
    let reply = ResponseItem::Message {
        id: None,
        role: "assistant".to_string(),
        content: vec![ContentItem::OutputText {
            text: "Hello world, this is a response.".to_string(),
        }],
        end_turn: None,
        phase: None,
    };

    let serialized_len = serde_json::to_string(&reply).unwrap().len() as i64;
    let got = estimate_response_item_model_visible_bytes(&reply);

    // Without images the estimate is exactly the serialized length.
    assert_eq!(got, serialized_len);
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue