diff --git a/rust/crates/api/src/prompt_cache.rs b/rust/crates/api/src/prompt_cache.rs new file mode 100644 index 0000000..be7cb83 --- /dev/null +++ b/rust/crates/api/src/prompt_cache.rs @@ -0,0 +1,727 @@ +use std::fs; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; + +use crate::types::{MessageRequest, MessageResponse, Usage}; + +const DEFAULT_COMPLETION_TTL_SECS: u64 = 30; +const DEFAULT_PROMPT_TTL_SECS: u64 = 5 * 60; +const DEFAULT_BREAK_MIN_DROP: u32 = 2_000; +const MAX_SANITIZED_LENGTH: usize = 80; +const REQUEST_FINGERPRINT_VERSION: u32 = 1; +const REQUEST_FINGERPRINT_PREFIX: &str = "v1"; +const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325; +const FNV_PRIME: u64 = 0x0000_0100_0000_01b3; + +#[derive(Debug, Clone)] +pub struct PromptCacheConfig { + pub session_id: String, + pub completion_ttl: Duration, + pub prompt_ttl: Duration, + pub cache_break_min_drop: u32, +} + +impl PromptCacheConfig { + #[must_use] + pub fn new(session_id: impl Into) -> Self { + Self { + session_id: session_id.into(), + completion_ttl: Duration::from_secs(DEFAULT_COMPLETION_TTL_SECS), + prompt_ttl: Duration::from_secs(DEFAULT_PROMPT_TTL_SECS), + cache_break_min_drop: DEFAULT_BREAK_MIN_DROP, + } + } +} + +impl Default for PromptCacheConfig { + fn default() -> Self { + Self::new("default") + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCachePaths { + pub root: PathBuf, + pub session_dir: PathBuf, + pub completion_dir: PathBuf, + pub session_state_path: PathBuf, + pub stats_path: PathBuf, +} + +impl PromptCachePaths { + #[must_use] + pub fn for_session(session_id: &str) -> Self { + let root = base_cache_root(); + let session_dir = root.join(sanitize_path_segment(session_id)); + let completion_dir = session_dir.join("completions"); + Self { + root, + session_state_path: session_dir.join("session-state.json"), + stats_path: 
session_dir.join("stats.json"), + session_dir, + completion_dir, + } + } + + #[must_use] + pub fn completion_entry_path(&self, request_hash: &str) -> PathBuf { + self.completion_dir.join(format!("{request_hash}.json")) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptCacheStats { + pub tracked_requests: u64, + pub completion_cache_hits: u64, + pub completion_cache_misses: u64, + pub completion_cache_writes: u64, + pub expected_invalidations: u64, + pub unexpected_cache_breaks: u64, + pub total_cache_creation_input_tokens: u64, + pub total_cache_read_input_tokens: u64, + pub last_cache_creation_input_tokens: Option, + pub last_cache_read_input_tokens: Option, + pub last_request_hash: Option, + pub last_completion_cache_key: Option, + pub last_break_reason: Option, + pub last_cache_source: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct CacheBreakEvent { + pub unexpected: bool, + pub reason: String, + pub previous_cache_read_input_tokens: u32, + pub current_cache_read_input_tokens: u32, + pub token_drop: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PromptCacheRecord { + pub cache_break: Option, + pub stats: PromptCacheStats, +} + +#[derive(Debug, Clone)] +pub struct PromptCache { + inner: Arc>, +} + +impl PromptCache { + #[must_use] + pub fn new(session_id: impl Into) -> Self { + Self::with_config(PromptCacheConfig::new(session_id)) + } + + #[must_use] + pub fn with_config(config: PromptCacheConfig) -> Self { + let paths = PromptCachePaths::for_session(&config.session_id); + let stats = read_json::(&paths.stats_path).unwrap_or_default(); + let previous = read_json::(&paths.session_state_path); + Self { + inner: Arc::new(Mutex::new(PromptCacheInner { + config, + paths, + stats, + previous, + })), + } + } + + #[must_use] + pub fn paths(&self) -> PromptCachePaths { + self.lock().paths.clone() + } + + #[must_use] + pub fn stats(&self) -> PromptCacheStats { + 
self.lock().stats.clone() + } + + #[must_use] + pub fn lookup_completion(&self, request: &MessageRequest) -> Option { + let request_hash = request_hash_hex(request); + let (paths, ttl) = { + let inner = self.lock(); + (inner.paths.clone(), inner.config.completion_ttl) + }; + let entry_path = paths.completion_entry_path(&request_hash); + let entry = read_json::(&entry_path); + let Some(entry) = entry else { + let mut inner = self.lock(); + inner.stats.completion_cache_misses += 1; + inner.stats.last_completion_cache_key = Some(request_hash); + persist_state(&inner); + return None; + }; + + if entry.fingerprint_version != current_fingerprint_version() { + let mut inner = self.lock(); + inner.stats.completion_cache_misses += 1; + inner.stats.last_completion_cache_key = Some(request_hash.clone()); + let _ = fs::remove_file(entry_path); + persist_state(&inner); + return None; + } + + let expired = now_unix_secs().saturating_sub(entry.cached_at_unix_secs) >= ttl.as_secs(); + let mut inner = self.lock(); + inner.stats.last_completion_cache_key = Some(request_hash.clone()); + if expired { + inner.stats.completion_cache_misses += 1; + let _ = fs::remove_file(entry_path); + persist_state(&inner); + return None; + } + + inner.stats.completion_cache_hits += 1; + apply_usage_to_stats( + &mut inner.stats, + &entry.response.usage, + &request_hash, + "completion-cache", + ); + inner.previous = Some(TrackedPromptState::from_usage( + request, + &entry.response.usage, + )); + persist_state(&inner); + Some(entry.response) + } + + #[must_use] + pub fn record_response( + &self, + request: &MessageRequest, + response: &MessageResponse, + ) -> PromptCacheRecord { + self.record_usage_internal(request, &response.usage, Some(response)) + } + + #[must_use] + pub fn record_usage(&self, request: &MessageRequest, usage: &Usage) -> PromptCacheRecord { + self.record_usage_internal(request, usage, None) + } + + fn record_usage_internal( + &self, + request: &MessageRequest, + usage: &Usage, + 
response: Option<&MessageResponse>, + ) -> PromptCacheRecord { + let request_hash = request_hash_hex(request); + let mut inner = self.lock(); + let previous = inner.previous.clone(); + let current = TrackedPromptState::from_usage(request, usage); + let cache_break = detect_cache_break(&inner.config, previous.as_ref(), &current); + + inner.stats.tracked_requests += 1; + apply_usage_to_stats(&mut inner.stats, usage, &request_hash, "api-response"); + if let Some(event) = &cache_break { + if event.unexpected { + inner.stats.unexpected_cache_breaks += 1; + } else { + inner.stats.expected_invalidations += 1; + } + inner.stats.last_break_reason = Some(event.reason.clone()); + } + + inner.previous = Some(current); + if let Some(response) = response { + write_completion_entry(&inner.paths, &request_hash, response); + inner.stats.completion_cache_writes += 1; + } + persist_state(&inner); + + PromptCacheRecord { + cache_break, + stats: inner.stats.clone(), + } + } + + fn lock(&self) -> std::sync::MutexGuard<'_, PromptCacheInner> { + self.inner + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } +} + +#[derive(Debug)] +struct PromptCacheInner { + config: PromptCacheConfig, + paths: PromptCachePaths, + stats: PromptCacheStats, + previous: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CompletionCacheEntry { + cached_at_unix_secs: u64, + #[serde(default = "current_fingerprint_version")] + fingerprint_version: u32, + response: MessageResponse, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct TrackedPromptState { + observed_at_unix_secs: u64, + #[serde(default = "current_fingerprint_version")] + fingerprint_version: u32, + model_hash: u64, + system_hash: u64, + tools_hash: u64, + messages_hash: u64, + cache_read_input_tokens: u32, +} + +impl TrackedPromptState { + fn from_usage(request: &MessageRequest, usage: &Usage) -> Self { + let hashes = RequestFingerprints::from_request(request); + Self { + observed_at_unix_secs: 
now_unix_secs(), + fingerprint_version: current_fingerprint_version(), + model_hash: hashes.model, + system_hash: hashes.system, + tools_hash: hashes.tools, + messages_hash: hashes.messages, + cache_read_input_tokens: usage.cache_read_input_tokens, + } + } +} + +#[derive(Debug, Clone, Copy)] +struct RequestFingerprints { + model: u64, + system: u64, + tools: u64, + messages: u64, +} + +impl RequestFingerprints { + fn from_request(request: &MessageRequest) -> Self { + Self { + model: hash_serializable(&request.model), + system: hash_serializable(&request.system), + tools: hash_serializable(&request.tools), + messages: hash_serializable(&request.messages), + } + } +} + +fn detect_cache_break( + config: &PromptCacheConfig, + previous: Option<&TrackedPromptState>, + current: &TrackedPromptState, +) -> Option { + let previous = previous?; + if previous.fingerprint_version != current.fingerprint_version { + return Some(CacheBreakEvent { + unexpected: false, + reason: format!( + "fingerprint version changed (v{} -> v{})", + previous.fingerprint_version, current.fingerprint_version + ), + previous_cache_read_input_tokens: previous.cache_read_input_tokens, + current_cache_read_input_tokens: current.cache_read_input_tokens, + token_drop: previous + .cache_read_input_tokens + .saturating_sub(current.cache_read_input_tokens), + }); + } + let token_drop = previous + .cache_read_input_tokens + .saturating_sub(current.cache_read_input_tokens); + if token_drop < config.cache_break_min_drop { + return None; + } + + let mut reasons = Vec::new(); + if previous.model_hash != current.model_hash { + reasons.push("model changed"); + } + if previous.system_hash != current.system_hash { + reasons.push("system prompt changed"); + } + if previous.tools_hash != current.tools_hash { + reasons.push("tool definitions changed"); + } + if previous.messages_hash != current.messages_hash { + reasons.push("message payload changed"); + } + + let elapsed = current + .observed_at_unix_secs + 
.saturating_sub(previous.observed_at_unix_secs); + + let (unexpected, reason) = if reasons.is_empty() { + if elapsed > config.prompt_ttl.as_secs() { + ( + false, + format!("possible prompt cache TTL expiry after {elapsed}s"), + ) + } else { + ( + true, + "cache read tokens dropped while prompt fingerprint remained stable".to_string(), + ) + } + } else { + (false, reasons.join(", ")) + }; + + Some(CacheBreakEvent { + unexpected, + reason, + previous_cache_read_input_tokens: previous.cache_read_input_tokens, + current_cache_read_input_tokens: current.cache_read_input_tokens, + token_drop, + }) +} + +fn apply_usage_to_stats( + stats: &mut PromptCacheStats, + usage: &Usage, + request_hash: &str, + source: &str, +) { + stats.total_cache_creation_input_tokens += u64::from(usage.cache_creation_input_tokens); + stats.total_cache_read_input_tokens += u64::from(usage.cache_read_input_tokens); + stats.last_cache_creation_input_tokens = Some(usage.cache_creation_input_tokens); + stats.last_cache_read_input_tokens = Some(usage.cache_read_input_tokens); + stats.last_request_hash = Some(request_hash.to_string()); + stats.last_cache_source = Some(source.to_string()); +} + +fn persist_state(inner: &PromptCacheInner) { + let _ = ensure_cache_dirs(&inner.paths); + let _ = write_json(&inner.paths.stats_path, &inner.stats); + if let Some(previous) = &inner.previous { + let _ = write_json(&inner.paths.session_state_path, previous); + } +} + +fn write_completion_entry( + paths: &PromptCachePaths, + request_hash: &str, + response: &MessageResponse, +) { + let _ = ensure_cache_dirs(paths); + let entry = CompletionCacheEntry { + cached_at_unix_secs: now_unix_secs(), + fingerprint_version: current_fingerprint_version(), + response: response.clone(), + }; + let _ = write_json(&paths.completion_entry_path(request_hash), &entry); +} + +fn ensure_cache_dirs(paths: &PromptCachePaths) -> std::io::Result<()> { + fs::create_dir_all(&paths.completion_dir) +} + +fn write_json(path: &Path, value: &T) 
-> std::io::Result<()> { + let json = serde_json::to_vec_pretty(value) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; + fs::write(path, json) +} + +fn read_json Deserialize<'de>>(path: &Path) -> Option { + let bytes = fs::read(path).ok()?; + serde_json::from_slice(&bytes).ok() +} + +fn request_hash_hex(request: &MessageRequest) -> String { + format!( + "{REQUEST_FINGERPRINT_PREFIX}-{:016x}", + hash_serializable(request) + ) +} + +fn hash_serializable(value: &T) -> u64 { + let json = serde_json::to_vec(value).unwrap_or_default(); + stable_hash_bytes(&json) +} + +fn sanitize_path_segment(value: &str) -> String { + let sanitized: String = value + .chars() + .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '-' }) + .collect(); + if sanitized.len() <= MAX_SANITIZED_LENGTH { + return sanitized; + } + let suffix = format!("-{:x}", hash_string(value)); + format!( + "{}{}", + &sanitized[..MAX_SANITIZED_LENGTH.saturating_sub(suffix.len())], + suffix + ) +} + +fn hash_string(value: &str) -> u64 { + stable_hash_bytes(value.as_bytes()) +} + +fn base_cache_root() -> PathBuf { + if let Some(config_home) = std::env::var_os("CLAUDE_CONFIG_HOME") { + return PathBuf::from(config_home) + .join("cache") + .join("prompt-cache"); + } + if let Some(home) = std::env::var_os("HOME") { + return PathBuf::from(home) + .join(".claude") + .join("cache") + .join("prompt-cache"); + } + std::env::temp_dir().join("claude-prompt-cache") +} + +fn now_unix_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +const fn current_fingerprint_version() -> u32 { + REQUEST_FINGERPRINT_VERSION +} + +fn stable_hash_bytes(bytes: &[u8]) -> u64 { + let mut hash = FNV_OFFSET_BASIS; + for byte in bytes { + hash ^= u64::from(*byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + hash +} + +#[cfg(test)] +mod tests { + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use super::{ + detect_cache_break, read_json, 
request_hash_hex, sanitize_path_segment, PromptCache, + PromptCacheConfig, PromptCachePaths, TrackedPromptState, REQUEST_FINGERPRINT_PREFIX, + }; + use crate::test_env_lock; + use crate::types::{InputMessage, MessageRequest, MessageResponse, OutputContentBlock, Usage}; + + #[test] + fn path_builder_sanitizes_session_identifier() { + let paths = PromptCachePaths::for_session("session:/with spaces"); + let session_dir = paths + .session_dir + .file_name() + .and_then(|value| value.to_str()) + .expect("session dir name"); + assert_eq!(session_dir, "session--with-spaces"); + assert!(paths.completion_dir.ends_with("completions")); + assert!(paths.stats_path.ends_with("stats.json")); + assert!(paths.session_state_path.ends_with("session-state.json")); + } + + #[test] + fn request_fingerprint_drives_unexpected_break_detection() { + let request = sample_request("same"); + let previous = TrackedPromptState::from_usage( + &request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 6_000, + output_tokens: 0, + }, + ); + let current = TrackedPromptState::from_usage( + &request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 1_000, + output_tokens: 0, + }, + ); + let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current) + .expect("break should be detected"); + assert!(event.unexpected); + assert!(event.reason.contains("stable")); + } + + #[test] + fn changed_prompt_marks_break_as_expected() { + let previous_request = sample_request("first"); + let current_request = sample_request("second"); + let previous = TrackedPromptState::from_usage( + &previous_request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 6_000, + output_tokens: 0, + }, + ); + let current = TrackedPromptState::from_usage( + &current_request, + &Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 1_000, + output_tokens: 0, + }, + 
); + let event = detect_cache_break(&PromptCacheConfig::default(), Some(&previous), &current) + .expect("break should be detected"); + assert!(!event.unexpected); + assert!(event.reason.contains("message payload changed")); + } + + #[test] + fn completion_cache_round_trip_persists_recent_response() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::new("unit-test-session"); + let request = sample_request("cache me"); + let response = sample_response(42, 12, "cached"); + + assert!(cache.lookup_completion(&request).is_none()); + let record = cache.record_response(&request, &response); + assert!(record.cache_break.is_none()); + + let cached = cache + .lookup_completion(&request) + .expect("cached response should load"); + assert_eq!(cached.content, response.content); + + let stats = cache.stats(); + assert_eq!(stats.completion_cache_hits, 1); + assert_eq!(stats.completion_cache_misses, 1); + assert_eq!(stats.completion_cache_writes, 1); + + let persisted = read_json::(&cache.paths().stats_path) + .expect("stats should persist"); + assert_eq!(persisted.completion_cache_hits, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn distinct_requests_do_not_collide_in_completion_cache() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-distinct-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::new("distinct-request-session"); + let first_request = sample_request("first"); + let second_request = sample_request("second"); + + let response = 
sample_response(42, 12, "cached"); + let _ = cache.record_response(&first_request, &response); + + assert!(cache.lookup_completion(&second_request).is_none()); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn expired_completion_entries_are_not_reused() { + let _guard = test_env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "prompt-cache-expired-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + let cache = PromptCache::with_config(PromptCacheConfig { + session_id: "expired-session".to_string(), + completion_ttl: Duration::ZERO, + ..PromptCacheConfig::default() + }); + let request = sample_request("expire me"); + let response = sample_response(7, 3, "stale"); + + let _ = cache.record_response(&request, &response); + + assert!(cache.lookup_completion(&request).is_none()); + let stats = cache.stats(); + assert_eq!(stats.completion_cache_hits, 0); + assert_eq!(stats.completion_cache_misses, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + } + + #[test] + fn sanitize_path_caps_long_values() { + let long_value = "x".repeat(200); + let sanitized = sanitize_path_segment(&long_value); + assert!(sanitized.len() <= 80); + } + + #[test] + fn request_hashes_are_versioned_and_stable() { + let request = sample_request("stable"); + let first = request_hash_hex(&request); + let second = request_hash_hex(&request); + assert_eq!(first, second); + assert!(first.starts_with(REQUEST_FINGERPRINT_PREFIX)); + } + + fn sample_request(text: &str) -> MessageRequest { + MessageRequest { + model: "claude-3-7-sonnet-latest".to_string(), + max_tokens: 64, + messages: vec![InputMessage::user_text(text)], + system: Some("system".to_string()), + tools: None, + tool_choice: None, + stream: false, + } + } + + 
fn sample_response( + cache_read_input_tokens: u32, + output_tokens: u32, + text: &str, + ) -> MessageResponse { + MessageResponse { + id: "msg_test".to_string(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::Text { + text: text.to_string(), + }], + model: "claude-3-7-sonnet-latest".to_string(), + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 10, + cache_creation_input_tokens: 5, + cache_read_input_tokens, + output_tokens, + }, + request_id: Some("req_test".to_string()), + } + } +} diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index bdccbdd..f235980 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; use std::time::Duration; use api::{ @@ -13,6 +14,13 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; use tokio::sync::Mutex; +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) +} + #[tokio::test] async fn send_message_posts_json_and_parses_response() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -46,6 +54,8 @@ async fn send_message_posts_json_and_parses_response() { assert_eq!(response.id, "msg_test"); assert_eq!(response.total_tokens(), 16); assert_eq!(response.request_id.as_deref(), Some("req_body_123")); + assert_eq!(response.usage.cache_creation_input_tokens, 0); + assert_eq!(response.usage.cache_read_input_tokens, 0); assert_eq!( response.content, vec![OutputContentBlock::Text { @@ -198,11 +208,55 @@ async fn send_message_applies_request_profile_and_records_telemetry() { } #[tokio::test] +async fn 
send_message_parses_prompt_cache_token_usage_from_response() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"msg_cache_tokens\",", + "\"type\":\"message\",", + "\"role\":\"assistant\",", + "\"content\":[{\"type\":\"text\",\"text\":\"Cache tokens\"}],", + "\"model\":\"claude-3-7-sonnet-latest\",", + "\"stop_reason\":\"end_turn\",", + "\"stop_sequence\":null,", + "\"usage\":{\"input_tokens\":12,\"cache_creation_input_tokens\":321,\"cache_read_input_tokens\":654,\"output_tokens\":4}", + "}" + ); + let server = spawn_server( + state, + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = AnthropicClient::new("test-key").with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.usage.input_tokens, 12); + assert_eq!(response.usage.cache_creation_input_tokens, 321); + assert_eq!(response.usage.cache_read_input_tokens, 654); + assert_eq!(response.usage.output_tokens, 4); +} + +#[tokio::test] +#[allow(clippy::await_holding_lock)] async fn stream_message_parses_sse_events_with_tool_use() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-stream-cache-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); let state = Arc::new(Mutex::new(Vec::::new())); let sse = concat!( "event: message_start\n", - "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", + "data: 
{\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":13,\"cache_read_input_tokens\":21,\"output_tokens\":0}}}\n\n", "event: content_block_start\n", "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n", "event: content_block_delta\n", @@ -210,7 +264,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { "event: content_block_stop\n", "data: {\"type\":\"content_block_stop\",\"index\":0}\n\n", "event: message_delta\n", - "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n", + "data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"cache_creation_input_tokens\":34,\"cache_read_input_tokens\":55,\"output_tokens\":1}}\n\n", "event: message_stop\n", "data: {\"type\":\"message_stop\"}\n\n", "data: [DONE]\n\n" @@ -228,7 +282,8 @@ async fn stream_message_parses_sse_events_with_tool_use() { let client = ApiClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) - .with_base_url(server.base_url()); + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::new("stream-session")); let mut stream = client .stream_message(&sample_request(false)) .await @@ -282,6 +337,20 @@ async fn stream_message_parses_sse_events_with_tool_use() { let captured = state.lock().await; let request = captured.first().expect("server should capture request"); assert!(request.body.contains("\"stream\":true")); + + let cache_stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(cache_stats.tracked_requests, 1); + 
assert_eq!(cache_stats.last_cache_creation_input_tokens, Some(34)); + assert_eq!(cache_stats.last_cache_read_input_tokens, Some(55)); + assert_eq!( + cache_stats.last_cache_source.as_deref(), + Some("api-response") + ); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); } #[tokio::test] @@ -406,6 +475,121 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { } } +#[tokio::test] +#[allow(clippy::await_holding_lock)] +async fn send_message_reuses_recent_completion_cache_entries() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-prompt-cache-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_cached\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Cached once\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":4000,\"output_tokens\":2}}", + )], + ) + .await; + + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::new("integration-session")); + + let first = client + .send_message(&sample_request(false)) + .await + .expect("first request should succeed"); + let second = client + .send_message(&sample_request(false)) + .await + .expect("second request should reuse cache"); + + assert_eq!(first.content, second.content); + assert_eq!(state.lock().await.len(), 1); + + let cache_stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(cache_stats.completion_cache_hits, 
1); + assert_eq!(cache_stats.completion_cache_misses, 1); + assert_eq!(cache_stats.completion_cache_writes, 1); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); +} + +#[tokio::test] +#[allow(clippy::await_holding_lock)] +async fn send_message_tracks_unexpected_prompt_cache_breaks() { + let _guard = env_lock(); + let temp_root = std::env::temp_dir().join(format!( + "api-prompt-break-{}-{}", + std::process::id(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time") + .as_nanos() + )); + std::env::set_var("CLAUDE_CONFIG_HOME", &temp_root); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state, + vec![ + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_one\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"One\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":5,\"cache_read_input_tokens\":6000,\"output_tokens\":2}}", + ), + http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_two\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Two\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"cache_creation_input_tokens\":0,\"cache_read_input_tokens\":1000,\"output_tokens\":2}}", + ), + ], + ) + .await; + + let request = sample_request(false); + let client = AnthropicClient::new("test-key") + .with_base_url(server.base_url()) + .with_prompt_cache(PromptCache::with_config(api::PromptCacheConfig { + session_id: "break-session".to_string(), + completion_ttl: Duration::from_secs(0), + ..api::PromptCacheConfig::default() + })); + + client + .send_message(&request) + .await + .expect("first response should succeed"); + client + .send_message(&request) + .await + 
.expect("second response should succeed"); + + let cache_stats = client + .prompt_cache_stats() + .expect("prompt cache stats should exist"); + assert_eq!(cache_stats.unexpected_cache_breaks, 1); + assert_eq!( + cache_stats.last_break_reason.as_deref(), + Some("cache read tokens dropped while prompt fingerprint remained stable") + ); + + std::fs::remove_dir_all(temp_root).expect("cleanup temp root"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); +} + #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY and network access"] async fn live_stream_smoke_test() { diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index cfb273b..13c227c 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -32,9 +32,19 @@ pub enum AssistantEvent { input: String, }, Usage(TokenUsage), + PromptCache(PromptCacheEvent), MessageStop, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PromptCacheEvent { + pub unexpected: bool, + pub reason: String, + pub previous_cache_read_input_tokens: u32, + pub current_cache_read_input_tokens: u32, + pub token_drop: u32, +} + pub trait ApiClient { fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError>; } @@ -91,6 +101,7 @@ impl std::error::Error for RuntimeError {} pub struct TurnSummary { pub assistant_messages: Vec, pub tool_results: Vec, + pub prompt_cache_events: Vec, pub iterations: usize, pub usage: TokenUsage, pub auto_compaction: Option, @@ -284,6 +295,7 @@ where let mut assistant_messages = Vec::new(); let mut tool_results = Vec::new(); + let mut prompt_cache_events = Vec::new(); let mut iterations = 0; loop { @@ -317,6 +329,7 @@ where if let Some(usage) = usage { self.usage_tracker.record(usage); } + prompt_cache_events.extend(turn_prompt_cache_events); let pending_tool_uses = assistant_message .blocks .iter() @@ -435,6 +448,7 @@ where Ok(TurnSummary { assistant_messages, tool_results, + prompt_cache_events, iterations, usage: 
self.usage_tracker.cumulative_usage(), auto_compaction, @@ -534,9 +548,17 @@ fn parse_auto_compaction_threshold(value: Option<&str>) -> u32 { fn build_assistant_message( events: Vec, -) -> Result<(ConversationMessage, Option), RuntimeError> { +) -> Result< + ( + ConversationMessage, + Option, + Vec, + ), + RuntimeError, +> { let mut text = String::new(); let mut blocks = Vec::new(); + let mut prompt_cache_events = Vec::new(); let mut finished = false; let mut usage = None; @@ -548,6 +570,7 @@ fn build_assistant_message( blocks.push(ContentBlock::ToolUse { id, name, input }); } AssistantEvent::Usage(value) => usage = Some(value), + AssistantEvent::PromptCache(event) => prompt_cache_events.push(event), AssistantEvent::MessageStop => { finished = true; } @@ -568,6 +591,7 @@ fn build_assistant_message( Ok(( ConversationMessage::assistant_with_usage(blocks, usage), usage, + prompt_cache_events, )) } @@ -700,6 +724,15 @@ mod tests { cache_creation_input_tokens: 1, cache_read_input_tokens: 3, }), + AssistantEvent::PromptCache(PromptCacheEvent { + unexpected: true, + reason: + "cache read tokens dropped while prompt fingerprint remained stable" + .to_string(), + previous_cache_read_input_tokens: 6_000, + current_cache_read_input_tokens: 1_000, + token_drop: 5_000, + }), AssistantEvent::MessageStop, ]) } @@ -753,6 +786,7 @@ mod tests { assert_eq!(summary.iterations, 2); assert_eq!(summary.assistant_messages.len(), 2); assert_eq!(summary.tool_results.len(), 1); + assert_eq!(summary.prompt_cache_events.len(), 1); assert_eq!(runtime.session().messages.len(), 4); assert_eq!(summary.usage.output_tokens, 10); assert_eq!(summary.auto_compaction, None); diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 71aef93..74a3292 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1311,6 +1311,7 @@ impl LiveCli { })), "tool_uses": collect_tool_uses(&summary), "tool_results": 
collect_tool_results(&summary), + "prompt_cache_events": collect_prompt_cache_events(&summary), "usage": { "input_tokens": summary.usage.input_tokens, "output_tokens": summary.usage.output_tokens, @@ -3276,7 +3277,8 @@ impl AnthropicRuntimeClient { Ok(Self { runtime: tokio::runtime::Runtime::new()?, client: AnthropicClient::from_auth(resolve_cli_auth_source()?) - .with_base_url(api::read_base_url()), + .with_base_url(api::read_base_url()) + .with_prompt_cache(PromptCache::new(session_id)), model, enable_tools, emit_output, @@ -3413,6 +3415,8 @@ impl ApiClient for AnthropicRuntimeClient { } } + push_prompt_cache_record(&self.client, &mut events); + if !saw_stop && events.iter().any(|event| { matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) @@ -3437,7 +3441,9 @@ impl ApiClient for AnthropicRuntimeClient { }) .await .map_err(|error| RuntimeError::new(error.to_string()))?; - response_to_events(response, out) + let mut events = response_to_events(response, out)?; + push_prompt_cache_record(&self.client, &mut events); + Ok(events) }) } } @@ -3498,6 +3504,39 @@ fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec Vec { + summary + .prompt_cache_events + .iter() + .map(|event| { + json!({ + "unexpected": event.unexpected, + "reason": event.reason, + "previous_cache_read_input_tokens": event.previous_cache_read_input_tokens, + "current_cache_read_input_tokens": event.current_cache_read_input_tokens, + "token_drop": event.token_drop, + }) + }) + .collect() +} + +fn print_prompt_cache_events(summary: &runtime::TurnSummary) { + for event in &summary.prompt_cache_events { + let label = if event.unexpected { + "Prompt cache break" + } else { + "Prompt cache invalidation" + }; + println!( + "{label}: {} (cache read {} -> {}, drop {})", + event.reason, + event.previous_cache_read_input_tokens, + event.current_cache_read_input_tokens, + event.token_drop, + ); + } +} + fn slash_command_completion_candidates() -> Vec { slash_command_specs() .iter() @@ 
-3653,6 +3692,8 @@ fn first_visible_line(text: &str) -> &str { } fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { + use std::fmt::Write as _; + let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; if let Some(task_id) = parsed .get("backgroundTaskId") @@ -3996,6 +4037,26 @@ fn response_to_events( Ok(events) } +fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec) { + if let Some(event) = client + .take_last_prompt_cache_record() + .and_then(prompt_cache_record_to_runtime_event) + { + events.push(AssistantEvent::PromptCache(event)); + } +} + +fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option { + let cache_break = record.cache_break?; + Some(PromptCacheEvent { + unexpected: cache_break.unexpected, + reason: cache_break.reason, + previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens, + current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens, + token_drop: cache_break.token_drop, + }) +} + struct CliToolExecutor { renderer: TerminalRenderer, emit_output: bool, diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index b4acf0a..0da62cd 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -1900,6 +1900,8 @@ impl ApiClient for ProviderRuntimeClient { } } + push_prompt_cache_record(&self.client, &mut events); + if !saw_stop && events.iter().any(|event| { matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) @@ -1924,7 +1926,9 @@ impl ApiClient for ProviderRuntimeClient { }) .await .map_err(|error| RuntimeError::new(error.to_string()))?; - Ok(response_to_events(response)) + let mut events = response_to_events(response); + push_prompt_cache_record(&self.client, &mut events); + Ok(events) }) } } @@ -2045,6 +2049,26 @@ fn response_to_events(response: MessageResponse) -> Vec { events } +fn push_prompt_cache_record(client: &AnthropicClient, events: &mut Vec) { + if let Some(event) = client 
+ .take_last_prompt_cache_record() + .and_then(prompt_cache_record_to_runtime_event) + { + events.push(AssistantEvent::PromptCache(event)); + } +} + +fn prompt_cache_record_to_runtime_event(record: PromptCacheRecord) -> Option { + let cache_break = record.cache_break?; + Some(PromptCacheEvent { + unexpected: cache_break.unexpected, + reason: cache_break.reason, + previous_cache_read_input_tokens: cache_break.previous_cache_read_input_tokens, + current_cache_read_input_tokens: cache_break.current_cache_read_input_tokens, + token_drop: cache_break.token_drop, + }) +} + fn final_assistant_text(summary: &runtime::TurnSummary) -> String { summary .assistant_messages