fix: auto compaction threshold default 200k tokens

This commit is contained in:
Yeachan-Heo
2026-04-01 03:55:00 +00:00
commit 6b5331576e
202 changed files with 31071 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
[package]
name = "api"
version.workspace = true
edition.workspace = true
license.workspace = true
publish.workspace = true
[dependencies]
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
runtime = { path = "../runtime" }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] }
[lints]
workspace = true

View File

@@ -0,0 +1,994 @@
use std::collections::VecDeque;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use runtime::{
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
OAuthTokenExchangeRequest,
};
use serde::Deserialize;
use crate::error::ApiError;
use crate::sse::SseParser;
use crate::types::{MessageRequest, MessageResponse, StreamEvent};
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
const ANTHROPIC_VERSION: &str = "2023-06-01";
const REQUEST_ID_HEADER: &str = "request-id";
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
const DEFAULT_MAX_RETRIES: u32 = 2;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthSource {
None,
ApiKey(String),
BearerToken(String),
ApiKeyAndBearer {
api_key: String,
bearer_token: String,
},
}
impl AuthSource {
pub fn from_env() -> Result<Self, ApiError> {
let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
match (api_key, auth_token) {
(Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
api_key,
bearer_token,
}),
(Some(api_key), None) => Ok(Self::ApiKey(api_key)),
(None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
(None, None) => Err(ApiError::MissingApiKey),
}
}
#[must_use]
pub fn api_key(&self) -> Option<&str> {
match self {
Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
Self::None | Self::BearerToken(_) => None,
}
}
#[must_use]
pub fn bearer_token(&self) -> Option<&str> {
match self {
Self::BearerToken(token)
| Self::ApiKeyAndBearer {
bearer_token: token,
..
} => Some(token),
Self::None | Self::ApiKey(_) => None,
}
}
#[must_use]
pub fn masked_authorization_header(&self) -> &'static str {
if self.bearer_token().is_some() {
"Bearer [REDACTED]"
} else {
"<absent>"
}
}
pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
if let Some(api_key) = self.api_key() {
request_builder = request_builder.header("x-api-key", api_key);
}
if let Some(token) = self.bearer_token() {
request_builder = request_builder.bearer_auth(token);
}
request_builder
}
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct OAuthTokenSet {
pub access_token: String,
pub refresh_token: Option<String>,
pub expires_at: Option<u64>,
#[serde(default)]
pub scopes: Vec<String>,
}
impl From<OAuthTokenSet> for AuthSource {
fn from(value: OAuthTokenSet) -> Self {
Self::BearerToken(value.access_token)
}
}
#[derive(Debug, Clone)]
pub struct AnthropicClient {
http: reqwest::Client,
auth: AuthSource,
base_url: String,
max_retries: u32,
initial_backoff: Duration,
max_backoff: Duration,
}
impl AnthropicClient {
#[must_use]
pub fn new(api_key: impl Into<String>) -> Self {
Self {
http: reqwest::Client::new(),
auth: AuthSource::ApiKey(api_key.into()),
base_url: DEFAULT_BASE_URL.to_string(),
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
}
}
#[must_use]
pub fn from_auth(auth: AuthSource) -> Self {
Self {
http: reqwest::Client::new(),
auth,
base_url: DEFAULT_BASE_URL.to_string(),
max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF,
max_backoff: DEFAULT_MAX_BACKOFF,
}
}
pub fn from_env() -> Result<Self, ApiError> {
Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
}
#[must_use]
pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
self.auth = auth;
self
}
#[must_use]
pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
match (
self.auth.api_key().map(ToOwned::to_owned),
auth_token.filter(|token| !token.is_empty()),
) {
(Some(api_key), Some(bearer_token)) => {
self.auth = AuthSource::ApiKeyAndBearer {
api_key,
bearer_token,
};
}
(Some(api_key), None) => {
self.auth = AuthSource::ApiKey(api_key);
}
(None, Some(bearer_token)) => {
self.auth = AuthSource::BearerToken(bearer_token);
}
(None, None) => {
self.auth = AuthSource::None;
}
}
self
}
#[must_use]
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into();
self
}
#[must_use]
pub fn with_retry_policy(
mut self,
max_retries: u32,
initial_backoff: Duration,
max_backoff: Duration,
) -> Self {
self.max_retries = max_retries;
self.initial_backoff = initial_backoff;
self.max_backoff = max_backoff;
self
}
#[must_use]
pub fn auth_source(&self) -> &AuthSource {
&self.auth
}
pub async fn send_message(
&self,
request: &MessageRequest,
) -> Result<MessageResponse, ApiError> {
let request = MessageRequest {
stream: false,
..request.clone()
};
let response = self.send_with_retry(&request).await?;
let request_id = request_id_from_headers(response.headers());
let mut response = response
.json::<MessageResponse>()
.await
.map_err(ApiError::from)?;
if response.request_id.is_none() {
response.request_id = request_id;
}
Ok(response)
}
pub async fn stream_message(
&self,
request: &MessageRequest,
) -> Result<MessageStream, ApiError> {
let response = self
.send_with_retry(&request.clone().with_streaming())
.await?;
Ok(MessageStream {
request_id: request_id_from_headers(response.headers()),
response,
parser: SseParser::new(),
pending: VecDeque::new(),
done: false,
})
}
pub async fn exchange_oauth_code(
&self,
config: &OAuthConfig,
request: &OAuthTokenExchangeRequest,
) -> Result<OAuthTokenSet, ApiError> {
let response = self
.http
.post(&config.token_url)
.header("content-type", "application/x-www-form-urlencoded")
.form(&request.form_params())
.send()
.await
.map_err(ApiError::from)?;
let response = expect_success(response).await?;
response
.json::<OAuthTokenSet>()
.await
.map_err(ApiError::from)
}
pub async fn refresh_oauth_token(
&self,
config: &OAuthConfig,
request: &OAuthRefreshRequest,
) -> Result<OAuthTokenSet, ApiError> {
let response = self
.http
.post(&config.token_url)
.header("content-type", "application/x-www-form-urlencoded")
.form(&request.form_params())
.send()
.await
.map_err(ApiError::from)?;
let response = expect_success(response).await?;
response
.json::<OAuthTokenSet>()
.await
.map_err(ApiError::from)
}
async fn send_with_retry(
&self,
request: &MessageRequest,
) -> Result<reqwest::Response, ApiError> {
let mut attempts = 0;
let mut last_error: Option<ApiError>;
loop {
attempts += 1;
match self.send_raw_request(request).await {
Ok(response) => match expect_success(response).await {
Ok(response) => return Ok(response),
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
last_error = Some(error);
}
Err(error) => return Err(error),
},
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
last_error = Some(error);
}
Err(error) => return Err(error),
}
if attempts > self.max_retries {
break;
}
tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
}
Err(ApiError::RetriesExhausted {
attempts,
last_error: Box::new(last_error.expect("retry loop must capture an error")),
})
}
async fn send_raw_request(
&self,
request: &MessageRequest,
) -> Result<reqwest::Response, ApiError> {
let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
let request_builder = self
.http
.post(&request_url)
.header("anthropic-version", ANTHROPIC_VERSION)
.header("content-type", "application/json");
let mut request_builder = self.auth.apply(request_builder);
request_builder = request_builder.json(request);
request_builder.send().await.map_err(ApiError::from)
}
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
return Err(ApiError::BackoffOverflow {
attempt,
base_delay: self.initial_backoff,
});
};
Ok(self
.initial_backoff
.checked_mul(multiplier)
.map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
}
}
impl AuthSource {
pub fn from_env_or_saved() -> Result<Self, ApiError> {
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
api_key,
bearer_token,
}),
None => Ok(Self::ApiKey(api_key)),
};
}
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
return Ok(Self::BearerToken(bearer_token));
}
match load_saved_oauth_token() {
Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
if token_set.refresh_token.is_some() {
Err(ApiError::Auth(
"saved OAuth token is expired; load runtime OAuth config to refresh it"
.to_string(),
))
} else {
Err(ApiError::ExpiredOAuthToken)
}
}
Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
Ok(None) => Err(ApiError::MissingApiKey),
Err(error) => Err(error),
}
}
}
#[must_use]
pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
token_set
.expires_at
.is_some_and(|expires_at| expires_at <= now_unix_timestamp())
}
pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
let Some(token_set) = load_saved_oauth_token()? else {
return Ok(None);
};
resolve_saved_oauth_token_set(config, token_set).map(Some)
}
pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
where
F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
{
if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
api_key,
bearer_token,
}),
None => Ok(AuthSource::ApiKey(api_key)),
};
}
if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
return Ok(AuthSource::BearerToken(bearer_token));
}
let Some(token_set) = load_saved_oauth_token()? else {
return Err(ApiError::MissingApiKey);
};
if !oauth_token_is_expired(&token_set) {
return Ok(AuthSource::BearerToken(token_set.access_token));
}
if token_set.refresh_token.is_none() {
return Err(ApiError::ExpiredOAuthToken);
}
let Some(config) = load_oauth_config()? else {
return Err(ApiError::Auth(
"saved OAuth token is expired; runtime OAuth config is missing".to_string(),
));
};
Ok(AuthSource::from(resolve_saved_oauth_token_set(
&config, token_set,
)?))
}
fn resolve_saved_oauth_token_set(
config: &OAuthConfig,
token_set: OAuthTokenSet,
) -> Result<OAuthTokenSet, ApiError> {
if !oauth_token_is_expired(&token_set) {
return Ok(token_set);
}
let Some(refresh_token) = token_set.refresh_token.clone() else {
return Err(ApiError::ExpiredOAuthToken);
};
let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
let refreshed = client_runtime_block_on(async {
client
.refresh_oauth_token(
config,
&OAuthRefreshRequest::from_config(
config,
refresh_token,
Some(token_set.scopes.clone()),
),
)
.await
})?;
let resolved = OAuthTokenSet {
access_token: refreshed.access_token,
refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
expires_at: refreshed.expires_at,
scopes: refreshed.scopes,
};
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: resolved.access_token.clone(),
refresh_token: resolved.refresh_token.clone(),
expires_at: resolved.expires_at,
scopes: resolved.scopes.clone(),
})
.map_err(ApiError::from)?;
Ok(resolved)
}
fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
where
F: std::future::Future<Output = Result<T, ApiError>>,
{
tokio::runtime::Runtime::new()
.map_err(ApiError::from)?
.block_on(future)
}
fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
let token_set = load_oauth_credentials().map_err(ApiError::from)?;
Ok(token_set.map(|token_set| OAuthTokenSet {
access_token: token_set.access_token,
refresh_token: token_set.refresh_token,
expires_at: token_set.expires_at,
scopes: token_set.scopes,
}))
}
fn now_unix_timestamp() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_or(0, |duration| duration.as_secs())
}
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
match std::env::var(key) {
Ok(value) if !value.is_empty() => Ok(Some(value)),
Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
Err(error) => Err(ApiError::from(error)),
}
}
#[cfg(test)]
fn read_api_key() -> Result<String, ApiError> {
let auth = AuthSource::from_env_or_saved()?;
auth.api_key()
.or_else(|| auth.bearer_token())
.map(ToOwned::to_owned)
.ok_or(ApiError::MissingApiKey)
}
#[cfg(test)]
fn read_auth_token() -> Option<String> {
read_env_non_empty("ANTHROPIC_AUTH_TOKEN")
.ok()
.and_then(std::convert::identity)
}
#[must_use]
pub fn read_base_url() -> String {
std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
}
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
headers
.get(REQUEST_ID_HEADER)
.or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
.and_then(|value| value.to_str().ok())
.map(ToOwned::to_owned)
}
#[derive(Debug)]
pub struct MessageStream {
request_id: Option<String>,
response: reqwest::Response,
parser: SseParser,
pending: VecDeque<StreamEvent>,
done: bool,
}
impl MessageStream {
#[must_use]
pub fn request_id(&self) -> Option<&str> {
self.request_id.as_deref()
}
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
loop {
if let Some(event) = self.pending.pop_front() {
return Ok(Some(event));
}
if self.done {
let remaining = self.parser.finish()?;
self.pending.extend(remaining);
if let Some(event) = self.pending.pop_front() {
return Ok(Some(event));
}
return Ok(None);
}
match self.response.chunk().await? {
Some(chunk) => {
self.pending.extend(self.parser.push(&chunk)?);
}
None => {
self.done = true;
}
}
}
}
}
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
let status = response.status();
if status.is_success() {
return Ok(response);
}
let body = response.text().await.unwrap_or_else(|_| String::new());
let parsed_error = serde_json::from_str::<AnthropicErrorEnvelope>(&body).ok();
let retryable = is_retryable_status(status);
Err(ApiError::Api {
status,
error_type: parsed_error
.as_ref()
.map(|error| error.error.error_type.clone()),
message: parsed_error
.as_ref()
.map(|error| error.error.message.clone()),
body,
retryable,
})
}
const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
}
#[derive(Debug, Deserialize)]
struct AnthropicErrorEnvelope {
error: AnthropicErrorBody,
}
#[derive(Debug, Deserialize)]
struct AnthropicErrorBody {
#[serde(rename = "type")]
error_type: String,
message: String,
}
#[cfg(test)]
mod tests {
use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
use std::io::{Read, Write};
use std::net::TcpListener;
use std::sync::{Mutex, OnceLock};
use std::thread;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};
use crate::client::{
now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet,
};
use crate::types::{ContentBlockDelta, MessageRequest};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock")
}
fn temp_config_home() -> std::path::PathBuf {
std::env::temp_dir().join(format!(
"api-oauth-test-{}-{}",
std::process::id(),
SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("time")
.as_nanos()
))
}
fn sample_oauth_config(token_url: String) -> OAuthConfig {
OAuthConfig {
client_id: "runtime-client".to_string(),
authorize_url: "https://console.test/oauth/authorize".to_string(),
token_url,
callback_port: Some(4545),
manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
scopes: vec!["org:read".to_string(), "user:write".to_string()],
}
}
fn spawn_token_server(response_body: &'static str) -> String {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
let address = listener.local_addr().expect("local addr");
thread::spawn(move || {
let (mut stream, _) = listener.accept().expect("accept connection");
let mut buffer = [0_u8; 4096];
let _ = stream.read(&mut buffer).expect("read request");
let response = format!(
"HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
response_body.len(),
response_body
);
stream
.write_all(response.as_bytes())
.expect("write response");
});
format!("http://{address}/oauth/token")
}
#[test]
fn read_api_key_requires_presence() {
let _guard = env_lock();
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
std::env::remove_var("CLAUDE_CONFIG_HOME");
let error = super::read_api_key().expect_err("missing key should error");
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
}
#[test]
fn read_api_key_requires_non_empty_value() {
let _guard = env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
std::env::remove_var("ANTHROPIC_API_KEY");
let error = super::read_api_key().expect_err("empty key should error");
assert!(matches!(error, crate::error::ApiError::MissingApiKey));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
}
#[test]
fn read_api_key_prefers_api_key_env() {
let _guard = env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
assert_eq!(
super::read_api_key().expect("api key should load"),
"legacy-key"
);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
}
#[test]
fn read_auth_token_reads_auth_token_env() {
let _guard = env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
}
#[test]
fn oauth_token_maps_to_bearer_auth_source() {
let auth = AuthSource::from(OAuthTokenSet {
access_token: "access-token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: Some(123),
scopes: vec!["scope:a".to_string()],
});
assert_eq!(auth.bearer_token(), Some("access-token"));
assert_eq!(auth.api_key(), None);
}
#[test]
fn auth_source_from_env_combines_api_key_and_bearer_token() {
let _guard = env_lock();
std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
let auth = AuthSource::from_env().expect("env auth");
assert_eq!(auth.api_key(), Some("legacy-key"));
assert_eq!(auth.bearer_token(), Some("auth-token"));
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
}
#[test]
fn auth_source_from_saved_oauth_when_env_absent() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "saved-access-token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: Some(now_unix_timestamp() + 300),
scopes: vec!["scope:a".to_string()],
})
.expect("save oauth credentials");
let auth = AuthSource::from_env_or_saved().expect("saved auth");
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn oauth_token_expiry_uses_expires_at_timestamp() {
assert!(oauth_token_is_expired(&OAuthTokenSet {
access_token: "access-token".to_string(),
refresh_token: None,
expires_at: Some(1),
scopes: Vec::new(),
}));
assert!(!oauth_token_is_expired(&OAuthTokenSet {
access_token: "access-token".to_string(),
refresh_token: None,
expires_at: Some(now_unix_timestamp() + 60),
scopes: Vec::new(),
}));
}
#[test]
fn resolve_saved_oauth_token_refreshes_expired_credentials() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "expired-access-token".to_string(),
refresh_token: Some("refresh-token".to_string()),
expires_at: Some(1),
scopes: vec!["scope:a".to_string()],
})
.expect("save expired oauth credentials");
let token_url = spawn_token_server(
"{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
);
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
.expect("resolve refreshed token")
.expect("token set present");
assert_eq!(resolved.access_token, "refreshed-token");
let stored = runtime::load_oauth_credentials()
.expect("load stored credentials")
.expect("stored token set");
assert_eq!(stored.access_token, "refreshed-token");
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "saved-access-token".to_string(),
refresh_token: Some("refresh".to_string()),
expires_at: Some(now_unix_timestamp() + 300),
scopes: vec!["scope:a".to_string()],
})
.expect("save oauth credentials");
let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
.expect("startup auth");
assert_eq!(auth.bearer_token(), Some("saved-access-token"));
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "expired-access-token".to_string(),
refresh_token: Some("refresh-token".to_string()),
expires_at: Some(1),
scopes: vec!["scope:a".to_string()],
})
.expect("save expired oauth credentials");
let error =
resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
assert!(
matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
);
let stored = runtime::load_oauth_credentials()
.expect("load stored credentials")
.expect("stored token set");
assert_eq!(stored.access_token, "expired-access-token");
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
let _guard = env_lock();
let config_home = temp_config_home();
std::env::set_var("CLAUDE_CONFIG_HOME", &config_home);
std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
std::env::remove_var("ANTHROPIC_API_KEY");
save_oauth_credentials(&runtime::OAuthTokenSet {
access_token: "expired-access-token".to_string(),
refresh_token: Some("refresh-token".to_string()),
expires_at: Some(1),
scopes: vec!["scope:a".to_string()],
})
.expect("save expired oauth credentials");
let token_url = spawn_token_server(
"{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
);
let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
.expect("resolve refreshed token")
.expect("token set present");
assert_eq!(resolved.access_token, "refreshed-token");
assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
let stored = runtime::load_oauth_credentials()
.expect("load stored credentials")
.expect("stored token set");
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
clear_oauth_credentials().expect("clear credentials");
std::env::remove_var("CLAUDE_CONFIG_HOME");
std::fs::remove_dir_all(config_home).expect("cleanup temp dir");
}
#[test]
fn message_request_stream_helper_sets_stream_true() {
let request = MessageRequest {
model: "claude-opus-4-6".to_string(),
max_tokens: 64,
messages: vec![],
system: None,
tools: None,
tool_choice: None,
stream: false,
};
assert!(request.with_streaming().stream);
}
#[test]
fn backoff_doubles_until_maximum() {
let client = AnthropicClient::new("test-key").with_retry_policy(
3,
Duration::from_millis(10),
Duration::from_millis(25),
);
assert_eq!(
client.backoff_for_attempt(1).expect("attempt 1"),
Duration::from_millis(10)
);
assert_eq!(
client.backoff_for_attempt(2).expect("attempt 2"),
Duration::from_millis(20)
);
assert_eq!(
client.backoff_for_attempt(3).expect("attempt 3"),
Duration::from_millis(25)
);
}
#[test]
fn retryable_statuses_are_detected() {
assert!(super::is_retryable_status(
reqwest::StatusCode::TOO_MANY_REQUESTS
));
assert!(super::is_retryable_status(
reqwest::StatusCode::INTERNAL_SERVER_ERROR
));
assert!(!super::is_retryable_status(
reqwest::StatusCode::UNAUTHORIZED
));
}
#[test]
fn tool_delta_variant_round_trips() {
let delta = ContentBlockDelta::InputJsonDelta {
partial_json: "{\"city\":\"Paris\"}".to_string(),
};
let encoded = serde_json::to_string(&delta).expect("delta should serialize");
let decoded: ContentBlockDelta =
serde_json::from_str(&encoded).expect("delta should deserialize");
assert_eq!(decoded, delta);
}
#[test]
fn request_id_uses_primary_or_fallback_header() {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
assert_eq!(
super::request_id_from_headers(&headers).as_deref(),
Some("req_primary")
);
headers.clear();
headers.insert(
ALT_REQUEST_ID_HEADER,
"req_fallback".parse().expect("header"),
);
assert_eq!(
super::request_id_from_headers(&headers).as_deref(),
Some("req_fallback")
);
}
#[test]
fn auth_source_applies_headers() {
let auth = AuthSource::ApiKeyAndBearer {
api_key: "test-key".to_string(),
bearer_token: "proxy-token".to_string(),
};
let request = auth
.apply(reqwest::Client::new().post("https://example.test"))
.build()
.expect("request build");
let headers = request.headers();
assert_eq!(
headers.get("x-api-key").and_then(|v| v.to_str().ok()),
Some("test-key")
);
assert_eq!(
headers.get("authorization").and_then(|v| v.to_str().ok()),
Some("Bearer proxy-token")
);
}
}

