/// Reasons the raw bytes of a prompt could not be turned into a `String`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PromptDecodeError {
    /// The input was treated as UTF-8 but contains an invalid sequence.
    /// `valid_up_to` is the byte offset, counted from the start of the
    /// original input (a leading UTF-8 BOM included), of the first byte that
    /// is not part of valid UTF-8.
    InvalidUtf8 { valid_up_to: usize },
    /// The input started with a UTF-16 BOM but the payload has an odd number
    /// of bytes or contains an unpaired surrogate.
    InvalidUtf16 { encoding: &'static str },
    /// The input started with a BOM for an encoding that is recognised but
    /// deliberately not decoded (UTF-32).
    UnsupportedBom { encoding: &'static str },
}

impl std::fmt::Display for PromptDecodeError {
    /// Writes a user-facing explanation together with a hint on how to
    /// convert the input to UTF-8.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PromptDecodeError::InvalidUtf8 { valid_up_to } => write!(
                f,
                "input is not valid UTF-8 (invalid byte at offset {valid_up_to}). Convert it to UTF-8 and retry (e.g., `iconv -f <ENC> -t UTF-8 prompt.txt`)."
            ),
            PromptDecodeError::InvalidUtf16 { encoding } => write!(
                f,
                "input looked like {encoding} but could not be decoded. Convert it to UTF-8 and retry."
            ),
            PromptDecodeError::UnsupportedBom { encoding } => write!(
                f,
                "input appears to be {encoding}. Convert it to UTF-8 and retry."
            ),
        }
    }
}

impl std::error::Error for PromptDecodeError {}

/// Decodes raw prompt bytes into a `String`, using a leading byte-order mark
/// (BOM) to pick the encoding.
///
/// * UTF-8 BOM: the BOM is dropped and the rest is decoded as UTF-8.
/// * UTF-16LE / UTF-16BE BOM: the BOM is dropped and the rest is decoded as
///   UTF-16 with the matching byte order.
/// * UTF-32LE / UTF-32BE BOM: rejected with
///   [`PromptDecodeError::UnsupportedBom`].
/// * No BOM: the whole input is decoded as UTF-8.
///
/// Only the first BOM is honoured: once a UTF-8 BOM has been seen, the bytes
/// that follow are never re-interpreted as a UTF-16/UTF-32 BOM.
///
/// # Errors
///
/// Returns [`PromptDecodeError::InvalidUtf8`] or
/// [`PromptDecodeError::InvalidUtf16`] when the payload is malformed for the
/// selected encoding, and [`PromptDecodeError::UnsupportedBom`] for UTF-32.
fn decode_prompt_bytes(input: &[u8]) -> Result<String, PromptDecodeError> {
    const UTF8_BOM: &[u8] = &[0xEF, 0xBB, 0xBF];
    const UTF32_LE_BOM: &[u8] = &[0xFF, 0xFE, 0x00, 0x00];
    const UTF32_BE_BOM: &[u8] = &[0x00, 0x00, 0xFE, 0xFF];
    const UTF16_LE_BOM: &[u8] = &[0xFF, 0xFE];
    const UTF16_BE_BOM: &[u8] = &[0xFE, 0xFF];

    let (text, skipped) = if let Some(rest) = input.strip_prefix(UTF8_BOM) {
        // A UTF-8 BOM settles the encoding; do not sniff `rest` for more BOMs.
        (rest, UTF8_BOM.len())
    } else {
        // The UTF-32LE BOM begins with the UTF-16LE BOM bytes, so the UTF-32
        // checks must run before the UTF-16 ones.
        if input.starts_with(UTF32_LE_BOM) {
            return Err(PromptDecodeError::UnsupportedBom {
                encoding: "UTF-32LE",
            });
        }
        if input.starts_with(UTF32_BE_BOM) {
            return Err(PromptDecodeError::UnsupportedBom {
                encoding: "UTF-32BE",
            });
        }
        if let Some(rest) = input.strip_prefix(UTF16_LE_BOM) {
            return decode_utf16(rest, "UTF-16LE", u16::from_le_bytes);
        }
        if let Some(rest) = input.strip_prefix(UTF16_BE_BOM) {
            return decode_utf16(rest, "UTF-16BE", u16::from_be_bytes);
        }
        (input, 0)
    };

    std::str::from_utf8(text)
        .map(str::to_owned)
        .map_err(|e| PromptDecodeError::InvalidUtf8 {
            // Report the offset relative to the bytes the caller supplied,
            // not relative to the BOM-stripped slice.
            valid_up_to: skipped + e.valid_up_to(),
        })
}

/// Decodes a BOM-less UTF-16 payload.
///
/// `decode_unit` converts each 2-byte chunk into a code unit (for example
/// `u16::from_le_bytes`); `encoding` is only used to label the error.
///
/// # Errors
///
/// Returns [`PromptDecodeError::InvalidUtf16`] when `input` has an odd length
/// or contains an unpaired surrogate.
fn decode_utf16(
    input: &[u8],
    encoding: &'static str,
    decode_unit: fn([u8; 2]) -> u16,
) -> Result<String, PromptDecodeError> {
    let pairs = input.chunks_exact(2);
    // A trailing single byte cannot form a UTF-16 code unit.
    if !pairs.remainder().is_empty() {
        return Err(PromptDecodeError::InvalidUtf16 { encoding });
    }

    // Decode straight from the byte pairs; no intermediate `Vec<u16>` needed.
    char::decode_utf16(pairs.map(|pair| decode_unit([pair[0], pair[1]])))
        .collect::<Result<String, _>>()
        .map_err(|_| PromptDecodeError::InvalidUtf16 { encoding })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn decode_prompt_bytes_strips_utf8_bom() {
        let input = [0xEF, 0xBB, 0xBF, b'h', b'i', b'\n'];

        let out = decode_prompt_bytes(&input).expect("decode utf-8 with BOM");

        assert_eq!(out, "hi\n");
    }

    #[test]
    fn decode_prompt_bytes_decodes_utf16le_bom() {
        // UTF-16LE BOM + "hi\n"
        let input = [0xFF, 0xFE, b'h', 0x00, b'i', 0x00, b'\n', 0x00];

        let out = decode_prompt_bytes(&input).expect("decode utf-16le with BOM");

        assert_eq!(out, "hi\n");
    }

    #[test]
    fn decode_prompt_bytes_decodes_utf16be_bom() {
        // UTF-16BE BOM + "hi\n"
        let input = [0xFE, 0xFF, 0x00, b'h', 0x00, b'i', 0x00, b'\n'];

        let out = decode_prompt_bytes(&input).expect("decode utf-16be with BOM");

        assert_eq!(out, "hi\n");
    }

    #[test]
    fn decode_prompt_bytes_rejects_utf32le_bom() {
        // UTF-32LE BOM + "hi\n"
        let input = [
            0xFF, 0xFE, 0x00, 0x00, b'h', 0x00, 0x00, 0x00, b'i', 0x00, 0x00, 0x00, b'\n', 0x00,
            0x00, 0x00,
        ];

        let err = decode_prompt_bytes(&input).expect_err("utf-32le should be rejected");

        assert_eq!(
            err,
            PromptDecodeError::UnsupportedBom {
                encoding: "UTF-32LE"
            }
        );
    }

    #[test]
    fn decode_prompt_bytes_rejects_utf32be_bom() {
        // UTF-32BE BOM + "hi\n"
        let input = [
            0x00, 0x00, 0xFE, 0xFF, 0x00, 0x00, 0x00, b'h', 0x00, 0x00, 0x00, b'i', 0x00, 0x00,
            0x00, b'\n',
        ];

        let err = decode_prompt_bytes(&input).expect_err("utf-32be should be rejected");

        assert_eq!(
            err,
            PromptDecodeError::UnsupportedBom {
                encoding: "UTF-32BE"
            }
        );
    }

    #[test]
    fn decode_prompt_bytes_rejects_invalid_utf8() {
        // Invalid UTF-8 sequence: 0xC3 0x28
        let input = [0xC3, 0x28];

        let err = decode_prompt_bytes(&input).expect_err("invalid utf-8 should fail");

        assert_eq!(err, PromptDecodeError::InvalidUtf8 { valid_up_to: 0 });
    }

    #[test]
    fn decode_prompt_bytes_rejects_odd_length_utf16() {
        // UTF-16LE BOM + one dangling byte
        let input = [0xFF, 0xFE, b'h'];

        let err = decode_prompt_bytes(&input).expect_err("odd-length utf-16 should fail");

        assert_eq!(
            err,
            PromptDecodeError::InvalidUtf16 {
                encoding: "UTF-16LE"
            }
        );
    }

    #[test]
    fn decode_prompt_bytes_rejects_unpaired_surrogate() {
        // UTF-16BE BOM + lone high surrogate U+D83D
        let input = [0xFE, 0xFF, 0xD8, 0x3D];

        let err = decode_prompt_bytes(&input).expect_err("lone surrogate should fail");

        assert_eq!(
            err,
            PromptDecodeError::InvalidUtf16 {
                encoding: "UTF-16BE"
            }
        );
    }

    #[test]
    fn decode_prompt_bytes_does_not_resniff_after_utf8_bom() {
        // UTF-8 BOM followed by bytes that look like a UTF-16LE BOM.
        let input = [0xEF, 0xBB, 0xBF, 0xFF, 0xFE, b'h', 0x00];

        let err = decode_prompt_bytes(&input).expect_err("bytes after utf-8 BOM must be utf-8");

        assert_eq!(err, PromptDecodeError::InvalidUtf8 { valid_up_to: 3 });
    }
}