use std::collections::VecDeque; use std::time::Duration; use serde::Deserialize; use crate::error::ApiError; use crate::sse::SseParser; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; const ANTHROPIC_VERSION: &str = "2023-06-01"; const REQUEST_ID_HEADER: &str = "request-id"; const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); const DEFAULT_MAX_RETRIES: u32 = 2; #[derive(Debug, Clone)] pub struct AnthropicClient { http: reqwest::Client, api_key: String, auth_token: Option, base_url: String, max_retries: u32, initial_backoff: Duration, max_backoff: Duration, } impl AnthropicClient { #[must_use] pub fn new(api_key: impl Into) -> Self { Self { http: reqwest::Client::new(), api_key: api_key.into(), auth_token: None, base_url: DEFAULT_BASE_URL.to_string(), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, } } pub fn from_env() -> Result { Ok(Self::new(read_api_key()?).with_base_url( std::env::var("ANTHROPIC_BASE_URL") .ok() .or_else(|| std::env::var("CLAUDE_CODE_API_BASE_URL").ok()) .unwrap_or_else(|| DEFAULT_BASE_URL.to_string()), )) } #[must_use] pub fn with_auth_token(mut self, auth_token: Option) -> Self { self.auth_token = auth_token.filter(|token| !token.is_empty()); self } #[must_use] pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = base_url.into(); self } #[must_use] pub fn with_retry_policy( mut self, max_retries: u32, initial_backoff: Duration, max_backoff: Duration, ) -> Self { self.max_retries = max_retries; self.initial_backoff = initial_backoff; self.max_backoff = max_backoff; self } pub async fn send_message( &self, request: &MessageRequest, ) -> Result { let request = MessageRequest { stream: false, ..request.clone() }; let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); let mut response = response .json::() .await .map_err(ApiError::from)?; if response.request_id.is_none() { response.request_id = request_id; } Ok(response) } pub async fn stream_message( &self, request: &MessageRequest, ) -> Result { let response = self .send_with_retry(&request.clone().with_streaming()) .await?; Ok(MessageStream { request_id: request_id_from_headers(response.headers()), response, parser: SseParser::new(), pending: VecDeque::new(), done: false, }) } async fn send_with_retry( &self, request: &MessageRequest, ) -> Result { let mut attempts = 0; let mut last_error: Option; loop { attempts += 1; match self.send_raw_request(request).await { Ok(response) => match expect_success(response).await { Ok(response) => return Ok(response), Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { last_error = Some(error); } Err(error) => return Err(error), }, Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { last_error = Some(error); } Err(error) => return Err(error), } if attempts > self.max_retries { break; } tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; } Err(ApiError::RetriesExhausted { attempts, last_error: Box::new(last_error.expect("retry loop must capture an error")), }) } async fn send_raw_request( &self, request: &MessageRequest, ) -> Result { let mut request_builder = self .http .post(format!( "{}/v1/messages", self.base_url.trim_end_matches('/') )) .header("x-api-key", &self.api_key) .header("anthropic-version", ANTHROPIC_VERSION) .header("content-type", "application/json"); if let Some(auth_token) = &self.auth_token { request_builder = request_builder.bearer_auth(auth_token); } request_builder .json(request) .send() .await .map_err(ApiError::from) } fn backoff_for_attempt(&self, attempt: u32) -> Result { let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { return Err(ApiError::BackoffOverflow { attempt, base_delay: self.initial_backoff, }); }; Ok(self .initial_backoff .checked_mul(multiplier) .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) } } fn read_api_key() -> Result { match std::env::var("ANTHROPIC_AUTH_TOKEN") { Ok(api_key) if !api_key.is_empty() => Ok(api_key), Ok(_) => Err(ApiError::MissingApiKey), Err(std::env::VarError::NotPresent) => match std::env::var("ANTHROPIC_API_KEY") { Ok(api_key) if !api_key.is_empty() => Ok(api_key), Ok(_) => Err(ApiError::MissingApiKey), Err(std::env::VarError::NotPresent) => Err(ApiError::MissingApiKey), Err(error) => Err(ApiError::from(error)), }, Err(error) => Err(ApiError::from(error)), } } fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { headers .get(REQUEST_ID_HEADER) .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) .and_then(|value| value.to_str().ok()) .map(ToOwned::to_owned) } #[derive(Debug)] pub struct MessageStream { request_id: Option, response: reqwest::Response, parser: SseParser, pending: VecDeque, done: bool, } impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { self.request_id.as_deref() } pub async fn next_event(&mut self) -> Result, ApiError> { loop { if let Some(event) = self.pending.pop_front() { return Ok(Some(event)); } if self.done { let remaining = self.parser.finish()?; self.pending.extend(remaining); if let Some(event) = self.pending.pop_front() { return Ok(Some(event)); } return Ok(None); } match self.response.chunk().await? { Some(chunk) => { self.pending.extend(self.parser.push(&chunk)?); } None => { self.done = true; } } } } } async fn expect_success(response: reqwest::Response) -> Result { let status = response.status(); if status.is_success() { return Ok(response); } let body = response.text().await.unwrap_or_else(|_| String::new()); let parsed_error = serde_json::from_str::(&body).ok(); let retryable = is_retryable_status(status); Err(ApiError::Api { status, error_type: parsed_error .as_ref() .map(|error| error.error.error_type.clone()), message: parsed_error .as_ref() .map(|error| error.error.message.clone()), body, retryable, }) } const fn is_retryable_status(status: reqwest::StatusCode) -> bool { matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) } #[derive(Debug, Deserialize)] struct AnthropicErrorEnvelope { error: AnthropicErrorBody, } #[derive(Debug, Deserialize)] struct AnthropicErrorBody { #[serde(rename = "type")] error_type: String, message: String, } #[cfg(test)] mod tests { use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; use std::time::Duration; use crate::types::{ContentBlockDelta, MessageRequest}; #[test] fn read_api_key_requires_presence() { std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_API_KEY"); let error = super::read_api_key().expect_err("missing key should error"); assert!(matches!(error, crate::error::ApiError::MissingApiKey)); } #[test] fn read_api_key_requires_non_empty_value() { std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); std::env::remove_var("ANTHROPIC_API_KEY"); let error = super::read_api_key().expect_err("empty key should error"); assert!(matches!(error, crate::error::ApiError::MissingApiKey)); } #[test] fn read_api_key_prefers_auth_token() { std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); assert_eq!( super::read_api_key().expect("token should load"), "auth-token" ); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); std::env::remove_var("ANTHROPIC_API_KEY"); } #[test] fn message_request_stream_helper_sets_stream_true() { let request = MessageRequest { model: "claude-3-7-sonnet-latest".to_string(), max_tokens: 64, messages: vec![], system: None, tools: None, tool_choice: None, stream: false, }; assert!(request.with_streaming().stream); } #[test] fn backoff_doubles_until_maximum() { let client = super::AnthropicClient::new("test-key").with_retry_policy( 3, Duration::from_millis(10), Duration::from_millis(25), ); assert_eq!( client.backoff_for_attempt(1).expect("attempt 1"), Duration::from_millis(10) ); assert_eq!( client.backoff_for_attempt(2).expect("attempt 2"), Duration::from_millis(20) ); assert_eq!( client.backoff_for_attempt(3).expect("attempt 3"), Duration::from_millis(25) ); } #[test] fn retryable_statuses_are_detected() { assert!(super::is_retryable_status( reqwest::StatusCode::TOO_MANY_REQUESTS )); assert!(super::is_retryable_status( reqwest::StatusCode::INTERNAL_SERVER_ERROR )); assert!(!super::is_retryable_status( reqwest::StatusCode::UNAUTHORIZED )); } #[test] fn tool_delta_variant_round_trips() { let delta = ContentBlockDelta::InputJsonDelta { partial_json: "{\"city\":\"Paris\"}".to_string(), }; let encoded = serde_json::to_string(&delta).expect("delta should serialize"); let decoded: ContentBlockDelta = serde_json::from_str(&encoded).expect("delta should deserialize"); assert_eq!(decoded, delta); } #[test] fn request_id_uses_primary_or_fallback_header() { let mut headers = reqwest::header::HeaderMap::new(); headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); assert_eq!( super::request_id_from_headers(&headers).as_deref(), Some("req_primary") ); headers.clear(); headers.insert( ALT_REQUEST_ID_HEADER, "req_fallback".parse().expect("header"), ); assert_eq!( super::request_id_from_headers(&headers).as_deref(), Some("req_fallback") ); } }