View File

@@ -0,0 +1,134 @@
use std::env::VarError;
use std::fmt::{Display, Formatter};
use std::time::Duration;
#[derive(Debug)]
pub enum ApiError {
MissingApiKey,
ExpiredOAuthToken,
Auth(String),
InvalidApiKeyEnv(VarError),
Http(reqwest::Error),
Io(std::io::Error),
Json(serde_json::Error),
Api {
status: reqwest::StatusCode,
error_type: Option<String>,
message: Option<String>,
body: String,
retryable: bool,
},
RetriesExhausted {
attempts: u32,
last_error: Box<ApiError>,
},
InvalidSseFrame(&'static str),
BackoffOverflow {
attempt: u32,
base_delay: Duration,
},
}
impl ApiError {
#[must_use]
pub fn is_retryable(&self) -> bool {
match self {
Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(),
Self::Api { retryable, .. } => *retryable,
Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(),
Self::MissingApiKey
| Self::ExpiredOAuthToken
| Self::Auth(_)
| Self::InvalidApiKeyEnv(_)
| Self::Io(_)
| Self::Json(_)
| Self::InvalidSseFrame(_)
| Self::BackoffOverflow { .. } => false,
}
}
}
impl Display for ApiError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::MissingApiKey => {
write!(
f,
"ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API"
)
}
Self::ExpiredOAuthToken => {
write!(
f,
"saved OAuth token is expired and no refresh token is available"
)
}
Self::Auth(message) => write!(f, "auth error: {message}"),
Self::InvalidApiKeyEnv(error) => {
write!(
f,
"failed to read ANTHROPIC_AUTH_TOKEN / ANTHROPIC_API_KEY: {error}"
)
}
Self::Http(error) => write!(f, "http error: {error}"),
Self::Io(error) => write!(f, "io error: {error}"),
Self::Json(error) => write!(f, "json error: {error}"),
Self::Api {
status,
error_type,
message,
body,
..
} => match (error_type, message) {
(Some(error_type), Some(message)) => {
write!(
f,
"anthropic api returned {status} ({error_type}): {message}"
)
}
_ => write!(f, "anthropic api returned {status}: {body}"),
},
Self::RetriesExhausted {
attempts,
last_error,
} => write!(
f,
"anthropic api failed after {attempts} attempts: {last_error}"
),
Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"),
Self::BackoffOverflow {
attempt,
base_delay,
} => write!(
f,
"retry backoff overflowed on attempt {attempt} with base delay {base_delay:?}"
),
}
}
}
impl std::error::Error for ApiError {}
impl From<reqwest::Error> for ApiError {
fn from(value: reqwest::Error) -> Self {
Self::Http(value)
}
}
impl From<std::io::Error> for ApiError {
fn from(value: std::io::Error) -> Self {
Self::Io(value)
}
}
impl From<serde_json::Error> for ApiError {
fn from(value: serde_json::Error) -> Self {
Self::Json(value)
}
}
impl From<VarError> for ApiError {
fn from(value: VarError) -> Self {
Self::InvalidApiKeyEnv(value)
}
}

View File

@@ -0,0 +1,17 @@
mod client;
mod error;
mod sse;
mod types;
pub use client::{
oauth_token_is_expired, read_base_url, resolve_saved_oauth_token, resolve_startup_auth_source,
AnthropicClient, AuthSource, MessageStream, OAuthTokenSet,
};
pub use error::ApiError;
pub use sse::{parse_frame, SseParser};
pub use types::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
};

219
rust/crates/api/src/sse.rs Normal file
View File

@@ -0,0 +1,219 @@
use crate::error::ApiError;
use crate::types::StreamEvent;
#[derive(Debug, Default)]
pub struct SseParser {
buffer: Vec<u8>,
}
impl SseParser {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn push(&mut self, chunk: &[u8]) -> Result<Vec<StreamEvent>, ApiError> {
self.buffer.extend_from_slice(chunk);
let mut events = Vec::new();
while let Some(frame) = self.next_frame() {
if let Some(event) = parse_frame(&frame)? {
events.push(event);
}
}
Ok(events)
}
pub fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
if self.buffer.is_empty() {
return Ok(Vec::new());
}
let trailing = std::mem::take(&mut self.buffer);
match parse_frame(&String::from_utf8_lossy(&trailing))? {
Some(event) => Ok(vec![event]),
None => Ok(Vec::new()),
}
}
fn next_frame(&mut self) -> Option<String> {
let separator = self
.buffer
.windows(2)
.position(|window| window == b"\n\n")
.map(|position| (position, 2))
.or_else(|| {
self.buffer
.windows(4)
.position(|window| window == b"\r\n\r\n")
.map(|position| (position, 4))
})?;
let (position, separator_len) = separator;
let frame = self
.buffer
.drain(..position + separator_len)
.collect::<Vec<_>>();
let frame_len = frame.len().saturating_sub(separator_len);
Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned())
}
}
pub fn parse_frame(frame: &str) -> Result<Option<StreamEvent>, ApiError> {
let trimmed = frame.trim();
if trimmed.is_empty() {
return Ok(None);
}
let mut data_lines = Vec::new();
let mut event_name: Option<&str> = None;
for line in trimmed.lines() {
if line.starts_with(':') {
continue;
}
if let Some(name) = line.strip_prefix("event:") {
event_name = Some(name.trim());
continue;
}
if let Some(data) = line.strip_prefix("data:") {
data_lines.push(data.trim_start());
}
}
if matches!(event_name, Some("ping")) {
return Ok(None);
}
if data_lines.is_empty() {
return Ok(None);
}
let payload = data_lines.join("\n");
if payload == "[DONE]" {
return Ok(None);
}
serde_json::from_str::<StreamEvent>(&payload)
.map(Some)
.map_err(ApiError::from)
}
#[cfg(test)]
mod tests {
use super::{parse_frame, SseParser};
use crate::types::{ContentBlockDelta, MessageDelta, OutputContentBlock, StreamEvent, Usage};
#[test]
fn parses_single_frame() {
let frame = concat!(
"event: content_block_start\n",
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"text\",\"text\":\"Hi\"}}\n\n"
);
let event = parse_frame(frame).expect("frame should parse");
assert_eq!(
event,
Some(StreamEvent::ContentBlockStart(
crate::types::ContentBlockStartEvent {
index: 0,
content_block: OutputContentBlock::Text {
text: "Hi".to_string(),
},
},
))
);
}
#[test]
fn parses_chunked_stream() {
let mut parser = SseParser::new();
let first = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hel";
let second = b"lo\"}}\n\n";
assert!(parser
.push(first)
.expect("first chunk should buffer")
.is_empty());
let events = parser.push(second).expect("second chunk should parse");
assert_eq!(
events,
vec![StreamEvent::ContentBlockDelta(
crate::types::ContentBlockDeltaEvent {
index: 0,
delta: ContentBlockDelta::TextDelta {
text: "Hello".to_string(),
},
}
)]
);
}
#[test]
fn ignores_ping_and_done() {
let mut parser = SseParser::new();
let payload = concat!(
": keepalive\n",
"event: ping\n",
"data: {\"type\":\"ping\"}\n\n",
"event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":1,\"output_tokens\":2}}\n\n",
"event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n",
"data: [DONE]\n\n"
);
let events = parser
.push(payload.as_bytes())
.expect("parser should succeed");
assert_eq!(
events,
vec![
StreamEvent::MessageDelta(crate::types::MessageDeltaEvent {
delta: MessageDelta {
stop_reason: Some("tool_use".to_string()),
stop_sequence: None,
},
usage: Usage {
input_tokens: 1,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
output_tokens: 2,
},
}),
StreamEvent::MessageStop(crate::types::MessageStopEvent {}),
]
);
}
#[test]
fn ignores_data_less_event_frames() {
let frame = "event: ping\n\n";
let event = parse_frame(frame).expect("frame without data should be ignored");
assert_eq!(event, None);
}
#[test]
fn parses_split_json_across_data_lines() {
let frame = concat!(
"event: content_block_delta\n",
"data: {\"type\":\"content_block_delta\",\"index\":0,\n",
"data: \"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n"
);
let event = parse_frame(frame).expect("frame should parse");
assert_eq!(
event,
Some(StreamEvent::ContentBlockDelta(
crate::types::ContentBlockDeltaEvent {
index: 0,
delta: ContentBlockDelta::TextDelta {
text: "Hello".to_string(),
},
}
))
);
}
}

