Translate websocket errors (#10937)
When getting errors over a websocket connection, translate the error into our regular API error format
This commit is contained in:
parent
cfce286459
commit
b2d3843109
2 changed files with 344 additions and 0 deletions
|
|
@ -13,7 +13,12 @@ use codex_client::TransportError;
|
|||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderName;
|
||||
use http::HeaderValue;
|
||||
use http::StatusCode;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::map::Map as JsonMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
|
@ -252,6 +257,83 @@ fn map_ws_error(err: WsError, url: &Url) -> ApiError {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct WrappedWebsocketErrorEvent {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
#[serde(alias = "status_code")]
|
||||
status: Option<u16>,
|
||||
#[serde(default)]
|
||||
error: Option<Value>,
|
||||
#[serde(default)]
|
||||
headers: Option<JsonMap<String, Value>>,
|
||||
}
|
||||
|
||||
fn parse_wrapped_websocket_error_event(payload: &str) -> Option<WrappedWebsocketErrorEvent> {
|
||||
let event: WrappedWebsocketErrorEvent = serde_json::from_str(payload).ok()?;
|
||||
if event.kind != "error" {
|
||||
return None;
|
||||
}
|
||||
Some(event)
|
||||
}
|
||||
|
||||
fn map_wrapped_websocket_error_event(event: WrappedWebsocketErrorEvent) -> Option<ApiError> {
|
||||
let WrappedWebsocketErrorEvent {
|
||||
status,
|
||||
error,
|
||||
headers,
|
||||
..
|
||||
} = event;
|
||||
|
||||
let status = StatusCode::from_u16(status?).ok()?;
|
||||
if status.is_success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let body = error.map(|error| {
|
||||
serde_json::to_string_pretty(&serde_json::json!({
|
||||
"error": error
|
||||
}))
|
||||
.unwrap_or_else(|_| {
|
||||
serde_json::json!({
|
||||
"error": error
|
||||
})
|
||||
.to_string()
|
||||
})
|
||||
});
|
||||
|
||||
Some(ApiError::Transport(TransportError::Http {
|
||||
status,
|
||||
url: None,
|
||||
headers: headers.map(json_headers_to_http_headers),
|
||||
body,
|
||||
}))
|
||||
}
|
||||
|
||||
fn json_headers_to_http_headers(headers: JsonMap<String, Value>) -> HeaderMap {
|
||||
let mut mapped = HeaderMap::new();
|
||||
for (name, value) in headers {
|
||||
let Ok(header_name) = HeaderName::from_bytes(name.as_bytes()) else {
|
||||
continue;
|
||||
};
|
||||
let Some(header_value) = json_header_value(value) else {
|
||||
continue;
|
||||
};
|
||||
mapped.insert(header_name, header_value);
|
||||
}
|
||||
mapped
|
||||
}
|
||||
|
||||
fn json_header_value(value: Value) -> Option<HeaderValue> {
|
||||
let value = match value {
|
||||
Value::String(value) => value,
|
||||
Value::Number(value) => value.to_string(),
|
||||
Value::Bool(value) => value.to_string(),
|
||||
_ => return None,
|
||||
};
|
||||
HeaderValue::from_str(&value).ok()
|
||||
}
|
||||
|
||||
async fn run_websocket_response_stream(
|
||||
ws_stream: &mut WsStream,
|
||||
tx_event: mpsc::Sender<std::result::Result<ResponseEvent, ApiError>>,
|
||||
|
|
@ -306,6 +388,12 @@ async fn run_websocket_response_stream(
|
|||
match message {
|
||||
Message::Text(text) => {
|
||||
trace!("websocket event: {text}");
|
||||
if let Some(wrapped_error) = parse_wrapped_websocket_error_event(&text)
|
||||
&& let Some(error) = map_wrapped_websocket_error_event(wrapped_error)
|
||||
{
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let event = match serde_json::from_str::<ResponsesStreamEvent>(&text) {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
|
|
@ -357,10 +445,124 @@ async fn run_websocket_response_stream(
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn websocket_config_enables_permessage_deflate() {
|
||||
let config = websocket_config();
|
||||
assert!(config.extensions.permessage_deflate.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wrapped_websocket_error_event_maps_to_transport_http() {
|
||||
let payload = json!({
|
||||
"type": "error",
|
||||
"status": 429,
|
||||
"error": {
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached",
|
||||
"plan_type": "pro",
|
||||
"resets_at": 1738888888
|
||||
},
|
||||
"headers": {
|
||||
"x-codex-primary-used-percent": "100.0",
|
||||
"x-codex-primary-window-minutes": 15
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let wrapped_error = parse_wrapped_websocket_error_event(&payload)
|
||||
.expect("expected websocket error payload to be parsed");
|
||||
let api_error = map_wrapped_websocket_error_event(wrapped_error)
|
||||
.expect("expected websocket error payload to map to ApiError");
|
||||
|
||||
let ApiError::Transport(TransportError::Http {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
..
|
||||
}) = api_error
|
||||
else {
|
||||
panic!("expected ApiError::Transport(Http)");
|
||||
};
|
||||
|
||||
assert_eq!(status, StatusCode::TOO_MANY_REQUESTS);
|
||||
let headers = headers.expect("expected headers");
|
||||
assert_eq!(
|
||||
headers
|
||||
.get("x-codex-primary-used-percent")
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("100.0")
|
||||
);
|
||||
assert_eq!(
|
||||
headers
|
||||
.get("x-codex-primary-window-minutes")
|
||||
.and_then(|value| value.to_str().ok()),
|
||||
Some("15")
|
||||
);
|
||||
let body = body.expect("expected body");
|
||||
assert!(body.contains("usage_limit_reached"));
|
||||
assert!(body.contains("The usage limit has been reached"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wrapped_websocket_error_event_ignores_non_error_payloads() {
|
||||
let payload = json!({
|
||||
"type": "response.created",
|
||||
"response": {
|
||||
"id": "resp-1"
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let wrapped_error = parse_wrapped_websocket_error_event(&payload);
|
||||
assert!(wrapped_error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wrapped_websocket_error_event_with_status_maps_invalid_request() {
|
||||
let payload = json!({
|
||||
"type": "error",
|
||||
"status": 400,
|
||||
"error": {
|
||||
"type": "invalid_request_error",
|
||||
"message": "Model does not support image inputs"
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let wrapped_error = parse_wrapped_websocket_error_event(&payload)
|
||||
.expect("expected websocket error payload to be parsed");
|
||||
let api_error = map_wrapped_websocket_error_event(wrapped_error)
|
||||
.expect("expected websocket error payload to map to ApiError");
|
||||
let ApiError::Transport(TransportError::Http { status, body, .. }) = api_error else {
|
||||
panic!("expected ApiError::Transport(Http)");
|
||||
};
|
||||
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||
let body = body.expect("expected body");
|
||||
assert!(body.contains("invalid_request_error"));
|
||||
assert!(body.contains("Model does not support image inputs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_wrapped_websocket_error_event_without_status_is_not_mapped() {
|
||||
let payload = json!({
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached"
|
||||
},
|
||||
"headers": {
|
||||
"x-codex-primary-used-percent": "100.0",
|
||||
"x-codex-primary-window-minutes": 15
|
||||
}
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let wrapped_error = parse_wrapped_websocket_error_event(&payload)
|
||||
.expect("expected websocket error payload to be parsed");
|
||||
let api_error = map_wrapped_websocket_error_event(wrapped_error);
|
||||
assert!(api_error.is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
142
codex-rs/core/tests/suite/client_websockets.rs
Normal file → Executable file
142
codex-rs/core/tests/suite/client_websockets.rs
Normal file → Executable file
|
|
@ -12,6 +12,8 @@ use codex_core::WireApi;
|
|||
use codex_core::X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER;
|
||||
use codex_core::features::Feature;
|
||||
use codex_core::models_manager::manager::ModelsManager;
|
||||
use codex_core::protocol::EventMsg;
|
||||
use codex_core::protocol::Op;
|
||||
use codex_core::protocol::SessionSource;
|
||||
use codex_otel::OtelManager;
|
||||
use codex_otel::TelemetryAuthMode;
|
||||
|
|
@ -22,6 +24,7 @@ use codex_protocol::account::PlanType;
|
|||
use codex_protocol::config_types::ReasoningSummary;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses::WebSocketConnectionConfig;
|
||||
use core_test_support::responses::WebSocketTestServer;
|
||||
|
|
@ -30,6 +33,8 @@ use core_test_support::responses::ev_response_created;
|
|||
use core_test_support::responses::start_websocket_server;
|
||||
use core_test_support::responses::start_websocket_server_with_headers;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use core_test_support::wait_for_event;
|
||||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use opentelemetry_sdk::metrics::InMemoryMetricExporter;
|
||||
|
|
@ -393,6 +398,143 @@ async fn responses_websocket_emits_rate_limit_events() {
|
|||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_usage_limit_error_emits_rate_limit_event() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let usage_limit_error = json!({
|
||||
"type": "error",
|
||||
"status": 429,
|
||||
"error": {
|
||||
"type": "usage_limit_reached",
|
||||
"message": "The usage limit has been reached",
|
||||
"plan_type": "pro",
|
||||
"resets_at": 1704067242,
|
||||
"resets_in_seconds": 1234
|
||||
},
|
||||
"headers": {
|
||||
"x-codex-primary-used-percent": "100.0",
|
||||
"x-codex-secondary-used-percent": "87.5",
|
||||
"x-codex-primary-over-secondary-limit-percent": "95.0",
|
||||
"x-codex-primary-window-minutes": "15",
|
||||
"x-codex-secondary-window-minutes": "60"
|
||||
}
|
||||
});
|
||||
|
||||
let server = start_websocket_server(vec![vec![vec![usage_limit_error]]]).await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
config.model_provider.stream_max_retries = Some(0);
|
||||
});
|
||||
let test = builder
|
||||
.build_with_websocket_server(&server)
|
||||
.await
|
||||
.expect("build websocket codex");
|
||||
|
||||
let submission_id = test
|
||||
.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "hello".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await
|
||||
.expect("submission should succeed while emitting usage limit error events");
|
||||
|
||||
let token_event =
|
||||
wait_for_event(&test.codex, |msg| matches!(msg, EventMsg::TokenCount(_))).await;
|
||||
let EventMsg::TokenCount(event) = token_event else {
|
||||
unreachable!();
|
||||
};
|
||||
|
||||
let event_json = serde_json::to_value(&event).expect("serialize token count event");
|
||||
pretty_assertions::assert_eq!(
|
||||
event_json,
|
||||
json!({
|
||||
"info": null,
|
||||
"rate_limits": {
|
||||
"primary": {
|
||||
"used_percent": 100.0,
|
||||
"window_minutes": 15,
|
||||
"resets_at": null
|
||||
},
|
||||
"secondary": {
|
||||
"used_percent": 87.5,
|
||||
"window_minutes": 60,
|
||||
"resets_at": null
|
||||
},
|
||||
"credits": null,
|
||||
"plan_type": null
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
let error_event = wait_for_event(&test.codex, |msg| matches!(msg, EventMsg::Error(_))).await;
|
||||
let EventMsg::Error(error_event) = error_event else {
|
||||
unreachable!();
|
||||
};
|
||||
assert!(
|
||||
error_event.message.to_lowercase().contains("usage limit"),
|
||||
"unexpected error message for submission {submission_id}: {}",
|
||||
error_event.message
|
||||
);
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_invalid_request_error_with_status_is_forwarded() {
|
||||
skip_if_no_network!();
|
||||
|
||||
let invalid_request_error = json!({
|
||||
"type": "error",
|
||||
"status": 400,
|
||||
"error": {
|
||||
"type": "invalid_request_error",
|
||||
"message": "Model 'castor-raikou-0205-ev3' does not support image inputs."
|
||||
}
|
||||
});
|
||||
|
||||
let server = start_websocket_server(vec![vec![vec![invalid_request_error]]]).await;
|
||||
let mut builder = test_codex().with_config(|config| {
|
||||
config.model_provider.request_max_retries = Some(0);
|
||||
config.model_provider.stream_max_retries = Some(0);
|
||||
});
|
||||
let test = builder
|
||||
.build_with_websocket_server(&server)
|
||||
.await
|
||||
.expect("build websocket codex");
|
||||
|
||||
let submission_id = test
|
||||
.codex
|
||||
.submit(Op::UserInput {
|
||||
items: vec![UserInput::Text {
|
||||
text: "hello".into(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
final_output_json_schema: None,
|
||||
})
|
||||
.await
|
||||
.expect("submission should succeed while emitting invalid request events");
|
||||
|
||||
let error_event = wait_for_event(&test.codex, |msg| matches!(msg, EventMsg::Error(_))).await;
|
||||
let EventMsg::Error(error_event) = error_event else {
|
||||
unreachable!();
|
||||
};
|
||||
assert!(
|
||||
error_event
|
||||
.message
|
||||
.to_lowercase()
|
||||
.contains("does not support image inputs"),
|
||||
"unexpected error message for submission {submission_id}: {}",
|
||||
error_event.message
|
||||
);
|
||||
|
||||
server.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn responses_websocket_appends_on_prefix() {
|
||||
skip_if_no_network!();
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue