diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml
index a8efec108..fc2ab5a04 100644
--- a/.github/workflows/rust-ci.yml
+++ b/.github/workflows/rust-ci.yml
@@ -177,11 +177,31 @@ jobs:
     steps:
       - uses: actions/checkout@v6
+      - name: Install UBSan runtime (musl)
+        if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl' }}
+        shell: bash
+        run: |
+          set -euo pipefail
+          if command -v apt-get >/dev/null 2>&1; then
+            sudo apt-get update -y
+            sudo DEBIAN_FRONTEND=noninteractive apt-get install -y libubsan1
+          fi
       - uses: dtolnay/rust-toolchain@1.92
         with:
          targets: ${{ matrix.target }}
          components: clippy
+      - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl' }}
+        name: Use hermetic Cargo home (musl)
+        shell: bash
+        run: |
+          set -euo pipefail
+          cargo_home="${GITHUB_WORKSPACE}/.cargo-home"
+          mkdir -p "${cargo_home}/bin"
+          echo "CARGO_HOME=${cargo_home}" >> "$GITHUB_ENV"
+          echo "${cargo_home}/bin" >> "$GITHUB_PATH"
+          : > "${cargo_home}/config.toml"
       - name: Compute lockfile hash
         id: lockhash
         working-directory: codex-rs
@@ -202,6 +222,10 @@ jobs:
             ~/.cargo/registry/index/
             ~/.cargo/registry/cache/
             ~/.cargo/git/db/
+            ${{ github.workspace }}/.cargo-home/bin/
+            ${{ github.workspace }}/.cargo-home/registry/index/
+            ${{ github.workspace }}/.cargo-home/registry/cache/
+            ${{ github.workspace }}/.cargo-home/git/db/
           key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ steps.lockhash.outputs.toolchain_hash }}
           restore-keys: |
             cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-
@@ -244,6 +268,14 @@ jobs:
             sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-
             sccache-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-
 
+      - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl' }}
+        name: Disable sccache wrapper (musl)
+        shell: bash
+        run: |
+          set -euo pipefail
+          echo "RUSTC_WRAPPER=" >> "$GITHUB_ENV"
+          echo "RUSTC_WORKSPACE_WRAPPER=" >> "$GITHUB_ENV"
+
       - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl' }}
         name: Prepare APT cache directories (musl)
         shell: bash
@@ -277,6 +309,58 @@ jobs:
         shell: bash
         run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh"
 
+      - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl' }}
+        name: Configure rustc UBSan wrapper (musl host)
+        shell: bash
+        run: |
+          set -euo pipefail
+          ubsan=""
+          if command -v ldconfig >/dev/null 2>&1; then
+            ubsan="$(ldconfig -p | grep -m1 'libubsan\.so\.1' | sed -E 's/.*=> (.*)$/\1/')"
+          fi
+          wrapper_root="${RUNNER_TEMP:-/tmp}"
+          wrapper="${wrapper_root}/rustc-ubsan-wrapper"
+          cat > "${wrapper}" <<EOF
+          #!/usr/bin/env bash
+          # Assumed wrapper body (original lost in extraction): forward to the
+          # real rustc and link the UBSan runtime found above, if any.
+          exec "\$@"${ubsan:+ -Clink-arg=${ubsan}}
+          EOF
+          chmod +x "${wrapper}"
+          echo "RUSTC_WRAPPER=${wrapper}" >> "$GITHUB_ENV"
+          echo "RUSTC_WORKSPACE_WRAPPER=" >> "$GITHUB_ENV"
+
+      - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl' }}
+        name: Clear sanitizer flags (musl)
+        shell: bash
+        run: |
+          set -euo pipefail
+          # Clear global Rust flags so host/proc-macro builds don't pull in UBSan.
+          echo "RUSTFLAGS=" >> "$GITHUB_ENV"
+          echo "CARGO_ENCODED_RUSTFLAGS=" >> "$GITHUB_ENV"
+          echo "RUSTDOCFLAGS=" >> "$GITHUB_ENV"
+          # Override any runner-level Cargo config rustflags as well.
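+          # Cargo resolves flags most-specific-first: CARGO_TARGET_<TRIPLE>_RUSTFLAGS
+          # overrides CARGO_BUILD_RUSTFLAGS, which overrides RUSTFLAGS, so every
+          # layer below is cleared explicitly.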
+ echo "CARGO_BUILD_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_TARGET_X86_64_UNKNOWN_LINUX_MUSL_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_TARGET_AARCH64_UNKNOWN_LINUX_MUSL_RUSTFLAGS=" >> "$GITHUB_ENV" + + sanitize_flags() { + local input="$1" + input="${input//-fsanitize=undefined/}" + input="${input//-fno-sanitize-recover=undefined/}" + input="${input//-fno-sanitize-trap=undefined/}" + echo "$input" + } + + cflags="$(sanitize_flags "${CFLAGS-}")" + cxxflags="$(sanitize_flags "${CXXFLAGS-}")" + echo "CFLAGS=${cflags}" >> "$GITHUB_ENV" + echo "CXXFLAGS=${cxxflags}" >> "$GITHUB_ENV" + - name: Install cargo-chef if: ${{ matrix.profile == 'release' }} uses: taiki-e/install-action@44c6d64aa62cd779e873306675c7a58e86d6d532 # v2 @@ -322,6 +406,10 @@ jobs: ~/.cargo/registry/index/ ~/.cargo/registry/cache/ ~/.cargo/git/db/ + ${{ github.workspace }}/.cargo-home/bin/ + ${{ github.workspace }}/.cargo-home/registry/index/ + ${{ github.workspace }}/.cargo-home/registry/cache/ + ${{ github.workspace }}/.cargo-home/git/db/ key: cargo-home-${{ matrix.runner }}-${{ matrix.target }}-${{ matrix.profile }}-${{ steps.lockhash.outputs.hash }}-${{ steps.lockhash.outputs.toolchain_hash }} - name: Save sccache cache (fallback) diff --git a/.github/workflows/rust-release.yml b/.github/workflows/rust-release.yml index e2f952ed0..7b8a4fe8a 100644 --- a/.github/workflows/rust-release.yml +++ b/.github/workflows/rust-release.yml @@ -21,7 +21,6 @@ jobs: steps: - uses: actions/checkout@v6 - uses: dtolnay/rust-toolchain@1.92 - - name: Validate tag matches Cargo.toml version shell: bash run: | @@ -90,10 +89,30 @@ jobs: steps: - uses: actions/checkout@v6 + - name: Install UBSan runtime (musl) + if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl' }} + shell: bash + run: | + set -euo pipefail + if command -v apt-get >/dev/null 2>&1; then + sudo apt-get update -y + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y libubsan1 + fi - uses: dtolnay/rust-toolchain@1.92 with: targets: ${{ matrix.target }} + - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}} + name: Use hermetic Cargo home (musl) + shell: bash + run: | + set -euo pipefail + cargo_home="${GITHUB_WORKSPACE}/.cargo-home" + mkdir -p "${cargo_home}/bin" + echo "CARGO_HOME=${cargo_home}" >> "$GITHUB_ENV" + echo "${cargo_home}/bin" >> "$GITHUB_PATH" + : > "${cargo_home}/config.toml" + - uses: actions/cache@v5 with: path: | @@ -101,6 +120,10 @@ jobs: ~/.cargo/registry/index/ ~/.cargo/registry/cache/ ~/.cargo/git/db/ + ${{ github.workspace }}/.cargo-home/bin/ + ${{ github.workspace }}/.cargo-home/registry/index/ + ${{ github.workspace }}/.cargo-home/registry/cache/ + ${{ github.workspace }}/.cargo-home/git/db/ ${{ github.workspace }}/codex-rs/target/ key: cargo-${{ matrix.runner }}-${{ matrix.target }}-release-${{ hashFiles('**/Cargo.lock') }} @@ -116,6 +139,58 @@ jobs: TARGET: ${{ matrix.target }} run: bash "${GITHUB_WORKSPACE}/.github/scripts/install-musl-build-tools.sh" + - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}} + name: Configure rustc UBSan wrapper (musl host) + shell: bash + run: | + set -euo pipefail + ubsan="" + if command -v ldconfig >/dev/null 2>&1; then + ubsan="$(ldconfig -p | grep -m1 'libubsan\.so\.1' | sed -E 's/.*=> (.*)$/\1/')" + fi + 
wrapper_root="${RUNNER_TEMP:-/tmp}" + wrapper="${wrapper_root}/rustc-ubsan-wrapper" + cat > "${wrapper}" <> "$GITHUB_ENV" + echo "RUSTC_WORKSPACE_WRAPPER=" >> "$GITHUB_ENV" + + - if: ${{ matrix.target == 'x86_64-unknown-linux-musl' || matrix.target == 'aarch64-unknown-linux-musl'}} + name: Clear sanitizer flags (musl) + shell: bash + run: | + set -euo pipefail + # Clear global Rust flags so host/proc-macro builds don't pull in UBSan. + echo "RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_ENCODED_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "RUSTDOCFLAGS=" >> "$GITHUB_ENV" + # Override any runner-level Cargo config rustflags as well. + echo "CARGO_BUILD_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_TARGET_X86_64_UNKNOWN_LINUX_MUSL_RUSTFLAGS=" >> "$GITHUB_ENV" + echo "CARGO_TARGET_AARCH64_UNKNOWN_LINUX_MUSL_RUSTFLAGS=" >> "$GITHUB_ENV" + + sanitize_flags() { + local input="$1" + input="${input//-fsanitize=undefined/}" + input="${input//-fno-sanitize-recover=undefined/}" + input="${input//-fno-sanitize-trap=undefined/}" + echo "$input" + } + + cflags="$(sanitize_flags "${CFLAGS-}")" + cxxflags="$(sanitize_flags "${CXXFLAGS-}")" + echo "CFLAGS=${cflags}" >> "$GITHUB_ENV" + echo "CXXFLAGS=${cxxflags}" >> "$GITHUB_ENV" + - name: Cargo build shell: bash run: | diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index e4fe78a3d..38a86f154 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -361,7 +361,7 @@ dependencies = [ "objc2-foundation", "parking_lot", "percent-encoding", - "windows-sys 0.52.0", + "windows-sys 0.60.2", "wl-clipboard-rs", "x11rb", ] @@ -602,6 +602,15 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -739,6 +748,9 @@ name = "bitflags" version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +dependencies = [ + "serde_core", +] [[package]] name = "block-buffer" @@ -1364,6 +1376,7 @@ dependencies = [ "codex-otel", "codex-protocol", "codex-rmcp-client", + "codex-state", "codex-utils-absolute-path", "codex-utils-cargo-bin", "codex-utils-pty", @@ -1828,6 +1841,23 @@ dependencies = [ "which", ] +[[package]] +name = "codex-state" +version = "0.0.0" +dependencies = [ + "anyhow", + "chrono", + "codex-otel", + "codex-protocol", + "pretty_assertions", + "serde", + "serde_json", + "sqlx", + "tokio", + "tracing", + "uuid", +] + [[package]] name = "codex-stdio-to-uds" version = "0.0.0" @@ -2111,6 +2141,12 @@ dependencies = [ "serde_core", ] +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + [[package]] name = "const_format" version = "0.2.35" @@ -2209,6 +2245,21 @@ dependencies = [ "libc", ] +[[package]] +name = "crc" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" +dependencies = [ + "crc-catalog", +] + +[[package]] +name = "crc-catalog" +version = "2.4.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" + [[package]] name = "crc32fast" version = "1.5.0" @@ -2252,6 +2303,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -2530,6 +2590,7 @@ version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" dependencies = [ + "const-oid", "pem-rfc7468", "zeroize", ] @@ -2628,6 +2689,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", + "const-oid", "crypto-common", "subtle", ] @@ -2775,6 +2837,9 @@ name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +dependencies = [ + "serde", +] [[package]] name = "ena" @@ -2908,7 +2973,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -2917,6 +2982,17 @@ version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" +[[package]] +name = "etcetera" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d1b5283a1ab77bd9257427ffd09d8667ced0570b6f938942bc7568ed5b943" +dependencies = [ + "cfg-if", + "home", + "windows-sys 0.48.0", +] + [[package]] name = "event-listener" version = "5.4.0" @@ -3008,7 +3084,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", "rustix 1.0.8", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3085,6 +3161,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "flume" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" +dependencies = [ + "futures-core", + "futures-sink", + "spin", +] + [[package]] name = "flume" version = "0.12.0" @@ -3233,6 +3320,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.31" @@ -3313,7 +3411,7 @@ dependencies = [ "libc", "log", "rustversion", - "windows-link 0.1.3", + "windows-link 0.2.0", "windows-result 0.3.4", ] @@ -3465,6 +3563,15 @@ dependencies = [ "foldhash 0.2.0", ] +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.4", +] + [[package]] name = "heck" version = "0.5.0" @@ -3668,7 +3775,7 @@ dependencies = [ "tokio", 
"tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.2", ] [[package]] @@ -4095,7 +4202,7 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4300,6 +4407,9 @@ name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +dependencies = [ + "spin", +] [[package]] name = "libc" @@ -4326,6 +4436,12 @@ dependencies = [ "windows-link 0.2.0", ] +[[package]] +name = "libm" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + [[package]] name = "libredox" version = "0.1.6" @@ -4334,6 +4450,18 @@ checksum = "4488594b9328dee448adb906d8b126d9b7deb7cf5c22161ee591610bb1be83c0" dependencies = [ "bitflags 2.10.0", "libc", + "redox_syscall", +] + +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", ] [[package]] @@ -4518,6 +4646,16 @@ dependencies = [ "wiremock", ] +[[package]] +name = "md-5" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +dependencies = [ + "cfg-if", + "digest", +] + [[package]] name = "md5" version = "0.8.0" @@ -4795,6 +4933,22 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-bigint-dig" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e661dda6640fad38e827a6d4a310ff4763082116fe217f279885c97f511bb0b7" +dependencies = [ + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-complex" version = "0.4.6" @@ -4848,6 +5002,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -5330,6 +5485,27 @@ dependencies = [ "futures-io", ] +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.32" @@ -5673,7 +5849,7 @@ dependencies = [ "once_cell", "socket2 0.6.1", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -5943,7 +6119,7 @@ checksum = "b28ee9e1e5d39264414b71f5c33e7fbb66b382c3fac456fe0daad39cf5509933" dependencies = [ "ahash", "const_format", - "flume", + "flume 0.12.0", "hex", "ipnet", "itertools 0.14.0", @@ -6001,7 +6177,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "def3d5d06d3ca3a2d2e4376cf93de0555cd9c7960f085bf77be9562f5c9ace8f" dependencies = [ "ahash", - "flume", + "flume 0.12.0", "itertools 0.14.0", "moka", "parking_lot", @@ -6293,7 +6469,7 @@ dependencies = [ 
"wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 1.0.2", ] [[package]] @@ -6364,6 +6540,26 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "rsa" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8573f03f5883dcaebdfcf4725caa1ecb9c15b2ef50c43a07b816e06799bb12d" +dependencies = [ + "const-oid", + "digest", + "num-bigint-dig", + "num-integer", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "signature", + "spki", + "subtle", + "zeroize", +] + [[package]] name = "rustc-demangle" version = "0.1.25" @@ -6395,7 +6591,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -6408,7 +6604,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.9.4", - "windows-sys 0.52.0", + "windows-sys 0.60.2", ] [[package]] @@ -7113,6 +7309,16 @@ dependencies = [ "libc", ] +[[package]] +name = "signature" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" +dependencies = [ + "digest", + "rand_core 0.6.4", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -7197,6 +7403,218 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + +[[package]] +name = "sqlx" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fefb893899429669dcdd979aff487bd78f4064e5e7907e4269081e0ef7d97dc" +dependencies = [ + "sqlx-core", + "sqlx-macros", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", +] + +[[package]] +name = "sqlx-core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6" +dependencies = [ + "base64", + "bytes", + "chrono", + "crc", + "crossbeam-queue", + "either", + "event-listener", + "futures-core", + "futures-intrusive", + "futures-io", + "futures-util", + "hashbrown 0.15.4", + "hashlink", + "indexmap 2.12.0", + "log", + "memchr", + "once_cell", + "percent-encoding", + "rustls", + "serde", + "serde_json", + "sha2", + "smallvec", + "thiserror 2.0.17", + "time", + "tokio", + "tokio-stream", + "tracing", + "url", + "uuid", + "webpki-roots 0.26.11", +] + +[[package]] +name = "sqlx-macros" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2d452988ccaacfbf5e0bdbc348fb91d7c8af5bee192173ac3636b5fb6e6715d" +dependencies = [ + "proc-macro2", + "quote", + "sqlx-core", + "sqlx-macros-core", + "syn 2.0.104", +] + +[[package]] +name = "sqlx-macros-core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19a9c1841124ac5a61741f96e1d9e2ec77424bf323962dd894bdb93f37d5219b" +dependencies = [ + "dotenvy", + "either", + "heck", + "hex", + "once_cell", + "proc-macro2", + "quote", + "serde", + "serde_json", + "sha2", + "sqlx-core", + "sqlx-mysql", + "sqlx-postgres", + "sqlx-sqlite", + "syn 2.0.104", + "tokio", + "url", +] + +[[package]] +name = "sqlx-mysql" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa003f0038df784eb8fecbbac13affe3da23b45194bd57dba231c8f48199c526" +dependencies = [ + "atoi", + "base64", + "bitflags 2.10.0", + 
"byteorder", + "bytes", + "chrono", + "crc", + "digest", + "dotenvy", + "either", + "futures-channel", + "futures-core", + "futures-io", + "futures-util", + "generic-array", + "hex", + "hkdf", + "hmac", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "percent-encoding", + "rand 0.8.5", + "rsa", + "serde", + "sha1", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.17", + "time", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-postgres" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46" +dependencies = [ + "atoi", + "base64", + "bitflags 2.10.0", + "byteorder", + "chrono", + "crc", + "dotenvy", + "etcetera", + "futures-channel", + "futures-core", + "futures-util", + "hex", + "hkdf", + "hmac", + "home", + "itoa", + "log", + "md-5", + "memchr", + "once_cell", + "rand 0.8.5", + "serde", + "serde_json", + "sha2", + "smallvec", + "sqlx-core", + "stringprep", + "thiserror 2.0.17", + "time", + "tracing", + "uuid", + "whoami", +] + +[[package]] +name = "sqlx-sqlite" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" +dependencies = [ + "atoi", + "chrono", + "flume 0.11.1", + "futures-channel", + "futures-core", + "futures-executor", + "futures-intrusive", + "futures-util", + "libsqlite3-sys", + "log", + "percent-encoding", + "serde", + "serde_urlencoded", + "sqlx-core", + "thiserror 2.0.17", + "time", + "tracing", + "url", + "uuid", +] + [[package]] name = "sse-stream" version = "0.2.1" @@ -7330,6 +7748,17 @@ dependencies = [ "precomputed-hash", ] +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + [[package]] name = "strsim" version = "0.10.0" @@ -7495,7 +7924,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix 1.0.8", - "windows-sys 0.52.0", + "windows-sys 0.61.1", ] [[package]] @@ -8290,6 +8719,12 @@ version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + [[package]] name = "unicode-ident" version = "1.0.18" @@ -8302,6 +8737,21 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" +[[package]] +name = "unicode-normalization" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "unicode-properties" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d" + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -8509,6 +8959,12 @@ dependencies = [ "wit-bindgen-rt", ] +[[package]] +name = "wasite" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -8708,6 +9164,15 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.2", +] + [[package]] name = "webpki-roots" version = "1.0.2" @@ -8734,6 +9199,16 @@ dependencies = [ "winsafe", ] +[[package]] +name = "whoami" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d" +dependencies = [ + "libredox", + "wasite", +] + [[package]] name = "widestring" version = "1.2.1" @@ -8777,7 +9252,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index a8ebb9be8..ae4eb768b 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -47,6 +47,7 @@ members = [ "utils/string", "codex-client", "codex-api", + "state", ] resolver = "2" @@ -91,6 +92,7 @@ codex-process-hardening = { path = "process-hardening" } codex-protocol = { path = "protocol" } codex-responses-api-proxy = { path = "responses-api-proxy" } codex-rmcp-client = { path = "rmcp-client" } +codex-state = { path = "state" } codex-stdio-to-uds = { path = "stdio-to-uds" } codex-tui = { path = "tui" } codex-utils-absolute-path = { path = "utils/absolute-path" } @@ -198,6 +200,7 @@ semver = "1.0" shlex = "1.3.0" similar = "2.7.0" socket2 = "0.6.1" +sqlx = { version = "0.8.6", default-features = false, features = ["chrono", "json", "macros", "migrate", "runtime-tokio-rustls", "sqlite", "time", "uuid"] } starlark = "0.13.0" strum = "0.27.2" strum_macros = "0.27.2" diff --git a/codex-rs/app-server/src/codex_message_processor.rs b/codex-rs/app-server/src/codex_message_processor.rs index e7acf3d29..c70f705a9 100644 --- a/codex-rs/app-server/src/codex_message_processor.rs +++ b/codex-rs/app-server/src/codex_message_processor.rs @@ -169,6 +169,7 @@ use codex_core::read_head_for_summary; use codex_core::read_session_meta_line; use codex_core::rollout_date_parts; use codex_core::sandboxing::SandboxPermissions; +use codex_core::state_db::{self}; use codex_core::windows_sandbox::WindowsSandboxLevelExt; use codex_feedback::CodexFeedback; use codex_login::ServerOptions as LoginServerOptions; @@ -1609,6 +1610,7 @@ impl CodexMessageProcessor { } async fn thread_archive(&mut self, request_id: RequestId, params: ThreadArchiveParams) { + // TODO(jif) mostly rewrite this using sqlite after phase 1 let thread_id = match ThreadId::from_string(¶ms.thread_id) { Ok(id) => id, Err(err) => { @@ -1658,6 +1660,7 @@ impl CodexMessageProcessor { } async fn thread_unarchive(&mut self, request_id: RequestId, params: ThreadUnarchiveParams) { + // TODO(jif) mostly rewrite this using sqlite after phase 1 let thread_id = match ThreadId::from_string(¶ms.thread_id) { Ok(id) => id, Err(err) => { @@ -1700,6 +1703,7 @@ impl CodexMessageProcessor { let rollout_path_display = archived_path.display().to_string(); let fallback_provider = self.config.model_provider_id.clone(); + let state_db_ctx = state_db::init_if_enabled(&self.config, None).await; let archived_folder = self .config .codex_home 
@@ -1778,6 +1782,11 @@ impl CodexMessageProcessor { message: format!("failed to unarchive thread: {err}"), data: None, })?; + if let Some(ctx) = state_db_ctx { + let _ = ctx + .mark_unarchived(thread_id, restored_path.as_path()) + .await; + } let summary = read_summary_from_rollout(restored_path.as_path(), fallback_provider.as_str()) .await @@ -2507,7 +2516,6 @@ impl CodexMessageProcessor { }; let fallback_provider = self.config.model_provider_id.as_str(); - match read_summary_from_rollout(&path, fallback_provider).await { Ok(summary) => { let response = GetConversationSummaryResponse { summary }; @@ -3530,8 +3538,13 @@ impl CodexMessageProcessor { }); } + let mut state_db_ctx = None; + // If the thread is active, request shutdown and wait briefly. if let Some(conversation) = self.thread_manager.remove_thread(&thread_id).await { + if let Some(ctx) = conversation.state_db() { + state_db_ctx = Some(ctx); + } info!("thread {thread_id} was active; shutting down"); // Request shutdown. match conversation.submit(Op::Shutdown).await { @@ -3558,14 +3571,24 @@ impl CodexMessageProcessor { } } + if state_db_ctx.is_none() { + state_db_ctx = state_db::init_if_enabled(&self.config, None).await; + } + // Move the rollout file to archived. - let result: std::io::Result<()> = async { + let result: std::io::Result<()> = async move { let archive_folder = self .config .codex_home .join(codex_core::ARCHIVED_SESSIONS_SUBDIR); tokio::fs::create_dir_all(&archive_folder).await?; - tokio::fs::rename(&canonical_rollout_path, &archive_folder.join(&file_name)).await?; + let archived_path = archive_folder.join(&file_name); + tokio::fs::rename(&canonical_rollout_path, &archived_path).await?; + if let Some(ctx) = state_db_ctx { + let _ = ctx + .mark_archived(thread_id, archived_path.as_path(), Utc::now()) + .await; + } Ok(()) } .await; diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index c186758cb..b336b296c 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -37,6 +37,7 @@ codex-keyring-store = { workspace = true } codex-otel = { workspace = true } codex-protocol = { workspace = true } codex-rmcp-client = { workspace = true } +codex-state = { workspace = true } codex-utils-absolute-path = { workspace = true } codex-utils-pty = { workspace = true } codex-utils-readiness = { workspace = true } diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index c11256c70..a756352da 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -204,6 +204,9 @@ "skill_mcp_dependency_install": { "type": "boolean" }, + "sqlite": { + "type": "boolean" + }, "steer": { "type": "boolean" }, @@ -1202,6 +1205,9 @@ "skill_mcp_dependency_install": { "type": "boolean" }, + "sqlite": { + "type": "boolean" + }, "steer": { "type": "boolean" }, diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index e9fbecfa9..e21eff974 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -154,6 +154,7 @@ use crate::protocol::WarningEvent; use crate::rollout::RolloutRecorder; use crate::rollout::RolloutRecorderParams; use crate::rollout::map_session_init_error; +use crate::rollout::metadata; use crate::shell; use crate::shell_snapshot::ShellSnapshot; use crate::skills::SkillError; @@ -165,6 +166,7 @@ use crate::skills::collect_explicit_skill_mentions; use crate::state::ActiveTurn; use crate::state::SessionServices; use crate::state::SessionState; +use crate::state_db; use crate::tasks::GhostSnapshotTask; use 
crate::tasks::ReviewTask;
 use crate::tasks::SessionTask;
@@ -420,6 +422,10 @@ impl Codex {
         let state = self.session.state.lock().await;
         state.session_configuration.thread_config_snapshot()
     }
+
+    pub(crate) fn state_db(&self) -> Option<StateDbHandle> {
+        self.session.state_db()
+    }
 }
 
 /// Context for an initialized model agent
@@ -699,6 +705,13 @@ impl Session {
                 RolloutRecorderParams::resume(resumed_history.rollout_path.clone()),
             ),
         };
+        let state_builder = match &initial_history {
+            InitialHistory::Resumed(resumed) => metadata::builder_from_items(
+                resumed.history.as_slice(),
+                resumed.rollout_path.as_path(),
+            ),
+            InitialHistory::New | InitialHistory::Forked(_) => None,
+        };
 
         // Kick off independent async setup tasks in parallel to reduce startup latency.
         //
@@ -707,11 +720,17 @@ impl Session {
         // - load history metadata
         let rollout_fut = async {
             if config.ephemeral {
-                Ok(None)
+                Ok::<_, anyhow::Error>((None, None))
             } else {
-                RolloutRecorder::new(&config, rollout_params)
-                    .await
-                    .map(Some)
+                let state_db_ctx = state_db::init_if_enabled(&config, None).await;
+                let rollout_recorder = RolloutRecorder::new(
+                    &config,
+                    rollout_params,
+                    state_db_ctx.clone(),
+                    state_builder.clone(),
+                )
+                .await?;
+                Ok((Some(rollout_recorder), state_db_ctx))
             }
         };
 
@@ -731,14 +750,14 @@ impl Session {
         // Join all independent futures.
         let (
-            rollout_recorder,
+            rollout_recorder_and_state_db,
             (history_log_id, history_entry_count),
             (auth, mcp_servers, auth_statuses),
         ) = tokio::join!(rollout_fut, history_meta_fut, auth_and_mcp_fut);
 
-        let rollout_recorder = rollout_recorder.map_err(|e| {
+        let (rollout_recorder, state_db_ctx) = rollout_recorder_and_state_db.map_err(|e| {
             error!("failed to initialize rollout recorder: {e:#}");
-            anyhow::Error::from(e)
+            e
         })?;
         let rollout_path = rollout_recorder
             .as_ref()
@@ -842,6 +861,7 @@ impl Session {
             tool_approvals: Mutex::new(ApprovalStore::default()),
             skills_manager,
             agent_control,
+            state_db: state_db_ctx.clone(),
         };
 
         let sess = Arc::new(Session {
@@ -914,6 +934,10 @@ impl Session {
         self.tx_event.clone()
     }
 
+    pub(crate) fn state_db(&self) -> Option<StateDbHandle> {
+        self.services.state_db.clone()
+    }
+
     /// Ensure all rollout writes are durably flushed.
    pub(crate) async fn flush_rollout(&self) {
         let recorder = {
@@ -4580,6 +4604,7 @@ mod tests {
             tool_approvals: Mutex::new(ApprovalStore::default()),
             skills_manager,
             agent_control,
+            state_db: None,
         };
 
         let turn_context = Session::make_turn_context(
@@ -4691,6 +4716,7 @@ mod tests {
             tool_approvals: Mutex::new(ApprovalStore::default()),
             skills_manager,
             agent_control,
+            state_db: None,
         };
 
         let turn_context = Arc::new(Session::make_turn_context(
diff --git a/codex-rs/core/src/codex_thread.rs b/codex-rs/core/src/codex_thread.rs
index 152679d14..fb8e466d7 100644
--- a/codex-rs/core/src/codex_thread.rs
+++ b/codex-rs/core/src/codex_thread.rs
@@ -12,6 +12,8 @@ use codex_protocol::protocol::SessionSource;
 use std::path::PathBuf;
 use tokio::sync::watch;
 
+use crate::state_db::StateDbHandle;
+
 #[derive(Clone, Debug)]
 pub struct ThreadConfigSnapshot {
     pub model: String,
@@ -64,6 +66,10 @@ impl CodexThread {
         self.rollout_path.clone()
     }
 
+    pub fn state_db(&self) -> Option<StateDbHandle> {
+        self.codex.state_db()
+    }
+
     pub async fn config_snapshot(&self) -> ThreadConfigSnapshot {
         self.codex.thread_config_snapshot().await
     }
diff --git a/codex-rs/core/src/features.rs b/codex-rs/core/src/features.rs
index bd739e5a7..22371babd 100644
--- a/codex-rs/core/src/features.rs
+++ b/codex-rs/core/src/features.rs
@@ -101,6 +101,8 @@ pub enum Feature {
     RemoteModels,
     /// Experimental shell snapshotting.
     ShellSnapshot,
+    /// Persist rollout metadata to a local SQLite database.
+    Sqlite,
     /// Append additional AGENTS.md guidance to user instructions.
     ChildAgentsMd,
     /// Enforce UTF8 output in Powershell.
@@ -377,6 +379,12 @@ pub const FEATURES: &[FeatureSpec] = &[
         },
         default_enabled: false,
     },
+    FeatureSpec {
+        id: Feature::Sqlite,
+        key: "sqlite",
+        stage: Stage::UnderDevelopment,
+        default_enabled: false,
+    },
     FeatureSpec {
         id: Feature::ChildAgentsMd,
         key: "child_agents_md",
diff --git a/codex-rs/core/src/lib.rs b/codex-rs/core/src/lib.rs
index 27fa255c8..eaf25d14e 100644
--- a/codex-rs/core/src/lib.rs
+++ b/codex-rs/core/src/lib.rs
@@ -91,6 +91,7 @@ pub mod shell;
 pub mod shell_snapshot;
 pub mod skills;
 pub mod spawn;
+pub mod state_db;
 pub mod terminal;
 mod tools;
 pub mod turn_diff_tracker;
diff --git a/codex-rs/core/src/rollout/list.rs b/codex-rs/core/src/rollout/list.rs
index 9609f366b..7469bbeb0 100644
--- a/codex-rs/core/src/rollout/list.rs
+++ b/codex-rs/core/src/rollout/list.rs
@@ -19,7 +19,9 @@ use uuid::Uuid;
 use super::ARCHIVED_SESSIONS_SUBDIR;
 use super::SESSIONS_SUBDIR;
 use crate::protocol::EventMsg;
+use crate::state_db;
 use codex_file_search as file_search;
+use codex_protocol::ThreadId;
 use codex_protocol::protocol::RolloutItem;
 use codex_protocol::protocol::RolloutLine;
 use codex_protocol::protocol::SessionMetaLine;
@@ -794,7 +796,7 @@ async fn collect_rollout_day_files(
     Ok(day_files)
 }
 
-fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uuid)> {
+pub(crate) fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uuid)> {
     // Expected: rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl
     let core = name.strip_prefix("rollout-")?.strip_suffix(".jsonl")?;
 
@@ -1093,11 +1095,39 @@ async fn find_thread_path_by_id_str_in_subdir(
     )
     .map_err(|e| io::Error::other(format!("file search failed: {e}")))?;
 
-    Ok(results
+    let found = results
         .matches
         .into_iter()
         .next()
-        .map(|m| root.join(m.path)))
+        .map(|m| root.join(m.path));
+
+    // Check that the DB is at parity with the filesystem scan.
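+    // Shadow read: the filesystem result stays authoritative; the SQLite
+    // lookup below only emits warnings and discrepancy metrics.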
+    // TODO(jif): sqlite migration phase 1
+    let archived_only = match subdir {
+        SESSIONS_SUBDIR => Some(false),
+        ARCHIVED_SESSIONS_SUBDIR => Some(true),
+        _ => None,
+    };
+    let state_db_ctx = state_db::open_if_present(codex_home, "").await;
+    if let Some(state_db_ctx) = state_db_ctx.as_deref()
+        && let Ok(thread_id) = ThreadId::from_string(id_str)
+    {
+        let db_path = state_db::find_rollout_path_by_id(
+            Some(state_db_ctx),
+            thread_id,
+            archived_only,
+            "find_path_query",
+        )
+        .await;
+        let canonical_path = found.as_deref();
+        if db_path.as_deref() != canonical_path {
+            tracing::warn!(
+                "state db path mismatch for thread {thread_id:?}: canonical={canonical_path:?} db={db_path:?}"
+            );
+            state_db::record_discrepancy("find_thread_path_by_id_str_in_subdir", "path_mismatch");
+        }
+    }
+    Ok(found)
 }
 
 /// Locate a recorded thread rollout file by its UUID string using the existing
diff --git a/codex-rs/core/src/rollout/metadata.rs b/codex-rs/core/src/rollout/metadata.rs
new file mode 100644
index 000000000..0d6fedc27
--- /dev/null
+++ b/codex-rs/core/src/rollout/metadata.rs
@@ -0,0 +1,333 @@
+use crate::config::Config;
+use crate::rollout;
+use crate::rollout::list::parse_timestamp_uuid_from_filename;
+use crate::rollout::recorder::RolloutRecorder;
+use chrono::DateTime;
+use chrono::NaiveDateTime;
+use chrono::Timelike;
+use chrono::Utc;
+use codex_otel::OtelManager;
+use codex_protocol::ThreadId;
+use codex_protocol::protocol::AskForApproval;
+use codex_protocol::protocol::RolloutItem;
+use codex_protocol::protocol::SandboxPolicy;
+use codex_protocol::protocol::SessionMetaLine;
+use codex_protocol::protocol::SessionSource;
+use codex_state::BackfillStats;
+use codex_state::DB_ERROR_METRIC;
+use codex_state::ExtractionOutcome;
+use codex_state::ThreadMetadataBuilder;
+use codex_state::apply_rollout_item;
+use std::path::Path;
+use std::path::PathBuf;
+use tracing::warn;
+
+const ROLLOUT_PREFIX: &str = "rollout-";
+const ROLLOUT_SUFFIX: &str = ".jsonl";
+
+pub(crate) fn builder_from_session_meta(
+    session_meta: &SessionMetaLine,
+    rollout_path: &Path,
+) -> Option<ThreadMetadataBuilder> {
+    let created_at = parse_timestamp_to_utc(session_meta.meta.timestamp.as_str())?;
+    let mut builder = ThreadMetadataBuilder::new(
+        session_meta.meta.id,
+        rollout_path.to_path_buf(),
+        created_at,
+        session_meta.meta.source.clone(),
+    );
+    builder.model_provider = session_meta.meta.model_provider.clone();
+    builder.cwd = session_meta.meta.cwd.clone();
+    builder.sandbox_policy = SandboxPolicy::ReadOnly;
+    builder.approval_mode = AskForApproval::OnRequest;
+    if let Some(git) = session_meta.git.as_ref() {
+        builder.git_sha = git.commit_hash.clone();
+        builder.git_branch = git.branch.clone();
+        builder.git_origin_url = git.repository_url.clone();
+    }
+    Some(builder)
+}
+
+pub(crate) fn builder_from_items(
+    items: &[RolloutItem],
+    rollout_path: &Path,
+) -> Option<ThreadMetadataBuilder> {
+    if let Some(session_meta) = items.iter().find_map(|item| match item {
+        RolloutItem::SessionMeta(meta_line) => Some(meta_line),
+        RolloutItem::ResponseItem(_)
+        | RolloutItem::Compacted(_)
+        | RolloutItem::TurnContext(_)
+        | RolloutItem::EventMsg(_) => None,
+    }) && let Some(builder) = builder_from_session_meta(session_meta, rollout_path)
+    {
+        return Some(builder);
+    }
+
+    let file_name = rollout_path.file_name()?.to_str()?;
+    if !file_name.starts_with(ROLLOUT_PREFIX) || !file_name.ends_with(ROLLOUT_SUFFIX) {
+        return None;
+    }
+    let (created_ts, uuid) = parse_timestamp_uuid_from_filename(file_name)?;
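+    // Fallback: without a SessionMeta line, the filename alone still encodes
+    // the creation timestamp and thread UUID.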
+    let created_at =
+        DateTime::<Utc>::from_timestamp(created_ts.unix_timestamp(), 0)?.with_nanosecond(0)?;
+    let id = ThreadId::from_string(&uuid.to_string()).ok()?;
+    Some(ThreadMetadataBuilder::new(
+        id,
+        rollout_path.to_path_buf(),
+        created_at,
+        SessionSource::default(),
+    ))
+}
+
+pub(crate) async fn extract_metadata_from_rollout(
+    rollout_path: &Path,
+    default_provider: &str,
+    otel: Option<&OtelManager>,
+) -> anyhow::Result<ExtractionOutcome> {
+    let (items, _thread_id, parse_errors) =
+        RolloutRecorder::load_rollout_items(rollout_path).await?;
+    if items.is_empty() {
+        return Err(anyhow::anyhow!(
+            "empty session file: {}",
+            rollout_path.display()
+        ));
+    }
+    let builder = builder_from_items(items.as_slice(), rollout_path).ok_or_else(|| {
+        anyhow::anyhow!(
+            "rollout missing metadata builder: {}",
+            rollout_path.display()
+        )
+    })?;
+    let mut metadata = builder.build(default_provider);
+    for item in &items {
+        apply_rollout_item(&mut metadata, item, default_provider);
+    }
+    if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
+        metadata.updated_at = updated_at;
+    }
+    if parse_errors > 0
+        && let Some(otel) = otel
+    {
+        otel.counter(
+            DB_ERROR_METRIC,
+            parse_errors as i64,
+            &[("stage", "extract_metadata_from_rollout")],
+        );
+    }
+    Ok(ExtractionOutcome {
+        metadata,
+        parse_errors,
+    })
+}
+
+pub(crate) async fn backfill_sessions(
+    runtime: &codex_state::StateRuntime,
+    config: &Config,
+    otel: Option<&OtelManager>,
+) -> BackfillStats {
+    let sessions_root = config.codex_home.join(rollout::SESSIONS_SUBDIR);
+    let archived_root = config.codex_home.join(rollout::ARCHIVED_SESSIONS_SUBDIR);
+    let mut rollout_paths: Vec<(PathBuf, bool)> = Vec::new();
+    for (root, archived) in [(sessions_root, false), (archived_root, true)] {
+        if !tokio::fs::try_exists(&root).await.unwrap_or(false) {
+            continue;
+        }
+        match collect_rollout_paths(&root).await {
+            Ok(paths) => {
+                rollout_paths.extend(paths.into_iter().map(|path| (path, archived)));
+            }
+            Err(err) => {
+                warn!(
+                    "failed to collect rollout paths under {}: {err}",
+                    root.display()
+                );
+            }
+        }
+    }
+    let mut stats = BackfillStats {
+        scanned: 0,
+        upserted: 0,
+        failed: 0,
+    };
+    for (path, archived) in rollout_paths {
+        stats.scanned = stats.scanned.saturating_add(1);
+        match extract_metadata_from_rollout(&path, config.model_provider_id.as_str(), otel).await {
+            Ok(outcome) => {
+                if outcome.parse_errors > 0
+                    && let Some(otel) = otel
+                {
+                    otel.counter(
+                        DB_ERROR_METRIC,
+                        outcome.parse_errors as i64,
+                        &[("stage", "backfill_sessions")],
+                    );
+                }
+                let mut metadata = outcome.metadata;
+                if archived && metadata.archived_at.is_none() {
+                    let fallback_archived_at = metadata.updated_at;
+                    metadata.archived_at = file_modified_time_utc(&path)
+                        .await
+                        .or(Some(fallback_archived_at));
+                }
+                if let Err(err) = runtime.upsert_thread(&metadata).await {
+                    stats.failed = stats.failed.saturating_add(1);
+                    warn!("failed to upsert rollout {}: {err}", path.display());
+                } else {
+                    stats.upserted = stats.upserted.saturating_add(1);
+                }
+            }
+            Err(err) => {
+                stats.failed = stats.failed.saturating_add(1);
+                warn!("failed to extract rollout {}: {err}", path.display());
+            }
+        }
+    }
+    stats
+}
+
+async fn file_modified_time_utc(path: &Path) -> Option<DateTime<Utc>> {
+    let modified = tokio::fs::metadata(path).await.ok()?.modified().ok()?;
+    let updated_at: DateTime<Utc> = modified.into();
+    updated_at.with_nanosecond(0)
+}
+
+fn parse_timestamp_to_utc(ts: &str) -> Option<DateTime<Utc>> {
+    const FILENAME_TS_FORMAT: &str = "%Y-%m-%dT%H-%M-%S";
+    if let Ok(naive) = NaiveDateTime::parse_from_str(ts, FILENAME_TS_FORMAT) {
+        let dt = DateTime::<Utc>::from_naive_utc_and_offset(naive, Utc);
+        return dt.with_nanosecond(0);
+    }
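+    // SessionMeta timestamps are RFC 3339; the hyphenated format above is the
+    // rollout filename convention.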
+    if let Ok(dt) = DateTime::parse_from_rfc3339(ts) {
+        return dt.with_timezone(&Utc).with_nanosecond(0);
+    }
+    None
+}
+
+async fn collect_rollout_paths(root: &Path) -> std::io::Result<Vec<PathBuf>> {
+    let mut stack = vec![root.to_path_buf()];
+    let mut paths = Vec::new();
+    while let Some(dir) = stack.pop() {
+        let mut read_dir = match tokio::fs::read_dir(&dir).await {
+            Ok(read_dir) => read_dir,
+            Err(err) => {
+                warn!("failed to read directory {}: {err}", dir.display());
+                continue;
+            }
+        };
+        while let Some(entry) = read_dir.next_entry().await? {
+            let path = entry.path();
+            let file_type = entry.file_type().await?;
+            if file_type.is_dir() {
+                stack.push(path);
+                continue;
+            }
+            if !file_type.is_file() {
+                continue;
+            }
+            let file_name = entry.file_name();
+            let Some(name) = file_name.to_str() else {
+                continue;
+            };
+            if name.starts_with(ROLLOUT_PREFIX) && name.ends_with(ROLLOUT_SUFFIX) {
+                paths.push(path);
+            }
+        }
+    }
+    Ok(paths)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use chrono::DateTime;
+    use chrono::NaiveDateTime;
+    use chrono::Timelike;
+    use chrono::Utc;
+    use codex_protocol::ThreadId;
+    use codex_protocol::protocol::CompactedItem;
+    use codex_protocol::protocol::RolloutItem;
+    use codex_protocol::protocol::RolloutLine;
+    use codex_protocol::protocol::SessionMeta;
+    use codex_protocol::protocol::SessionMetaLine;
+    use codex_protocol::protocol::SessionSource;
+    use codex_state::ThreadMetadataBuilder;
+    use pretty_assertions::assert_eq;
+    use std::fs::File;
+    use std::io::Write;
+    use tempfile::tempdir;
+    use uuid::Uuid;
+
+    #[tokio::test]
+    async fn extract_metadata_from_rollout_uses_session_meta() {
+        let dir = tempdir().expect("tempdir");
+        let uuid = Uuid::new_v4();
+        let id = ThreadId::from_string(&uuid.to_string()).expect("thread id");
+        let path = dir
+            .path()
+            .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl"));
+
+        let session_meta = SessionMeta {
+            id,
+            forked_from_id: None,
+            timestamp: "2026-01-27T12:34:56Z".to_string(),
+            cwd: dir.path().to_path_buf(),
+            originator: "cli".to_string(),
+            cli_version: "0.0.0".to_string(),
+            source: SessionSource::default(),
+            model_provider: Some("openai".to_string()),
+            base_instructions: None,
+        };
+        let session_meta_line = SessionMetaLine {
+            meta: session_meta,
+            git: None,
+        };
+        let rollout_line = RolloutLine {
+            timestamp: "2026-01-27T12:34:56Z".to_string(),
+            item: RolloutItem::SessionMeta(session_meta_line.clone()),
+        };
+        let json = serde_json::to_string(&rollout_line).expect("rollout json");
+        let mut file = File::create(&path).expect("create rollout");
+        writeln!(file, "{json}").expect("write rollout");
+
+        let outcome = extract_metadata_from_rollout(&path, "openai", None)
+            .await
+            .expect("extract");
+
+        let builder =
+            builder_from_session_meta(&session_meta_line, path.as_path()).expect("builder");
+        let mut expected = builder.build("openai");
+        apply_rollout_item(&mut expected, &rollout_line.item, "openai");
+        expected.updated_at = file_modified_time_utc(&path).await.expect("mtime");
+
+        assert_eq!(outcome.metadata, expected);
+        assert_eq!(outcome.parse_errors, 0);
+    }
+
+    #[test]
+    fn builder_from_items_falls_back_to_filename() {
+        let dir = tempdir().expect("tempdir");
+        let uuid = Uuid::new_v4();
+        let path = dir
+            .path()
+            .join(format!("rollout-2026-01-27T12-34-56-{uuid}.jsonl"));
+        let items = vec![RolloutItem::Compacted(CompactedItem {
+            message: "noop".to_string(),
+            replacement_history: None,
+        })];
+
+        let builder = builder_from_items(items.as_slice(), path.as_path()).expect("builder");
+        let naive = NaiveDateTime::parse_from_str("2026-01-27T12-34-56", "%Y-%m-%dT%H-%M-%S")
+            .expect("timestamp");
+        let created_at = DateTime::<Utc>::from_naive_utc_and_offset(naive, Utc)
+            .with_nanosecond(0)
+            .expect("nanosecond");
+        let expected = ThreadMetadataBuilder::new(
+            ThreadId::from_string(&uuid.to_string()).expect("thread id"),
+            path,
+            created_at,
+            SessionSource::default(),
+        );
+
+        assert_eq!(builder, expected);
+    }
+}
diff --git a/codex-rs/core/src/rollout/mod.rs b/codex-rs/core/src/rollout/mod.rs
index fbddfecf3..cfc2d82d8 100644
--- a/codex-rs/core/src/rollout/mod.rs
+++ b/codex-rs/core/src/rollout/mod.rs
@@ -9,6 +9,7 @@ pub const INTERACTIVE_SESSION_SOURCES: &[SessionSource] =
 pub(crate) mod error;
 pub mod list;
+pub(crate) mod metadata;
 pub(crate) mod policy;
 pub mod recorder;
 pub(crate) mod truncation;
diff --git a/codex-rs/core/src/rollout/recorder.rs b/codex-rs/core/src/rollout/recorder.rs
index 53425051c..cc3585054 100644
--- a/codex-rs/core/src/rollout/recorder.rs
+++ b/codex-rs/core/src/rollout/recorder.rs
@@ -28,11 +28,14 @@ use super::list::ThreadSortKey;
 use super::list::ThreadsPage;
 use super::list::get_threads;
 use super::list::get_threads_in_root;
+use super::metadata;
 use super::policy::is_persisted_response_item;
 use crate::config::Config;
 use crate::default_client::originator;
 use crate::git_info::collect_git_info;
 use crate::path_utils;
+use crate::state_db;
+use crate::state_db::StateDbHandle;
 use codex_protocol::protocol::InitialHistory;
 use codex_protocol::protocol::ResumedHistory;
 use codex_protocol::protocol::RolloutItem;
@@ -40,6 +43,7 @@ use codex_protocol::protocol::RolloutLine;
 use codex_protocol::protocol::SessionMeta;
 use codex_protocol::protocol::SessionMetaLine;
 use codex_protocol::protocol::SessionSource;
+use codex_state::ThreadMetadataBuilder;
 
 /// Records all [`ResponseItem`]s for a session and flushes them to disk after
 /// every update.
@@ -54,6 +58,7 @@
 pub struct RolloutRecorder {
     tx: Sender<RolloutCmd>,
     pub(crate) rollout_path: PathBuf,
+    state_db: Option<StateDbHandle>,
 }
 
 #[derive(Clone)]
@@ -111,7 +116,8 @@ impl RolloutRecorder {
         model_providers: Option<&[String]>,
         default_provider: &str,
     ) -> std::io::Result<ThreadsPage> {
-        get_threads(
+        let stage = "list_threads";
+        let page = get_threads(
             codex_home,
             page_size,
             cursor,
@@ -120,7 +126,34 @@ impl RolloutRecorder {
             model_providers,
             default_provider,
         )
+        .await?;
+
+        // TODO(jif): drop after sqlite migration phase 1
+        let state_db_ctx = state_db::open_if_present(codex_home, default_provider).await;
+        if let Some(db_ids) = state_db::list_thread_ids_db(
+            state_db_ctx.as_deref(),
+            codex_home,
+            page_size,
+            cursor,
+            sort_key,
+            allowed_sources,
+            model_providers,
+            false,
+            stage,
+        )
         .await
+        {
+            if page.items.len() != db_ids.len() {
+                state_db::record_discrepancy(stage, "bad_len");
+                return Ok(page);
+            }
+            for (id, item) in db_ids.iter().zip(page.items.iter()) {
+                if !item.path.display().to_string().contains(&id.to_string()) {
+                    state_db::record_discrepancy(stage, "bad_id");
+                }
+            }
+        }
+        Ok(page)
     }
 
     /// List archived threads (rollout files) under the archived sessions directory.
@@ -133,8 +166,9 @@ impl RolloutRecorder {
         model_providers: Option<&[String]>,
         default_provider: &str,
     ) -> std::io::Result<ThreadsPage> {
+        let stage = "list_archived_threads";
         let root = codex_home.join(ARCHIVED_SESSIONS_SUBDIR);
-        get_threads_in_root(
+        let page = get_threads_in_root(
             root,
             page_size,
             cursor,
@@ -146,7 +180,34 @@ impl RolloutRecorder {
                 layout: ThreadListLayout::Flat,
             },
         )
+        .await?;
+
+        // TODO(jif): drop after sqlite migration phase 1
+        let state_db_ctx = state_db::open_if_present(codex_home, default_provider).await;
+        if let Some(db_ids) = state_db::list_thread_ids_db(
+            state_db_ctx.as_deref(),
+            codex_home,
+            page_size,
+            cursor,
+            sort_key,
+            allowed_sources,
+            model_providers,
+            true,
+            stage,
+        )
         .await
+        {
+            if page.items.len() != db_ids.len() {
+                state_db::record_discrepancy(stage, "bad_len");
+                return Ok(page);
+            }
+            for (id, item) in db_ids.iter().zip(page.items.iter()) {
+                if !item.path.display().to_string().contains(&id.to_string()) {
+                    state_db::record_discrepancy(stage, "bad_id");
+                }
+            }
+        }
+        Ok(page)
     }
 
     /// Find the newest recorded thread path, optionally filtering to a matching cwd.
@@ -186,7 +247,12 @@ impl RolloutRecorder {
     /// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
     /// cannot be created or the rollout file cannot be opened we return the
     /// error so the caller can decide whether to disable persistence.
-    pub async fn new(config: &Config, params: RolloutRecorderParams) -> std::io::Result<Self> {
+    pub async fn new(
+        config: &Config,
+        params: RolloutRecorderParams,
+        state_db_ctx: Option<StateDbHandle>,
+        state_builder: Option<ThreadMetadataBuilder>,
+    ) -> std::io::Result<Self> {
         let (file, rollout_path, meta) = match params {
             RolloutRecorderParams::Create {
                 conversation_id,
@@ -246,9 +312,30 @@ impl RolloutRecorder {
         // Spawn a Tokio task that owns the file handle and performs async
         // writes. Using `tokio::fs::File` keeps everything on the async I/O
         // driver instead of blocking the runtime.
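+        // The writer task now also mirrors persisted items into the optional
+        // SQLite state db, alongside the JSONL rollout file.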
-        tokio::task::spawn(rollout_writer(file, rx, meta, cwd));
+        tokio::task::spawn(rollout_writer(
+            file,
+            rx,
+            meta,
+            cwd,
+            rollout_path.clone(),
+            state_db_ctx.clone(),
+            state_builder,
+            config.model_provider_id.clone(),
+        ));
 
-        Ok(Self { tx, rollout_path })
+        Ok(Self {
+            tx,
+            rollout_path,
+            state_db: state_db_ctx,
+        })
+    }
+
+    pub fn rollout_path(&self) -> &Path {
+        self.rollout_path.as_path()
+    }
+
+    pub fn state_db(&self) -> Option<StateDbHandle> {
+        self.state_db.clone()
     }
 
     pub(crate) async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
@@ -281,7 +368,9 @@ impl RolloutRecorder {
             .map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
     }
 
-    pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
+    pub(crate) async fn load_rollout_items(
+        path: &Path,
+    ) -> std::io::Result<(Vec<RolloutItem>, Option<ThreadId>, usize)> {
         info!("Resuming rollout from {path:?}");
         let text = tokio::fs::read_to_string(path).await?;
         if text.trim().is_empty() {
@@ -290,6 +379,7 @@ impl RolloutRecorder {
 
         let mut items: Vec<RolloutItem> = Vec::new();
         let mut thread_id: Option<ThreadId> = None;
+        let mut parse_errors = 0usize;
         for line in text.lines() {
             if line.trim().is_empty() {
                 continue;
@@ -298,6 +388,7 @@ impl RolloutRecorder {
                 Ok(v) => v,
                 Err(e) => {
                     warn!("failed to parse line as JSON: {line:?}, error: {e}");
+                    parse_errors = parse_errors.saturating_add(1);
                     continue;
                 }
             };
@@ -328,15 +419,22 @@ impl RolloutRecorder {
                 },
                 Err(e) => {
                     warn!("failed to parse rollout line: {v:?}, error: {e}");
+                    parse_errors = parse_errors.saturating_add(1);
                 }
             }
         }
 
         info!(
-            "Resumed rollout with {} items, thread ID: {:?}",
+            "Resumed rollout with {} items, thread ID: {:?}, parse errors: {}",
             items.len(),
-            thread_id
+            thread_id,
+            parse_errors,
         );
+        Ok((items, thread_id, parse_errors))
+    }
+
+    pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
+        let (items, thread_id, _parse_errors) = Self::load_rollout_items(path).await?;
 
         let conversation_id = thread_id
             .ok_or_else(|| IoError::other("failed to parse thread ID from rollout file"))?;
@@ -417,13 +515,21 @@ fn create_log_file(config: &Config, conversation_id: ThreadId) -> std::io::Result<LogFileInfo> {
     })
 }
 
+#[allow(clippy::too_many_arguments)]
 async fn rollout_writer(
     file: tokio::fs::File,
     mut rx: mpsc::Receiver<RolloutCmd>,
     mut meta: Option<SessionMeta>,
     cwd: std::path::PathBuf,
+    rollout_path: PathBuf,
+    state_db_ctx: Option<StateDbHandle>,
+    mut state_builder: Option<ThreadMetadataBuilder>,
+    default_provider: String,
 ) -> std::io::Result<()> {
     let mut writer = JsonlWriter { file };
+    if let Some(builder) = state_builder.as_mut() {
+        builder.rollout_path = rollout_path.clone();
+    }
 
     // If we have a meta, collect git info asynchronously and write meta first
     if let Some(session_meta) = meta.take() {
@@ -432,22 +538,50 @@ async fn rollout_writer(
             meta: session_meta,
             git: git_info,
         };
+        if state_db_ctx.is_some() {
+            state_builder =
+                metadata::builder_from_session_meta(&session_meta_line, rollout_path.as_path());
+        }
         // Write the SessionMeta as the first item in the file, wrapped in a rollout line
-        writer
-            .write_rollout_item(RolloutItem::SessionMeta(session_meta_line))
-            .await?;
+        let rollout_item = RolloutItem::SessionMeta(session_meta_line);
+        writer.write_rollout_item(&rollout_item).await?;
+        state_db::reconcile_rollout(
+            state_db_ctx.as_deref(),
+            rollout_path.as_path(),
+            default_provider.as_str(),
+            state_builder.as_ref(),
+            std::slice::from_ref(&rollout_item),
+        )
+        .await;
     }
 
     // Process rollout commands
     while let Some(cmd) = rx.recv().await {
         match cmd {
             RolloutCmd::AddItems(items) => {
+                let mut persisted_items = Vec::new();
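+                // Collect exactly the items that pass the persistence policy
+                // so SQLite mirrors what lands in the JSONL file.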
                for item in items {
                     if is_persisted_response_item(&item) {
-                        writer.write_rollout_item(item).await?;
+                        writer.write_rollout_item(&item).await?;
+                        persisted_items.push(item);
                     }
                 }
+                if persisted_items.is_empty() {
+                    continue;
+                }
+                if let Some(builder) = state_builder.as_mut() {
+                    builder.rollout_path = rollout_path.clone();
+                }
+                state_db::apply_rollout_items(
+                    state_db_ctx.as_deref(),
+                    rollout_path.as_path(),
+                    default_provider.as_str(),
+                    state_builder.as_ref(),
+                    persisted_items.as_slice(),
+                    "rollout_writer",
+                )
+                .await;
             }
             RolloutCmd::Flush { ack } => {
                 // Ensure underlying file is flushed and then ack.
@@ -470,8 +604,15 @@
 struct JsonlWriter {
     file: tokio::fs::File,
 }
 
+#[derive(serde::Serialize)]
+struct RolloutLineRef<'a> {
+    timestamp: String,
+    #[serde(flatten)]
+    item: &'a RolloutItem,
+}
+
 impl JsonlWriter {
-    async fn write_rollout_item(&mut self, rollout_item: RolloutItem) -> std::io::Result<()> {
+    async fn write_rollout_item(&mut self, rollout_item: &RolloutItem) -> std::io::Result<()> {
         let timestamp_format: &[FormatItem] = format_description!(
             "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
         );
@@ -479,7 +620,7 @@
             .format(timestamp_format)
             .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
 
-        let line = RolloutLine {
+        let line = RolloutLineRef {
             timestamp,
             item: rollout_item,
         };
diff --git a/codex-rs/core/src/state/service.rs b/codex-rs/core/src/state/service.rs
index cd1f1c049..6559ec2e1 100644
--- a/codex-rs/core/src/state/service.rs
+++ b/codex-rs/core/src/state/service.rs
@@ -7,6 +7,7 @@ use crate::exec_policy::ExecPolicyManager;
 use crate::mcp_connection_manager::McpConnectionManager;
 use crate::models_manager::manager::ModelsManager;
 use crate::skills::SkillsManager;
+use crate::state_db::StateDbHandle;
 use crate::tools::sandboxing::ApprovalStore;
 use crate::unified_exec::UnifiedExecProcessManager;
 use crate::user_notification::UserNotifier;
@@ -30,4 +31,5 @@ pub(crate) struct SessionServices {
     pub(crate) tool_approvals: Mutex<ApprovalStore>,
     pub(crate) skills_manager: Arc<SkillsManager>,
     pub(crate) agent_control: AgentControl,
+    pub(crate) state_db: Option<StateDbHandle>,
 }
diff --git a/codex-rs/core/src/state_db.rs b/codex-rs/core/src/state_db.rs
new file mode 100644
index 000000000..d859328e1
--- /dev/null
+++ b/codex-rs/core/src/state_db.rs
@@ -0,0 +1,303 @@
+use crate::config::Config;
+use crate::features::Feature;
+use crate::rollout::list::Cursor;
+use crate::rollout::list::ThreadSortKey;
+use crate::rollout::metadata;
+use chrono::DateTime;
+use chrono::NaiveDateTime;
+use chrono::Timelike;
+use chrono::Utc;
+use codex_otel::OtelManager;
+use codex_protocol::ThreadId;
+use codex_protocol::protocol::RolloutItem;
+use codex_protocol::protocol::SessionSource;
+use codex_state::DB_METRIC_BACKFILL;
+use codex_state::STATE_DB_FILENAME;
+use codex_state::ThreadMetadataBuilder;
+use serde_json::Value;
+use std::path::Path;
+use std::path::PathBuf;
+use std::sync::Arc;
+use tracing::info;
+use tracing::warn;
+use uuid::Uuid;
+
+/// Core-facing handle to the optional SQLite-backed state runtime.
+pub type StateDbHandle = Arc<codex_state::StateRuntime>;
+
+/// Initialize the state runtime when the `sqlite` feature flag is enabled.
+pub async fn init_if_enabled(config: &Config, otel: Option<&OtelManager>) -> Option<StateDbHandle> {
+    let state_path = config.codex_home.join(STATE_DB_FILENAME);
+    if !config.features.enabled(Feature::Sqlite) {
+        // We delete the file on a best-effort basis to preserve backward compatibility in the future.
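+        // SQLite in WAL mode keeps `-wal` and `-shm` sidecar files next to
+        // the main database, so all three are removed together.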
+        let wal_path = state_path.with_extension("sqlite-wal");
+        let shm_path = state_path.with_extension("sqlite-shm");
+        for path in [state_path.as_path(), wal_path.as_path(), shm_path.as_path()] {
+            tokio::fs::remove_file(path).await.ok();
+        }
+        return None;
+    }
+    let existed = tokio::fs::try_exists(&state_path).await.unwrap_or(false);
+    let runtime = match codex_state::StateRuntime::init(
+        config.codex_home.clone(),
+        config.model_provider_id.clone(),
+        otel.cloned(),
+    )
+    .await
+    {
+        Ok(runtime) => runtime,
+        Err(err) => {
+            warn!(
+                "failed to initialize state runtime at {}: {err}",
+                config.codex_home.display()
+            );
+            if let Some(otel) = otel {
+                otel.counter("codex.db.init", 1, &[("status", "init_error")]);
+            }
+            return None;
+        }
+    };
+    if !existed {
+        let stats = metadata::backfill_sessions(runtime.as_ref(), config, otel).await;
+        info!(
+            "state db backfill scanned={}, upserted={}, failed={}",
+            stats.scanned, stats.upserted, stats.failed
+        );
+        if let Some(otel) = otel {
+            otel.counter(
+                DB_METRIC_BACKFILL,
+                stats.upserted as i64,
+                &[("status", "upserted")],
+            );
+            otel.counter(
+                DB_METRIC_BACKFILL,
+                stats.failed as i64,
+                &[("status", "failed")],
+            );
+        }
+    }
+    Some(runtime)
+}
+
+/// Open the state runtime when the SQLite file exists, without feature gating.
+///
+/// This is used for parity checks during the SQLite migration phase.
+pub async fn open_if_present(codex_home: &Path, default_provider: &str) -> Option<StateDbHandle> {
+    let db_path = codex_home.join(STATE_DB_FILENAME);
+    if !tokio::fs::try_exists(&db_path).await.unwrap_or(false) {
+        return None;
+    }
+    let runtime = codex_state::StateRuntime::init(
+        codex_home.to_path_buf(),
+        default_provider.to_string(),
+        None,
+    )
+    .await
+    .ok()?;
+    Some(runtime)
+}
+
+fn cursor_to_anchor(cursor: Option<&Cursor>) -> Option<codex_state::Anchor> {
+    let cursor = cursor?;
+    let value = serde_json::to_value(cursor).ok()?;
+    let cursor_str = value.as_str()?;
+    let (ts_str, id_str) = cursor_str.split_once('|')?;
+    if id_str.contains('|') {
+        return None;
+    }
+    let id = Uuid::parse_str(id_str).ok()?;
+    let ts = if let Ok(naive) = NaiveDateTime::parse_from_str(ts_str, "%Y-%m-%dT%H-%M-%S") {
+        DateTime::<Utc>::from_naive_utc_and_offset(naive, Utc)
+    } else if let Ok(dt) = DateTime::parse_from_rfc3339(ts_str) {
+        dt.with_timezone(&Utc)
+    } else {
+        return None;
+    }
+    .with_nanosecond(0)?;
+    Some(codex_state::Anchor { ts, id })
+}
+
+/// List thread ids from SQLite for parity checks without rollout scanning.
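+///
+/// Illustrative call, a sketch only: `ctx`, `home`, and `cursor` are assumed
+/// to exist at the call site, and the stage tag is a placeholder.
+///
+/// ```ignore
+/// let ids = list_thread_ids_db(
+///     Some(&ctx),
+///     home,
+///     50,
+///     cursor.as_ref(),
+///     ThreadSortKey::UpdatedAt,
+///     &[SessionSource::default()],
+///     None,  // any model provider
+///     false, // active (non-archived) threads only
+///     "list_parity",
+/// )
+/// .await;
+/// ```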
+#[allow(clippy::too_many_arguments)]
+pub async fn list_thread_ids_db(
+    context: Option<&codex_state::StateRuntime>,
+    codex_home: &Path,
+    page_size: usize,
+    cursor: Option<&Cursor>,
+    sort_key: ThreadSortKey,
+    allowed_sources: &[SessionSource],
+    model_providers: Option<&[String]>,
+    archived_only: bool,
+    stage: &str,
+) -> Option<Vec<ThreadId>> {
+    let ctx = context?;
+    if ctx.codex_home() != codex_home {
+        warn!(
+            "state db codex_home mismatch: expected {}, got {}",
+            ctx.codex_home().display(),
+            codex_home.display()
+        );
+    }
+
+    let anchor = cursor_to_anchor(cursor);
+    let allowed_sources: Vec<String> = allowed_sources
+        .iter()
+        .map(|value| match serde_json::to_value(value) {
+            Ok(Value::String(s)) => s,
+            Ok(other) => other.to_string(),
+            Err(_) => String::new(),
+        })
+        .collect();
+    let model_providers = model_providers.map(<[String]>::to_vec);
+    match ctx
+        .list_thread_ids(
+            page_size,
+            anchor.as_ref(),
+            match sort_key {
+                ThreadSortKey::CreatedAt => codex_state::SortKey::CreatedAt,
+                ThreadSortKey::UpdatedAt => codex_state::SortKey::UpdatedAt,
+            },
+            allowed_sources.as_slice(),
+            model_providers.as_deref(),
+            archived_only,
+        )
+        .await
+    {
+        Ok(ids) => Some(ids),
+        Err(err) => {
+            warn!("state db list_thread_ids failed during {stage}: {err}");
+            None
+        }
+    }
+}
+
+/// Look up the rollout path for a thread id using SQLite.
+pub async fn find_rollout_path_by_id(
+    context: Option<&codex_state::StateRuntime>,
+    thread_id: ThreadId,
+    archived_only: Option<bool>,
+    stage: &str,
+) -> Option<PathBuf> {
+    let ctx = context?;
+    ctx.find_rollout_path_by_id(thread_id, archived_only)
+        .await
+        .unwrap_or_else(|err| {
+            warn!("state db find_rollout_path_by_id failed during {stage}: {err}");
+            None
+        })
+}
+
+/// Reconcile rollout items into SQLite, falling back to scanning the rollout file.
+pub async fn reconcile_rollout(
+    context: Option<&codex_state::StateRuntime>,
+    rollout_path: &Path,
+    default_provider: &str,
+    builder: Option<&ThreadMetadataBuilder>,
+    items: &[RolloutItem],
+) {
+    let Some(ctx) = context else {
+        return;
+    };
+    if builder.is_some() || !items.is_empty() {
+        apply_rollout_items(
+            Some(ctx),
+            rollout_path,
+            default_provider,
+            builder,
+            items,
+            "reconcile_rollout",
+        )
+        .await;
+        return;
+    }
+    let outcome =
+        match metadata::extract_metadata_from_rollout(rollout_path, default_provider, None).await {
+            Ok(outcome) => outcome,
+            Err(err) => {
+                warn!(
+                    "state db reconcile_rollout extraction failed {}: {err}",
+                    rollout_path.display()
+                );
+                return;
+            }
+        };
+    if let Err(err) = ctx.upsert_thread(&outcome.metadata).await {
+        warn!(
+            "state db reconcile_rollout upsert failed {}: {err}",
+            rollout_path.display()
+        );
+    }
+}
+
+/// Apply rollout items incrementally to SQLite.
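+///
+/// When no builder is supplied, one is derived from the items themselves via
+/// `metadata::builder_from_items`; if that also fails, the write is skipped and
+/// a `missing_builder` discrepancy is recorded rather than returning an error.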
+pub async fn apply_rollout_items(
+    context: Option<&codex_state::StateRuntime>,
+    rollout_path: &Path,
+    _default_provider: &str,
+    builder: Option<&ThreadMetadataBuilder>,
+    items: &[RolloutItem],
+    stage: &str,
+) {
+    let Some(ctx) = context else {
+        return;
+    };
+    let mut builder = match builder {
+        Some(builder) => builder.clone(),
+        None => match metadata::builder_from_items(items, rollout_path) {
+            Some(builder) => builder,
+            None => {
+                warn!(
+                    "state db apply_rollout_items missing builder during {stage}: {}",
+                    rollout_path.display()
+                );
+                record_discrepancy(stage, "missing_builder");
+                return;
+            }
+        },
+    };
+    builder.rollout_path = rollout_path.to_path_buf();
+    if let Err(err) = ctx.apply_rollout_items(&builder, items, None).await {
+        warn!(
+            "state db apply_rollout_items failed during {stage} for {}: {err}",
+            rollout_path.display()
+        );
+    }
+}
+
+/// Record a state discrepancy metric with a stage and reason tag.
+pub fn record_discrepancy(stage: &str, reason: &str) {
+    // We use the global metrics client because the call sites might not have
+    // access to the broader OtelManager.
+    if let Some(metric) = codex_otel::metrics::global() {
+        let _ = metric.counter(
+            "codex.db.discrepancy",
+            1,
+            &[("stage", stage), ("reason", reason)],
+        );
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::rollout::list::parse_cursor;
+    use pretty_assertions::assert_eq;
+
+    #[test]
+    fn cursor_to_anchor_normalizes_timestamp_format() {
+        let uuid = Uuid::new_v4();
+        let ts_str = "2026-01-27T12-34-56";
+        let token = format!("{ts_str}|{uuid}");
+        let cursor = parse_cursor(token.as_str()).expect("cursor should parse");
+        let anchor = cursor_to_anchor(Some(&cursor)).expect("anchor should parse");
+
+        let naive =
+            NaiveDateTime::parse_from_str(ts_str, "%Y-%m-%dT%H-%M-%S").expect("ts should parse");
+        let expected_ts = DateTime::<Utc>::from_naive_utc_and_offset(naive, Utc)
+            .with_nanosecond(0)
+            .expect("nanosecond");
+
+        assert_eq!(anchor.id, uuid);
+        assert_eq!(anchor.ts, expected_ts);
+    }
+}
diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs
index 9bf4ef8ef..e054070b3 100644
--- a/codex-rs/core/tests/suite/mod.rs
+++ b/codex-rs/core/tests/suite/mod.rs
@@ -65,6 +65,7 @@ mod shell_command;
 mod shell_serialization;
 mod shell_snapshot;
 mod skills;
+mod sqlite_state;
 mod stream_error_allows_next_turn;
 mod stream_no_completed;
 mod text_encoding_fix;
diff --git a/codex-rs/core/tests/suite/sqlite_state.rs b/codex-rs/core/tests/suite/sqlite_state.rs
new file mode 100644
index 000000000..df6fa2ef7
--- /dev/null
+++ b/codex-rs/core/tests/suite/sqlite_state.rs
@@ -0,0 +1,199 @@
+use anyhow::Result;
+use codex_core::features::Feature;
+use codex_protocol::ThreadId;
+use codex_protocol::protocol::EventMsg;
+use codex_protocol::protocol::RolloutItem;
+use codex_protocol::protocol::RolloutLine;
+use codex_protocol::protocol::SessionMeta;
+use codex_protocol::protocol::SessionMetaLine;
+use codex_protocol::protocol::SessionSource;
+use codex_protocol::protocol::UserMessageEvent;
+use codex_state::STATE_DB_FILENAME;
+use core_test_support::load_sse_fixture_with_id;
+use core_test_support::responses::mount_sse_sequence;
+use core_test_support::responses::start_mock_server;
+use core_test_support::test_codex::test_codex;
+use pretty_assertions::assert_eq;
+use std::fs;
+use tokio::time::Duration;
+use uuid::Uuid;
+
+fn sse_completed(id: &str) -> String {
+    load_sse_fixture_with_id("../fixtures/completed_template.json", id)
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn new_thread_is_recorded_in_state_db() -> Result<()> {
+    let server = start_mock_server().await;
+    let mut builder = test_codex().with_config(|config| {
+        config.features.enable(Feature::Sqlite);
+    });
+    let test = builder.build(&server).await?;
+
+    let thread_id = test.session_configured.session_id;
+    let rollout_path = test.codex.rollout_path().expect("rollout path");
+    let db_path = test.config.codex_home.join(STATE_DB_FILENAME);
+
+    for _ in 0..100 {
+        if tokio::fs::try_exists(&db_path).await.unwrap_or(false) {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(25)).await;
+    }
+
+    let db = test.codex.state_db().expect("state db enabled");
+
+    let mut metadata = None;
+    for _ in 0..100 {
+        metadata = db.get_thread(thread_id).await?;
+        if metadata.is_some() {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(25)).await;
+    }
+
+    let metadata = metadata.expect("thread should exist in state db");
+    assert_eq!(metadata.id, thread_id);
+    assert_eq!(metadata.rollout_path, rollout_path);
+
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn backfill_scans_existing_rollouts() -> Result<()> {
+    let server = start_mock_server().await;
+
+    let uuid = Uuid::now_v7();
+    let thread_id = ThreadId::from_string(&uuid.to_string())?;
+    let rollout_rel_path = format!("sessions/2026/01/27/rollout-2026-01-27T12-00-00-{uuid}.jsonl");
+    let rollout_rel_path_for_hook = rollout_rel_path.clone();
+
+    let mut builder = test_codex()
+        .with_pre_build_hook(move |codex_home| {
+            let rollout_path = codex_home.join(&rollout_rel_path_for_hook);
+            let parent = rollout_path
+                .parent()
+                .expect("rollout path should have parent");
+            fs::create_dir_all(parent).expect("should create rollout directory");
+
+            let session_meta_line = SessionMetaLine {
+                meta: SessionMeta {
+                    id: thread_id,
+                    forked_from_id: None,
+                    timestamp: "2026-01-27T12:00:00Z".to_string(),
+                    cwd: codex_home.to_path_buf(),
+                    originator: "test".to_string(),
+                    cli_version: "test".to_string(),
+                    source: SessionSource::default(),
+                    model_provider: None,
+                    base_instructions: None,
+                },
+                git: None,
+            };
+
+            let lines = [
+                RolloutLine {
+                    timestamp: "2026-01-27T12:00:00Z".to_string(),
+                    item: RolloutItem::SessionMeta(session_meta_line),
+                },
+                RolloutLine {
+                    timestamp: "2026-01-27T12:00:01Z".to_string(),
+                    item: RolloutItem::EventMsg(EventMsg::UserMessage(UserMessageEvent {
+                        message: "hello from backfill".to_string(),
+                        images: None,
+                        local_images: Vec::new(),
+                        text_elements: Vec::new(),
+                    })),
+                },
+            ];
+
+            let jsonl = lines
+                .iter()
+                .map(|line| serde_json::to_string(line).expect("rollout line should serialize"))
+                .collect::<Vec<_>>()
+                .join("\n");
+            fs::write(&rollout_path, format!("{jsonl}\n")).expect("should write rollout file");
+        })
+        .with_config(|config| {
+            config.features.enable(Feature::Sqlite);
+        });
+
+    let test = builder.build(&server).await?;
+
+    let db_path = test.config.codex_home.join(STATE_DB_FILENAME);
+    let rollout_path = test.config.codex_home.join(&rollout_rel_path);
+    let default_provider = test.config.model_provider_id.clone();
+
+    for _ in 0..20 {
+        if tokio::fs::try_exists(&db_path).await.unwrap_or(false) {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(25)).await;
+    }
+
+    let db = test.codex.state_db().expect("state db enabled");
+
+    let mut metadata = None;
+    for _ in 0..40 {
+        metadata = db.get_thread(thread_id).await?;
+        if metadata.is_some() {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(25)).await;
+    }
+
+    let metadata = metadata.expect("backfilled thread should exist in state db");
metadata.expect("backfilled thread should exist in state db"); + assert_eq!(metadata.id, thread_id); + assert_eq!(metadata.rollout_path, rollout_path); + assert_eq!(metadata.model_provider, default_provider); + assert!(metadata.has_user_event); + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn user_messages_persist_in_state_db() -> Result<()> { + let server = start_mock_server().await; + mount_sse_sequence( + &server, + vec![sse_completed("resp-1"), sse_completed("resp-2")], + ) + .await; + + let mut builder = test_codex().with_config(|config| { + config.features.enable(Feature::Sqlite); + }); + let test = builder.build(&server).await?; + + let db_path = test.config.codex_home.join(STATE_DB_FILENAME); + for _ in 0..100 { + if tokio::fs::try_exists(&db_path).await.unwrap_or(false) { + break; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + + test.submit_turn("hello from sqlite").await?; + test.submit_turn("another message").await?; + + let db = test.codex.state_db().expect("state db enabled"); + let thread_id = test.session_configured.session_id; + + let mut metadata = None; + for _ in 0..100 { + metadata = db.get_thread(thread_id).await?; + if metadata + .as_ref() + .map(|entry| entry.has_user_event) + .unwrap_or(false) + { + break; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + + let metadata = metadata.expect("thread should exist in state db"); + assert!(metadata.has_user_event); + + Ok(()) +} diff --git a/codex-rs/otel/src/metrics/mod.rs b/codex-rs/otel/src/metrics/mod.rs index b13d5f917..30a7aa1cf 100644 --- a/codex-rs/otel/src/metrics/mod.rs +++ b/codex-rs/otel/src/metrics/mod.rs @@ -17,6 +17,6 @@ pub(crate) fn install_global(metrics: MetricsClient) { let _ = GLOBAL_METRICS.set(metrics); } -pub(crate) fn global() -> Option { +pub fn global() -> Option { GLOBAL_METRICS.get().cloned() } diff --git a/codex-rs/protocol/src/thread_id.rs b/codex-rs/protocol/src/thread_id.rs index 7b27db836..8d6d96eff 100644 --- a/codex-rs/protocol/src/thread_id.rs +++ b/codex-rs/protocol/src/thread_id.rs @@ -28,6 +28,28 @@ impl ThreadId { } } +impl TryFrom<&str> for ThreadId { + type Error = uuid::Error; + + fn try_from(value: &str) -> Result { + Self::from_string(value) + } +} + +impl TryFrom for ThreadId { + type Error = uuid::Error; + + fn try_from(value: String) -> Result { + Self::from_string(value.as_str()) + } +} + +impl From for String { + fn from(value: ThreadId) -> Self { + value.to_string() + } +} + impl Default for ThreadId { fn default() -> Self { Self::new() @@ -36,7 +58,7 @@ impl Default for ThreadId { impl Display for ThreadId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.uuid) + Display::fmt(&self.uuid, f) } } diff --git a/codex-rs/state/BUILD.bazel b/codex-rs/state/BUILD.bazel new file mode 100644 index 000000000..b1f793216 --- /dev/null +++ b/codex-rs/state/BUILD.bazel @@ -0,0 +1,7 @@ +load("//:defs.bzl", "codex_rust_crate") + +codex_rust_crate( + name = "state", + crate_name = "codex_state", + compile_data = glob(["migrations/**"]), +) diff --git a/codex-rs/state/Cargo.toml b/codex-rs/state/Cargo.toml new file mode 100644 index 000000000..810e250ca --- /dev/null +++ b/codex-rs/state/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "codex-state" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow = { workspace = true } +chrono = { workspace = true } +codex-otel = { workspace = true } +codex-protocol = { workspace = true } 
+serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +sqlx = { workspace = true } +tokio = { workspace = true, features = ["fs", "io-util", "macros", "rt-multi-thread", "sync", "time"] } +tracing = { workspace = true } +uuid = { workspace = true } + +[dev-dependencies] +pretty_assertions = { workspace = true } + +[lints] +workspace = true diff --git a/codex-rs/state/migrations/0001_threads.sql b/codex-rs/state/migrations/0001_threads.sql new file mode 100644 index 000000000..7063ce11a --- /dev/null +++ b/codex-rs/state/migrations/0001_threads.sql @@ -0,0 +1,25 @@ +CREATE TABLE threads ( + id TEXT PRIMARY KEY, + rollout_path TEXT NOT NULL, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + source TEXT NOT NULL, + model_provider TEXT NOT NULL, + cwd TEXT NOT NULL, + title TEXT NOT NULL, + sandbox_policy TEXT NOT NULL, + approval_mode TEXT NOT NULL, + tokens_used INTEGER NOT NULL DEFAULT 0, + has_user_event INTEGER NOT NULL DEFAULT 0, + archived INTEGER NOT NULL DEFAULT 0, + archived_at INTEGER, + git_sha TEXT, + git_branch TEXT, + git_origin_url TEXT +); + +CREATE INDEX idx_threads_created_at ON threads(created_at DESC, id DESC); +CREATE INDEX idx_threads_updated_at ON threads(updated_at DESC, id DESC); +CREATE INDEX idx_threads_archived ON threads(archived); +CREATE INDEX idx_threads_source ON threads(source); +CREATE INDEX idx_threads_provider ON threads(model_provider); diff --git a/codex-rs/state/src/extract.rs b/codex-rs/state/src/extract.rs new file mode 100644 index 000000000..9c8ce2155 --- /dev/null +++ b/codex-rs/state/src/extract.rs @@ -0,0 +1,182 @@ +use crate::model::ThreadMetadata; +use codex_protocol::models::ContentItem; +use codex_protocol::models::ResponseItem; +use codex_protocol::models::is_local_image_close_tag_text; +use codex_protocol::models::is_local_image_open_tag_text; +use codex_protocol::protocol::EventMsg; +use codex_protocol::protocol::RolloutItem; +use codex_protocol::protocol::SessionMetaLine; +use codex_protocol::protocol::TurnContextItem; +use codex_protocol::protocol::USER_MESSAGE_BEGIN; +use serde::Serialize; +use serde_json::Value; + +/// Apply a rollout item to the metadata structure. 
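+///
+/// Illustrative usage, a sketch only (`builder` and `rollout_lines` are
+/// assumed to exist at the call site, and "openai" is a placeholder provider):
+///
+/// ```ignore
+/// let mut metadata = builder.build("openai");
+/// for line in &rollout_lines {
+///     apply_rollout_item(&mut metadata, &line.item, "openai");
+/// }
+/// ```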
+pub fn apply_rollout_item(
+    metadata: &mut ThreadMetadata,
+    item: &RolloutItem,
+    default_provider: &str,
+) {
+    match item {
+        RolloutItem::SessionMeta(meta_line) => apply_session_meta_from_item(metadata, meta_line),
+        RolloutItem::TurnContext(turn_ctx) => apply_turn_context(metadata, turn_ctx),
+        RolloutItem::EventMsg(event) => apply_event_msg(metadata, event),
+        RolloutItem::ResponseItem(item) => apply_response_item(metadata, item),
+        RolloutItem::Compacted(_) => {}
+    }
+    if metadata.model_provider.is_empty() {
+        metadata.model_provider = default_provider.to_string();
+    }
+}
+
+fn apply_session_meta_from_item(metadata: &mut ThreadMetadata, meta_line: &SessionMetaLine) {
+    metadata.id = meta_line.meta.id;
+    metadata.source = enum_to_string(&meta_line.meta.source);
+    if let Some(provider) = meta_line.meta.model_provider.as_deref() {
+        metadata.model_provider = provider.to_string();
+    }
+    if !meta_line.meta.cwd.as_os_str().is_empty() {
+        metadata.cwd = meta_line.meta.cwd.clone();
+    }
+    if let Some(git) = meta_line.git.as_ref() {
+        metadata.git_sha = git.commit_hash.clone();
+        metadata.git_branch = git.branch.clone();
+        metadata.git_origin_url = git.repository_url.clone();
+    }
+}
+
+fn apply_turn_context(metadata: &mut ThreadMetadata, turn_ctx: &TurnContextItem) {
+    metadata.cwd = turn_ctx.cwd.clone();
+    metadata.sandbox_policy = enum_to_string(&turn_ctx.sandbox_policy);
+    metadata.approval_mode = enum_to_string(&turn_ctx.approval_policy);
+}
+
+fn apply_event_msg(metadata: &mut ThreadMetadata, event: &EventMsg) {
+    match event {
+        EventMsg::TokenCount(token_count) => {
+            if let Some(info) = token_count.info.as_ref() {
+                metadata.tokens_used = info.total_token_usage.total_tokens.max(0);
+            }
+        }
+        EventMsg::UserMessage(user) => {
+            metadata.has_user_event = true;
+            if metadata.title.is_empty() {
+                metadata.title = strip_user_message_prefix(user.message.as_str()).to_string();
+            }
+        }
+        _ => {}
+    }
+}
+
+fn apply_response_item(metadata: &mut ThreadMetadata, item: &ResponseItem) {
+    if let Some(text) = extract_user_message_text(item) {
+        metadata.has_user_event = true;
+        if metadata.title.is_empty() {
+            metadata.title = text;
+        }
+    }
+}
+
+fn extract_user_message_text(item: &ResponseItem) -> Option<String> {
+    let ResponseItem::Message { role, content, .. } = item else {
+        return None;
+    };
+    if role != "user" {
+        return None;
+    }
+    let texts: Vec<&str> = content
+        .iter()
+        .filter_map(|content_item| match content_item {
+            ContentItem::InputText { text } => Some(text.as_str()),
+            ContentItem::InputImage { .. } | ContentItem::OutputText { .. } => None,
} => None, + }) + .filter(|text| !is_local_image_open_tag_text(text) && !is_local_image_close_tag_text(text)) + .collect(); + if texts.is_empty() { + return None; + } + let joined = texts.join("\n"); + Some( + strip_user_message_prefix(joined.as_str()) + .trim() + .to_string(), + ) +} + +fn strip_user_message_prefix(text: &str) -> &str { + match text.find(USER_MESSAGE_BEGIN) { + Some(idx) => text[idx + USER_MESSAGE_BEGIN.len()..].trim(), + None => text.trim(), + } +} + +pub(crate) fn enum_to_string(value: &T) -> String { + match serde_json::to_value(value) { + Ok(Value::String(s)) => s, + Ok(other) => other.to_string(), + Err(_) => String::new(), + } +} + +#[cfg(test)] +mod tests { + use super::extract_user_message_text; + use crate::model::ThreadMetadata; + use chrono::DateTime; + use chrono::Utc; + use codex_protocol::ThreadId; + use codex_protocol::models::ContentItem; + use codex_protocol::models::ResponseItem; + use codex_protocol::protocol::USER_MESSAGE_BEGIN; + use pretty_assertions::assert_eq; + use std::path::PathBuf; + use uuid::Uuid; + + #[test] + fn extracts_user_message_text() { + let item = ResponseItem::Message { + id: None, + role: "user".to_string(), + content: vec![ + ContentItem::InputText { + text: format!(" {USER_MESSAGE_BEGIN}actual question"), + }, + ContentItem::InputImage { + image_url: "https://example.com/image.png".to_string(), + }, + ], + end_turn: None, + }; + let actual = extract_user_message_text(&item); + assert_eq!(actual.as_deref(), Some("actual question")); + } + + #[test] + fn diff_fields_detects_changes() { + let id = ThreadId::from_string(&Uuid::now_v7().to_string()).expect("thread id"); + let created_at = DateTime::::from_timestamp(1_735_689_600, 0).expect("timestamp"); + let base = ThreadMetadata { + id, + rollout_path: PathBuf::from("/tmp/a.jsonl"), + created_at, + updated_at: created_at, + source: "cli".to_string(), + model_provider: "openai".to_string(), + cwd: PathBuf::from("/tmp"), + title: "hello".to_string(), + sandbox_policy: "read-only".to_string(), + approval_mode: "on-request".to_string(), + tokens_used: 1, + has_user_event: false, + archived_at: None, + git_sha: None, + git_branch: None, + git_origin_url: None, + }; + let mut other = base.clone(); + other.tokens_used = 2; + other.title = "world".to_string(); + let diffs = base.diff_fields(&other); + assert_eq!(diffs, vec!["title", "tokens_used"]); + } +} diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs new file mode 100644 index 000000000..67533a379 --- /dev/null +++ b/codex-rs/state/src/lib.rs @@ -0,0 +1,34 @@ +//! SQLite-backed state for rollout metadata. +//! +//! This crate is intentionally small and focused: it extracts rollout metadata +//! from JSONL rollouts and mirrors it into a local SQLite database. Backfill +//! orchestration and rollout scanning live in `codex-core`. + +mod extract; +mod migrations; +mod model; +mod paths; +mod runtime; + +/// Preferred entrypoint: owns configuration and metrics. +pub use runtime::StateRuntime; + +/// Low-level storage engine: useful for focused tests. +/// +/// Most consumers should prefer [`StateRuntime`]. +pub use extract::apply_rollout_item; +pub use model::Anchor; +pub use model::BackfillStats; +pub use model::ExtractionOutcome; +pub use model::SortKey; +pub use model::ThreadMetadata; +pub use model::ThreadMetadataBuilder; +pub use model::ThreadsPage; +pub use runtime::STATE_DB_FILENAME; + +/// Errors encountered during DB operations. 
+pub const DB_ERROR_METRIC: &str = "codex.db.error";
+/// Metrics on the backfill process during first init of the db. Tags: [status]
+pub const DB_METRIC_BACKFILL: &str = "codex.db.backfill";
+/// Metrics on errors during comparison between DB and rollout file. Tags: [stage]
+pub const DB_METRIC_COMPARE_ERROR: &str = "codex.db.compare_error";
diff --git a/codex-rs/state/src/migrations.rs b/codex-rs/state/src/migrations.rs
new file mode 100644
index 000000000..24b310224
--- /dev/null
+++ b/codex-rs/state/src/migrations.rs
@@ -0,0 +1,3 @@
+use sqlx::migrate::Migrator;
+
+pub(crate) static MIGRATOR: Migrator = sqlx::migrate!("./migrations");
diff --git a/codex-rs/state/src/model.rs b/codex-rs/state/src/model.rs
new file mode 100644
index 000000000..7d475efff
--- /dev/null
+++ b/codex-rs/state/src/model.rs
@@ -0,0 +1,352 @@
+use anyhow::Result;
+use chrono::DateTime;
+use chrono::Timelike;
+use chrono::Utc;
+use codex_protocol::ThreadId;
+use codex_protocol::protocol::AskForApproval;
+use codex_protocol::protocol::SandboxPolicy;
+use codex_protocol::protocol::SessionSource;
+use sqlx::Row;
+use sqlx::sqlite::SqliteRow;
+use std::path::PathBuf;
+use uuid::Uuid;
+
+/// The sort key to use when listing threads.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum SortKey {
+    /// Sort by the thread's creation timestamp.
+    CreatedAt,
+    /// Sort by the thread's last update timestamp.
+    UpdatedAt,
+}
+
+/// A pagination anchor used for keyset pagination.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct Anchor {
+    /// The timestamp component of the anchor.
+    pub ts: DateTime<Utc>,
+    /// The UUID component of the anchor.
+    pub id: Uuid,
+}
+
+/// A single page of thread metadata results.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ThreadsPage {
+    /// The thread metadata items in this page.
+    pub items: Vec<ThreadMetadata>,
+    /// The next anchor to use for pagination, if any.
+    pub next_anchor: Option<Anchor>,
+    /// The number of rows scanned to produce this page.
+    pub num_scanned_rows: usize,
+}
+
+/// The outcome of extracting metadata from a rollout.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ExtractionOutcome {
+    /// The extracted thread metadata.
+    pub metadata: ThreadMetadata,
+    /// The number of rollout lines that failed to parse.
+    pub parse_errors: usize,
+}
+
+/// Canonical thread metadata derived from rollout files.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ThreadMetadata {
+    /// The thread identifier.
+    pub id: ThreadId,
+    /// The absolute rollout path on disk.
+    pub rollout_path: PathBuf,
+    /// The creation timestamp.
+    pub created_at: DateTime<Utc>,
+    /// The last update timestamp.
+    pub updated_at: DateTime<Utc>,
+    /// The session source (stringified enum).
+    pub source: String,
+    /// The model provider identifier.
+    pub model_provider: String,
+    /// The working directory for the thread.
+    pub cwd: PathBuf,
+    /// A best-effort thread title.
+    pub title: String,
+    /// The sandbox policy (stringified enum).
+    pub sandbox_policy: String,
+    /// The approval mode (stringified enum).
+    pub approval_mode: String,
+    /// The last observed token usage.
+    pub tokens_used: i64,
+    /// Whether the thread has observed a user message.
+    pub has_user_event: bool,
+    /// The archive timestamp, if the thread is archived.
+    pub archived_at: Option<DateTime<Utc>>,
+    /// The git commit SHA, if known.
+    pub git_sha: Option<String>,
+    /// The git branch name, if known.
+    pub git_branch: Option<String>,
+    /// The git origin URL, if known.
+    pub git_origin_url: Option<String>,
+}
+
+/// Builder data required to construct [`ThreadMetadata`] without parsing filenames.
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct ThreadMetadataBuilder {
+    /// The thread identifier.
+    pub id: ThreadId,
+    /// The absolute rollout path on disk.
+    pub rollout_path: PathBuf,
+    /// The creation timestamp.
+    pub created_at: DateTime<Utc>,
+    /// The last update timestamp, if known.
+    pub updated_at: Option<DateTime<Utc>>,
+    /// The session source.
+    pub source: SessionSource,
+    /// The model provider identifier, if known.
+    pub model_provider: Option<String>,
+    /// The working directory for the thread.
+    pub cwd: PathBuf,
+    /// The sandbox policy.
+    pub sandbox_policy: SandboxPolicy,
+    /// The approval mode.
+    pub approval_mode: AskForApproval,
+    /// The archive timestamp, if the thread is archived.
+    pub archived_at: Option<DateTime<Utc>>,
+    /// The git commit SHA, if known.
+    pub git_sha: Option<String>,
+    /// The git branch name, if known.
+    pub git_branch: Option<String>,
+    /// The git origin URL, if known.
+    pub git_origin_url: Option<String>,
+}
+
+impl ThreadMetadataBuilder {
+    /// Create a new builder with required fields and sensible defaults.
+    pub fn new(
+        id: ThreadId,
+        rollout_path: PathBuf,
+        created_at: DateTime<Utc>,
+        source: SessionSource,
+    ) -> Self {
+        Self {
+            id,
+            rollout_path,
+            created_at,
+            updated_at: None,
+            source,
+            model_provider: None,
+            cwd: PathBuf::new(),
+            sandbox_policy: SandboxPolicy::ReadOnly,
+            approval_mode: AskForApproval::OnRequest,
+            archived_at: None,
+            git_sha: None,
+            git_branch: None,
+            git_origin_url: None,
+        }
+    }
+
+    /// Build canonical thread metadata, filling missing values from defaults.
+    pub fn build(&self, default_provider: &str) -> ThreadMetadata {
+        let source = crate::extract::enum_to_string(&self.source);
+        let sandbox_policy = crate::extract::enum_to_string(&self.sandbox_policy);
+        let approval_mode = crate::extract::enum_to_string(&self.approval_mode);
+        let created_at = canonicalize_datetime(self.created_at);
+        let updated_at = self
+            .updated_at
+            .map(canonicalize_datetime)
+            .unwrap_or(created_at);
+        ThreadMetadata {
+            id: self.id,
+            rollout_path: self.rollout_path.clone(),
+            created_at,
+            updated_at,
+            source,
+            model_provider: self
+                .model_provider
+                .clone()
+                .unwrap_or_else(|| default_provider.to_string()),
+            cwd: self.cwd.clone(),
+            title: String::new(),
+            sandbox_policy,
+            approval_mode,
+            tokens_used: 0,
+            has_user_event: false,
+            archived_at: self.archived_at.map(canonicalize_datetime),
+            git_sha: self.git_sha.clone(),
+            git_branch: self.git_branch.clone(),
+            git_origin_url: self.git_origin_url.clone(),
+        }
+    }
+}
+
+impl ThreadMetadata {
+    /// Return the list of field names that differ between `self` and `other`.
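+    ///
+    /// Names come back in struct declaration order; e.g. two entries that
+    /// differ only in title and token count yield `["title", "tokens_used"]`.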
+    pub fn diff_fields(&self, other: &Self) -> Vec<&'static str> {
+        let mut diffs = Vec::new();
+        if self.id != other.id {
+            diffs.push("id");
+        }
+        if self.rollout_path != other.rollout_path {
+            diffs.push("rollout_path");
+        }
+        if self.created_at != other.created_at {
+            diffs.push("created_at");
+        }
+        if self.updated_at != other.updated_at {
+            diffs.push("updated_at");
+        }
+        if self.source != other.source {
+            diffs.push("source");
+        }
+        if self.model_provider != other.model_provider {
+            diffs.push("model_provider");
+        }
+        if self.cwd != other.cwd {
+            diffs.push("cwd");
+        }
+        if self.title != other.title {
+            diffs.push("title");
+        }
+        if self.sandbox_policy != other.sandbox_policy {
+            diffs.push("sandbox_policy");
+        }
+        if self.approval_mode != other.approval_mode {
+            diffs.push("approval_mode");
+        }
+        if self.tokens_used != other.tokens_used {
+            diffs.push("tokens_used");
+        }
+        if self.has_user_event != other.has_user_event {
+            diffs.push("has_user_event");
+        }
+        if self.archived_at != other.archived_at {
+            diffs.push("archived_at");
+        }
+        if self.git_sha != other.git_sha {
+            diffs.push("git_sha");
+        }
+        if self.git_branch != other.git_branch {
+            diffs.push("git_branch");
+        }
+        if self.git_origin_url != other.git_origin_url {
+            diffs.push("git_origin_url");
+        }
+        diffs
+    }
+}
+
+fn canonicalize_datetime(dt: DateTime<Utc>) -> DateTime<Utc> {
+    dt.with_nanosecond(0).unwrap_or(dt)
+}
+
+#[derive(Debug)]
+pub(crate) struct ThreadRow {
+    id: String,
+    rollout_path: String,
+    created_at: i64,
+    updated_at: i64,
+    source: String,
+    model_provider: String,
+    cwd: String,
+    title: String,
+    sandbox_policy: String,
+    approval_mode: String,
+    tokens_used: i64,
+    has_user_event: bool,
+    archived_at: Option<i64>,
+    git_sha: Option<String>,
+    git_branch: Option<String>,
+    git_origin_url: Option<String>,
+}
+
+impl ThreadRow {
+    pub(crate) fn try_from_row(row: &SqliteRow) -> Result<Self> {
+        Ok(Self {
+            id: row.try_get("id")?,
+            rollout_path: row.try_get("rollout_path")?,
+            created_at: row.try_get("created_at")?,
+            updated_at: row.try_get("updated_at")?,
+            source: row.try_get("source")?,
+            model_provider: row.try_get("model_provider")?,
+            cwd: row.try_get("cwd")?,
+            title: row.try_get("title")?,
+            sandbox_policy: row.try_get("sandbox_policy")?,
+            approval_mode: row.try_get("approval_mode")?,
+            tokens_used: row.try_get("tokens_used")?,
+            has_user_event: row.try_get("has_user_event")?,
+            archived_at: row.try_get("archived_at")?,
+            git_sha: row.try_get("git_sha")?,
+            git_branch: row.try_get("git_branch")?,
+            git_origin_url: row.try_get("git_origin_url")?,
+        })
+    }
+}
+
+impl TryFrom<ThreadRow> for ThreadMetadata {
+    type Error = anyhow::Error;
+
+    fn try_from(row: ThreadRow) -> std::result::Result<Self, Self::Error> {
+        let ThreadRow {
+            id,
+            rollout_path,
+            created_at,
+            updated_at,
+            source,
+            model_provider,
+            cwd,
+            title,
+            sandbox_policy,
+            approval_mode,
+            tokens_used,
+            has_user_event,
+            archived_at,
+            git_sha,
+            git_branch,
+            git_origin_url,
+        } = row;
+        Ok(Self {
+            id: ThreadId::try_from(id)?,
+            rollout_path: PathBuf::from(rollout_path),
+            created_at: epoch_seconds_to_datetime(created_at)?,
+            updated_at: epoch_seconds_to_datetime(updated_at)?,
+            source,
+            model_provider,
+            cwd: PathBuf::from(cwd),
+            title,
+            sandbox_policy,
+            approval_mode,
+            tokens_used,
+            has_user_event,
+            archived_at: archived_at.map(epoch_seconds_to_datetime).transpose()?,
+            git_sha,
+            git_branch,
+            git_origin_url,
+        })
+    }
+}
+
+pub(crate) fn anchor_from_item(item: &ThreadMetadata, sort_key: SortKey) -> Option<Anchor> {
+    let id = Uuid::parse_str(&item.id.to_string()).ok()?;
+    let ts = match sort_key {
SortKey::CreatedAt => item.created_at, + SortKey::UpdatedAt => item.updated_at, + }; + Some(Anchor { ts, id }) +} + +pub(crate) fn datetime_to_epoch_seconds(dt: DateTime) -> i64 { + dt.timestamp() +} + +pub(crate) fn epoch_seconds_to_datetime(secs: i64) -> Result> { + DateTime::::from_timestamp(secs, 0) + .ok_or_else(|| anyhow::anyhow!("invalid unix timestamp: {secs}")) +} + +/// Statistics about a backfill operation. +#[derive(Debug, Clone)] +pub struct BackfillStats { + /// The number of rollout files scanned. + pub scanned: usize, + /// The number of rows upserted successfully. + pub upserted: usize, + /// The number of rows that failed to upsert. + pub failed: usize, +} diff --git a/codex-rs/state/src/paths.rs b/codex-rs/state/src/paths.rs new file mode 100644 index 000000000..812374382 --- /dev/null +++ b/codex-rs/state/src/paths.rs @@ -0,0 +1,10 @@ +use chrono::DateTime; +use chrono::Timelike; +use chrono::Utc; +use std::path::Path; + +pub(crate) async fn file_modified_time_utc(path: &Path) -> Option> { + let modified = tokio::fs::metadata(path).await.ok()?.modified().ok()?; + let updated_at: DateTime = modified.into(); + Some(updated_at.with_nanosecond(0).unwrap_or(updated_at)) +} diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs new file mode 100644 index 000000000..f341ee3bb --- /dev/null +++ b/codex-rs/state/src/runtime.rs @@ -0,0 +1,458 @@ +use crate::DB_ERROR_METRIC; +use crate::SortKey; +use crate::ThreadMetadata; +use crate::ThreadMetadataBuilder; +use crate::ThreadsPage; +use crate::apply_rollout_item; +use crate::migrations::MIGRATOR; +use crate::model::ThreadRow; +use crate::model::anchor_from_item; +use crate::model::datetime_to_epoch_seconds; +use crate::paths::file_modified_time_utc; +use chrono::DateTime; +use chrono::Utc; +use codex_otel::OtelManager; +use codex_protocol::ThreadId; +use codex_protocol::protocol::RolloutItem; +use sqlx::QueryBuilder; +use sqlx::Row; +use sqlx::Sqlite; +use sqlx::SqlitePool; +use sqlx::sqlite::SqliteConnectOptions; +use sqlx::sqlite::SqliteJournalMode; +use sqlx::sqlite::SqlitePoolOptions; +use sqlx::sqlite::SqliteSynchronous; +use std::path::Path; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use tracing::warn; + +pub const STATE_DB_FILENAME: &str = "state.sqlite"; + +const METRIC_DB_INIT: &str = "codex.db.init"; + +#[derive(Clone)] +pub struct StateRuntime { + codex_home: PathBuf, + default_provider: String, + pool: Arc, +} + +impl StateRuntime { + /// Initialize the state runtime using the provided Codex home and default provider. + /// + /// This opens (and migrates) the SQLite database at `codex_home/state.sqlite`. 
+    pub async fn init(
+        codex_home: PathBuf,
+        default_provider: String,
+        otel: Option<OtelManager>,
+    ) -> anyhow::Result<Arc<Self>> {
+        tokio::fs::create_dir_all(&codex_home).await?;
+        let state_path = codex_home.join(STATE_DB_FILENAME);
+        let existed = tokio::fs::try_exists(&state_path).await.unwrap_or(false);
+        let pool = match open_sqlite(&state_path).await {
+            Ok(db) => Arc::new(db),
+            Err(err) => {
+                warn!("failed to open state db at {}: {err}", state_path.display());
+                if let Some(otel) = otel.as_ref() {
+                    otel.counter(METRIC_DB_INIT, 1, &[("status", "open_error")]);
+                }
+                return Err(err);
+            }
+        };
+        if let Some(otel) = otel.as_ref() {
+            otel.counter(METRIC_DB_INIT, 1, &[("status", "opened")]);
+        }
+        let runtime = Arc::new(Self {
+            pool,
+            codex_home,
+            default_provider,
+        });
+        if !existed && let Some(otel) = otel.as_ref() {
+            otel.counter(METRIC_DB_INIT, 1, &[("status", "created")]);
+        }
+        Ok(runtime)
+    }
+
+    /// Return the configured Codex home directory for this runtime.
+    pub fn codex_home(&self) -> &Path {
+        self.codex_home.as_path()
+    }
+
+    /// Load thread metadata by id using the underlying database.
+    pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<ThreadMetadata>> {
+        let row = sqlx::query(
+            r#"
+SELECT
+    id,
+    rollout_path,
+    created_at,
+    updated_at,
+    source,
+    model_provider,
+    cwd,
+    title,
+    sandbox_policy,
+    approval_mode,
+    tokens_used,
+    has_user_event,
+    archived_at,
+    git_sha,
+    git_branch,
+    git_origin_url
+FROM threads
+WHERE id = ?
+            "#,
+        )
+        .bind(id.to_string())
+        .fetch_optional(self.pool.as_ref())
+        .await?;
+        row.map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
+            .transpose()
+    }
+
+    /// Find a rollout path by thread id using the underlying database.
+    pub async fn find_rollout_path_by_id(
+        &self,
+        id: ThreadId,
+        archived_only: Option<bool>,
+    ) -> anyhow::Result<Option<PathBuf>> {
+        let mut builder =
+            QueryBuilder::<Sqlite>::new("SELECT rollout_path FROM threads WHERE id = ");
+        builder.push_bind(id.to_string());
+        match archived_only {
+            Some(true) => {
+                builder.push(" AND archived = 1");
+            }
+            Some(false) => {
+                builder.push(" AND archived = 0");
+            }
+            None => {}
+        }
+        let row = builder.build().fetch_optional(self.pool.as_ref()).await?;
+        Ok(row
+            .and_then(|r| r.try_get::<String, _>("rollout_path").ok())
+            .map(PathBuf::from))
+    }
+
+    /// List threads using the underlying database.
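+    ///
+    /// Pagination is keyset-based: `page_size + 1` rows are fetched, and when
+    /// the extra row is present the last returned item becomes the anchor for
+    /// the next page. A sketch of paging through results (`runtime` is assumed
+    /// to exist):
+    ///
+    /// ```ignore
+    /// let first = runtime
+    ///     .list_threads(25, None, SortKey::UpdatedAt, &[], None, false)
+    ///     .await?;
+    /// if let Some(anchor) = first.next_anchor.as_ref() {
+    ///     let second = runtime
+    ///         .list_threads(25, Some(anchor), SortKey::UpdatedAt, &[], None, false)
+    ///         .await?;
+    /// }
+    /// ```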
+    pub async fn list_threads(
+        &self,
+        page_size: usize,
+        anchor: Option<&crate::Anchor>,
+        sort_key: crate::SortKey,
+        allowed_sources: &[String],
+        model_providers: Option<&[String]>,
+        archived_only: bool,
+    ) -> anyhow::Result<ThreadsPage> {
+        let limit = page_size.saturating_add(1);
+
+        let mut builder = QueryBuilder::<Sqlite>::new(
+            r#"
+SELECT
+    id,
+    rollout_path,
+    created_at,
+    updated_at,
+    source,
+    model_provider,
+    cwd,
+    title,
+    sandbox_policy,
+    approval_mode,
+    tokens_used,
+    has_user_event,
+    archived_at,
+    git_sha,
+    git_branch,
+    git_origin_url
+FROM threads
+            "#,
+        );
+        push_thread_filters(
+            &mut builder,
+            archived_only,
+            allowed_sources,
+            model_providers,
+            anchor,
+            sort_key,
+        );
+        push_thread_order_and_limit(&mut builder, sort_key, limit);
+
+        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
+        let mut items = rows
+            .into_iter()
+            .map(|row| ThreadRow::try_from_row(&row).and_then(ThreadMetadata::try_from))
+            .collect::<Result<Vec<_>, _>>()?;
+        let num_scanned_rows = items.len();
+        let next_anchor = if items.len() > page_size {
+            items.pop();
+            items
+                .last()
+                .and_then(|item| anchor_from_item(item, sort_key))
+        } else {
+            None
+        };
+        Ok(ThreadsPage {
+            items,
+            next_anchor,
+            num_scanned_rows,
+        })
+    }
+
+    /// List thread ids using the underlying database (no rollout scanning).
+    pub async fn list_thread_ids(
+        &self,
+        limit: usize,
+        anchor: Option<&crate::Anchor>,
+        sort_key: crate::SortKey,
+        allowed_sources: &[String],
+        model_providers: Option<&[String]>,
+        archived_only: bool,
+    ) -> anyhow::Result<Vec<ThreadId>> {
+        let mut builder = QueryBuilder::<Sqlite>::new("SELECT id FROM threads");
+        push_thread_filters(
+            &mut builder,
+            archived_only,
+            allowed_sources,
+            model_providers,
+            anchor,
+            sort_key,
+        );
+        push_thread_order_and_limit(&mut builder, sort_key, limit);
+
+        let rows = builder.build().fetch_all(self.pool.as_ref()).await?;
+        rows.into_iter()
+            .map(|row| {
+                let id: String = row.try_get("id")?;
+                Ok(ThreadId::try_from(id)?)
+            })
+            .collect()
+    }
+
+    /// Insert or replace thread metadata directly.
+    pub async fn upsert_thread(&self, metadata: &crate::ThreadMetadata) -> anyhow::Result<()> {
+        sqlx::query(
+            r#"
+INSERT INTO threads (
+    id,
+    rollout_path,
+    created_at,
+    updated_at,
+    source,
+    model_provider,
+    cwd,
+    title,
+    sandbox_policy,
+    approval_mode,
+    tokens_used,
+    has_user_event,
+    archived,
+    archived_at,
+    git_sha,
+    git_branch,
+    git_origin_url
+) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ON CONFLICT(id) DO UPDATE SET
+    rollout_path = excluded.rollout_path,
+    created_at = excluded.created_at,
+    updated_at = excluded.updated_at,
+    source = excluded.source,
+    model_provider = excluded.model_provider,
+    cwd = excluded.cwd,
+    title = excluded.title,
+    sandbox_policy = excluded.sandbox_policy,
+    approval_mode = excluded.approval_mode,
+    tokens_used = excluded.tokens_used,
+    has_user_event = excluded.has_user_event,
+    archived = excluded.archived,
+    archived_at = excluded.archived_at,
+    git_sha = excluded.git_sha,
+    git_branch = excluded.git_branch,
+    git_origin_url = excluded.git_origin_url
+            "#,
+        )
+        .bind(metadata.id.to_string())
+        .bind(metadata.rollout_path.display().to_string())
+        .bind(datetime_to_epoch_seconds(metadata.created_at))
+        .bind(datetime_to_epoch_seconds(metadata.updated_at))
+        .bind(metadata.source.as_str())
+        .bind(metadata.model_provider.as_str())
+        .bind(metadata.cwd.display().to_string())
+        .bind(metadata.title.as_str())
+        .bind(metadata.sandbox_policy.as_str())
+        .bind(metadata.approval_mode.as_str())
+        .bind(metadata.tokens_used)
+        .bind(metadata.has_user_event)
+        .bind(metadata.archived_at.is_some())
+        .bind(metadata.archived_at.map(datetime_to_epoch_seconds))
+        .bind(metadata.git_sha.as_deref())
+        .bind(metadata.git_branch.as_deref())
+        .bind(metadata.git_origin_url.as_deref())
+        .execute(self.pool.as_ref())
+        .await?;
+        Ok(())
+    }
+
+    /// Apply rollout items incrementally using the underlying database.
+    pub async fn apply_rollout_items(
+        &self,
+        builder: &ThreadMetadataBuilder,
+        items: &[RolloutItem],
+        otel: Option<&OtelManager>,
+    ) -> anyhow::Result<()> {
+        if items.is_empty() {
+            return Ok(());
+        }
+        let mut metadata = self
+            .get_thread(builder.id)
+            .await?
+            .unwrap_or_else(|| builder.build(&self.default_provider));
+        metadata.rollout_path = builder.rollout_path.clone();
+        for item in items {
+            apply_rollout_item(&mut metadata, item, &self.default_provider);
+        }
+        if let Some(updated_at) = file_modified_time_utc(builder.rollout_path.as_path()).await {
+            metadata.updated_at = updated_at;
+        }
+        if let Err(err) = self.upsert_thread(&metadata).await {
+            if let Some(otel) = otel {
+                otel.counter(DB_ERROR_METRIC, 1, &[("stage", "apply_rollout_items")]);
+            }
+            return Err(err);
+        }
+        Ok(())
+    }
+
+    /// Mark a thread as archived using the underlying database.
+    pub async fn mark_archived(
+        &self,
+        thread_id: ThreadId,
+        rollout_path: &Path,
+        archived_at: DateTime<Utc>,
+    ) -> anyhow::Result<()> {
+        let Some(mut metadata) = self.get_thread(thread_id).await? else {
+            return Ok(());
+        };
+        metadata.archived_at = Some(archived_at);
+        metadata.rollout_path = rollout_path.to_path_buf();
+        if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
+            metadata.updated_at = updated_at;
+        }
+        if metadata.id != thread_id {
+            warn!(
+                "thread id mismatch during archive: expected {thread_id}, got {}",
+                metadata.id
+            );
+        }
+        self.upsert_thread(&metadata).await
+    }
+
+    /// Mark a thread as unarchived using the underlying database.
+    pub async fn mark_unarchived(
+        &self,
+        thread_id: ThreadId,
+        rollout_path: &Path,
+    ) -> anyhow::Result<()> {
+        let Some(mut metadata) = self.get_thread(thread_id).await? else {
+            return Ok(());
+        };
+        metadata.archived_at = None;
+        metadata.rollout_path = rollout_path.to_path_buf();
+        if let Some(updated_at) = file_modified_time_utc(rollout_path).await {
+            metadata.updated_at = updated_at;
+        }
+        if metadata.id != thread_id {
+            warn!(
+                "thread id mismatch during unarchive: expected {thread_id}, got {}",
+                metadata.id
+            );
+        }
+        self.upsert_thread(&metadata).await
+    }
+}
+
+async fn open_sqlite(path: &Path) -> anyhow::Result<SqlitePool> {
+    let options = SqliteConnectOptions::new()
+        .filename(path)
+        .create_if_missing(true)
+        .journal_mode(SqliteJournalMode::Wal)
+        .synchronous(SqliteSynchronous::Normal)
+        .busy_timeout(Duration::from_secs(5));
+    let pool = SqlitePoolOptions::new()
+        .max_connections(5)
+        .connect_with(options)
+        .await?;
+    MIGRATOR.run(&pool).await?;
+    Ok(pool)
+}
+
+fn push_thread_filters<'a>(
+    builder: &mut QueryBuilder<'a, Sqlite>,
+    archived_only: bool,
+    allowed_sources: &'a [String],
+    model_providers: Option<&'a [String]>,
+    anchor: Option<&crate::Anchor>,
+    sort_key: SortKey,
+) {
+    builder.push(" WHERE 1 = 1");
+    if archived_only {
+        builder.push(" AND archived = 1");
+    } else {
+        builder.push(" AND archived = 0");
+    }
+    builder.push(" AND has_user_event = 1");
+    if !allowed_sources.is_empty() {
+        builder.push(" AND source IN (");
+        let mut separated = builder.separated(", ");
+        for source in allowed_sources {
+            separated.push_bind(source);
+        }
+        separated.push_unseparated(")");
+    }
+    if let Some(model_providers) = model_providers
+        && !model_providers.is_empty()
+    {
+        builder.push(" AND model_provider IN (");
+        let mut separated = builder.separated(", ");
+        for provider in model_providers {
+            separated.push_bind(provider);
+        }
+        separated.push_unseparated(")");
+    }
+    if let Some(anchor) = anchor {
+        let anchor_ts = datetime_to_epoch_seconds(anchor.ts);
+        let column = match sort_key {
+            SortKey::CreatedAt => "created_at",
+            SortKey::UpdatedAt => "updated_at",
+        };
+        builder.push(" AND (");
+        builder.push(column);
+        builder.push(" < ");
+        builder.push_bind(anchor_ts);
+        builder.push(" OR (");
+        builder.push(column);
+        builder.push(" = ");
+        builder.push_bind(anchor_ts);
+        builder.push(" AND id < ");
+        builder.push_bind(anchor.id.to_string());
+        builder.push("))");
+    }
+}
+
+fn push_thread_order_and_limit(
+    builder: &mut QueryBuilder<'_, Sqlite>,
+    sort_key: SortKey,
+    limit: usize,
+) {
+    let order_column = match sort_key {
+        SortKey::CreatedAt => "created_at",
+        SortKey::UpdatedAt => "updated_at",
+    };
+    builder.push(" ORDER BY ");
+    builder.push(order_column);
+    builder.push(" DESC, id DESC");
+    builder.push(" LIMIT ");
+    builder.push_bind(limit as i64);
+}