View File

@@ -0,0 +1,212 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageRequest {
pub model: String,
pub max_tokens: u32,
pub messages: Vec<InputMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolDefinition>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub stream: bool,
}
impl MessageRequest {
#[must_use]
pub fn with_streaming(mut self) -> Self {
self.stream = true;
self
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InputMessage {
pub role: String,
pub content: Vec<InputContentBlock>,
}
impl InputMessage {
#[must_use]
pub fn user_text(text: impl Into<String>) -> Self {
Self {
role: "user".to_string(),
content: vec![InputContentBlock::Text { text: text.into() }],
}
}
#[must_use]
pub fn user_tool_result(
tool_use_id: impl Into<String>,
content: impl Into<String>,
is_error: bool,
) -> Self {
Self {
role: "user".to_string(),
content: vec![InputContentBlock::ToolResult {
tool_use_id: tool_use_id.into(),
content: vec![ToolResultContentBlock::Text {
text: content.into(),
}],
is_error,
}],
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum InputContentBlock {
Text {
text: String,
},
ToolUse {
id: String,
name: String,
input: Value,
},
ToolResult {
tool_use_id: String,
content: Vec<ToolResultContentBlock>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
is_error: bool,
},
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContentBlock {
Text { text: String },
Json { value: Value },
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub input_schema: Value,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolChoice {
Auto,
Any,
Tool { name: String },
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageResponse {
pub id: String,
#[serde(rename = "type")]
pub kind: String,
pub role: String,
pub content: Vec<OutputContentBlock>,
pub model: String,
#[serde(default)]
pub stop_reason: Option<String>,
#[serde(default)]
pub stop_sequence: Option<String>,
pub usage: Usage,
#[serde(default)]
pub request_id: Option<String>,
}
impl MessageResponse {
#[must_use]
pub fn total_tokens(&self) -> u32 {
self.usage.total_tokens()
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum OutputContentBlock {
Text {
text: String,
},
ToolUse {
id: String,
name: String,
input: Value,
},
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Usage {
pub input_tokens: u32,
#[serde(default)]
pub cache_creation_input_tokens: u32,
#[serde(default)]
pub cache_read_input_tokens: u32,
pub output_tokens: u32,
}
impl Usage {
#[must_use]
pub const fn total_tokens(&self) -> u32 {
self.input_tokens + self.output_tokens
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageStartEvent {
pub message: MessageResponse,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MessageDeltaEvent {
pub delta: MessageDelta,
pub usage: Usage,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageDelta {
#[serde(default)]
pub stop_reason: Option<String>,
#[serde(default)]
pub stop_sequence: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ContentBlockStartEvent {
pub index: u32,
pub content_block: OutputContentBlock,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ContentBlockDeltaEvent {
pub index: u32,
pub delta: ContentBlockDelta,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlockDelta {
TextDelta { text: String },
InputJsonDelta { partial_json: String },
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContentBlockStopEvent {
pub index: u32,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageStopEvent {}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamEvent {
MessageStart(MessageStartEvent),
MessageDelta(MessageDeltaEvent),
ContentBlockStart(ContentBlockStartEvent),
ContentBlockDelta(ContentBlockDeltaEvent),
ContentBlockStop(ContentBlockStopEvent),
MessageStop(MessageStopEvent),
}

View File

@@ -0,0 +1,443 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use api::{
AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent,
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock,
StreamEvent, ToolChoice, ToolDefinition,
};
use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
#[tokio::test]
async fn send_message_posts_json_and_parses_response() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let body = concat!(
"{",
"\"id\":\"msg_test\",",
"\"type\":\"message\",",
"\"role\":\"assistant\",",
"\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],",
"\"model\":\"claude-3-7-sonnet-latest\",",
"\"stop_reason\":\"end_turn\",",
"\"stop_sequence\":null,",
"\"usage\":{\"input_tokens\":12,\"output_tokens\":4},",
"\"request_id\":\"req_body_123\"",
"}"
);
let server = spawn_server(
state.clone(),
vec![http_response("200 OK", "application/json", body)],
)
.await;
let client = AnthropicClient::new("test-key")
.with_auth_token(Some("proxy-token".to_string()))
.with_base_url(server.base_url());
let response = client
.send_message(&sample_request(false))
.await
.expect("request should succeed");
assert_eq!(response.id, "msg_test");
assert_eq!(response.total_tokens(), 16);
assert_eq!(response.request_id.as_deref(), Some("req_body_123"));
assert_eq!(
response.content,
vec![OutputContentBlock::Text {
text: "Hello from Claude".to_string(),
}]
);
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
assert_eq!(request.method, "POST");
assert_eq!(request.path, "/v1/messages");
assert_eq!(
request.headers.get("x-api-key").map(String::as_str),
Some("test-key")
);
assert_eq!(
request.headers.get("authorization").map(String::as_str),
Some("Bearer proxy-token")
);
let body: serde_json::Value =
serde_json::from_str(&request.body).expect("request body should be json");
assert_eq!(
body.get("model").and_then(serde_json::Value::as_str),
Some("claude-3-7-sonnet-latest")
);
assert!(body.get("stream").is_none());
assert_eq!(body["tools"][0]["name"], json!("get_weather"));
assert_eq!(body["tool_choice"]["type"], json!("auto"));
}
#[tokio::test]
async fn stream_message_parses_sse_events_with_tool_use() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!(
"event: message_start\n",
"data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n",
"event: content_block_start\n",
"data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n",
"event: content_block_delta\n",
"data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"input_json_delta\",\"partial_json\":\"{\\\"city\\\":\\\"Paris\\\"}\"}}\n\n",
"event: content_block_stop\n",
"data: {\"type\":\"content_block_stop\",\"index\":0}\n\n",
"event: message_delta\n",
"data: {\"type\":\"message_delta\",\"delta\":{\"stop_reason\":\"tool_use\",\"stop_sequence\":null},\"usage\":{\"input_tokens\":8,\"output_tokens\":1}}\n\n",
"event: message_stop\n",
"data: {\"type\":\"message_stop\"}\n\n",
"data: [DONE]\n\n"
);
let server = spawn_server(
state.clone(),
vec![http_response_with_headers(
"200 OK",
"text/event-stream",
sse,
&[("request-id", "req_stream_456")],
)],
)
.await;
let client = AnthropicClient::new("test-key")
.with_auth_token(Some("proxy-token".to_string()))
.with_base_url(server.base_url());
let mut stream = client
.stream_message(&sample_request(false))
.await
.expect("stream should start");
assert_eq!(stream.request_id(), Some("req_stream_456"));
let mut events = Vec::new();
while let Some(event) = stream
.next_event()
.await
.expect("stream event should parse")
{
events.push(event);
}
assert_eq!(events.len(), 6);
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
assert!(matches!(
events[1],
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
content_block: OutputContentBlock::ToolUse { .. },
..
})
));
assert!(matches!(
events[2],
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
delta: ContentBlockDelta::InputJsonDelta { .. },
..
})
));
assert!(matches!(events[3], StreamEvent::ContentBlockStop(_)));
assert!(matches!(
events[4],
StreamEvent::MessageDelta(MessageDeltaEvent { .. })
));
assert!(matches!(events[5], StreamEvent::MessageStop(_)));
match &events[1] {
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
content_block: OutputContentBlock::ToolUse { name, input, .. },
..
}) => {
assert_eq!(name, "get_weather");
assert_eq!(input, &json!({}));
}
other => panic!("expected tool_use block, got {other:?}"),
}
let captured = state.lock().await;
let request = captured.first().expect("server should capture request");
assert!(request.body.contains("\"stream\":true"));
}
#[tokio::test]
async fn retries_retryable_failures_before_succeeding() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![
http_response(
"429 Too Many Requests",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"rate_limit_error\",\"message\":\"slow down\"}}",
),
http_response(
"200 OK",
"application/json",
"{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}",
),
],
)
.await;
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2));
let response = client
.send_message(&sample_request(false))
.await
.expect("retry should eventually succeed");
assert_eq!(response.total_tokens(), 5);
assert_eq!(state.lock().await.len(), 2);
}
#[tokio::test]
async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let server = spawn_server(
state.clone(),
vec![
http_response(
"503 Service Unavailable",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"busy\"}}",
),
http_response(
"503 Service Unavailable",
"application/json",
"{\"type\":\"error\",\"error\":{\"type\":\"overloaded_error\",\"message\":\"still busy\"}}",
),
],
)
.await;
let client = AnthropicClient::new("test-key")
.with_base_url(server.base_url())
.with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2));
let error = client
.send_message(&sample_request(false))
.await
.expect_err("persistent 503 should fail");
match error {
ApiError::RetriesExhausted {
attempts,
last_error,
} => {
assert_eq!(attempts, 2);
assert!(matches!(
*last_error,
ApiError::Api {
status: reqwest::StatusCode::SERVICE_UNAVAILABLE,
retryable: true,
..
}
));
}
other => panic!("expected retries exhausted, got {other:?}"),
}
}
#[tokio::test]
#[ignore = "requires ANTHROPIC_API_KEY and network access"]
async fn live_stream_smoke_test() {
let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set");
let mut stream = client
.stream_message(&MessageRequest {
model: std::env::var("ANTHROPIC_MODEL")
.unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()),
max_tokens: 32,
messages: vec![InputMessage::user_text(
"Reply with exactly: hello from rust",
)],
system: None,
tools: None,
tool_choice: None,
stream: false,
})
.await
.expect("live stream should start");
while let Some(_event) = stream
.next_event()
.await
.expect("live stream should yield events")
{}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct CapturedRequest {
method: String,
path: String,
headers: HashMap<String, String>,
body: String,
}
struct TestServer {
base_url: String,
join_handle: tokio::task::JoinHandle<()>,
}
impl TestServer {
fn base_url(&self) -> String {
self.base_url.clone()
}
}
impl Drop for TestServer {
fn drop(&mut self) {
self.join_handle.abort();
}
}
async fn spawn_server(
state: Arc<Mutex<Vec<CapturedRequest>>>,
responses: Vec<String>,
) -> TestServer {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("listener should bind");
let address = listener
.local_addr()
.expect("listener should have local addr");
let join_handle = tokio::spawn(async move {
for response in responses {
let (mut socket, _) = listener.accept().await.expect("server should accept");
let mut buffer = Vec::new();
let mut header_end = None;
loop {
let mut chunk = [0_u8; 1024];
let read = socket
.read(&mut chunk)
.await
.expect("request read should succeed");
if read == 0 {
break;
}
buffer.extend_from_slice(&chunk[..read]);
if let Some(position) = find_header_end(&buffer) {
header_end = Some(position);
break;
}
}
let header_end = header_end.expect("request should include headers");
let (header_bytes, remaining) = buffer.split_at(header_end);
let header_text =
String::from_utf8(header_bytes.to_vec()).expect("headers should be utf8");
let mut lines = header_text.split("\r\n");
let request_line = lines.next().expect("request line should exist");
let mut parts = request_line.split_whitespace();
let method = parts.next().expect("method should exist").to_string();
let path = parts.next().expect("path should exist").to_string();
let mut headers = HashMap::new();
let mut content_length = 0_usize;
for line in lines {
if line.is_empty() {
continue;
}
let (name, value) = line.split_once(':').expect("header should have colon");
let value = value.trim().to_string();
if name.eq_ignore_ascii_case("content-length") {
content_length = value.parse().expect("content length should parse");
}
headers.insert(name.to_ascii_lowercase(), value);
}
let mut body = remaining[4..].to_vec();
while body.len() < content_length {
let mut chunk = vec![0_u8; content_length - body.len()];
let read = socket
.read(&mut chunk)
.await
.expect("body read should succeed");
if read == 0 {
break;
}
body.extend_from_slice(&chunk[..read]);
}
state.lock().await.push(CapturedRequest {
method,
path,
headers,
body: String::from_utf8(body).expect("body should be utf8"),
});
socket
.write_all(response.as_bytes())
.await
.expect("response write should succeed");
}
});
TestServer {
base_url: format!("http://{address}"),
join_handle,
}
}
fn find_header_end(bytes: &[u8]) -> Option<usize> {
bytes.windows(4).position(|window| window == b"\r\n\r\n")
}
fn http_response(status: &str, content_type: &str, body: &str) -> String {
http_response_with_headers(status, content_type, body, &[])
}
fn http_response_with_headers(
status: &str,
content_type: &str,
body: &str,
headers: &[(&str, &str)],
) -> String {
let mut extra_headers = String::new();
for (name, value) in headers {
use std::fmt::Write as _;
write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write should succeed");
}
format!(
"HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
body.len()
)
}
fn sample_request(stream: bool) -> MessageRequest {
MessageRequest {
model: "claude-3-7-sonnet-latest".to_string(),
max_tokens: 64,
messages: vec![InputMessage {
role: "user".to_string(),
content: vec![
InputContentBlock::Text {
text: "Say hello".to_string(),
},
InputContentBlock::ToolResult {
tool_use_id: "toolu_prev".to_string(),
content: vec![api::ToolResultContentBlock::Json {
value: json!({"forecast": "sunny"}),
}],
is_error: false,
},
],
}],
system: Some("Use tools when needed".to_string()),
tools: Some(vec![ToolDefinition {
name: "get_weather".to_string(),
description: Some("Fetches the weather".to_string()),
input_schema: json!({
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"]
}),
}]),
tool_choice: Some(ToolChoice::Auto),
stream,
}
}