From de589d47a58cbb65df4b89d6bc7a16ead312b577 Mon Sep 17 00:00:00 2001 From: YeonGyu-Kim Date: Thu, 2 Apr 2026 11:31:53 +0900 Subject: [PATCH] fix: restore anthropic request profile integration --- rust/crates/api/src/lib.rs | 5 + rust/crates/api/src/providers/anthropic.rs | 207 ++++++++++++++++++++- 2 files changed, 204 insertions(+), 8 deletions(-) diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index e48510f..cac2f53 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -1,5 +1,6 @@ mod client; mod error; +mod prompt_cache; mod providers; mod sse; mod types; @@ -9,6 +10,10 @@ pub use client::{ resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient, }; pub use error::ApiError; +pub use prompt_cache::{ + CacheBreakEvent, PromptCache, PromptCacheConfig, PromptCachePaths, PromptCacheRecord, + PromptCacheStats, +}; pub use providers::anthropic::{AnthropicClient, AnthropicClient as ApiClient, AuthSource}; pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig}; pub use providers::{ diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs index 0ffcf59..ec04c3b 100644 --- a/rust/crates/api/src/providers/anthropic.rs +++ b/rust/crates/api/src/providers/anthropic.rs @@ -1,18 +1,22 @@ use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use runtime::format_usd; use runtime::{ load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, OAuthTokenExchangeRequest, }; use serde::Deserialize; -use telemetry::SessionTracer; +use serde_json::{Map, Value}; +use telemetry::{AnalyticsEvent, AnthropicRequestProfile, ClientIdentity, SessionTracer}; use crate::error::ApiError; +use crate::prompt_cache::{PromptCache, PromptCacheRecord, PromptCacheStats}; use super::{Provider, ProviderFuture}; use crate::sse::SseParser; -use crate::types::{MessageRequest, MessageResponse, StreamEvent}; +use crate::types::{MessageDeltaEvent, MessageRequest, MessageResponse, StreamEvent, Usage}; pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; const ANTHROPIC_VERSION: &str = "2023-06-01"; @@ -114,6 +118,10 @@ pub struct AnthropicClient { max_retries: u32, initial_backoff: Duration, max_backoff: Duration, + request_profile: AnthropicRequestProfile, + session_tracer: Option, + prompt_cache: Option, + last_prompt_cache_record: Arc>>, } impl AnthropicClient { @@ -126,6 +134,10 @@ impl AnthropicClient { max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + request_profile: AnthropicRequestProfile::default(), + session_tracer: None, + prompt_cache: None, + last_prompt_cache_record: Arc::new(Mutex::new(None)), } } @@ -138,6 +150,10 @@ impl AnthropicClient { max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, + request_profile: AnthropicRequestProfile::default(), + session_tracer: None, + prompt_cache: None, + last_prompt_cache_record: Arc::new(Mutex::new(None)), } } @@ -196,7 +212,66 @@ impl AnthropicClient { } #[must_use] - pub fn with_session_tracer(self, _session_tracer: SessionTracer) -> Self { + pub fn with_session_tracer(mut self, session_tracer: SessionTracer) -> Self { + self.session_tracer = Some(session_tracer); + self + } + + #[must_use] + pub fn with_client_identity(mut self, client_identity: ClientIdentity) -> Self { + self.request_profile.client_identity = client_identity; + self + } + + #[must_use] + pub fn with_beta(mut self, beta: impl Into) -> Self { + self.request_profile = self.request_profile.with_beta(beta); + self + } + + #[must_use] + pub fn with_extra_body_param(mut self, key: impl Into, value: Value) -> Self { + self.request_profile = self.request_profile.with_extra_body(key, value); + self + } + + #[must_use] + pub fn with_prompt_cache(mut self, prompt_cache: PromptCache) -> Self { + self.prompt_cache = Some(prompt_cache); + self + } + + #[must_use] + pub fn prompt_cache_stats(&self) -> Option { + self.prompt_cache.as_ref().map(PromptCache::stats) + } + + #[must_use] + pub fn request_profile(&self) -> &AnthropicRequestProfile { + &self.request_profile + } + + #[must_use] + pub fn session_tracer(&self) -> Option<&SessionTracer> { + self.session_tracer.as_ref() + } + + #[must_use] + pub fn prompt_cache(&self) -> Option<&PromptCache> { + self.prompt_cache.as_ref() + } + + #[must_use] + pub fn take_last_prompt_cache_record(&self) -> Option { + self.last_prompt_cache_record + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take() + } + + #[must_use] + pub fn with_request_profile(mut self, request_profile: AnthropicRequestProfile) -> Self { + self.request_profile = request_profile; self } @@ -213,6 +288,13 @@ impl AnthropicClient { stream: false, ..request.clone() }; + + if let Some(prompt_cache) = &self.prompt_cache { + if let Some(response) = prompt_cache.lookup_completion(&request) { + return Ok(response); + } + } + let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); let mut response = response @@ -222,6 +304,30 @@ impl AnthropicClient { if response.request_id.is_none() { response.request_id = request_id; } + + if let Some(prompt_cache) = &self.prompt_cache { + let record = prompt_cache.record_response(&request, &response); + self.store_last_prompt_cache_record(record); + } + if let Some(session_tracer) = &self.session_tracer { + session_tracer.record_analytics( + AnalyticsEvent::new("api", "message_usage") + .with_property( + "request_id", + response + .request_id + .clone() + .map_or(Value::Null, Value::String), + ) + .with_property("total_tokens", Value::from(response.total_tokens())) + .with_property( + "estimated_cost_usd", + Value::String(format_usd( + response.usage.estimated_cost_usd(&response.model).total_cost_usd(), + )), + ), + ); + } Ok(response) } @@ -238,6 +344,11 @@ impl AnthropicClient { parser: SseParser::new(), pending: VecDeque::new(), done: false, + request: request.clone(), + prompt_cache: self.prompt_cache.clone(), + latest_usage: None, + usage_recorded: false, + last_prompt_cache_record: Arc::clone(&self.last_prompt_cache_record), }) } @@ -290,18 +401,46 @@ impl AnthropicClient { loop { attempts += 1; + if let Some(session_tracer) = &self.session_tracer { + session_tracer.record_http_request_started( + attempts, + "POST", + "/v1/messages", + Map::new(), + ); + } match self.send_raw_request(request).await { Ok(response) => match expect_success(response).await { - Ok(response) => return Ok(response), + Ok(response) => { + if let Some(session_tracer) = &self.session_tracer { + session_tracer.record_http_request_succeeded( + attempts, + "POST", + "/v1/messages", + response.status().as_u16(), + request_id_from_headers(response.headers()), + Map::new(), + ); + } + return Ok(response); + } Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + self.record_request_failure(attempts, &error); last_error = Some(error); } - Err(error) => return Err(error), + Err(error) => { + self.record_request_failure(attempts, &error); + return Err(error); + } }, Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + self.record_request_failure(attempts, &error); last_error = Some(error); } - Err(error) => return Err(error), + Err(error) => { + self.record_request_failure(attempts, &error); + return Err(error); + } } if attempts > self.max_retries { @@ -325,14 +464,37 @@ impl AnthropicClient { let request_builder = self .http .post(&request_url) - .header("anthropic-version", ANTHROPIC_VERSION) .header("content-type", "application/json"); let mut request_builder = self.auth.apply(request_builder); + for (header_name, header_value) in self.request_profile.header_pairs() { + request_builder = request_builder.header(header_name, header_value); + } - request_builder = request_builder.json(request); + let request_body = self.request_profile.render_json_body(request)?; + request_builder = request_builder.json(&request_body); request_builder.send().await.map_err(ApiError::from) } + fn record_request_failure(&self, attempt: u32, error: &ApiError) { + if let Some(session_tracer) = &self.session_tracer { + session_tracer.record_http_request_failed( + attempt, + "POST", + "/v1/messages", + error.to_string(), + error.is_retryable(), + Map::new(), + ); + } + } + + fn store_last_prompt_cache_record(&self, record: PromptCacheRecord) { + *self + .last_prompt_cache_record + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record); + } + fn backoff_for_attempt(&self, attempt: u32) -> Result { let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { return Err(ApiError::BackoffOverflow { @@ -571,6 +733,11 @@ pub struct MessageStream { parser: SseParser, pending: VecDeque, done: bool, + request: MessageRequest, + prompt_cache: Option, + latest_usage: Option, + usage_recorded: bool, + last_prompt_cache_record: Arc>>, } impl MessageStream { @@ -582,6 +749,7 @@ impl MessageStream { pub async fn next_event(&mut self) -> Result, ApiError> { loop { if let Some(event) = self.pending.pop_front() { + self.observe_event(&event); return Ok(Some(event)); } @@ -604,6 +772,29 @@ impl MessageStream { } } } + + fn observe_event(&mut self, event: &StreamEvent) { + match event { + StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => { + self.latest_usage = Some(usage.clone()); + } + StreamEvent::MessageStop(_) => { + if !self.usage_recorded { + if let (Some(prompt_cache), Some(usage)) = + (&self.prompt_cache, self.latest_usage.as_ref()) + { + let record = prompt_cache.record_usage(&self.request, usage); + *self + .last_prompt_cache_record + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(record); + } + self.usage_recorded = true; + } + } + _ => {} + } + } } async fn expect_success(response: reqwest::Response) -> Result {