diff --git a/codex-rs/codex-api/src/sse/chat.rs b/codex-rs/codex-api/src/sse/chat.rs index 5e48c57bd..21adfa571 100644 --- a/codex-rs/codex-api/src/sse/chat.rs +++ b/codex-rs/codex-api/src/sse/chat.rs @@ -10,6 +10,7 @@ use eventsource_stream::Eventsource; use futures::Stream; use futures::StreamExt; use std::collections::HashMap; +use std::collections::HashSet; use std::time::Duration; use tokio::sync::mpsc; use tokio::time::Instant; @@ -41,12 +42,17 @@ pub async fn process_chat_sse( #[derive(Default, Debug)] struct ToolCallState { + id: Option, name: Option, arguments: String, } - let mut tool_calls: HashMap = HashMap::new(); - let mut tool_call_order: Vec = Vec::new(); + let mut tool_calls: HashMap = HashMap::new(); + let mut tool_call_order: Vec = Vec::new(); + let mut tool_call_order_seen: HashSet = HashSet::new(); + let mut tool_call_index_by_id: HashMap = HashMap::new(); + let mut next_tool_call_index = 0usize; + let mut last_tool_call_index: Option = None; let mut assistant_item: Option = None; let mut reasoning_item: Option = None; let mut completed_sent = false; @@ -149,15 +155,40 @@ pub async fn process_chat_sse( if let Some(tool_call_values) = delta.get("tool_calls").and_then(|c| c.as_array()) { for tool_call in tool_call_values { - let id = tool_call - .get("id") - .and_then(|i| i.as_str()) - .map(str::to_string) - .unwrap_or_else(|| format!("tool-call-{}", tool_call_order.len())); + let mut index = tool_call + .get("index") + .and_then(serde_json::Value::as_u64) + .map(|i| i as usize); - let call_state = tool_calls.entry(id.clone()).or_default(); - if !tool_call_order.contains(&id) { - tool_call_order.push(id.clone()); + let mut call_id_for_lookup = None; + if let Some(call_id) = tool_call.get("id").and_then(|i| i.as_str()) { + call_id_for_lookup = Some(call_id.to_string()); + if let Some(existing) = tool_call_index_by_id.get(call_id) { + index = Some(*existing); + } + } + + if index.is_none() && call_id_for_lookup.is_none() { + index = last_tool_call_index; + } + + let index = index.unwrap_or_else(|| { + while tool_calls.contains_key(&next_tool_call_index) { + next_tool_call_index += 1; + } + let idx = next_tool_call_index; + next_tool_call_index += 1; + idx + }); + + let call_state = tool_calls.entry(index).or_default(); + if tool_call_order_seen.insert(index) { + tool_call_order.push(index); + } + + if let Some(id) = tool_call.get("id").and_then(|i| i.as_str()) { + call_state.id.get_or_insert_with(|| id.to_string()); + tool_call_index_by_id.entry(id.to_string()).or_insert(index); } if let Some(func) = tool_call.get("function") { @@ -171,6 +202,8 @@ pub async fn process_chat_sse( call_state.arguments.push_str(arguments); } } + + last_tool_call_index = Some(index); } } } @@ -224,13 +257,25 @@ pub async fn process_chat_sse( .await; } - for call_id in tool_call_order.drain(..) { - let state = tool_calls.remove(&call_id).unwrap_or_default(); + for index in tool_call_order.drain(..) { + let Some(state) = tool_calls.remove(&index) else { + continue; + }; + tool_call_order_seen.remove(&index); + let ToolCallState { + id, + name, + arguments, + } = state; + let Some(name) = name else { + debug!("Skipping tool call at index {index} because name is missing"); + continue; + }; let item = ResponseItem::FunctionCall { id: None, - name: state.name.unwrap_or_default(), - arguments: state.arguments, - call_id: call_id.clone(), + name, + arguments, + call_id: id.unwrap_or_else(|| format!("tool-call-{index}")), }; let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; } @@ -335,6 +380,59 @@ mod tests { out } + #[tokio::test] + async fn concatenates_tool_call_arguments_across_deltas() { + let delta_name = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "id": "call_a", + "index": 0, + "function": { "name": "do_a" } + }] + } + }] + }); + + let delta_args_1 = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "index": 0, + "function": { "arguments": "{ \"foo\":" } + }] + } + }] + }); + + let delta_args_2 = json!({ + "choices": [{ + "delta": { + "tool_calls": [{ + "index": 0, + "function": { "arguments": "1}" } + }] + } + }] + }); + + let finish = json!({ + "choices": [{ + "finish_reason": "tool_calls" + }] + }); + + let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]); + let events = collect_events(&body).await; + assert_matches!( + &events[..], + [ + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }), + ResponseEvent::Completed { .. } + ] if call_id == "call_a" && name == "do_a" && arguments == "{ \"foo\":1}" + ); + } + #[tokio::test] async fn emits_multiple_tool_calls() { let delta_a = json!({ @@ -367,50 +465,74 @@ mod tests { let body = build_body(&[delta_a, delta_b, finish]); let events = collect_events(&body).await; - assert_eq!(events.len(), 3); - assert_matches!( - &events[0], - ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }) - if call_id == "call_a" && name == "do_a" && arguments == "{\"foo\":1}" + &events[..], + [ + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_a, name: name_a, arguments: args_a, .. }), + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_b, name: name_b, arguments: args_b, .. }), + ResponseEvent::Completed { .. } + ] if call_a == "call_a" && name_a == "do_a" && args_a == "{\"foo\":1}" && call_b == "call_b" && name_b == "do_b" && args_b == "{\"bar\":2}" ); - assert_matches!( - &events[1], - ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id, name, arguments, .. }) - if call_id == "call_b" && name == "do_b" && arguments == "{\"bar\":2}" - ); - assert_matches!(events[2], ResponseEvent::Completed { .. }); } #[tokio::test] - async fn concatenates_tool_call_arguments_across_deltas() { - let delta_name = json!({ + async fn emits_tool_calls_for_multiple_choices() { + let payload = json!({ + "choices": [ + { + "delta": { + "tool_calls": [{ + "id": "call_a", + "index": 0, + "function": { "name": "do_a", "arguments": "{}" } + }] + }, + "finish_reason": "tool_calls" + }, + { + "delta": { + "tool_calls": [{ + "id": "call_b", + "index": 0, + "function": { "name": "do_b", "arguments": "{}" } + }] + }, + "finish_reason": "tool_calls" + } + ] + }); + + let body = build_body(&[payload]); + let events = collect_events(&body).await; + assert_matches!( + &events[..], + [ + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_a, name: name_a, arguments: args_a, .. }), + ResponseEvent::OutputItemDone(ResponseItem::FunctionCall { call_id: call_b, name: name_b, arguments: args_b, .. }), + ResponseEvent::Completed { .. } + ] if call_a == "call_a" && name_a == "do_a" && args_a == "{}" && call_b == "call_b" && name_b == "do_b" && args_b == "{}" + ); + } + + #[tokio::test] + async fn merges_tool_calls_by_index_when_id_missing_on_subsequent_deltas() { + let delta_with_id = json!({ "choices": [{ "delta": { "tool_calls": [{ + "index": 0, "id": "call_a", - "function": { "name": "do_a" } + "function": { "name": "do_a", "arguments": "{ \"foo\":" } }] } }] }); - let delta_args_1 = json!({ + let delta_without_id = json!({ "choices": [{ "delta": { "tool_calls": [{ - "id": "call_a", - "function": { "arguments": "{ \"foo\":" } - }] - } - }] - }); - - let delta_args_2 = json!({ - "choices": [{ - "delta": { - "tool_calls": [{ - "id": "call_a", + "index": 0, "function": { "arguments": "1}" } }] } @@ -423,7 +545,7 @@ mod tests { }] }); - let body = build_body(&[delta_name, delta_args_1, delta_args_2, finish]); + let body = build_body(&[delta_with_id, delta_without_id, finish]); let events = collect_events(&body).await; assert_matches!( &events[..],