Prefer websockets when providers support them (#13592)

Remove the `responses_websockets` / `responses_websockets_v2` feature flags and the per-model `prefer_websockets` setting; websocket transport is now gated solely by the provider's `supports_websockets` capability.

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
pakrym-oai 2026-03-17 19:46:44 -07:00 committed by GitHub
parent d950543e65
commit 770616414a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 348 additions and 303 deletions

View file

@ -1943,6 +1943,7 @@ impl CodexMessageProcessor {
config_overrides,
typesafe_overrides,
&cloud_requirements,
&listener_task_context.codex_home,
)
.await
{
@ -3392,6 +3393,7 @@ impl CodexMessageProcessor {
typesafe_overrides,
history_cwd,
&cloud_requirements,
&self.config.codex_home,
)
.await
{
@ -3918,6 +3920,7 @@ impl CodexMessageProcessor {
typesafe_overrides,
history_cwd,
&cloud_requirements,
&self.config.codex_home,
)
.await
{
@ -7016,6 +7019,7 @@ impl CodexMessageProcessor {
},
Some(command_cwd.clone()),
&cloud_requirements,
&config.codex_home,
)
.await;
let setup_result = match derived_config {
@ -7610,6 +7614,7 @@ async fn derive_config_from_params(
request_overrides: Option<HashMap<String, serde_json::Value>>,
typesafe_overrides: ConfigOverrides,
cloud_requirements: &CloudRequirementsLoader,
codex_home: &Path,
) -> std::io::Result<Config> {
let merged_cli_overrides = cli_overrides
.iter()
@ -7623,6 +7628,7 @@ async fn derive_config_from_params(
.collect::<Vec<_>>();
codex_core::config::ConfigBuilder::default()
.codex_home(codex_home.to_path_buf())
.cli_overrides(merged_cli_overrides)
.harness_overrides(typesafe_overrides)
.cloud_requirements(cloud_requirements.clone())
@ -7636,6 +7642,7 @@ async fn derive_config_for_cwd(
typesafe_overrides: ConfigOverrides,
cwd: Option<PathBuf>,
cloud_requirements: &CloudRequirementsLoader,
codex_home: &Path,
) -> std::io::Result<Config> {
let merged_cli_overrides = cli_overrides
.iter()
@ -7649,6 +7656,7 @@ async fn derive_config_for_cwd(
.collect::<Vec<_>>();
codex_core::config::ConfigBuilder::default()
.codex_home(codex_home.to_path_buf())
.cli_overrides(merged_cli_overrides)
.harness_overrides(typesafe_overrides)
.fallback_cwd(cwd)

View file

@ -34,21 +34,23 @@ pub fn write_mock_responses_config_toml(
Some(true) => "requires_openai_auth = true\n".to_string(),
Some(false) | None => String::new(),
};
let provider_block = if model_provider_id == "openai" {
String::new()
let provider_name = if matches!(requires_openai_auth, Some(true)) {
"OpenAI"
} else {
format!(
r#"
[model_providers.mock_provider]
name = "Mock provider for test"
"Mock provider for test"
};
let provider_block = format!(
r#"
[model_providers.{model_provider_id}]
name = "{provider_name}"
base_url = "{server_uri}/v1"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
supports_websockets = false
{requires_line}
"#
)
};
);
let openai_base_url_line = if model_provider_id == "openai" {
format!("openai_base_url = \"{server_uri}/v1\"\n")
} else {

View file

@ -45,7 +45,6 @@ fn preset_to_info(preset: &ModelPreset, priority: i32) -> ModelInfo {
effective_context_window_percent: 95,
experimental_supported_tools: Vec::new(),
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
}

View file

@ -149,7 +149,7 @@ async fn auto_compaction_remote_emits_started_and_completed_items() -> Result<()
&BTreeMap::default(),
REMOTE_AUTO_COMPACT_LIMIT,
Some(true),
"openai",
"mock_provider",
COMPACT_PROMPT,
)?;
write_chatgpt_auth(

View file

@ -93,7 +93,6 @@ async fn models_client_hits_models_endpoint() {
effective_context_window_percent: 95,
experimental_supported_tools: Vec::new(),
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
}],

View file

@ -1,7 +1,6 @@
{
"models": [
{
"prefer_websockets": false,
"support_verbosity": true,
"default_verbosity": "low",
"apply_patch_tool_type": "freeform",
@ -75,7 +74,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": true,
"default_verbosity": "low",
"apply_patch_tool_type": "freeform",
@ -146,7 +144,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": false,
"default_verbosity": null,
"apply_patch_tool_type": "freeform",
@ -221,7 +218,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": false,
"default_verbosity": null,
"apply_patch_tool_type": "freeform",
@ -289,7 +285,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": false,
"default_verbosity": null,
"apply_patch_tool_type": "freeform",
@ -353,7 +348,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": true,
"default_verbosity": "low",
"apply_patch_tool_type": "freeform",
@ -421,7 +415,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": true,
"default_verbosity": "low",
"apply_patch_tool_type": "freeform",
@ -485,7 +478,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": false,
"default_verbosity": null,
"apply_patch_tool_type": "freeform",
@ -549,7 +541,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": true,
"default_verbosity": null,
"apply_patch_tool_type": null,
@ -617,7 +608,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": true,
"default_verbosity": null,
"apply_patch_tool_type": "freeform",
@ -677,7 +667,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": true,
"default_verbosity": null,
"apply_patch_tool_type": "freeform",
@ -737,7 +726,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": false,
"default_verbosity": null,
"apply_patch_tool_type": "freeform",
@ -797,7 +785,6 @@
"supports_reasoning_summaries": true
},
{
"prefer_websockets": false,
"support_verbosity": false,
"default_verbosity": null,
"apply_patch_tool_type": "freeform",

View file

@ -2,7 +2,7 @@
//!
//! `ModelClient` is intended to live for the lifetime of a Codex session and holds the stable
//! configuration and state needed to talk to a provider (auth, provider selection, conversation id,
//! and feature-gated request behavior).
//! and transport fallback state).
//!
//! Per-turn settings (model selection, reasoning controls, telemetry context, and turn metadata)
//! are passed explicitly to streaming and unary methods so that the turn lifetime is visible at the
@ -94,7 +94,6 @@ use crate::auth::RefreshTokenError;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::config::Config;
use crate::default_client::build_reqwest_client;
use crate::error::CodexErr;
use crate::error::Result;
@ -122,14 +121,6 @@ const MEMORIES_SUMMARIZE_ENDPOINT: &str = "/memories/trace_summarize";
#[cfg(test)]
pub(crate) const WEBSOCKET_CONNECT_TIMEOUT: Duration =
Duration::from_millis(crate::model_provider_info::DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS);
pub fn ws_version_from_features(config: &Config) -> bool {
config
.features
.enabled(crate::features::Feature::ResponsesWebsockets)
|| config
.features
.enabled(crate::features::Feature::ResponsesWebsocketsV2)
}
/// Session-scoped state shared by all [`ModelClient`] clones.
///
@ -143,7 +134,6 @@ struct ModelClientState {
auth_env_telemetry: AuthEnvTelemetry,
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
responses_websockets_enabled_by_feature: bool,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
@ -175,8 +165,7 @@ impl RequestRouteTelemetry {
/// A session-scoped client for model-provider API calls.
///
/// This holds configuration and state that should be shared across turns within a Codex session
/// (auth, provider selection, conversation id, feature-gated request behavior, and transport
/// fallback state).
/// (auth, provider selection, conversation id, and transport fallback state).
///
/// WebSocket fallback is session-scoped: once a turn activates the HTTP fallback, subsequent turns
/// will also use HTTP for the remainder of the session.
@ -265,7 +254,6 @@ impl ModelClient {
provider: ModelProviderInfo,
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
responses_websockets_enabled_by_feature: bool,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
@ -282,7 +270,6 @@ impl ModelClient {
auth_env_telemetry,
session_source,
model_verbosity,
responses_websockets_enabled_by_feature,
enable_request_compression,
include_timing_metrics,
beta_features_header,
@ -324,9 +311,9 @@ impl ModelClient {
pub(crate) fn force_http_fallback(
&self,
session_telemetry: &SessionTelemetry,
model_info: &ModelInfo,
_model_info: &ModelInfo,
) -> bool {
let websocket_enabled = self.responses_websocket_enabled(model_info);
let websocket_enabled = self.responses_websocket_enabled();
let activated =
websocket_enabled && !self.state.disable_websockets.swap(true, Ordering::Relaxed);
if activated {
@ -517,19 +504,16 @@ impl ModelClient {
/// Returns whether the Responses-over-WebSocket transport is active for this session.
///
/// This combines provider capability and feature gating; both must be true for websocket paths
/// to be eligible.
///
/// If websockets are only enabled via model preference (no explicit feature flag), prefer the
/// current v2 behavior.
pub fn responses_websocket_enabled(&self, model_info: &ModelInfo) -> bool {
/// WebSocket use is controlled by provider capability and session-scoped fallback state.
pub fn responses_websocket_enabled(&self) -> bool {
if !self.state.provider.supports_websockets
|| self.state.disable_websockets.load(Ordering::Relaxed)
|| (*CODEX_RS_SSE_FIXTURE).is_some()
{
return false;
}
self.state.responses_websockets_enabled_by_feature || model_info.prefer_websockets
true
}
/// Returns auth + provider configuration resolved from the current session auth state.
@ -868,9 +852,9 @@ impl ModelClientSession {
pub async fn preconnect_websocket(
&mut self,
session_telemetry: &SessionTelemetry,
model_info: &ModelInfo,
_model_info: &ModelInfo,
) -> std::result::Result<(), ApiError> {
if !self.client.responses_websocket_enabled(model_info) {
if !self.client.responses_websocket_enabled() {
return Ok(());
}
if self.websocket_session.connection.is_some() {
@ -1248,7 +1232,7 @@ impl ModelClientSession {
service_tier: Option<ServiceTier>,
turn_metadata_header: Option<&str>,
) -> Result<()> {
if !self.client.responses_websocket_enabled(model_info) {
if !self.client.responses_websocket_enabled() {
return Ok(());
}
if self.websocket_session.last_request.is_some() {
@ -1292,8 +1276,8 @@ impl ModelClientSession {
///
/// The caller is responsible for passing per-turn settings explicitly (model selection,
/// reasoning settings, telemetry context, and turn metadata). This method will prefer the
/// Responses WebSocket transport when enabled and healthy, and will fall back to the HTTP
/// Responses API transport otherwise.
/// Responses WebSocket transport when the provider supports it and it remains healthy, and will
/// fall back to the HTTP Responses API transport otherwise.
pub async fn stream(
&mut self,
prompt: &Prompt,
@ -1307,7 +1291,7 @@ impl ModelClientSession {
let wire_api = self.client.state.provider.wire_api;
match wire_api {
WireApi::Responses => {
if self.client.responses_websocket_enabled(model_info) {
if self.client.responses_websocket_enabled() {
match self
.stream_responses_websocket(
prompt,

View file

@ -23,7 +23,6 @@ fn test_model_client(session_source: SessionSource) -> ModelClient {
None,
false,
false,
false,
None,
)
}

View file

@ -53,7 +53,6 @@ use crate::terminal;
use crate::truncate::TruncationPolicy;
use crate::turn_metadata::TurnMetadataState;
use crate::util::error_or_panic;
use crate::ws_version_from_features;
use async_channel::Receiver;
use async_channel::Sender;
use chrono::Local;
@ -1807,7 +1806,6 @@ impl Session {
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
ws_version_from_features(config.as_ref()),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Self::build_model_client_beta_features_header(config.as_ref()),
@ -6239,10 +6237,7 @@ async fn run_sampling_request(
// transient reconnect messages. In debug builds, keep full visibility for diagnosis.
let report_error = retries > 1
|| cfg!(debug_assertions)
|| !sess
.services
.model_client
.responses_websocket_enabled(&turn_context.model_info);
|| !sess.services.model_client.responses_websocket_enabled();
if report_error {
// Surface retry information to any UI/frontend so the
// user understands what is happening instead of staring

View file

@ -239,7 +239,6 @@ fn test_model_client_session() -> crate::client::ModelClientSession {
None,
false,
false,
false,
None,
)
.new_session()
@ -2513,7 +2512,6 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
ws_version_from_features(config.as_ref()),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
@ -3308,7 +3306,6 @@ pub(crate) async fn make_session_and_context_with_dynamic_tools_and_rx(
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
ws_version_from_features(config.as_ref()),
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),

View file

@ -184,9 +184,9 @@ pub enum Feature {
TuiAppServer,
/// Prevent idle system sleep while a turn is actively running.
PreventIdleSleep,
/// Use the Responses API WebSocket transport for OpenAI by default.
/// Legacy rollout flag for Responses API WebSocket transport experiments.
ResponsesWebsockets,
/// Enable Responses API websocket v2 mode.
/// Legacy rollout flag for Responses API WebSocket transport v2 experiments.
ResponsesWebsocketsV2,
}
@ -860,13 +860,13 @@ pub const FEATURES: &[FeatureSpec] = &[
FeatureSpec {
id: Feature::ResponsesWebsockets,
key: "responses_websockets",
stage: Stage::UnderDevelopment,
stage: Stage::Removed,
default_enabled: false,
},
FeatureSpec {
id: Feature::ResponsesWebsocketsV2,
key: "responses_websockets_v2",
stage: Stage::UnderDevelopment,
stage: Stage::Removed,
default_enabled: false,
},
];

View file

@ -162,7 +162,6 @@ pub(crate) use codex_shell_command::powershell;
pub use client::ModelClient;
pub use client::ModelClientSession;
pub use client::X_CODEX_TURN_METADATA_HEADER;
pub use client::ws_version_from_features;
pub use client_common::Prompt;
pub use client_common::REVIEW_PROMPT;
pub use client_common::ResponseEvent;

View file

@ -88,7 +88,6 @@ pub(crate) fn model_info_from_slug(slug: &str) -> ModelInfo {
effective_context_window_percent: 95,
experimental_supported_tools: Vec::new(),
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: true, // this is the fallback model metadata
supports_search_tool: false,
}

View file

@ -153,11 +153,8 @@ impl TestCodexBuilder {
let base_url_clone = base_url.clone();
self.config_mutators.push(Box::new(move |config| {
config.model_provider.base_url = Some(base_url_clone);
config.model_provider.supports_websockets = true;
config.experimental_realtime_ws_model = Some("realtime-test-model".to_string());
config
.features
.enable(Feature::ResponsesWebsockets)
.expect("test config should allow feature update");
}));
Box::pin(self.build_with_home_and_base_url(base_url, home, /*resume_from*/ None)).await
}
@ -271,6 +268,9 @@ impl TestCodexBuilder {
) -> anyhow::Result<(Config, Arc<TempDir>)> {
let model_provider = ModelProviderInfo {
base_url: Some(base_url),
// Most core tests use SSE-only mock servers, so keep websocket transport off unless
// a test explicitly opts into websocket coverage.
supports_websockets: false,
..built_in_model_providers(/*openai_base_url*/ None)["openai"].clone()
};
let cwd = Arc::new(TempDir::new()?);

View file

@ -94,7 +94,6 @@ async fn responses_stream_includes_subagent_header_on_review() {
config.model_verbosity,
false,
false,
false,
None,
);
let mut client_session = client.new_session();
@ -208,7 +207,6 @@ async fn responses_stream_includes_subagent_header_on_other() {
config.model_verbosity,
false,
false,
false,
None,
);
let mut client_session = client.new_session();
@ -321,7 +319,6 @@ async fn responses_respects_model_info_overrides_from_config() {
config.model_verbosity,
false,
false,
false,
None,
);
let mut client_session = client.new_session();

View file

@ -717,6 +717,7 @@ async fn chatgpt_auth_sends_correct_request() {
let mut model_provider = built_in_model_providers(/* openai_base_url */ None)["openai"].clone();
model_provider.base_url = Some(format!("{}/api/codex", server.uri()));
model_provider.supports_websockets = false;
let mut builder = test_codex()
.with_auth(create_dummy_codex_auth())
.with_config(move |config| {
@ -791,6 +792,7 @@ async fn prefers_apikey_when_config_prefers_apikey_even_with_chatgpt_tokens() {
let model_provider = ModelProviderInfo {
base_url: Some(format!("{}/v1", server.uri())),
supports_websockets: false,
..built_in_model_providers(/* openai_base_url */ None)["openai"].clone()
};
@ -1832,7 +1834,6 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
config.model_verbosity,
false,
false,
false,
None,
);
let mut client_session = client.new_session();
@ -1968,6 +1969,7 @@ async fn token_count_includes_rate_limits_snapshot() {
let mut provider = built_in_model_providers(/* openai_base_url */ None)["openai"].clone();
provider.base_url = Some(format!("{}/v1", server.uri()));
provider.supports_websockets = false;
let mut builder = test_codex()
.with_auth(CodexAuth::from_api_key("test"))

View file

@ -8,7 +8,6 @@ use codex_core::ResponseEvent;
use codex_core::WireApi;
use codex_core::X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER;
use codex_core::features::Feature;
use codex_core::ws_version_from_features;
use codex_otel::SessionTelemetry;
use codex_otel::TelemetryAuthMode;
use codex_otel::metrics::MetricsClient;
@ -98,6 +97,28 @@ async fn responses_websocket_streams_request() {
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_streams_without_feature_flag_when_provider_supports_websockets() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![vec![
ev_response_created("resp-1"),
ev_completed("resp-1"),
]]])
.await;
let harness = websocket_harness_with_options(&server, false).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
stream_until_complete(&mut client_session, &harness, &prompt).await;
assert_eq!(server.handshakes().len(), 1);
assert_eq!(server.single_connection().len(), 1);
server.shutdown().await;
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_preconnect_reuses_connection() {
skip_if_no_network!();
@ -133,7 +154,7 @@ async fn responses_websocket_request_prewarm_reuses_connection() {
]])
.await;
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
let harness = websocket_harness_with_options(&server, true).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
@ -252,7 +273,7 @@ async fn responses_websocket_request_prewarm_is_reused_even_with_header_changes(
]])
.await;
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
let harness = websocket_harness_with_options(&server, true).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
@ -308,7 +329,7 @@ async fn responses_websocket_request_prewarm_is_reused_even_with_header_changes(
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_prewarm_uses_v2_when_model_prefers_websockets_and_feature_disabled() {
async fn responses_websocket_prewarm_uses_v2_when_provider_supports_websockets() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![vec![
@ -317,7 +338,7 @@ async fn responses_websocket_prewarm_uses_v2_when_model_prefers_websockets_and_f
]]])
.await;
let harness = websocket_harness_with_options(&server, false, false, false, true).await;
let harness = websocket_harness_with_options(&server, false).await;
let mut client_session = harness.client.new_session();
let prompt = prompt_with_input(vec![message_item("hello")]);
client_session
@ -374,7 +395,7 @@ async fn responses_websocket_preconnect_runs_when_only_v2_feature_enabled() {
]]])
.await;
let harness = websocket_harness_with_options(&server, false, false, true, false).await;
let harness = websocket_harness_with_options(&server, true).await;
let mut client_session = harness.client.new_session();
client_session
.preconnect_websocket(&harness.session_telemetry, &harness.model_info)
@ -404,7 +425,7 @@ async fn responses_websocket_preconnect_runs_when_only_v2_feature_enabled() {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn responses_websocket_v2_requests_use_v2_when_model_prefers_websockets() {
async fn responses_websocket_v2_requests_use_v2_when_provider_supports_websockets() {
skip_if_no_network!();
let server = start_websocket_server(vec![vec![
@ -417,7 +438,7 @@ async fn responses_websocket_v2_requests_use_v2_when_model_prefers_websockets()
]])
.await;
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
let harness = websocket_harness_with_options(&server, true).await;
let mut client_session = harness.client.new_session();
let prompt_one = prompt_with_input(vec![message_item("hello")]);
let prompt_two = prompt_with_input(vec![
@ -466,7 +487,7 @@ async fn responses_websocket_v2_incremental_requests_are_reused_across_turns() {
]])
.await;
let harness = websocket_harness_with_options(&server, false, false, true, true).await;
let harness = websocket_harness_with_options(&server, false).await;
let prompt_one = prompt_with_input(vec![message_item("hello")]);
let prompt_two = prompt_with_input(vec![
message_item("hello"),
@ -510,7 +531,7 @@ async fn responses_websocket_v2_wins_when_both_features_enabled() {
]])
.await;
let harness = websocket_harness_with_options(&server, false, true, true, false).await;
let harness = websocket_harness_with_options(&server, false).await;
let mut client_session = harness.client.new_session();
let prompt_one = prompt_with_input(vec![message_item("hello")]);
let prompt_two = prompt_with_input(vec![
@ -1534,69 +1555,39 @@ async fn websocket_harness_with_runtime_metrics(
server: &WebSocketTestServer,
runtime_metrics_enabled: bool,
) -> WebsocketTestHarness {
websocket_harness_with_options(server, runtime_metrics_enabled, true, false, false).await
websocket_harness_with_options(server, runtime_metrics_enabled).await
}
async fn websocket_harness_with_v2(
server: &WebSocketTestServer,
websocket_v2_enabled: bool,
runtime_metrics_enabled: bool,
) -> WebsocketTestHarness {
websocket_harness_with_options(server, false, true, websocket_v2_enabled, false).await
websocket_harness_with_options(server, runtime_metrics_enabled).await
}
async fn websocket_harness_with_options(
server: &WebSocketTestServer,
runtime_metrics_enabled: bool,
websocket_enabled: bool,
websocket_v2_enabled: bool,
prefer_websockets: bool,
) -> WebsocketTestHarness {
websocket_harness_with_provider_options(
websocket_provider(server),
runtime_metrics_enabled,
websocket_enabled,
websocket_v2_enabled,
prefer_websockets,
)
.await
websocket_harness_with_provider_options(websocket_provider(server), runtime_metrics_enabled)
.await
}
async fn websocket_harness_with_provider_options(
provider: ModelProviderInfo,
runtime_metrics_enabled: bool,
websocket_enabled: bool,
websocket_v2_enabled: bool,
prefer_websockets: bool,
) -> WebsocketTestHarness {
let codex_home = TempDir::new().unwrap();
let mut config = load_default_config_for_test(&codex_home).await;
config.model = Some(MODEL.to_string());
if websocket_enabled {
config
.features
.enable(Feature::ResponsesWebsockets)
.expect("test config should allow feature update");
} else {
config
.features
.disable(Feature::ResponsesWebsockets)
.expect("test config should allow feature update");
}
if runtime_metrics_enabled {
config
.features
.enable(Feature::RuntimeMetrics)
.expect("test config should allow feature update");
}
if websocket_v2_enabled {
config
.features
.enable(Feature::ResponsesWebsocketsV2)
.expect("test config should allow feature update");
}
let config = Arc::new(config);
let mut model_info = codex_core::test_support::construct_model_info_offline(MODEL, &config);
model_info.prefer_websockets = prefer_websockets;
let model_info = codex_core::test_support::construct_model_info_offline(MODEL, &config);
let conversation_id = ThreadId::new();
let auth_manager =
codex_core::test_support::auth_manager_from_auth(CodexAuth::from_api_key("Test API Key"));
@ -1627,7 +1618,6 @@ async fn websocket_harness_with_provider_options(
provider.clone(),
SessionSource::Exec,
config.model_verbosity,
ws_version_from_features(&config),
false,
runtime_metrics_enabled,
None,

View file

@ -96,6 +96,7 @@ fn non_openai_model_provider(server: &MockServer) -> ModelProviderInfo {
let mut provider = built_in_model_providers(/* openai_base_url */ None)["openai"].clone();
provider.name = "OpenAI (test)".into();
provider.base_url = Some(format!("{}/v1", server.uri()));
provider.supports_websockets = false;
provider
}

View file

@ -53,7 +53,6 @@ fn test_model_info(
visibility: ModelVisibility::List,
supported_in_api: true,
input_modalities,
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
priority: 1,
@ -849,7 +848,6 @@ async fn model_switch_to_smaller_model_updates_token_context_window() -> Result<
visibility: ModelVisibility::List,
supported_in_api: true,
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
priority: 1,

View file

@ -351,7 +351,6 @@ fn test_remote_model(slug: &str, priority: i32) -> ModelInfo {
effective_context_window_percent: 95,
experimental_supported_tools: Vec::new(),
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
}

View file

@ -659,7 +659,6 @@ async fn remote_model_friendly_personality_instructions_with_feature() -> anyhow
effective_context_window_percent: 95,
experimental_supported_tools: Vec::new(),
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
};
@ -775,7 +774,6 @@ async fn user_turn_personality_remote_model_template_includes_update_message() -
effective_context_window_percent: 95,
experimental_supported_tools: Vec::new(),
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
};

View file

@ -289,7 +289,6 @@ async fn remote_models_remote_model_uses_unified_exec() -> Result<()> {
visibility: ModelVisibility::List,
supported_in_api: true,
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
priority: 1,
@ -533,7 +532,6 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
visibility: ModelVisibility::List,
supported_in_api: true,
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
priority: 1,
@ -1001,7 +999,6 @@ fn test_remote_model_with_policy(
visibility,
supported_in_api: true,
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
priority,

View file

@ -419,7 +419,6 @@ async fn stdio_image_responses_are_sanitized_for_text_only_model() -> anyhow::Re
effective_context_window_percent: 95,
experimental_supported_tools: Vec::new(),
input_modalities: vec![InputModality::Text],
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
}],

View file

@ -64,7 +64,6 @@ fn test_model_info(
visibility,
supported_in_api: true,
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
priority: 1,

View file

@ -1270,7 +1270,6 @@ async fn view_image_tool_returns_unsupported_message_for_text_only_model() -> an
visibility: ModelVisibility::List,
supported_in_api: true,
input_modalities: vec![InputModality::Text],
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
priority: 1,

View file

@ -1,5 +1,4 @@
use anyhow::Result;
use codex_core::features::Feature;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::Op;
@ -45,10 +44,7 @@ async fn websocket_fallback_switches_to_http_on_upgrade_required_connect() -> Re
move |config| {
config.model_provider.base_url = Some(base_url);
config.model_provider.wire_api = codex_core::WireApi::Responses;
config
.features
.enable(Feature::ResponsesWebsockets)
.expect("test config should allow feature update");
config.model_provider.supports_websockets = true;
// If we don't treat 426 specially, the sampling loop would retry the WebSocket
// handshake before switching to the HTTP transport.
config.model_provider.stream_max_retries = Some(2);
@ -94,10 +90,7 @@ async fn websocket_fallback_switches_to_http_after_retries_exhausted() -> Result
move |config| {
config.model_provider.base_url = Some(base_url);
config.model_provider.wire_api = codex_core::WireApi::Responses;
config
.features
.enable(Feature::ResponsesWebsockets)
.expect("test config should allow feature update");
config.model_provider.supports_websockets = true;
config.model_provider.stream_max_retries = Some(2);
config.model_provider.request_max_retries = Some(0);
}
@ -142,10 +135,7 @@ async fn websocket_fallback_hides_first_websocket_retry_stream_error() -> Result
move |config| {
config.model_provider.base_url = Some(base_url);
config.model_provider.wire_api = codex_core::WireApi::Responses;
config
.features
.enable(Feature::ResponsesWebsockets)
.expect("test config should allow feature update");
config.model_provider.supports_websockets = true;
config.model_provider.stream_max_retries = Some(2);
config.model_provider.request_max_retries = Some(0);
}
@ -220,10 +210,7 @@ async fn websocket_fallback_is_sticky_across_turns() -> Result<()> {
move |config| {
config.model_provider.base_url = Some(base_url);
config.model_provider.wire_api = codex_core::WireApi::Responses;
config
.features
.enable(Feature::ResponsesWebsockets)
.expect("test config should allow feature update");
config.model_provider.supports_websockets = true;
config.model_provider.stream_max_retries = Some(2);
config.model_provider.request_max_retries = Some(0);
}

View file

@ -284,9 +284,6 @@ pub struct ModelInfo {
/// Input modalities accepted by the backend for this model.
#[serde(default = "default_input_modalities")]
pub input_modalities: Vec<InputModality>,
/// When true, this model should use websocket transport even when websocket features are off.
#[serde(default)]
pub prefer_websockets: bool,
/// Internal-only marker set by core when a model slug resolved to fallback metadata.
#[serde(default, skip_serializing, skip_deserializing)]
#[schemars(skip)]
@ -548,7 +545,6 @@ mod tests {
effective_context_window_percent: 95,
experimental_supported_tools: vec![],
input_modalities: default_input_modalities(),
prefer_websockets: false,
used_fallback_model_metadata: false,
supports_search_tool: false,
}
@ -751,8 +747,7 @@ mod tests {
"auto_compact_token_limit": null,
"effective_context_window_percent": 95,
"experimental_supported_tools": [],
"input_modalities": ["text", "image"],
"prefer_websockets": false
"input_modalities": ["text", "image"]
}))
.expect("deserialize model info");

View file

@ -3,6 +3,7 @@ module.exports = {
preset: "ts-jest/presets/default-esm",
testEnvironment: "node",
extensionsToTreatAsEsm: [".ts"],
setupFilesAfterEnv: ["<rootDir>/tests/setupCodexHome.ts"],
moduleNameMapper: {
"^(\\.{1,2}/.*)\\.js$": "$1",
},

View file

@ -1,9 +1,5 @@
import path from "node:path";
import { describe, expect, it } from "@jest/globals";
import { Codex } from "../src/codex";
import {
assistantMessage,
responseCompleted,
@ -13,8 +9,7 @@ import {
SseResponseBody,
startResponsesTestProxy,
} from "./responsesProxy";
const codexExecPath = path.join(process.cwd(), "..", "..", "codex-rs", "target", "debug", "codex");
import { createMockClient } from "./testCodex";
function* infiniteShellCall(): Generator<SseResponseBody> {
while (true) {
@ -28,9 +23,9 @@ describe("AbortSignal support", () => {
statusCode: 200,
responseBodies: infiniteShellCall(),
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
// Create an abort controller and abort it immediately
@ -40,6 +35,7 @@ describe("AbortSignal support", () => {
// The operation should fail because the signal is already aborted
await expect(thread.run("Hello, world!", { signal: controller.signal })).rejects.toThrow();
} finally {
cleanup();
await close();
}
});
@ -49,9 +45,9 @@ describe("AbortSignal support", () => {
statusCode: 200,
responseBodies: infiniteShellCall(),
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
// Create an abort controller and abort it immediately
@ -78,6 +74,7 @@ describe("AbortSignal support", () => {
expect(error).toBeDefined();
}
} finally {
cleanup();
await close();
}
});
@ -87,9 +84,9 @@ describe("AbortSignal support", () => {
statusCode: 200,
responseBodies: infiniteShellCall(),
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
const controller = new AbortController();
@ -103,6 +100,7 @@ describe("AbortSignal support", () => {
// The operation should fail
await expect(runPromise).rejects.toThrow();
} finally {
cleanup();
await close();
}
});
@ -112,9 +110,9 @@ describe("AbortSignal support", () => {
statusCode: 200,
responseBodies: infiniteShellCall(),
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
const controller = new AbortController();
@ -137,6 +135,7 @@ describe("AbortSignal support", () => {
})(),
).rejects.toThrow();
} finally {
cleanup();
await close();
}
});
@ -146,9 +145,9 @@ describe("AbortSignal support", () => {
statusCode: 200,
responseBodies: [sse(responseStarted(), assistantMessage("Hi!"), responseCompleted())],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
const controller = new AbortController();
@ -159,6 +158,7 @@ describe("AbortSignal support", () => {
expect(result.finalResponse).toBe("Hi!");
expect(result.items).toHaveLength(1);
} finally {
cleanup();
await close();
}
});

View file

@ -93,4 +93,54 @@ describe("CodexExec", () => {
expect(imageIndex).toBeGreaterThan(-1);
expect(resumeIndex).toBeLessThan(imageIndex);
});
// Verifies that an env map passed to CodexExec replaces the inherited process
// environment for the spawned CLI: only the provided entries (plus values the
// exec layer itself injects) may appear in the spawn call's env.
it("allows overriding the env passed to the Codex CLI", async () => {
const { CodexExec } = await import("../src/exec");
spawnMock.mockClear();
// Fake child process so no real codex binary is launched.
const child = new FakeChildProcess();
spawnMock.mockReturnValue(child as unknown as child_process.ChildProcess);
// Close the fake child's streams and exit cleanly on the next tick so the
// async run loop below terminates.
setImmediate(() => {
child.stdout.end();
child.stderr.end();
child.emit("exit", 0, null);
});
// Deliberately pollute the real process env; it must NOT leak into the spawn.
process.env.CODEX_ENV_SHOULD_NOT_LEAK = "leak";
try {
const exec = new CodexExec("codex", {
CODEX_HOME: "/tmp/codex-home",
CUSTOM_ENV: "custom",
});
// Drain the run's async iterator so the child-exit path completes.
for await (const _ of exec.run({
input: "custom env",
apiKey: "test",
baseUrl: "https://example.test",
})) {
// no-op
}
const commandArgs = spawnMock.mock.calls[0]?.[1] as string[] | undefined;
expect(commandArgs).toBeDefined();
const spawnOptions = spawnMock.mock.calls[0]?.[2] as child_process.SpawnOptions | undefined;
const spawnEnv = spawnOptions?.env as Record<string, string> | undefined;
expect(spawnEnv).toBeDefined();
if (!spawnEnv || !commandArgs) {
throw new Error("Spawn args missing");
}
// The explicitly provided entries survive; the polluted process var does not.
expect(spawnEnv.CODEX_HOME).toBe("/tmp/codex-home");
expect(spawnEnv.CUSTOM_ENV).toBe("custom");
expect(spawnEnv.CODEX_ENV_SHOULD_NOT_LEAK).toBeUndefined();
expect(spawnEnv.OPENAI_BASE_URL).toBeUndefined();
// The apiKey option is surfaced as CODEX_API_KEY; the originator override is
// presumably injected by CodexExec itself — TODO confirm against src/exec.
expect(spawnEnv.CODEX_API_KEY).toBe("test");
expect(spawnEnv.CODEX_INTERNAL_ORIGINATOR_OVERRIDE).toBeDefined();
// baseUrl travels as a --config override rather than an env var.
expect(commandArgs).toContain("--config");
expect(commandArgs).toContain(`openai_base_url=${JSON.stringify("https://example.test")}`);
} finally {
delete process.env.CODEX_ENV_SHOULD_NOT_LEAK;
}
});
});

View file

@ -5,8 +5,6 @@ import path from "node:path";
import { codexExecSpy } from "./codexExecSpy";
import { describe, expect, it } from "@jest/globals";
import { Codex } from "../src/codex";
import {
assistantMessage,
responseCompleted,
@ -16,8 +14,7 @@ import {
startResponsesTestProxy,
SseResponseBody,
} from "./responsesProxy";
const codexExecPath = path.join(process.cwd(), "..", "..", "codex-rs", "target", "debug", "codex");
import { createMockClient, createTestClient } from "./testCodex";
describe("Codex", () => {
it("returns thread events", async () => {
@ -25,10 +22,9 @@ describe("Codex", () => {
statusCode: 200,
responseBodies: [sse(responseStarted(), assistantMessage("Hi!"), responseCompleted())],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
const result = await thread.run("Hello, world!");
@ -47,6 +43,7 @@ describe("Codex", () => {
});
expect(thread.id).toEqual(expect.any(String));
} finally {
cleanup();
await close();
}
});
@ -67,10 +64,9 @@ describe("Codex", () => {
),
],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
await thread.run("first input");
await thread.run("second input");
@ -90,6 +86,7 @@ describe("Codex", () => {
)?.text;
expect(assistantText).toBe("First response");
} finally {
cleanup();
await close();
}
});
@ -110,10 +107,9 @@ describe("Codex", () => {
),
],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
await thread.run("first input");
await thread.run("second input");
@ -134,6 +130,7 @@ describe("Codex", () => {
)?.text;
expect(assistantText).toBe("First response");
} finally {
cleanup();
await close();
}
});
@ -154,10 +151,9 @@ describe("Codex", () => {
),
],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const originalThread = client.startThread();
await originalThread.run("first input");
@ -181,6 +177,7 @@ describe("Codex", () => {
)?.text;
expect(assistantText).toBe("First response");
} finally {
cleanup();
await close();
}
});
@ -198,10 +195,9 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread({
model: "gpt-test-1",
sandboxMode: "workspace-write",
@ -219,6 +215,7 @@ describe("Codex", () => {
expectPair(commandArgs, ["--sandbox", "workspace-write"]);
expectPair(commandArgs, ["--model", "gpt-test-1"]);
} finally {
cleanup();
restore();
await close();
}
@ -237,10 +234,9 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread({
modelReasoningEffort: "high",
});
@ -250,6 +246,7 @@ describe("Codex", () => {
expect(commandArgs).toBeDefined();
expectPair(commandArgs, ["--config", 'model_reasoning_effort="high"']);
} finally {
cleanup();
restore();
await close();
}
@ -268,10 +265,9 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread({
networkAccessEnabled: true,
});
@ -281,6 +277,7 @@ describe("Codex", () => {
expect(commandArgs).toBeDefined();
expectPair(commandArgs, ["--config", "sandbox_workspace_write.network_access=true"]);
} finally {
cleanup();
restore();
await close();
}
@ -299,10 +296,9 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread({
webSearchEnabled: true,
});
@ -312,6 +308,7 @@ describe("Codex", () => {
expect(commandArgs).toBeDefined();
expectPair(commandArgs, ["--config", 'web_search="live"']);
} finally {
cleanup();
restore();
await close();
}
@ -330,10 +327,9 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread({
webSearchMode: "cached",
});
@ -343,6 +339,7 @@ describe("Codex", () => {
expect(commandArgs).toBeDefined();
expectPair(commandArgs, ["--config", 'web_search="cached"']);
} finally {
cleanup();
restore();
await close();
}
@ -361,10 +358,9 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread({
webSearchEnabled: false,
});
@ -374,6 +370,7 @@ describe("Codex", () => {
expect(commandArgs).toBeDefined();
expectPair(commandArgs, ["--config", 'web_search="disabled"']);
} finally {
cleanup();
restore();
await close();
}
@ -392,10 +389,9 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread({
approvalPolicy: "on-request",
});
@ -405,6 +401,7 @@ describe("Codex", () => {
expect(commandArgs).toBeDefined();
expectPair(commandArgs, ["--config", 'approval_policy="on-request"']);
} finally {
cleanup();
restore();
await close();
}
@ -423,20 +420,18 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createTestClient({
baseUrl: url,
apiKey: "test",
config: {
approval_policy: "never",
sandbox_workspace_write: { network_access: true },
retry_budget: 3,
tool_rules: { allow: ["git status", "git diff"] },
},
});
try {
const client = new Codex({
codexPathOverride: codexExecPath,
baseUrl: url,
apiKey: "test",
config: {
approval_policy: "never",
sandbox_workspace_write: { network_access: true },
retry_budget: 3,
tool_rules: { allow: ["git status", "git diff"] },
},
});
const thread = client.startThread();
await thread.run("apply config overrides");
@ -447,6 +442,7 @@ describe("Codex", () => {
expectPair(commandArgs, ["--config", "retry_budget=3"]);
expectPair(commandArgs, ["--config", 'tool_rules.allow=["git status", "git diff"]']);
} finally {
cleanup();
restore();
await close();
}
@ -465,15 +461,13 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createTestClient({
baseUrl: url,
apiKey: "test",
config: { approval_policy: "never" },
});
try {
const client = new Codex({
codexPathOverride: codexExecPath,
baseUrl: url,
apiKey: "test",
config: { approval_policy: "never" },
});
const thread = client.startThread({ approvalPolicy: "on-request" });
await thread.run("override approval policy");
@ -485,56 +479,7 @@ describe("Codex", () => {
]);
expect(approvalPolicyOverrides.at(-1)).toBe('approval_policy="on-request"');
} finally {
restore();
await close();
}
});
it("allows overriding the env passed to the Codex CLI", async () => {
const { url, close } = await startResponsesTestProxy({
statusCode: 200,
responseBodies: [
sse(
responseStarted("response_1"),
assistantMessage("Custom env", "item_1"),
responseCompleted("response_1"),
),
],
});
const { args: spawnArgs, envs: spawnEnvs, restore } = codexExecSpy();
process.env.CODEX_ENV_SHOULD_NOT_LEAK = "leak";
try {
const client = new Codex({
codexPathOverride: codexExecPath,
baseUrl: url,
apiKey: "test",
env: { CUSTOM_ENV: "custom" },
});
const thread = client.startThread();
await thread.run("custom env");
const spawnEnv = spawnEnvs[0];
expect(spawnEnv).toBeDefined();
if (!spawnEnv) {
throw new Error("Spawn env missing");
}
const commandArgs = spawnArgs[0];
expect(commandArgs).toBeDefined();
if (!commandArgs) {
throw new Error("Command args missing");
}
expect(spawnEnv.CUSTOM_ENV).toBe("custom");
expect(spawnEnv.CODEX_ENV_SHOULD_NOT_LEAK).toBeUndefined();
expect(spawnEnv.OPENAI_BASE_URL).toBeUndefined();
expect(spawnEnv.CODEX_API_KEY).toBe("test");
expect(spawnEnv.CODEX_INTERNAL_ORIGINATOR_OVERRIDE).toBeDefined();
expect(commandArgs).toContain("--config");
expect(commandArgs).toContain(`openai_base_url=${JSON.stringify(url)}`);
} finally {
delete process.env.CODEX_ENV_SHOULD_NOT_LEAK;
cleanup();
restore();
await close();
}
@ -553,10 +498,9 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread({
additionalDirectories: ["../backend", "/tmp/shared"],
});
@ -577,6 +521,7 @@ describe("Codex", () => {
}
expect(addDirArgs).toEqual(["../backend", "/tmp/shared"]);
} finally {
cleanup();
restore();
await close();
}
@ -605,9 +550,9 @@ describe("Codex", () => {
additionalProperties: false,
} as const;
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const { client, cleanup } = createMockClient(url);
try {
const thread = client.startThread();
await thread.run("structured", { outputSchema: schema });
@ -634,6 +579,7 @@ describe("Codex", () => {
}
expect(fs.existsSync(schemaPath)).toBe(false);
} finally {
cleanup();
restore();
await close();
}
@ -649,10 +595,9 @@ describe("Codex", () => {
),
],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
await thread.run([
{ type: "text", text: "Describe file changes" },
@ -664,6 +609,7 @@ describe("Codex", () => {
const lastUser = payload!.json.input.at(-1);
expect(lastUser?.content?.[0]?.text).toBe("Describe file changes\n\nFocus on impacted tests");
} finally {
cleanup();
await close();
}
});
@ -688,10 +634,9 @@ describe("Codex", () => {
imagesDirectoryEntries.forEach((image, index) => {
fs.writeFileSync(image, `image-${index}`);
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
await thread.run([
{ type: "text", text: "describe the images" },
@ -709,6 +654,7 @@ describe("Codex", () => {
}
expect(forwardedImages).toEqual(imagesDirectoryEntries);
} finally {
cleanup();
fs.rmSync(tempDir, { recursive: true, force: true });
restore();
await close();
@ -727,15 +673,13 @@ describe("Codex", () => {
});
const { args: spawnArgs, restore } = codexExecSpy();
const workingDirectory = fs.mkdtempSync(path.join(os.tmpdir(), "codex-working-dir-"));
const { client, cleanup } = createTestClient({
baseUrl: url,
apiKey: "test",
});
try {
const workingDirectory = fs.mkdtempSync(path.join(os.tmpdir(), "codex-working-dir-"));
const client = new Codex({
codexPathOverride: codexExecPath,
baseUrl: url,
apiKey: "test",
});
const thread = client.startThread({
workingDirectory,
skipGitRepoCheck: true,
@ -745,6 +689,8 @@ describe("Codex", () => {
const commandArgs = spawnArgs[0];
expectPair(commandArgs, ["--cd", workingDirectory]);
} finally {
cleanup();
fs.rmSync(workingDirectory, { recursive: true, force: true });
restore();
await close();
}
@ -761,15 +707,13 @@ describe("Codex", () => {
),
],
});
const workingDirectory = fs.mkdtempSync(path.join(os.tmpdir(), "codex-working-dir-"));
const { client, cleanup } = createTestClient({
baseUrl: url,
apiKey: "test",
});
try {
const workingDirectory = fs.mkdtempSync(path.join(os.tmpdir(), "codex-working-dir-"));
const client = new Codex({
codexPathOverride: codexExecPath,
baseUrl: url,
apiKey: "test",
});
const thread = client.startThread({
workingDirectory,
});
@ -777,6 +721,8 @@ describe("Codex", () => {
/Not inside a trusted directory/,
);
} finally {
cleanup();
fs.rmSync(workingDirectory, { recursive: true, force: true });
await close();
}
});
@ -786,10 +732,9 @@ describe("Codex", () => {
statusCode: 200,
responseBodies: [sse(responseStarted(), assistantMessage("Hi!"), responseCompleted())],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
await thread.run("Hello, originator!");
@ -801,6 +746,7 @@ describe("Codex", () => {
expect(originatorHeader).toBe("codex_sdk_ts");
}
} finally {
cleanup();
await close();
}
});
@ -814,12 +760,13 @@ describe("Codex", () => {
}
})(),
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
await expect(thread.run("fail")).rejects.toThrow("stream disconnected before completion:");
} finally {
cleanup();
await close();
}
}, 10000); // TODO(pakrym): remove timeout

View file

@ -1,8 +1,5 @@
import path from "node:path";
import { describe, expect, it } from "@jest/globals";
import { Codex } from "../src/codex";
import { ThreadEvent } from "../src/index";
import {
@ -12,8 +9,7 @@ import {
sse,
startResponsesTestProxy,
} from "./responsesProxy";
const codexExecPath = path.join(process.cwd(), "..", "..", "codex-rs", "target", "debug", "codex");
import { createMockClient } from "./testCodex";
describe("Codex", () => {
it("returns thread events", async () => {
@ -21,10 +17,9 @@ describe("Codex", () => {
statusCode: 200,
responseBodies: [sse(responseStarted(), assistantMessage("Hi!"), responseCompleted())],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
const result = await thread.runStreamed("Hello, world!");
@ -60,6 +55,7 @@ describe("Codex", () => {
]);
expect(thread.id).toEqual(expect.any(String));
} finally {
cleanup();
await close();
}
});
@ -80,10 +76,9 @@ describe("Codex", () => {
),
],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
const first = await thread.runStreamed("first input");
await drainEvents(first.events);
@ -106,6 +101,7 @@ describe("Codex", () => {
)?.text;
expect(assistantText).toBe("First response");
} finally {
cleanup();
await close();
}
});
@ -126,10 +122,9 @@ describe("Codex", () => {
),
],
});
const { client, cleanup } = createMockClient(url);
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const originalThread = client.startThread();
const first = await originalThread.runStreamed("first input");
await drainEvents(first.events);
@ -154,6 +149,7 @@ describe("Codex", () => {
)?.text;
expect(assistantText).toBe("First response");
} finally {
cleanup();
await close();
}
});
@ -169,6 +165,7 @@ describe("Codex", () => {
),
],
});
const { client, cleanup } = createMockClient(url);
const schema = {
type: "object",
@ -180,8 +177,6 @@ describe("Codex", () => {
} as const;
try {
const client = new Codex({ codexPathOverride: codexExecPath, baseUrl: url, apiKey: "test" });
const thread = client.startThread();
const streamed = await thread.runStreamed("structured", { outputSchema: schema });
await drainEvents(streamed.events);
@ -198,6 +193,7 @@ describe("Codex", () => {
schema,
});
} finally {
cleanup();
await close();
}
});

View file

@ -0,0 +1,28 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, beforeEach } from "@jest/globals";

// Snapshot of CODEX_HOME as it was when the suite loaded, so it can be
// restored after every test.
const savedCodexHome = process.env.CODEX_HOME;
let activeCodexHome: string | undefined;

// Give each test a fresh, isolated CODEX_HOME under the OS temp directory.
beforeEach(async () => {
  activeCodexHome = await fs.mkdtemp(path.join(os.tmpdir(), "codex-sdk-test-"));
  process.env.CODEX_HOME = activeCodexHome;
});

// Restore the original CODEX_HOME (or unset it) and remove the per-test
// directory, even if the test failed.
afterEach(async () => {
  const staleHome = activeCodexHome;
  activeCodexHome = undefined;
  if (savedCodexHome !== undefined) {
    process.env.CODEX_HOME = savedCodexHome;
  } else {
    delete process.env.CODEX_HOME;
  }
  if (staleHome) {
    await fs.rm(staleHome, { recursive: true, force: true });
  }
});

View file

@ -0,0 +1,94 @@
import path from "node:path";
import { Codex } from "../src/codex";
import type { CodexConfigObject } from "../src/codexOptions";
// Path to the locally built debug codex CLI binary exercised by SDK tests.
export const codexExecPath = path.join(process.cwd(), "..", "..", "codex-rs", "target", "debug", "codex");
// Options for createTestClient; every field is optional. When inheritEnv is
// not false, the current process env (minus internal overrides) is merged
// beneath any caller-provided env entries.
type CreateTestClientOptions = {
apiKey?: string;
baseUrl?: string;
config?: CodexConfigObject;
env?: Record<string, string>;
inheritEnv?: boolean;
};
// A Codex client paired with a cleanup hook tests invoke in their finally
// blocks (currently a no-op, kept for uniform call sites).
export type TestClient = {
cleanup: () => void;
client: Codex;
};
export function createMockClient(url: string): TestClient {
return createTestClient({
config: {
model_provider: "mock",
model_providers: {
mock: {
name: "Mock provider for test",
base_url: url,
wire_api: "responses",
supports_websockets: false,
},
},
},
});
}
// Constructs a Codex client wired to the locally built CLI binary.
// The spawn env inherits the current process env (see getCurrentEnv) unless
// inheritEnv is false; caller-provided env entries win on key collisions.
// cleanup is a no-op today but keeps call sites uniform.
export function createTestClient(options: CreateTestClientOptions = {}): TestClient {
  const inherit = options.inheritEnv !== false;
  const baseEnv = inherit ? getCurrentEnv() : {};
  const env = { ...baseEnv, ...options.env };
  const client = new Codex({
    codexPathOverride: codexExecPath,
    baseUrl: options.baseUrl,
    apiKey: options.apiKey,
    config: mergeTestProviderConfig(options.baseUrl, options.config),
    env,
  });
  return { cleanup: () => {}, client };
}
// Injects a custom "mock" provider pointed at baseUrl so tests hit the local
// proxy, unless the caller already configured providers explicitly (or gave
// no baseUrl), in which case the config passes through untouched.
function mergeTestProviderConfig(
  baseUrl: string | undefined,
  config: CodexConfigObject | undefined,
): CodexConfigObject | undefined {
  if (!baseUrl || hasExplicitProviderConfig(config)) {
    return config;
  }
  // Built-in providers are merged before user config, so tests need a custom
  // provider entry to force SSE against the local mock server.
  const merged: CodexConfigObject = {
    ...config,
    model_provider: "mock",
    model_providers: {
      mock: {
        name: "Mock provider for test",
        base_url: baseUrl,
        wire_api: "responses",
        supports_websockets: false,
      },
    },
  };
  return merged;
}
// True when the caller already selected a provider or declared provider
// entries, meaning the test harness must not inject its mock provider.
function hasExplicitProviderConfig(config: CodexConfigObject | undefined): boolean {
  if (config === undefined) {
    return false;
  }
  return config.model_provider !== undefined || config.model_providers !== undefined;
}
// Returns a copy of process.env with undefined values dropped and the
// internal originator override removed, so spawned CLIs never inherit it.
function getCurrentEnv(): Record<string, string> {
  const kept = Object.entries(process.env).filter(
    ([key, value]) => key !== "CODEX_INTERNAL_ORIGINATOR_OVERRIDE" && value !== undefined,
  );
  return Object.fromEntries(kept) as Record<string, string>;
}