From cbc0a83059a93a99d88e3b83c4c200546f31ed27 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:01:37 +0000 Subject: [PATCH 1/5] auto: save WIP progress from rcc session --- rust/crates/api/src/error.rs | 43 +- rust/crates/api/src/providers/anthropic.rs | 994 +++++++++++++++++++++ rust/crates/api/src/providers/mod.rs | 202 +++++ 3 files changed, 1218 insertions(+), 21 deletions(-) create mode 100644 rust/crates/api/src/providers/anthropic.rs create mode 100644 rust/crates/api/src/providers/mod.rs diff --git a/rust/crates/api/src/error.rs b/rust/crates/api/src/error.rs index 2c31691..7649889 100644 --- a/rust/crates/api/src/error.rs +++ b/rust/crates/api/src/error.rs @@ -4,7 +4,10 @@ use std::time::Duration; #[derive(Debug)] pub enum ApiError { - MissingApiKey, + MissingCredentials { + provider: &'static str, + env_vars: &'static [&'static str], + }, ExpiredOAuthToken, Auth(String), InvalidApiKeyEnv(VarError), @@ -30,13 +33,21 @@ pub enum ApiError { } impl ApiError { + #[must_use] + pub const fn missing_credentials( + provider: &'static str, + env_vars: &'static [&'static str], + ) -> Self { + Self::MissingCredentials { provider, env_vars } + } + #[must_use] pub fn is_retryable(&self) -> bool { match self { Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(), Self::Api { retryable, .. } => *retryable, Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), - Self::MissingApiKey + Self::MissingCredentials { .. } | Self::ExpiredOAuthToken | Self::Auth(_) | Self::InvalidApiKeyEnv(_) @@ -51,12 +62,11 @@ impl ApiError { impl Display for ApiError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::MissingApiKey => { - write!( - f, - "ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API" - ) - } + Self::MissingCredentials { provider, env_vars } => write!( + f, + "missing {provider} credentials; export {} before calling the {provider} API", + env_vars.join(" or ") + ), Self::ExpiredOAuthToken => { write!( f, @@ -65,10 +75,7 @@ impl Display for ApiError { } Self::Auth(message) => write!(f, "auth error: {message}"), Self::InvalidApiKeyEnv(error) => { - write!( - f, - "failed to read ANTHROPIC_AUTH_TOKEN / ANTHROPIC_API_KEY: {error}" - ) + write!(f, "failed to read credential environment variable: {error}") } Self::Http(error) => write!(f, "http error: {error}"), Self::Io(error) => write!(f, "io error: {error}"), @@ -81,20 +88,14 @@ impl Display for ApiError { .. } => match (error_type, message) { (Some(error_type), Some(message)) => { - write!( - f, - "anthropic api returned {status} ({error_type}): {message}" - ) + write!(f, "api returned {status} ({error_type}): {message}") } - _ => write!(f, "anthropic api returned {status}: {body}"), + _ => write!(f, "api returned {status}: {body}"), }, Self::RetriesExhausted { attempts, last_error, - } => write!( - f, - "anthropic api failed after {attempts} attempts: {last_error}" - ), + } => write!(f, "api failed after {attempts} attempts: {last_error}"), Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"), Self::BackoffOverflow { attempt, diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs new file mode 100644 index 0000000..4f6dd98 --- /dev/null +++ b/rust/crates/api/src/providers/anthropic.rs @@ -0,0 +1,994 @@ +use std::collections::VecDeque; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use runtime::{ + load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, + OAuthTokenExchangeRequest, +}; +use serde::Deserialize; + +use crate::error::ApiError; +use crate::sse::SseParser; +use crate::types::{MessageRequest, MessageResponse, StreamEvent}; + +const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthSource { + None, + ApiKey(String), + BearerToken(String), + ApiKeyAndBearer { + api_key: String, + bearer_token: String, + }, +} + +impl AuthSource { + pub fn from_env() -> Result { + let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; + let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; + match (api_key, auth_token) { + (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + (Some(api_key), None) => Ok(Self::ApiKey(api_key)), + (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), + (None, None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])), + } + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + match self { + Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), + Self::None | Self::BearerToken(_) => None, + } + } + + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Self::BearerToken(token) + | Self::ApiKeyAndBearer { + bearer_token: token, + .. + } => Some(token), + Self::None | Self::ApiKey(_) => None, + } + } + + #[must_use] + pub fn masked_authorization_header(&self) -> &'static str { + if self.bearer_token().is_some() { + "Bearer [REDACTED]" + } else { + "" + } + } + + pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(api_key) = self.api_key() { + request_builder = request_builder.header("x-api-key", api_key); + } + if let Some(token) = self.bearer_token() { + request_builder = request_builder.bearer_auth(token); + } + request_builder + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + #[serde(default)] + pub scopes: Vec, +} + +impl From for AuthSource { + fn from(value: OAuthTokenSet) -> Self { + Self::BearerToken(value.access_token) + } +} + +#[derive(Debug, Clone)] +pub struct AnthropicClient { + http: reqwest::Client, + auth: AuthSource, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl AnthropicClient { + #[must_use] + pub fn new(api_key: impl Into) -> Self { + Self { + http: reqwest::Client::new(), + auth: AuthSource::ApiKey(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + #[must_use] + pub fn from_auth(auth: AuthSource) -> Self { + Self { + http: reqwest::Client::new(), + auth, + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env() -> Result { + Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) + } + + #[must_use] + pub fn with_auth_source(mut self, auth: AuthSource) -> Self { + self.auth = auth; + self + } + + #[must_use] + pub fn with_auth_token(mut self, auth_token: Option) -> Self { + match ( + self.auth.api_key().map(ToOwned::to_owned), + auth_token.filter(|token| !token.is_empty()), + ) { + (Some(api_key), Some(bearer_token)) => { + self.auth = AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }; + } + (Some(api_key), None) => { + self.auth = AuthSource::ApiKey(api_key); + } + (None, Some(bearer_token)) => { + self.auth = AuthSource::BearerToken(bearer_token); + } + (None, None) => { + self.auth = AuthSource::None; + } + } + self + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + #[must_use] + pub fn auth_source(&self) -> &AuthSource { + &self.auth + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let mut response = response + .json::() + .await + .map_err(ApiError::from)?; + if response.request_id.is_none() { + response.request_id = request_id; + } + Ok(response) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: SseParser::new(), + pending: VecDeque::new(), + done: false, + }) + } + + pub async fn exchange_oauth_code( + &self, + config: &OAuthConfig, + request: &OAuthTokenExchangeRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + pub async fn refresh_oauth_token( + &self, + config: &OAuthConfig, + request: &OAuthRefreshRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + let mut last_error: Option; + + loop { + attempts += 1; + match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + } + + if attempts > self.max_retries { + break; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + } + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error.expect("retry loop must capture an error")), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + let request_builder = self + .http + .post(&request_url) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json"); + let mut request_builder = self.auth.apply(request_builder); + + request_builder = request_builder.json(request); + request_builder.send().await.map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl AuthSource { + pub fn from_env_or_saved() -> Result { + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(Self::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(Self::BearerToken(bearer_token)); + } + match load_saved_oauth_token() { + Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { + if token_set.refresh_token.is_some() { + Err(ApiError::Auth( + "saved OAuth token is expired; load runtime OAuth config to refresh it" + .to_string(), + )) + } else { + Err(ApiError::ExpiredOAuthToken) + } + } + Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), + Ok(None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])), + Err(error) => Err(error), + } + } +} + +#[must_use] +pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { + token_set + .expires_at + .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) +} + +pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { + let Some(token_set) = load_saved_oauth_token()? else { + return Ok(None); + }; + resolve_saved_oauth_token_set(config, token_set).map(Some) +} + +pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result +where + F: FnOnce() -> Result, ApiError>, +{ + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(AuthSource::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(AuthSource::BearerToken(bearer_token)); + } + + let Some(token_set) = load_saved_oauth_token()? else { + return Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])); + }; + if !oauth_token_is_expired(&token_set) { + return Ok(AuthSource::BearerToken(token_set.access_token)); + } + if token_set.refresh_token.is_none() { + return Err(ApiError::ExpiredOAuthToken); + } + + let Some(config) = load_oauth_config()? else { + return Err(ApiError::Auth( + "saved OAuth token is expired; runtime OAuth config is missing".to_string(), + )); + }; + Ok(AuthSource::from(resolve_saved_oauth_token_set( + &config, token_set, + )?)) +} + +fn resolve_saved_oauth_token_set( + config: &OAuthConfig, + token_set: OAuthTokenSet, +) -> Result { + if !oauth_token_is_expired(&token_set) { + return Ok(token_set); + } + let Some(refresh_token) = token_set.refresh_token.clone() else { + return Err(ApiError::ExpiredOAuthToken); + }; + let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url()); + let refreshed = client_runtime_block_on(async { + client + .refresh_oauth_token( + config, + &OAuthRefreshRequest::from_config( + config, + refresh_token, + Some(token_set.scopes.clone()), + ), + ) + .await + })?; + let resolved = OAuthTokenSet { + access_token: refreshed.access_token, + refresh_token: refreshed.refresh_token.or(token_set.refresh_token), + expires_at: refreshed.expires_at, + scopes: refreshed.scopes, + }; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: resolved.access_token.clone(), + refresh_token: resolved.refresh_token.clone(), + expires_at: resolved.expires_at, + scopes: resolved.scopes.clone(), + }) + .map_err(ApiError::from)?; + Ok(resolved) +} + +fn client_runtime_block_on(future: F) -> Result +where + F: std::future::Future>, +{ + tokio::runtime::Runtime::new() + .map_err(ApiError::from)? + .block_on(future) +} + +fn load_saved_oauth_token() -> Result, ApiError> { + let token_set = load_oauth_credentials().map_err(ApiError::from)?; + Ok(token_set.map(|token_set| OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })) +} + +fn now_unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[cfg(test)] +fn read_api_key() -> Result { + let auth = AuthSource::from_env_or_saved()?; + auth.api_key() + .or_else(|| auth.bearer_token()) + .map(ToOwned::to_owned) + .ok_or(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])) +} + +#[cfg(test)] +fn read_auth_token() -> Option { + read_env_non_empty("ANTHROPIC_AUTH_TOKEN") + .ok() + .and_then(std::convert::identity) +} + +#[must_use] +pub fn read_base_url() -> String { + std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: SseParser, + pending: VecDeque, + done: bool, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + let remaining = self.parser.finish()?; + self.pending.extend(remaining); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + self.pending.extend(self.parser.push(&chunk)?); + } + None => { + self.done = true; + } + } + } + } +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_else(|_| String::new()); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .map(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .map(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +#[derive(Debug, Deserialize)] +struct AnthropicErrorEnvelope { + error: AnthropicErrorBody, +} + +#[derive(Debug, Deserialize)] +struct AnthropicErrorBody { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +#[cfg(test)] +mod tests { + use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::sync::{Mutex, OnceLock}; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; + + use super::{ + now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, + resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, + }; + use crate::types::{ContentBlockDelta, MessageRequest}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + fn temp_config_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!( + "api-oauth-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )) + } + + fn sample_oauth_config(token_url: String) -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url, + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + fn spawn_token_server(response_body: &'static str) -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let address = listener.local_addr().expect("local addr"); + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept connection"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer).expect("read request"); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + format!("http://{address}/oauth/token") + } + + #[test] + fn read_api_key_requires_presence() { + let _guard = env_lock(); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + let error = super::read_api_key().expect_err("missing key should error"); + assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. })); + } + + #[test] + fn read_api_key_requires_non_empty_value() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); + std::env::remove_var("ANTHROPIC_API_KEY"); + let error = super::read_api_key().expect_err("empty key should error"); + assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. })); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn read_api_key_prefers_api_key_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + assert_eq!( + super::read_api_key().expect("api key should load"), + "legacy-key" + ); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn read_auth_token_reads_auth_token_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn oauth_token_maps_to_bearer_auth_source() { + let auth = AuthSource::from(OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }); + assert_eq!(auth.bearer_token(), Some("access-token")); + assert_eq!(auth.api_key(), None); + } + + #[test] + fn auth_source_from_env_combines_api_key_and_bearer_token() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + let auth = AuthSource::from_env().expect("env auth"); + assert_eq!(auth.api_key(), Some("legacy-key")); + assert_eq!(auth.bearer_token(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn auth_source_from_saved_oauth_when_env_absent() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = AuthSource::from_env_or_saved().expect("saved auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn oauth_token_expiry_uses_expires_at_timestamp() { + assert!(oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(1), + scopes: Vec::new(), + })); + assert!(!oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(now_unix_timestamp() + 60), + scopes: Vec::new(), + })); + } + + #[test] + fn resolve_saved_oauth_token_refreshes_expired_credentials() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "refreshed-token"); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) + .expect("startup auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let error = + resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); + assert!( + matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) + ); + + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "expired-access-token"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); + } + + #[test] + fn message_request_stream_helper_sets_stream_true() { + let request = MessageRequest { + model: "claude-opus-4-6".to_string(), + max_tokens: 64, + messages: vec![], + system: None, + tools: None, + tool_choice: None, + stream: false, + }; + + assert!(request.with_streaming().stream); + } + + #[test] + fn backoff_doubles_until_maximum() { + let client = AnthropicClient::new("test-key").with_retry_policy( + 3, + Duration::from_millis(10), + Duration::from_millis(25), + ); + assert_eq!( + client.backoff_for_attempt(1).expect("attempt 1"), + Duration::from_millis(10) + ); + assert_eq!( + client.backoff_for_attempt(2).expect("attempt 2"), + Duration::from_millis(20) + ); + assert_eq!( + client.backoff_for_attempt(3).expect("attempt 3"), + Duration::from_millis(25) + ); + } + + #[test] + fn retryable_statuses_are_detected() { + assert!(super::is_retryable_status( + reqwest::StatusCode::TOO_MANY_REQUESTS + )); + assert!(super::is_retryable_status( + reqwest::StatusCode::INTERNAL_SERVER_ERROR + )); + assert!(!super::is_retryable_status( + reqwest::StatusCode::UNAUTHORIZED + )); + } + + #[test] + fn tool_delta_variant_round_trips() { + let delta = ContentBlockDelta::InputJsonDelta { + partial_json: "{\"city\":\"Paris\"}".to_string(), + }; + let encoded = serde_json::to_string(&delta).expect("delta should serialize"); + let decoded: ContentBlockDelta = + serde_json::from_str(&encoded).expect("delta should deserialize"); + assert_eq!(decoded, delta); + } + + #[test] + fn request_id_uses_primary_or_fallback_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_primary") + ); + + headers.clear(); + headers.insert( + ALT_REQUEST_ID_HEADER, + "req_fallback".parse().expect("header"), + ); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_fallback") + ); + } + + #[test] + fn auth_source_applies_headers() { + let auth = AuthSource::ApiKeyAndBearer { + api_key: "test-key".to_string(), + bearer_token: "proxy-token".to_string(), + }; + let request = auth + .apply(reqwest::Client::new().post("https://example.test")) + .build() + .expect("request build"); + let headers = request.headers(); + assert_eq!( + headers.get("x-api-key").and_then(|v| v.to_str().ok()), + Some("test-key") + ); + assert_eq!( + headers.get("authorization").and_then(|v| v.to_str().ok()), + Some("Bearer proxy-token") + ); + } +} diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs new file mode 100644 index 0000000..cf891cc --- /dev/null +++ b/rust/crates/api/src/providers/mod.rs @@ -0,0 +1,202 @@ +use std::future::Future; +use std::pin::Pin; + +use crate::error::ApiError; +use crate::types::{MessageRequest, MessageResponse}; + +pub mod anthropic; +pub mod openai_compat; + +pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; + +pub trait Provider { + type Stream; + + fn send_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, MessageResponse>; + + fn stream_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, Self::Stream>; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProviderKind { + Anthropic, + Xai, + OpenAi, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ProviderMetadata { + pub provider: ProviderKind, + pub canonical_model: &'static str, + pub auth_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ + ( + "opus", + ProviderMetadata { + provider: ProviderKind::Anthropic, + canonical_model: "claude-opus-4-6", + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "sonnet", + ProviderMetadata { + provider: ProviderKind::Anthropic, + canonical_model: "claude-sonnet-4-6", + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "haiku", + ProviderMetadata { + provider: ProviderKind::Anthropic, + canonical_model: "claude-haiku-4-5-20251213", + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }, + ), + ( + "grok", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-3", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-3", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-3-mini", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-3-mini", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-2", + ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: "grok-2", + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), +]; + +#[must_use] +pub fn resolve_model_alias(model: &str) -> String { + let trimmed = model.trim(); + let lower = trimmed.to_ascii_lowercase(); + MODEL_REGISTRY + .iter() + .find_map(|(alias, metadata)| (*alias == lower).then_some(metadata.canonical_model)) + .map_or_else(|| trimmed.to_string(), ToOwned::to_owned) +} + +#[must_use] +pub fn metadata_for_model(model: &str) -> Option { + let canonical = resolve_model_alias(model); + if canonical.starts_with("claude") { + return Some(ProviderMetadata { + provider: ProviderKind::Anthropic, + canonical_model: Box::leak(canonical.into_boxed_str()), + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: anthropic::DEFAULT_BASE_URL, + }); + } + if canonical.starts_with("grok") { + return Some(ProviderMetadata { + provider: ProviderKind::Xai, + canonical_model: Box::leak(canonical.into_boxed_str()), + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }); + } + None +} + +#[must_use] +pub fn detect_provider_kind(model: &str) -> ProviderKind { + if let Some(metadata) = metadata_for_model(model) { + return metadata.provider; + } + if anthropic::has_auth_from_env_or_saved().unwrap_or(false) { + return ProviderKind::Anthropic; + } + if openai_compat::has_api_key("OPENAI_API_KEY") { + return ProviderKind::OpenAi; + } + if openai_compat::has_api_key("XAI_API_KEY") { + return ProviderKind::Xai; + } + ProviderKind::Anthropic +} + +#[must_use] +pub fn max_tokens_for_model(model: &str) -> u32 { + let canonical = resolve_model_alias(model); + if canonical.contains("opus") { + 32_000 + } else { + 64_000 + } +} + +#[cfg(test)] +mod tests { + use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind}; + + #[test] + fn resolves_grok_aliases() { + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + assert_eq!(resolve_model_alias("grok-2"), "grok-2"); + } + + #[test] + fn detects_provider_from_model_name_first() { + assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); + assert_eq!(detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic); + } + + #[test] + fn keeps_existing_max_token_heuristic() { + assert_eq!(max_tokens_for_model("opus"), 32_000); + assert_eq!(max_tokens_for_model("grok-3"), 64_000); + } +} From 2a0f4b677af854d7a20a430c367c753331615448 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:10:46 +0000 Subject: [PATCH 2/5] feat: provider abstraction layer + Grok API support --- rust/crates/api/src/client.rs | 1018 ++-------------- rust/crates/api/src/lib.rs | 10 +- rust/crates/api/src/providers/anthropic.rs | 58 +- rust/crates/api/src/providers/mod.rs | 44 +- .../crates/api/src/providers/openai_compat.rs | 1025 +++++++++++++++++ .../api/tests/openai_compat_integration.rs | 312 +++++ rust/crates/rusty-claude-cli/src/main.rs | 49 +- rust/crates/tools/src/lib.rs | 30 +- 8 files changed, 1547 insertions(+), 999 deletions(-) create mode 100644 rust/crates/api/src/providers/openai_compat.rs create mode 100644 rust/crates/api/tests/openai_compat_integration.rs diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 7ef7e83..467697e 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -1,994 +1,142 @@ -use std::collections::VecDeque; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use runtime::{ - load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, - OAuthTokenExchangeRequest, -}; -use serde::Deserialize; - use crate::error::ApiError; -use crate::sse::SseParser; +use crate::providers::anthropic::{self, AnthropicClient, AuthSource}; +use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; +use crate::providers::{self, Provider, ProviderKind}; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; -const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; -const ANTHROPIC_VERSION: &str = "2023-06-01"; -const REQUEST_ID_HEADER: &str = "request-id"; -const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; -const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); -const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); -const DEFAULT_MAX_RETRIES: u32 = 2; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum AuthSource { - None, - ApiKey(String), - BearerToken(String), - ApiKeyAndBearer { - api_key: String, - bearer_token: String, - }, +async fn send_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.send_message(request).await } -impl AuthSource { - pub fn from_env() -> Result { - let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; - let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; - match (api_key, auth_token) { - (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { - api_key, - bearer_token, - }), - (Some(api_key), None) => Ok(Self::ApiKey(api_key)), - (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), - (None, None) => Err(ApiError::MissingApiKey), - } - } - - #[must_use] - pub fn api_key(&self) -> Option<&str> { - match self { - Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), - Self::None | Self::BearerToken(_) => None, - } - } - - #[must_use] - pub fn bearer_token(&self) -> Option<&str> { - match self { - Self::BearerToken(token) - | Self::ApiKeyAndBearer { - bearer_token: token, - .. - } => Some(token), - Self::None | Self::ApiKey(_) => None, - } - } - - #[must_use] - pub fn masked_authorization_header(&self) -> &'static str { - if self.bearer_token().is_some() { - "Bearer [REDACTED]" - } else { - "" - } - } - - pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - if let Some(api_key) = self.api_key() { - request_builder = request_builder.header("x-api-key", api_key); - } - if let Some(token) = self.bearer_token() { - request_builder = request_builder.bearer_auth(token); - } - request_builder - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] -pub struct OAuthTokenSet { - pub access_token: String, - pub refresh_token: Option, - pub expires_at: Option, - #[serde(default)] - pub scopes: Vec, -} - -impl From for AuthSource { - fn from(value: OAuthTokenSet) -> Self { - Self::BearerToken(value.access_token) - } +async fn stream_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.stream_message(request).await } #[derive(Debug, Clone)] -pub struct AnthropicClient { - http: reqwest::Client, - auth: AuthSource, - base_url: String, - max_retries: u32, - initial_backoff: Duration, - max_backoff: Duration, +pub enum ProviderClient { + Anthropic(AnthropicClient), + Xai(OpenAiCompatClient), + OpenAi(OpenAiCompatClient), } -impl AnthropicClient { - #[must_use] - pub fn new(api_key: impl Into) -> Self { - Self { - http: reqwest::Client::new(), - auth: AuthSource::ApiKey(api_key.into()), - base_url: DEFAULT_BASE_URL.to_string(), - max_retries: DEFAULT_MAX_RETRIES, - initial_backoff: DEFAULT_INITIAL_BACKOFF, - max_backoff: DEFAULT_MAX_BACKOFF, +impl ProviderClient { + pub fn from_model(model: &str) -> Result { + Self::from_model_with_anthropic_auth(model, None) + } + + pub fn from_model_with_anthropic_auth( + model: &str, + anthropic_auth: Option, + ) -> Result { + let resolved_model = providers::resolve_model_alias(model); + match providers::detect_provider_kind(&resolved_model) { + ProviderKind::Anthropic => Ok(Self::Anthropic( + anthropic_auth + .map(AnthropicClient::from_auth) + .unwrap_or(AnthropicClient::from_env()?), + )), + ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( + OpenAiCompatConfig::xai(), + )?)), + ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env( + OpenAiCompatConfig::openai(), + )?)), } } #[must_use] - pub fn from_auth(auth: AuthSource) -> Self { - Self { - http: reqwest::Client::new(), - auth, - base_url: DEFAULT_BASE_URL.to_string(), - max_retries: DEFAULT_MAX_RETRIES, - initial_backoff: DEFAULT_INITIAL_BACKOFF, - max_backoff: DEFAULT_MAX_BACKOFF, + pub const fn provider_kind(&self) -> ProviderKind { + match self { + Self::Anthropic(_) => ProviderKind::Anthropic, + Self::Xai(_) => ProviderKind::Xai, + Self::OpenAi(_) => ProviderKind::OpenAi, } } - pub fn from_env() -> Result { - Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) - } - - #[must_use] - pub fn with_auth_source(mut self, auth: AuthSource) -> Self { - self.auth = auth; - self - } - - #[must_use] - pub fn with_auth_token(mut self, auth_token: Option) -> Self { - match ( - self.auth.api_key().map(ToOwned::to_owned), - auth_token.filter(|token| !token.is_empty()), - ) { - (Some(api_key), Some(bearer_token)) => { - self.auth = AuthSource::ApiKeyAndBearer { - api_key, - bearer_token, - }; - } - (Some(api_key), None) => { - self.auth = AuthSource::ApiKey(api_key); - } - (None, Some(bearer_token)) => { - self.auth = AuthSource::BearerToken(bearer_token); - } - (None, None) => { - self.auth = AuthSource::None; - } - } - self - } - - #[must_use] - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = base_url.into(); - self - } - - #[must_use] - pub fn with_retry_policy( - mut self, - max_retries: u32, - initial_backoff: Duration, - max_backoff: Duration, - ) -> Self { - self.max_retries = max_retries; - self.initial_backoff = initial_backoff; - self.max_backoff = max_backoff; - self - } - - #[must_use] - pub fn auth_source(&self) -> &AuthSource { - &self.auth - } - pub async fn send_message( &self, request: &MessageRequest, ) -> Result { - let request = MessageRequest { - stream: false, - ..request.clone() - }; - let response = self.send_with_retry(&request).await?; - let request_id = request_id_from_headers(response.headers()); - let mut response = response - .json::() - .await - .map_err(ApiError::from)?; - if response.request_id.is_none() { - response.request_id = request_id; + match self { + Self::Anthropic(client) => send_via_provider(client, request).await, + Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await, } - Ok(response) } pub async fn stream_message( &self, request: &MessageRequest, ) -> Result { - let response = self - .send_with_retry(&request.clone().with_streaming()) - .await?; - Ok(MessageStream { - request_id: request_id_from_headers(response.headers()), - response, - parser: SseParser::new(), - pending: VecDeque::new(), - done: false, - }) - } - - pub async fn exchange_oauth_code( - &self, - config: &OAuthConfig, - request: &OAuthTokenExchangeRequest, - ) -> Result { - let response = self - .http - .post(&config.token_url) - .header("content-type", "application/x-www-form-urlencoded") - .form(&request.form_params()) - .send() - .await - .map_err(ApiError::from)?; - let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) - } - - pub async fn refresh_oauth_token( - &self, - config: &OAuthConfig, - request: &OAuthRefreshRequest, - ) -> Result { - let response = self - .http - .post(&config.token_url) - .header("content-type", "application/x-www-form-urlencoded") - .form(&request.form_params()) - .send() - .await - .map_err(ApiError::from)?; - let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) - } - - async fn send_with_retry( - &self, - request: &MessageRequest, - ) -> Result { - let mut attempts = 0; - let mut last_error: Option; - - loop { - attempts += 1; - match self.send_raw_request(request).await { - Ok(response) => match expect_success(response).await { - Ok(response) => return Ok(response), - Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { - last_error = Some(error); - } - Err(error) => return Err(error), - }, - Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { - last_error = Some(error); - } - Err(error) => return Err(error), - } - - if attempts > self.max_retries { - break; - } - - tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; - } - - Err(ApiError::RetriesExhausted { - attempts, - last_error: Box::new(last_error.expect("retry loop must capture an error")), - }) - } - - async fn send_raw_request( - &self, - request: &MessageRequest, - ) -> Result { - let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); - let request_builder = self - .http - .post(&request_url) - .header("anthropic-version", ANTHROPIC_VERSION) - .header("content-type", "application/json"); - let mut request_builder = self.auth.apply(request_builder); - - request_builder = request_builder.json(request); - request_builder.send().await.map_err(ApiError::from) - } - - fn backoff_for_attempt(&self, attempt: u32) -> Result { - let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { - return Err(ApiError::BackoffOverflow { - attempt, - base_delay: self.initial_backoff, - }); - }; - Ok(self - .initial_backoff - .checked_mul(multiplier) - .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) - } -} - -impl AuthSource { - pub fn from_env_or_saved() -> Result { - if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { - return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - Some(bearer_token) => Ok(Self::ApiKeyAndBearer { - api_key, - bearer_token, - }), - None => Ok(Self::ApiKey(api_key)), - }; - } - if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - return Ok(Self::BearerToken(bearer_token)); - } - match load_saved_oauth_token() { - Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { - if token_set.refresh_token.is_some() { - Err(ApiError::Auth( - "saved OAuth token is expired; load runtime OAuth config to refresh it" - .to_string(), - )) - } else { - Err(ApiError::ExpiredOAuthToken) - } - } - Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), - Ok(None) => Err(ApiError::MissingApiKey), - Err(error) => Err(error), + match self { + Self::Anthropic(client) => stream_via_provider(client, request) + .await + .map(MessageStream::Anthropic), + Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request) + .await + .map(MessageStream::OpenAiCompat), } } } -#[must_use] -pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { - token_set - .expires_at - .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) -} - -pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { - let Some(token_set) = load_saved_oauth_token()? else { - return Ok(None); - }; - resolve_saved_oauth_token_set(config, token_set).map(Some) -} - -pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result -where - F: FnOnce() -> Result, ApiError>, -{ - if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { - return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { - api_key, - bearer_token, - }), - None => Ok(AuthSource::ApiKey(api_key)), - }; - } - if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - return Ok(AuthSource::BearerToken(bearer_token)); - } - - let Some(token_set) = load_saved_oauth_token()? else { - return Err(ApiError::MissingApiKey); - }; - if !oauth_token_is_expired(&token_set) { - return Ok(AuthSource::BearerToken(token_set.access_token)); - } - if token_set.refresh_token.is_none() { - return Err(ApiError::ExpiredOAuthToken); - } - - let Some(config) = load_oauth_config()? else { - return Err(ApiError::Auth( - "saved OAuth token is expired; runtime OAuth config is missing".to_string(), - )); - }; - Ok(AuthSource::from(resolve_saved_oauth_token_set( - &config, token_set, - )?)) -} - -fn resolve_saved_oauth_token_set( - config: &OAuthConfig, - token_set: OAuthTokenSet, -) -> Result { - if !oauth_token_is_expired(&token_set) { - return Ok(token_set); - } - let Some(refresh_token) = token_set.refresh_token.clone() else { - return Err(ApiError::ExpiredOAuthToken); - }; - let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url()); - let refreshed = client_runtime_block_on(async { - client - .refresh_oauth_token( - config, - &OAuthRefreshRequest::from_config( - config, - refresh_token, - Some(token_set.scopes.clone()), - ), - ) - .await - })?; - let resolved = OAuthTokenSet { - access_token: refreshed.access_token, - refresh_token: refreshed.refresh_token.or(token_set.refresh_token), - expires_at: refreshed.expires_at, - scopes: refreshed.scopes, - }; - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: resolved.access_token.clone(), - refresh_token: resolved.refresh_token.clone(), - expires_at: resolved.expires_at, - scopes: resolved.scopes.clone(), - }) - .map_err(ApiError::from)?; - Ok(resolved) -} - -fn client_runtime_block_on(future: F) -> Result -where - F: std::future::Future>, -{ - tokio::runtime::Runtime::new() - .map_err(ApiError::from)? - .block_on(future) -} - -fn load_saved_oauth_token() -> Result, ApiError> { - let token_set = load_oauth_credentials().map_err(ApiError::from)?; - Ok(token_set.map(|token_set| OAuthTokenSet { - access_token: token_set.access_token, - refresh_token: token_set.refresh_token, - expires_at: token_set.expires_at, - scopes: token_set.scopes, - })) -} - -fn now_unix_timestamp() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |duration| duration.as_secs()) -} - -fn read_env_non_empty(key: &str) -> Result, ApiError> { - match std::env::var(key) { - Ok(value) if !value.is_empty() => Ok(Some(value)), - Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), - Err(error) => Err(ApiError::from(error)), - } -} - -#[cfg(test)] -fn read_api_key() -> Result { - let auth = AuthSource::from_env_or_saved()?; - auth.api_key() - .or_else(|| auth.bearer_token()) - .map(ToOwned::to_owned) - .ok_or(ApiError::MissingApiKey) -} - -#[cfg(test)] -fn read_auth_token() -> Option { - read_env_non_empty("ANTHROPIC_AUTH_TOKEN") - .ok() - .and_then(std::convert::identity) -} - -#[must_use] -pub fn read_base_url() -> String { - std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) -} - -fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { - headers - .get(REQUEST_ID_HEADER) - .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) - .and_then(|value| value.to_str().ok()) - .map(ToOwned::to_owned) -} - #[derive(Debug)] -pub struct MessageStream { - request_id: Option, - response: reqwest::Response, - parser: SseParser, - pending: VecDeque, - done: bool, +pub enum MessageStream { + Anthropic(anthropic::MessageStream), + OpenAiCompat(openai_compat::MessageStream), } impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { - self.request_id.as_deref() + match self { + Self::Anthropic(stream) => stream.request_id(), + Self::OpenAiCompat(stream) => stream.request_id(), + } } pub async fn next_event(&mut self) -> Result, ApiError> { - loop { - if let Some(event) = self.pending.pop_front() { - return Ok(Some(event)); - } - - if self.done { - let remaining = self.parser.finish()?; - self.pending.extend(remaining); - if let Some(event) = self.pending.pop_front() { - return Ok(Some(event)); - } - return Ok(None); - } - - match self.response.chunk().await? { - Some(chunk) => { - self.pending.extend(self.parser.push(&chunk)?); - } - None => { - self.done = true; - } - } + match self { + Self::Anthropic(stream) => stream.next_event().await, + Self::OpenAiCompat(stream) => stream.next_event().await, } } } -async fn expect_success(response: reqwest::Response) -> Result { - let status = response.status(); - if status.is_success() { - return Ok(response); - } - - let body = response.text().await.unwrap_or_else(|_| String::new()); - let parsed_error = serde_json::from_str::(&body).ok(); - let retryable = is_retryable_status(status); - - Err(ApiError::Api { - status, - error_type: parsed_error - .as_ref() - .map(|error| error.error.error_type.clone()), - message: parsed_error - .as_ref() - .map(|error| error.error.message.clone()), - body, - retryable, - }) +pub use anthropic::{ + oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet, +}; +#[must_use] +pub fn read_base_url() -> String { + anthropic::read_base_url() } -const fn is_retryable_status(status: reqwest::StatusCode) -> bool { - matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) -} - -#[derive(Debug, Deserialize)] -struct AnthropicErrorEnvelope { - error: AnthropicErrorBody, -} - -#[derive(Debug, Deserialize)] -struct AnthropicErrorBody { - #[serde(rename = "type")] - error_type: String, - message: String, +#[must_use] +pub fn read_xai_base_url() -> String { + openai_compat::read_base_url(OpenAiCompatConfig::xai()) } #[cfg(test)] mod tests { - use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; - use std::io::{Read, Write}; - use std::net::TcpListener; - use std::sync::{Mutex, OnceLock}; - use std::thread; - use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind}; - use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; - - use crate::client::{ - now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, - resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, - }; - use crate::types::{ContentBlockDelta, MessageRequest}; - - fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) - .lock() - .expect("env lock") - } - - fn temp_config_home() -> std::path::PathBuf { - std::env::temp_dir().join(format!( - "api-oauth-test-{}-{}", - std::process::id(), - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time") - .as_nanos() - )) - } - - fn sample_oauth_config(token_url: String) -> OAuthConfig { - OAuthConfig { - client_id: "runtime-client".to_string(), - authorize_url: "https://console.test/oauth/authorize".to_string(), - token_url, - callback_port: Some(4545), - manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), - scopes: vec!["org:read".to_string(), "user:write".to_string()], - } - } - - fn spawn_token_server(response_body: &'static str) -> String { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); - let address = listener.local_addr().expect("local addr"); - thread::spawn(move || { - let (mut stream, _) = listener.accept().expect("accept connection"); - let mut buffer = [0_u8; 4096]; - let _ = stream.read(&mut buffer).expect("read request"); - let response = format!( - "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", - response_body.len(), - response_body - ); - stream - .write_all(response.as_bytes()) - .expect("write response"); - }); - format!("http://{address}/oauth/token") + #[test] + fn resolves_existing_and_grok_aliases() { + assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); } #[test] - fn read_api_key_requires_presence() { - let _guard = env_lock(); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - let error = super::read_api_key().expect_err("missing key should error"); - assert!(matches!(error, crate::error::ApiError::MissingApiKey)); - } - - #[test] - fn read_api_key_requires_non_empty_value() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); - std::env::remove_var("ANTHROPIC_API_KEY"); - let error = super::read_api_key().expect_err("empty key should error"); - assert!(matches!(error, crate::error::ApiError::MissingApiKey)); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - } - - #[test] - fn read_api_key_prefers_api_key_env() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + fn provider_detection_prefers_model_family() { + assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai); assert_eq!( - super::read_api_key().expect("api key should load"), - "legacy-key" - ); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - } - - #[test] - fn read_auth_token_reads_auth_token_env() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - } - - #[test] - fn oauth_token_maps_to_bearer_auth_source() { - let auth = AuthSource::from(OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(123), - scopes: vec!["scope:a".to_string()], - }); - assert_eq!(auth.bearer_token(), Some("access-token")); - assert_eq!(auth.api_key(), None); - } - - #[test] - fn auth_source_from_env_combines_api_key_and_bearer_token() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); - let auth = AuthSource::from_env().expect("env auth"); - assert_eq!(auth.api_key(), Some("legacy-key")); - assert_eq!(auth.bearer_token(), Some("auth-token")); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - } - - #[test] - fn auth_source_from_saved_oauth_when_env_absent() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "saved-access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(now_unix_timestamp() + 300), - scopes: vec!["scope:a".to_string()], - }) - .expect("save oauth credentials"); - - let auth = AuthSource::from_env_or_saved().expect("saved auth"); - assert_eq!(auth.bearer_token(), Some("saved-access-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn oauth_token_expiry_uses_expires_at_timestamp() { - assert!(oauth_token_is_expired(&OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: None, - expires_at: Some(1), - scopes: Vec::new(), - })); - assert!(!oauth_token_is_expired(&OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: None, - expires_at: Some(now_unix_timestamp() + 60), - scopes: Vec::new(), - })); - } - - #[test] - fn resolve_saved_oauth_token_refreshes_expired_credentials() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let token_url = spawn_token_server( - "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", - ); - let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) - .expect("resolve refreshed token") - .expect("token set present"); - assert_eq!(resolved.access_token, "refreshed-token"); - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.access_token, "refreshed-token"); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "saved-access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(now_unix_timestamp() + 300), - scopes: vec!["scope:a".to_string()], - }) - .expect("save oauth credentials"); - - let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) - .expect("startup auth"); - assert_eq!(auth.bearer_token(), Some("saved-access-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let error = - resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); - assert!( - matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) - ); - - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.access_token, "expired-access-token"); - assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let token_url = spawn_token_server( - "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", - ); - let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) - .expect("resolve refreshed token") - .expect("token set present"); - assert_eq!(resolved.access_token, "refreshed-token"); - assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn message_request_stream_helper_sets_stream_true() { - let request = MessageRequest { - model: "claude-opus-4-6".to_string(), - max_tokens: 64, - messages: vec![], - system: None, - tools: None, - tool_choice: None, - stream: false, - }; - - assert!(request.with_streaming().stream); - } - - #[test] - fn backoff_doubles_until_maximum() { - let client = AnthropicClient::new("test-key").with_retry_policy( - 3, - Duration::from_millis(10), - Duration::from_millis(25), - ); - assert_eq!( - client.backoff_for_attempt(1).expect("attempt 1"), - Duration::from_millis(10) - ); - assert_eq!( - client.backoff_for_attempt(2).expect("attempt 2"), - Duration::from_millis(20) - ); - assert_eq!( - client.backoff_for_attempt(3).expect("attempt 3"), - Duration::from_millis(25) - ); - } - - #[test] - fn retryable_statuses_are_detected() { - assert!(super::is_retryable_status( - reqwest::StatusCode::TOO_MANY_REQUESTS - )); - assert!(super::is_retryable_status( - reqwest::StatusCode::INTERNAL_SERVER_ERROR - )); - assert!(!super::is_retryable_status( - reqwest::StatusCode::UNAUTHORIZED - )); - } - - #[test] - fn tool_delta_variant_round_trips() { - let delta = ContentBlockDelta::InputJsonDelta { - partial_json: "{\"city\":\"Paris\"}".to_string(), - }; - let encoded = serde_json::to_string(&delta).expect("delta should serialize"); - let decoded: ContentBlockDelta = - serde_json::from_str(&encoded).expect("delta should deserialize"); - assert_eq!(decoded, delta); - } - - #[test] - fn request_id_uses_primary_or_fallback_header() { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); - assert_eq!( - super::request_id_from_headers(&headers).as_deref(), - Some("req_primary") - ); - - headers.clear(); - headers.insert( - ALT_REQUEST_ID_HEADER, - "req_fallback".parse().expect("header"), - ); - assert_eq!( - super::request_id_from_headers(&headers).as_deref(), - Some("req_fallback") - ); - } - - #[test] - fn auth_source_applies_headers() { - let auth = AuthSource::ApiKeyAndBearer { - api_key: "test-key".to_string(), - bearer_token: "proxy-token".to_string(), - }; - let request = auth - .apply(reqwest::Client::new().post("https://example.test")) - .build() - .expect("request build"); - let headers = request.headers(); - assert_eq!( - headers.get("x-api-key").and_then(|v| v.to_str().ok()), - Some("test-key") - ); - assert_eq!( - headers.get("authorization").and_then(|v| v.to_str().ok()), - Some("Bearer proxy-token") + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::Anthropic ); } } diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index 4108187..7702fee 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -1,13 +1,19 @@ mod client; mod error; +mod providers; mod sse; mod types; pub use client::{ - oauth_token_is_expired, read_base_url, resolve_saved_oauth_token, resolve_startup_auth_source, - AnthropicClient, AuthSource, MessageStream, OAuthTokenSet, + oauth_token_is_expired, read_base_url, read_xai_base_url, resolve_saved_oauth_token, + resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient, }; pub use error::ApiError; +pub use providers::anthropic::{AnthropicClient, AuthSource}; +pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig}; +pub use providers::{ + detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind, +}; pub use sse::{parse_frame, SseParser}; pub use types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, diff --git a/rust/crates/api/src/providers/anthropic.rs b/rust/crates/api/src/providers/anthropic.rs index 4f6dd98..0883e60 100644 --- a/rust/crates/api/src/providers/anthropic.rs +++ b/rust/crates/api/src/providers/anthropic.rs @@ -8,10 +8,12 @@ use runtime::{ use serde::Deserialize; use crate::error::ApiError; + +use super::{Provider, ProviderFuture}; use crate::sse::SseParser; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; -const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; +pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; const ANTHROPIC_VERSION: &str = "2023-06-01"; const REQUEST_ID_HEADER: &str = "request-id"; const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; @@ -41,7 +43,10 @@ impl AuthSource { }), (Some(api_key), None) => Ok(Self::ApiKey(api_key)), (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), - (None, None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])), + (None, None) => Err(ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), } } @@ -362,7 +367,10 @@ impl AuthSource { } } Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), - Ok(None) => Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])), + Ok(None) => Err(ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), Err(error) => Err(error), } } @@ -382,6 +390,12 @@ pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result Result { + Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some() + || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some() + || load_saved_oauth_token()?.is_some()) +} + pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result where F: FnOnce() -> Result, ApiError>, @@ -400,7 +414,10 @@ where } let Some(token_set) = load_saved_oauth_token()? else { - return Err(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])); + return Err(ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )); }; if !oauth_token_is_expired(&token_set) { return Ok(AuthSource::BearerToken(token_set.access_token)); @@ -497,7 +514,10 @@ fn read_api_key() -> Result { auth.api_key() .or_else(|| auth.bearer_token()) .map(ToOwned::to_owned) - .ok_or(ApiError::missing_credentials("Anthropic", &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"])) + .ok_or(ApiError::missing_credentials( + "Anthropic", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )) } #[cfg(test)] @@ -520,6 +540,24 @@ fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + #[derive(Debug)] pub struct MessageStream { request_id: Option, @@ -673,7 +711,10 @@ mod tests { std::env::remove_var("ANTHROPIC_API_KEY"); std::env::remove_var("CLAUDE_CONFIG_HOME"); let error = super::read_api_key().expect_err("missing key should error"); - assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. })); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); } #[test] @@ -682,7 +723,10 @@ mod tests { std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); std::env::remove_var("ANTHROPIC_API_KEY"); let error = super::read_api_key().expect_err("empty key should error"); - assert!(matches!(error, crate::error::ApiError::MissingCredentials { .. })); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); } diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs index cf891cc..d28febd 100644 --- a/rust/crates/api/src/providers/mod.rs +++ b/rust/crates/api/src/providers/mod.rs @@ -12,9 +12,15 @@ pub type ProviderFuture<'a, T> = Pin pub trait Provider { type Stream; - fn send_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, MessageResponse>; + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse>; - fn stream_message<'a>(&'a self, request: &'a MessageRequest) -> ProviderFuture<'a, Self::Stream>; + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream>; } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -27,7 +33,6 @@ pub enum ProviderKind { #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ProviderMetadata { pub provider: ProviderKind, - pub canonical_model: &'static str, pub auth_env: &'static str, pub base_url_env: &'static str, pub default_base_url: &'static str, @@ -38,7 +43,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ "opus", ProviderMetadata { provider: ProviderKind::Anthropic, - canonical_model: "claude-opus-4-6", auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", default_base_url: anthropic::DEFAULT_BASE_URL, @@ -48,7 +52,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ "sonnet", ProviderMetadata { provider: ProviderKind::Anthropic, - canonical_model: "claude-sonnet-4-6", auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", default_base_url: anthropic::DEFAULT_BASE_URL, @@ -58,7 +61,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ "haiku", ProviderMetadata { provider: ProviderKind::Anthropic, - canonical_model: "claude-haiku-4-5-20251213", auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", default_base_url: anthropic::DEFAULT_BASE_URL, @@ -68,7 +70,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ "grok", ProviderMetadata { provider: ProviderKind::Xai, - canonical_model: "grok-3", auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, @@ -78,7 +79,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ "grok-3", ProviderMetadata { provider: ProviderKind::Xai, - canonical_model: "grok-3", auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, @@ -88,7 +88,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ "grok-mini", ProviderMetadata { provider: ProviderKind::Xai, - canonical_model: "grok-3-mini", auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, @@ -98,7 +97,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ "grok-3-mini", ProviderMetadata { provider: ProviderKind::Xai, - canonical_model: "grok-3-mini", auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, @@ -108,7 +106,6 @@ const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ "grok-2", ProviderMetadata { provider: ProviderKind::Xai, - canonical_model: "grok-2", auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, @@ -122,7 +119,23 @@ pub fn resolve_model_alias(model: &str) -> String { let lower = trimmed.to_ascii_lowercase(); MODEL_REGISTRY .iter() - .find_map(|(alias, metadata)| (*alias == lower).then_some(metadata.canonical_model)) + .find_map(|(alias, metadata)| { + (*alias == lower).then_some(match metadata.provider { + ProviderKind::Anthropic => match *alias { + "opus" => "claude-opus-4-6", + "sonnet" => "claude-sonnet-4-6", + "haiku" => "claude-haiku-4-5-20251213", + _ => trimmed, + }, + ProviderKind::Xai => match *alias { + "grok" | "grok-3" => "grok-3", + "grok-mini" | "grok-3-mini" => "grok-3-mini", + "grok-2" => "grok-2", + _ => trimmed, + }, + ProviderKind::OpenAi => trimmed, + }) + }) .map_or_else(|| trimmed.to_string(), ToOwned::to_owned) } @@ -132,7 +145,6 @@ pub fn metadata_for_model(model: &str) -> Option { if canonical.starts_with("claude") { return Some(ProviderMetadata { provider: ProviderKind::Anthropic, - canonical_model: Box::leak(canonical.into_boxed_str()), auth_env: "ANTHROPIC_API_KEY", base_url_env: "ANTHROPIC_BASE_URL", default_base_url: anthropic::DEFAULT_BASE_URL, @@ -141,7 +153,6 @@ pub fn metadata_for_model(model: &str) -> Option { if canonical.starts_with("grok") { return Some(ProviderMetadata { provider: ProviderKind::Xai, - canonical_model: Box::leak(canonical.into_boxed_str()), auth_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, @@ -191,7 +202,10 @@ mod tests { #[test] fn detects_provider_from_model_name_first() { assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); - assert_eq!(detect_provider_kind("claude-sonnet-4-6"), ProviderKind::Anthropic); + assert_eq!( + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::Anthropic + ); } #[test] diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs new file mode 100644 index 0000000..8a0fe9c --- /dev/null +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -0,0 +1,1025 @@ +use std::collections::{BTreeMap, VecDeque}; +use std::time::Duration; + +use serde::Deserialize; +use serde_json::{json, Value}; + +use crate::error::ApiError; +use crate::types::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, + MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, + ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, +}; + +use super::{Provider, ProviderFuture}; + +pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; +pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OpenAiCompatConfig { + pub provider_name: &'static str, + pub api_key_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; +const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; + +impl OpenAiCompatConfig { + #[must_use] + pub const fn xai() -> Self { + Self { + provider_name: "xAI", + api_key_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: DEFAULT_XAI_BASE_URL, + } + } + + #[must_use] + pub const fn openai() -> Self { + Self { + provider_name: "OpenAI", + api_key_env: "OPENAI_API_KEY", + base_url_env: "OPENAI_BASE_URL", + default_base_url: DEFAULT_OPENAI_BASE_URL, + } + } + #[must_use] + pub fn credential_env_vars(self) -> &'static [&'static str] { + match self.provider_name { + "xAI" => XAI_ENV_VARS, + "OpenAI" => OPENAI_ENV_VARS, + _ => &[], + } + } +} + +#[derive(Debug, Clone)] +pub struct OpenAiCompatClient { + http: reqwest::Client, + api_key: String, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl OpenAiCompatClient { + #[must_use] + pub fn new(api_key: impl Into, config: OpenAiCompatConfig) -> Self { + Self { + http: reqwest::Client::new(), + api_key: api_key.into(), + base_url: read_base_url(config), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env(config: OpenAiCompatConfig) -> Result { + let Some(api_key) = read_env_non_empty(config.api_key_env)? else { + return Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )); + }; + Ok(Self::new(api_key, config)) + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let payload = response.json::().await?; + let mut normalized = normalize_response(&request.model, payload)?; + if normalized.request_id.is_none() { + normalized.request_id = request_id; + } + Ok(normalized) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: OpenAiSseParser::new(), + pending: VecDeque::new(), + done: false, + state: StreamState::new(request.model.clone()), + }) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + + let last_error = loop { + attempts += 1; + let retryable_error = match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }; + + if attempts > self.max_retries { + break retryable_error; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + }; + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); + self.http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&self.api_key) + .json(&build_chat_completion_request(request)) + .send() + .await + .map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl Provider for OpenAiCompatClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: OpenAiSseParser, + pending: VecDeque, + done: bool, + state: StreamState, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + self.pending.extend(self.state.finish()?); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + for parsed in self.parser.push(&chunk)? { + self.pending.extend(self.state.ingest_chunk(parsed)?); + } + } + None => { + self.done = true; + } + } + } + } +} + +#[derive(Debug, Default)] +struct OpenAiSseParser { + buffer: Vec, +} + +impl OpenAiSseParser { + fn new() -> Self { + Self::default() + } + + fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { + self.buffer.extend_from_slice(chunk); + let mut events = Vec::new(); + + while let Some(frame) = next_sse_frame(&mut self.buffer) { + if let Some(event) = parse_sse_frame(&frame)? { + events.push(event); + } + } + + Ok(events) + } +} + +#[derive(Debug)] +struct StreamState { + model: String, + message_started: bool, + text_started: bool, + text_finished: bool, + finished: bool, + stop_reason: Option, + usage: Option, + tool_calls: BTreeMap, +} + +impl StreamState { + fn new(model: String) -> Self { + Self { + model, + message_started: false, + text_started: false, + text_finished: false, + finished: false, + stop_reason: None, + usage: None, + tool_calls: BTreeMap::new(), + } + } + + fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { + let mut events = Vec::new(); + if !self.message_started { + self.message_started = true; + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: MessageResponse { + id: chunk.id.clone(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: Vec::new(), + model: chunk.model.clone().unwrap_or_else(|| self.model.clone()), + stop_reason: None, + stop_sequence: None, + usage: Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }, + request_id: None, + }, + })); + } + + if let Some(usage) = chunk.usage { + self.usage = Some(Usage { + input_tokens: usage.prompt_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: usage.completion_tokens, + }); + } + + for choice in chunk.choices { + if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) { + if !self.text_started { + self.text_started = true; + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Text { + text: String::new(), + }, + })); + } + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::TextDelta { text: content }, + })); + } + + for tool_call in choice.delta.tool_calls { + let state = self.tool_calls.entry(tool_call.index).or_default(); + state.apply(tool_call); + let block_index = state.block_index(); + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + } else { + continue; + } + } + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + } + + if let Some(finish_reason) = choice.finish_reason { + self.stop_reason = Some(normalize_finish_reason(&finish_reason)); + if finish_reason == "tool_calls" { + for state in self.tool_calls.values_mut() { + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + } + } + } + + Ok(events) + } + + fn finish(&mut self) -> Result, ApiError> { + if self.finished { + return Ok(Vec::new()); + } + self.finished = true; + + let mut events = Vec::new(); + if self.text_started && !self.text_finished { + self.text_finished = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: 0, + })); + } + + for state in self.tool_calls.values_mut() { + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + } + } + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + + if self.message_started { + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some( + self.stop_reason + .clone() + .unwrap_or_else(|| "end_turn".to_string()), + ), + stop_sequence: None, + }, + usage: self.usage.clone().unwrap_or(Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }), + })); + events.push(StreamEvent::MessageStop(MessageStopEvent {})); + } + Ok(events) + } +} + +#[derive(Debug, Default)] +struct ToolCallState { + openai_index: u32, + id: Option, + name: Option, + arguments: String, + emitted_len: usize, + started: bool, + stopped: bool, +} + +impl ToolCallState { + fn apply(&mut self, tool_call: DeltaToolCall) { + self.openai_index = tool_call.index; + if let Some(id) = tool_call.id { + self.id = Some(id); + } + if let Some(name) = tool_call.function.name { + self.name = Some(name); + } + if let Some(arguments) = tool_call.function.arguments { + self.arguments.push_str(&arguments); + } + } + + const fn block_index(&self) -> u32 { + self.openai_index + 1 + } + + fn start_event(&self) -> Result, ApiError> { + let Some(name) = self.name.clone() else { + return Ok(None); + }; + let id = self + .id + .clone() + .unwrap_or_else(|| format!("tool_call_{}", self.openai_index)); + Ok(Some(ContentBlockStartEvent { + index: self.block_index(), + content_block: OutputContentBlock::ToolUse { + id, + name, + input: json!({}), + }, + })) + } + + fn delta_event(&mut self) -> Option { + if self.emitted_len >= self.arguments.len() { + return None; + } + let delta = self.arguments[self.emitted_len..].to_string(); + self.emitted_len = self.arguments.len(); + Some(ContentBlockDeltaEvent { + index: self.block_index(), + delta: ContentBlockDelta::InputJsonDelta { + partial_json: delta, + }, + }) + } +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionResponse { + id: String, + model: String, + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatChoice { + message: ChatMessage, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatMessage { + role: String, + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolCall { + id: String, + function: ResponseToolFunction, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolFunction { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct OpenAiUsage { + #[serde(default)] + prompt_tokens: u32, + #[serde(default)] + completion_tokens: u32, +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionChunk { + id: String, + #[serde(default)] + model: Option, + #[serde(default)] + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChunkChoice { + delta: ChunkDelta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Default, Deserialize)] +struct ChunkDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct DeltaToolCall { + #[serde(default)] + index: u32, + #[serde(default)] + id: Option, + #[serde(default)] + function: DeltaFunction, +} + +#[derive(Debug, Default, Deserialize)] +struct DeltaFunction { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct ErrorEnvelope { + error: ErrorBody, +} + +#[derive(Debug, Deserialize)] +struct ErrorBody { + #[serde(rename = "type")] + error_type: Option, + message: Option, +} + +fn build_chat_completion_request(request: &MessageRequest) -> Value { + let mut messages = Vec::new(); + if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { + messages.push(json!({ + "role": "system", + "content": system, + })); + } + for message in &request.messages { + messages.extend(translate_message(message)); + } + + let mut payload = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": request.stream, + }); + + if let Some(tools) = &request.tools { + payload["tools"] = + Value::Array(tools.iter().map(openai_tool_definition).collect::>()); + } + if let Some(tool_choice) = &request.tool_choice { + payload["tool_choice"] = openai_tool_choice(tool_choice); + } + + payload +} + +fn translate_message(message: &InputMessage) -> Vec { + match message.role.as_str() { + "assistant" => { + let mut text = String::new(); + let mut tool_calls = Vec::new(); + for block in &message.content { + match block { + InputContentBlock::Text { text: value } => text.push_str(value), + InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({ + "id": id, + "type": "function", + "function": { + "name": name, + "arguments": input.to_string(), + } + })), + InputContentBlock::ToolResult { .. } => {} + } + } + if text.is_empty() && tool_calls.is_empty() { + Vec::new() + } else { + vec![json!({ + "role": "assistant", + "content": (!text.is_empty()).then_some(text), + "tool_calls": tool_calls, + })] + } + } + _ => message + .content + .iter() + .filter_map(|block| match block { + InputContentBlock::Text { text } => Some(json!({ + "role": "user", + "content": text, + })), + InputContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } => Some(json!({ + "role": "tool", + "tool_call_id": tool_use_id, + "content": flatten_tool_result_content(content), + "is_error": is_error, + })), + InputContentBlock::ToolUse { .. } => None, + }) + .collect(), + } +} + +fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { + content + .iter() + .map(|block| match block { + ToolResultContentBlock::Text { text } => text.clone(), + ToolResultContentBlock::Json { value } => value.to_string(), + }) + .collect::>() + .join("\n") +} + +fn openai_tool_definition(tool: &ToolDefinition) -> Value { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + } + }) +} + +fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { + match tool_choice { + ToolChoice::Auto => Value::String("auto".to_string()), + ToolChoice::Any => Value::String("required".to_string()), + ToolChoice::Tool { name } => json!({ + "type": "function", + "function": { "name": name }, + }), + } +} + +fn normalize_response( + model: &str, + response: ChatCompletionResponse, +) -> Result { + let choice = response + .choices + .into_iter() + .next() + .ok_or(ApiError::InvalidSseFrame( + "chat completion response missing choices", + ))?; + let mut content = Vec::new(); + if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) { + content.push(OutputContentBlock::Text { text }); + } + for tool_call in choice.message.tool_calls { + content.push(OutputContentBlock::ToolUse { + id: tool_call.id, + name: tool_call.function.name, + input: parse_tool_arguments(&tool_call.function.arguments), + }); + } + + Ok(MessageResponse { + id: response.id, + kind: "message".to_string(), + role: choice.message.role, + content, + model: response.model.if_empty_then(model.to_string()), + stop_reason: choice + .finish_reason + .map(|value| normalize_finish_reason(&value)), + stop_sequence: None, + usage: Usage { + input_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.prompt_tokens), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.completion_tokens), + }, + request_id: None, + }) +} + +fn parse_tool_arguments(arguments: &str) -> Value { + serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments })) +} + +fn next_sse_frame(buffer: &mut Vec) -> Option { + let separator = buffer + .windows(2) + .position(|window| window == b"\n\n") + .map(|position| (position, 2)) + .or_else(|| { + buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| (position, 4)) + })?; + + let (position, separator_len) = separator; + let frame = buffer.drain(..position + separator_len).collect::>(); + let frame_len = frame.len().saturating_sub(separator_len); + Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) +} + +fn parse_sse_frame(frame: &str) -> Result, ApiError> { + let trimmed = frame.trim(); + if trimmed.is_empty() { + return Ok(None); + } + + let mut data_lines = Vec::new(); + for line in trimmed.lines() { + if line.starts_with(':') { + continue; + } + if let Some(data) = line.strip_prefix("data:") { + data_lines.push(data.trim_start()); + } + } + if data_lines.is_empty() { + return Ok(None); + } + let payload = data_lines.join("\n"); + if payload == "[DONE]" { + return Ok(None); + } + serde_json::from_str(&payload) + .map(Some) + .map_err(ApiError::from) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[must_use] +pub fn has_api_key(key: &str) -> bool { + read_env_non_empty(key) + .ok() + .and_then(std::convert::identity) + .is_some() +} + +#[must_use] +pub fn read_base_url(config: OpenAiCompatConfig) -> String { + std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_default(); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .and_then(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .and_then(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +fn normalize_finish_reason(value: &str) -> String { + match value { + "stop" => "end_turn", + "tool_calls" => "tool_use", + other => other, + } + .to_string() +} + +trait StringExt { + fn if_empty_then(self, fallback: String) -> String; +} + +impl StringExt for String { + fn if_empty_then(self, fallback: String) -> String { + if self.is_empty() { + fallback + } else { + self + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + build_chat_completion_request, normalize_finish_reason, openai_tool_choice, + parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, + }; + use crate::error::ApiError; + use crate::types::{ + InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, + ToolResultContentBlock, + }; + use serde_json::json; + use std::sync::{Mutex, OnceLock}; + + #[test] + fn request_translation_uses_openai_compatible_shape() { + let payload = build_chat_completion_request(&MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![ + InputContentBlock::Text { + text: "hello".to_string(), + }, + InputContentBlock::ToolResult { + tool_use_id: "tool_1".to_string(), + content: vec![ToolResultContentBlock::Json { + value: json!({"ok": true}), + }], + is_error: false, + }, + ], + }], + system: Some("be helpful".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Get weather".to_string()), + input_schema: json!({"type": "object"}), + }]), + tool_choice: Some(ToolChoice::Auto), + stream: false, + }); + + assert_eq!(payload["messages"][0]["role"], json!("system")); + assert_eq!(payload["messages"][1]["role"], json!("user")); + assert_eq!(payload["messages"][2]["role"], json!("tool")); + assert_eq!(payload["tools"][0]["type"], json!("function")); + assert_eq!(payload["tool_choice"], json!("auto")); + } + + #[test] + fn tool_choice_translation_supports_required_function() { + assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); + assert_eq!( + openai_tool_choice(&ToolChoice::Tool { + name: "weather".to_string(), + }), + json!({"type": "function", "function": {"name": "weather"}}) + ); + } + + #[test] + fn parses_tool_arguments_fallback() { + assert_eq!( + parse_tool_arguments("{\"city\":\"Paris\"}"), + json!({"city": "Paris"}) + ); + assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"})); + } + + #[test] + fn missing_xai_api_key_is_provider_specific() { + let _lock = env_lock(); + std::env::remove_var("XAI_API_KEY"); + let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai()) + .expect_err("missing key should error"); + assert!(matches!( + error, + ApiError::MissingCredentials { + provider: "xAI", + .. + } + )); + } + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + #[test] + fn normalizes_stop_reasons() { + assert_eq!(normalize_finish_reason("stop"), "end_turn"); + assert_eq!(normalize_finish_reason("tool_calls"), "tool_use"); + } +} diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs new file mode 100644 index 0000000..b1b6a0a --- /dev/null +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -0,0 +1,312 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use api::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, + OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition, +}; +use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; + +#[tokio::test] +async fn send_message_uses_openai_compatible_endpoint_and_auth() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_test\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.model, "grok-3"); + assert_eq!(response.total_tokens(), 16); + assert_eq!( + response.content, + vec![OutputContentBlock::Text { + text: "Hello from Grok".to_string(), + }] + ); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); + let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body"); + assert_eq!(body["model"], json!("grok-3")); + assert_eq!(body["messages"][0]["role"], json!("system")); + assert_eq!(body["tools"][0]["type"], json!("function")); +} + +#[tokio::test] +async fn stream_message_normalizes_text_and_multiple_tool_calls() { + let state = Arc::new(Mutex::new(Vec::::new())); + let sse = concat!( + "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n", + "data: [DONE]\n\n" + ); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "text/event-stream", + sse, + &[("x-request-id", "req_grok_stream")], + )], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let mut stream = client + .stream_message(&sample_request(false)) + .await + .expect("stream should start"); + + assert_eq!(stream.request_id(), Some("req_grok_stream")); + + let mut events = Vec::new(); + while let Some(event) = stream.next_event().await.expect("event should parse") { + events.push(event); + } + + assert!(matches!(events[0], StreamEvent::MessageStart(_))); + assert!(matches!( + events[1], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::Text { .. }, + .. + }) + )); + assert!(matches!( + events[2], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + delta: ContentBlockDelta::TextDelta { .. }, + .. + }) + )); + assert!(matches!( + events[3], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 1, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[4], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 1, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[5], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 2, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[6], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 2, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[7], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 }) + )); + assert!(matches!( + events[8], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 }) + )); + assert!(matches!( + events[9], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 }) + )); + assert!(matches!(events[10], StreamEvent::MessageDelta(_))); + assert!(matches!(events[11], StreamEvent::MessageStop(_))); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert!(request.body.contains("\"stream\":true")); +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CapturedRequest { + path: String, + headers: HashMap, + body: String, +} + +struct TestServer { + base_url: String, + join_handle: tokio::task::JoinHandle<()>, +} + +impl TestServer { + fn base_url(&self) -> String { + self.base_url.clone() + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.join_handle.abort(); + } +} + +async fn spawn_server( + state: Arc>>, + responses: Vec, +) -> TestServer { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let address = listener.local_addr().expect("listener addr"); + let join_handle = tokio::spawn(async move { + for response in responses { + let (mut socket, _) = listener.accept().await.expect("accept"); + let mut buffer = Vec::new(); + let mut header_end = None; + loop { + let mut chunk = [0_u8; 1024]; + let read = socket.read(&mut chunk).await.expect("read request"); + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if let Some(position) = find_header_end(&buffer) { + header_end = Some(position); + break; + } + } + + let header_end = header_end.expect("headers should exist"); + let (header_bytes, remaining) = buffer.split_at(header_end); + let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers"); + let mut lines = header_text.split("\r\n"); + let request_line = lines.next().expect("request line"); + let path = request_line + .split_whitespace() + .nth(1) + .expect("path") + .to_string(); + let mut headers = HashMap::new(); + let mut content_length = 0_usize; + for line in lines { + if line.is_empty() { + continue; + } + let (name, value) = line.split_once(':').expect("header"); + let value = value.trim().to_string(); + if name.eq_ignore_ascii_case("content-length") { + content_length = value.parse().expect("content length"); + } + headers.insert(name.to_ascii_lowercase(), value); + } + + let mut body = remaining[4..].to_vec(); + while body.len() < content_length { + let mut chunk = vec![0_u8; content_length - body.len()]; + let read = socket.read(&mut chunk).await.expect("read body"); + if read == 0 { + break; + } + body.extend_from_slice(&chunk[..read]); + } + + state.lock().await.push(CapturedRequest { + path, + headers, + body: String::from_utf8(body).expect("utf8 body"), + }); + + socket + .write_all(response.as_bytes()) + .await + .expect("write response"); + } + }); + + TestServer { + base_url: format!("http://{address}"), + join_handle, + } +} + +fn find_header_end(bytes: &[u8]) -> Option { + bytes.windows(4).position(|window| window == b"\r\n\r\n") +} + +fn http_response(status: &str, content_type: &str, body: &str) -> String { + http_response_with_headers(status, content_type, body, &[]) +} + +fn http_response_with_headers( + status: &str, + content_type: &str, + body: &str, + headers: &[(&str, &str)], +) -> String { + let mut extra_headers = String::new(); + for (name, value) in headers { + use std::fmt::Write as _; + write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write"); + } + format!( + "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ) +} + +fn sample_request(stream: bool) -> MessageRequest { + MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "Say hello".to_string(), + }], + }], + system: Some("Use tools when needed".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Fetches weather".to_string()), + input_schema: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }), + }]), + tool_choice: Some(ToolChoice::Auto), + stream, + } +} diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 5f8a7a6..00ef7cd 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -12,8 +12,9 @@ use std::process::Command; use std::time::{SystemTime, UNIX_EPOCH}; use api::{ - resolve_startup_auth_source, AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, - InputMessage, MessageRequest, MessageResponse, OutputContentBlock, + detect_provider_kind, max_tokens_for_model, resolve_model_alias, resolve_startup_auth_source, + AnthropicClient, AuthSource, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; @@ -35,13 +36,6 @@ use serde_json::json; use tools::{execute_tool, mvp_tool_specs, ToolSpec}; const DEFAULT_MODEL: &str = "claude-opus-4-6"; -fn max_tokens_for_model(model: &str) -> u32 { - if model.contains("opus") { - 32_000 - } else { - 64_000 - } -} const DEFAULT_DATE: &str = "2026-03-31"; const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -288,15 +282,6 @@ fn parse_args(args: &[String]) -> Result { } } -fn resolve_model_alias(model: &str) -> &str { - match model { - "opus" => "claude-opus-4-6", - "sonnet" => "claude-sonnet-4-6", - "haiku" => "claude-haiku-4-5-20251213", - _ => model, - } -} - fn normalize_allowed_tools(values: &[String]) -> Result, String> { if values.is_empty() { return Ok(None); @@ -980,7 +965,7 @@ struct LiveCli { allowed_tools: Option, permission_mode: PermissionMode, system_prompt: Vec, - runtime: ConversationRuntime, + runtime: ConversationRuntime, session: SessionHandle, } @@ -1920,11 +1905,11 @@ fn build_runtime( emit_output: bool, allowed_tools: Option, permission_mode: PermissionMode, -) -> Result, Box> +) -> Result, Box> { Ok(ConversationRuntime::new_with_features( session, - AnthropicRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, + ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, CliToolExecutor::new(allowed_tools, emit_output), permission_policy(permission_mode), system_prompt, @@ -1978,26 +1963,33 @@ impl runtime::PermissionPrompter for CliPermissionPrompter { } } -struct AnthropicRuntimeClient { +struct ProviderRuntimeClient { runtime: tokio::runtime::Runtime, - client: AnthropicClient, + client: ProviderClient, model: String, enable_tools: bool, emit_output: bool, allowed_tools: Option, } -impl AnthropicRuntimeClient { +impl ProviderRuntimeClient { fn new( model: String, enable_tools: bool, emit_output: bool, allowed_tools: Option, ) -> Result> { + let model = resolve_model_alias(&model).to_string(); + let client = match detect_provider_kind(&model) { + ProviderKind::Anthropic => ProviderClient::from_model_with_anthropic_auth( + &model, + Some(resolve_cli_auth_source()?), + )?, + ProviderKind::Xai | ProviderKind::OpenAi => ProviderClient::from_model(&model)?, + }; Ok(Self { runtime: tokio::runtime::Runtime::new()?, - client: AnthropicClient::from_auth(resolve_cli_auth_source()?) - .with_base_url(api::read_base_url()), + client, model, enable_tools, emit_output, @@ -2016,7 +2008,7 @@ fn resolve_cli_auth_source() -> Result> { })?) } -impl ApiClient for AnthropicRuntimeClient { +impl ApiClient for ProviderRuntimeClient { #[allow(clippy::too_many_lines)] fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { let message_request = MessageRequest { @@ -2911,6 +2903,9 @@ mod tests { assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6"); assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213"); + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + assert_eq!(resolve_model_alias("grok-2"), "grok-2"); assert_eq!(resolve_model_alias("claude-opus"), "claude-opus"); } diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 8dcd33d..6448ca0 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -4,9 +4,10 @@ use std::process::Command; use std::time::{Duration, Instant}; use api::{ - read_base_url, AnthropicClient, ContentBlockDelta, InputContentBlock, InputMessage, - MessageRequest, MessageResponse, OutputContentBlock, StreamEvent as ApiStreamEvent, ToolChoice, - ToolDefinition, ToolResultContentBlock, + detect_provider_kind, max_tokens_for_model, resolve_model_alias, ContentBlockDelta, + InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock, + ProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, + ToolResultContentBlock, }; use reqwest::blocking::Client; use runtime::{ @@ -1459,14 +1460,14 @@ fn run_agent_job(job: &AgentJob) -> Result<(), String> { fn build_agent_runtime( job: &AgentJob, -) -> Result, String> { +) -> Result, String> { let model = job .manifest .model .clone() .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); let allowed_tools = job.allowed_tools.clone(); - let api_client = AnthropicRuntimeClient::new(model, allowed_tools.clone())?; + let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?; let tool_executor = SubagentToolExecutor::new(allowed_tools); Ok(ConversationRuntime::new( Session::new(), @@ -1635,18 +1636,21 @@ fn format_agent_terminal_output(status: &str, result: Option<&str>, error: Optio sections.join("") } -struct AnthropicRuntimeClient { +struct ProviderRuntimeClient { runtime: tokio::runtime::Runtime, - client: AnthropicClient, + client: ProviderClient, model: String, allowed_tools: BTreeSet, } -impl AnthropicRuntimeClient { +impl ProviderRuntimeClient { fn new(model: String, allowed_tools: BTreeSet) -> Result { - let client = AnthropicClient::from_env() - .map_err(|error| error.to_string())? - .with_base_url(read_base_url()); + let model = resolve_model_alias(&model).to_string(); + let client = match detect_provider_kind(&model) { + ProviderKind::Anthropic | ProviderKind::Xai | ProviderKind::OpenAi => { + ProviderClient::from_model(&model).map_err(|error| error.to_string())? + } + }; Ok(Self { runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, client, @@ -1656,7 +1660,7 @@ impl AnthropicRuntimeClient { } } -impl ApiClient for AnthropicRuntimeClient { +impl ApiClient for ProviderRuntimeClient { fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) .into_iter() @@ -1668,7 +1672,7 @@ impl ApiClient for AnthropicRuntimeClient { .collect::>(); let message_request = MessageRequest { model: self.model.clone(), - max_tokens: 32_000, + max_tokens: max_tokens_for_model(&self.model), messages: convert_messages(&request.messages), system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), tools: (!tools.is_empty()).then_some(tools), From 178934a9a0d0750b1926fa10289e1e45c3013d08 Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 04:20:15 +0000 Subject: [PATCH 3/5] feat: grok provider tests + cargo fmt --- rust/crates/rusty-claude-cli/src/main.rs | 110 +++++++++++++++++++---- rust/crates/tools/src/lib.rs | 101 ++++++++++++++++----- 2 files changed, 172 insertions(+), 39 deletions(-) diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index 00ef7cd..dee54d9 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -2046,7 +2046,7 @@ impl ApiClient for ProviderRuntimeClient { let renderer = TerminalRenderer::new(); let mut markdown_stream = MarkdownStreamState::default(); let mut events = Vec::new(); - let mut pending_tool: Option<(String, String, String)> = None; + let mut pending_tools: BTreeMap = BTreeMap::new(); let mut saw_stop = false; while let Some(event) = stream @@ -2057,15 +2057,23 @@ impl ApiClient for ProviderRuntimeClient { match event { ApiStreamEvent::MessageStart(start) => { for block in start.message.content { - push_output_block(block, out, &mut events, &mut pending_tool, true)?; + push_output_block( + block, + 0, + out, + &mut events, + &mut pending_tools, + true, + )?; } } ApiStreamEvent::ContentBlockStart(start) => { push_output_block( start.content_block, + start.index, out, &mut events, - &mut pending_tool, + &mut pending_tools, true, )?; } @@ -2081,18 +2089,18 @@ impl ApiClient for ProviderRuntimeClient { } } ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = &mut pending_tool { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { input.push_str(&partial_json); } } }, - ApiStreamEvent::ContentBlockStop(_) => { + ApiStreamEvent::ContentBlockStop(stop) => { if let Some(rendered) = markdown_stream.flush(&renderer) { write!(out, "{rendered}") .and_then(|()| out.flush()) .map_err(|error| RuntimeError::new(error.to_string()))?; } - if let Some((id, name, input)) = pending_tool.take() { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { // Display tool call now that input is fully accumulated writeln!(out, "\n{}", format_tool_call_start(&name, &input)) .and_then(|()| out.flush()) @@ -2556,9 +2564,10 @@ fn truncate_for_summary(value: &str, limit: usize) -> String { fn push_output_block( block: OutputContentBlock, + block_index: u32, out: &mut (impl Write + ?Sized), events: &mut Vec, - pending_tool: &mut Option<(String, String, String)>, + pending_tools: &mut BTreeMap, streaming_tool_input: bool, ) -> Result<(), RuntimeError> { match block { @@ -2583,7 +2592,7 @@ fn push_output_block( } else { input.to_string() }; - *pending_tool = Some((id, name, initial_input)); + pending_tools.insert(block_index, (id, name, initial_input)); } } Ok(()) @@ -2594,11 +2603,13 @@ fn response_to_events( out: &mut (impl Write + ?Sized), ) -> Result, RuntimeError> { let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); - for block in response.content { - push_output_block(block, out, &mut events, &mut pending_tool, false)?; - if let Some((id, name, input)) = pending_tool.take() { + for (index, block) in response.content.into_iter().enumerate() { + let index = + u32::try_from(index).map_err(|_| RuntimeError::new("response block index overflow"))?; + push_output_block(block, index, out, &mut events, &mut pending_tools, false)?; + if let Some((id, name, input)) = pending_tools.remove(&index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -2824,6 +2835,7 @@ mod tests { use api::{MessageResponse, OutputContentBlock, Usage}; use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode}; use serde_json::json; + use std::collections::BTreeMap; use std::path::PathBuf; #[test] @@ -3373,15 +3385,16 @@ mod tests { fn push_output_block_renders_markdown_text() { let mut out = Vec::new(); let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); push_output_block( OutputContentBlock::Text { text: "# Heading".to_string(), }, + 0, &mut out, &mut events, - &mut pending_tool, + &mut pending_tools, false, ) .expect("text block should render"); @@ -3395,7 +3408,7 @@ mod tests { fn push_output_block_skips_empty_object_prefix_for_tool_streams() { let mut out = Vec::new(); let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); push_output_block( OutputContentBlock::ToolUse { @@ -3403,20 +3416,83 @@ mod tests { name: "read_file".to_string(), input: json!({}), }, + 1, &mut out, &mut events, - &mut pending_tool, + &mut pending_tools, true, ) .expect("tool block should accumulate"); assert!(events.is_empty()); assert_eq!( - pending_tool, + pending_tools.remove(&1), Some(("tool-1".to_string(), "read_file".to_string(), String::new(),)) ); } + #[test] + fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + 1, + &mut out, + &mut events, + &mut pending_tools, + true, + ) + .expect("first tool should accumulate"); + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "grep_search".to_string(), + input: json!({}), + }, + 2, + &mut out, + &mut events, + &mut pending_tools, + true, + ) + .expect("second tool should accumulate"); + + pending_tools + .get_mut(&1) + .expect("first tool pending") + .2 + .push_str("{\"path\":\"src/main.rs\"}"); + pending_tools + .get_mut(&2) + .expect("second tool pending") + .2 + .push_str("{\"pattern\":\"TODO\"}"); + + assert_eq!( + pending_tools.remove(&1), + Some(( + "tool-1".to_string(), + "read_file".to_string(), + "{\"path\":\"src/main.rs\"}".to_string(), + )) + ); + assert_eq!( + pending_tools.remove(&2), + Some(( + "tool-2".to_string(), + "grep_search".to_string(), + "{\"pattern\":\"TODO\"}".to_string(), + )) + ); + } + #[test] fn response_to_events_preserves_empty_object_json_input_outside_streaming() { let mut out = Vec::new(); diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 6448ca0..63be324 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -4,10 +4,9 @@ use std::process::Command; use std::time::{Duration, Instant}; use api::{ - detect_provider_kind, max_tokens_for_model, resolve_model_alias, ContentBlockDelta, - InputContentBlock, InputMessage, MessageRequest, MessageResponse, OutputContentBlock, - ProviderClient, ProviderKind, StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, - ToolResultContentBlock, + max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use reqwest::blocking::Client; use runtime::{ @@ -1646,11 +1645,7 @@ struct ProviderRuntimeClient { impl ProviderRuntimeClient { fn new(model: String, allowed_tools: BTreeSet) -> Result { let model = resolve_model_alias(&model).to_string(); - let client = match detect_provider_kind(&model) { - ProviderKind::Anthropic | ProviderKind::Xai | ProviderKind::OpenAi => { - ProviderClient::from_model(&model).map_err(|error| error.to_string())? - } - }; + let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?; Ok(Self { runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, client, @@ -1687,7 +1682,7 @@ impl ApiClient for ProviderRuntimeClient { .await .map_err(|error| RuntimeError::new(error.to_string()))?; let mut events = Vec::new(); - let mut pending_tool: Option<(String, String, String)> = None; + let mut pending_tools: BTreeMap = BTreeMap::new(); let mut saw_stop = false; while let Some(event) = stream @@ -1698,14 +1693,15 @@ impl ApiClient for ProviderRuntimeClient { match event { ApiStreamEvent::MessageStart(start) => { for block in start.message.content { - push_output_block(block, &mut events, &mut pending_tool, true); + push_output_block(block, 0, &mut events, &mut pending_tools, true); } } ApiStreamEvent::ContentBlockStart(start) => { push_output_block( start.content_block, + start.index, &mut events, - &mut pending_tool, + &mut pending_tools, true, ); } @@ -1716,13 +1712,13 @@ impl ApiClient for ProviderRuntimeClient { } } ContentBlockDelta::InputJsonDelta { partial_json } => { - if let Some((_, _, input)) = &mut pending_tool { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { input.push_str(&partial_json); } } }, - ApiStreamEvent::ContentBlockStop(_) => { - if let Some((id, name, input)) = pending_tool.take() { + ApiStreamEvent::ContentBlockStop(stop) => { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -1843,8 +1839,9 @@ fn convert_messages(messages: &[ConversationMessage]) -> Vec { fn push_output_block( block: OutputContentBlock, + block_index: u32, events: &mut Vec, - pending_tool: &mut Option<(String, String, String)>, + pending_tools: &mut BTreeMap, streaming_tool_input: bool, ) { match block { @@ -1862,18 +1859,19 @@ fn push_output_block( } else { input.to_string() }; - *pending_tool = Some((id, name, initial_input)); + pending_tools.insert(block_index, (id, name, initial_input)); } } } fn response_to_events(response: MessageResponse) -> Vec { let mut events = Vec::new(); - let mut pending_tool = None; + let mut pending_tools = BTreeMap::new(); - for block in response.content { - push_output_block(block, &mut events, &mut pending_tool, false); - if let Some((id, name, input)) = pending_tool.take() { + for (index, block) in response.content.into_iter().enumerate() { + let index = u32::try_from(index).expect("response block index overflow"); + push_output_block(block, index, &mut events, &mut pending_tools, false); + if let Some((id, name, input)) = pending_tools.remove(&index) { events.push(AssistantEvent::ToolUse { id, name, input }); } } @@ -2897,6 +2895,7 @@ fn parse_skill_description(contents: &str) -> Option { #[cfg(test)] mod tests { + use std::collections::BTreeMap; use std::collections::BTreeSet; use std::fs; use std::io::{Read, Write}; @@ -2909,8 +2908,9 @@ mod tests { use super::{ agent_permission_policy, allowed_tools_for_subagent, execute_agent_with_spawn, execute_tool, final_assistant_text, mvp_tool_specs, persist_agent_terminal_state, - AgentInput, AgentJob, SubagentToolExecutor, + push_output_block, AgentInput, AgentJob, SubagentToolExecutor, }; + use api::OutputContentBlock; use runtime::{ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, Session}; use serde_json::json; @@ -3125,6 +3125,63 @@ mod tests { assert!(error.contains("relative URL without a base") || error.contains("empty host")); } + #[test] + fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() { + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + 1, + &mut events, + &mut pending_tools, + true, + ); + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "grep_search".to_string(), + input: json!({}), + }, + 2, + &mut events, + &mut pending_tools, + true, + ); + + pending_tools + .get_mut(&1) + .expect("first tool pending") + .2 + .push_str("{\"path\":\"src/main.rs\"}"); + pending_tools + .get_mut(&2) + .expect("second tool pending") + .2 + .push_str("{\"pattern\":\"TODO\"}"); + + assert_eq!( + pending_tools.remove(&1), + Some(( + "tool-1".to_string(), + "read_file".to_string(), + "{\"path\":\"src/main.rs\"}".to_string(), + )) + ); + assert_eq!( + pending_tools.remove(&2), + Some(( + "tool-2".to_string(), + "grep_search".to_string(), + "{\"pattern\":\"TODO\"}".to_string(), + )) + ); + } + #[test] fn todo_write_persists_and_returns_previous_state() { let _guard = env_lock() From f477dde4a6c4ab76253c3d7b96d78183f0a94c0a Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 05:45:27 +0000 Subject: [PATCH 4/5] feat: provider tests + grok integration --- rust/crates/api/src/client.rs | 9 +- rust/crates/api/tests/client_integration.rs | 47 +++++++++- .../api/tests/openai_compat_integration.rs | 70 ++++++++++++++- .../api/tests/provider_client_integration.rs | 86 +++++++++++++++++++ rust/crates/runtime/src/conversation.rs | 12 +-- rust/crates/runtime/src/hooks.rs | 62 +++++++------ 6 files changed, 244 insertions(+), 42 deletions(-) create mode 100644 rust/crates/api/tests/provider_client_integration.rs diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 467697e..a4ac1c0 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -36,11 +36,10 @@ impl ProviderClient { ) -> Result { let resolved_model = providers::resolve_model_alias(model); match providers::detect_provider_kind(&resolved_model) { - ProviderKind::Anthropic => Ok(Self::Anthropic( - anthropic_auth - .map(AnthropicClient::from_auth) - .unwrap_or(AnthropicClient::from_env()?), - )), + ProviderKind::Anthropic => Ok(Self::Anthropic(match anthropic_auth { + Some(auth) => AnthropicClient::from_auth(auth), + None => AnthropicClient::from_env()?, + })), ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( OpenAiCompatConfig::xai(), )?)), diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index c37fa99..b52f890 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use std::time::Duration; use api::{ - AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, - InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, - StreamEvent, ToolChoice, ToolDefinition, + AnthropicClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent, + ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -195,6 +195,47 @@ async fn retries_retryable_failures_before_succeeding() { assert_eq!(state.lock().await.len(), 2); } +#[tokio::test] +async fn provider_client_dispatches_anthropic_requests() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + )], + ) + .await; + + let client = ProviderClient::from_model_with_anthropic_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("test-key".to_string())), + ) + .expect("anthropic provider client should be constructed"); + let client = match client { + ProviderClient::Anthropic(client) => { + ProviderClient::Anthropic(client.with_base_url(server.base_url())) + } + other => panic!("expected anthropic provider, got {other:?}"), + }; + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 5); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/v1/messages"); + assert_eq!( + request.headers.get("x-api-key").map(String::as_str), + Some("test-key") + ); +} + #[tokio::test] async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { let state = Arc::new(Mutex::new(Vec::::new())); diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs index b1b6a0a..81a65f4 100644 --- a/rust/crates/api/tests/openai_compat_integration.rs +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -1,10 +1,12 @@ use std::collections::HashMap; +use std::ffi::OsString; use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; use api::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, - OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -158,6 +160,43 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() { assert!(request.body.contains("\"stream\":true")); } +#[tokio::test] +async fn provider_client_dispatches_xai_requests_from_env() { + let _lock = env_lock(); + let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key"); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}", + )], + ) + .await; + let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url()); + + let client = + ProviderClient::from_model("grok").expect("xAI provider client should be constructed"); + assert!(matches!(client, ProviderClient::Xai(_))); + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 13); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); +} + #[derive(Debug, Clone, PartialEq, Eq)] struct CapturedRequest { path: String, @@ -310,3 +349,32 @@ fn sample_request(stream: bool) -> MessageRequest { stream, } } + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct ScopedEnvVar { + key: &'static str, + previous: Option, +} + +impl ScopedEnvVar { + fn set(key: &'static str, value: impl AsRef) -> Self { + let previous = std::env::var_os(key); + std::env::set_var(key, value); + Self { key, previous } + } +} + +impl Drop for ScopedEnvVar { + fn drop(&mut self) { + match &self.previous { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/api/tests/provider_client_integration.rs b/rust/crates/api/tests/provider_client_integration.rs new file mode 100644 index 0000000..204bf35 --- /dev/null +++ b/rust/crates/api/tests/provider_client_integration.rs @@ -0,0 +1,86 @@ +use std::ffi::OsString; +use std::sync::{Mutex, OnceLock}; + +use api::{read_xai_base_url, ApiError, AuthSource, ProviderClient, ProviderKind}; + +#[test] +fn provider_client_routes_grok_aliases_through_xai() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key")); + + let client = ProviderClient::from_model("grok-mini").expect("grok alias should resolve"); + + assert_eq!(client.provider_kind(), ProviderKind::Xai); +} + +#[test] +fn provider_client_reports_missing_xai_credentials_for_grok_models() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", None); + + let error = ProviderClient::from_model("grok-3") + .expect_err("grok requests without XAI_API_KEY should fail fast"); + + match error { + ApiError::MissingCredentials { provider, env_vars } => { + assert_eq!(provider, "xAI"); + assert_eq!(env_vars, &["XAI_API_KEY"]); + } + other => panic!("expected missing xAI credentials, got {other:?}"), + } +} + +#[test] +fn provider_client_uses_explicit_anthropic_auth_without_env_lookup() { + let _lock = env_lock(); + let _anthropic_api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None); + let _anthropic_auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None); + + let client = ProviderClient::from_model_with_anthropic_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("anthropic-test-key".to_string())), + ) + .expect("explicit anthropic auth should avoid env lookup"); + + assert_eq!(client.provider_kind(), ProviderKind::Anthropic); +} + +#[test] +fn read_xai_base_url_prefers_env_override() { + let _lock = env_lock(); + let _xai_base_url = EnvVarGuard::set("XAI_BASE_URL", Some("https://example.xai.test/v1")); + + assert_eq!(read_xai_base_url(), "https://example.xai.test/v1"); +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct EnvVarGuard { + key: &'static str, + original: Option, +} + +impl EnvVarGuard { + fn set(key: &'static str, value: Option<&str>) -> Self { + let original = std::env::var_os(key); + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + Self { key, original } + } +} + +impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.original { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index 4ffbabc..1abdce4 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -118,7 +118,7 @@ where tool_executor, permission_policy, system_prompt, - RuntimeFeatureConfig::default(), + &RuntimeFeatureConfig::default(), ) } @@ -129,7 +129,7 @@ where tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec, - feature_config: RuntimeFeatureConfig, + feature_config: &RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -140,7 +140,7 @@ where system_prompt, max_iterations: usize::MAX, usage_tracker, - hook_runner: HookRunner::from_feature_config(&feature_config), + hook_runner: HookRunner::from_feature_config(feature_config), } } @@ -609,7 +609,7 @@ mod tests { }), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'blocked by hook'; exit 2")], Vec::new(), )), @@ -675,7 +675,7 @@ mod tests { StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), PermissionPolicy::new(PermissionMode::DangerFullAccess), vec!["system".to_string()], - RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + &RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( vec![shell_snippet("printf 'pre hook ran'")], vec![shell_snippet("printf 'post hook ran'")], )), @@ -697,7 +697,7 @@ mod tests { "post hook should preserve non-error result: {output:?}" ); assert!( - output.contains("4"), + output.contains('4'), "tool output missing value: {output:?}" ); assert!( diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs index 36756a0..63ef9ff 100644 --- a/rust/crates/runtime/src/hooks.rs +++ b/rust/crates/runtime/src/hooks.rs @@ -51,6 +51,16 @@ pub struct HookRunner { config: RuntimeHookConfig, } +#[derive(Debug, Clone, Copy)] +struct HookCommandRequest<'a> { + event: HookEvent, + tool_name: &'a str, + tool_input: &'a str, + tool_output: Option<&'a str>, + is_error: bool, + payload: &'a str, +} + impl HookRunner { #[must_use] pub fn new(config: RuntimeHookConfig) -> Self { @@ -118,14 +128,16 @@ impl HookRunner { let mut messages = Vec::new(); for command in commands { - match self.run_command( + match Self::run_command( command, - event, - tool_name, - tool_input, - tool_output, - is_error, - &payload, + HookCommandRequest { + event, + tool_name, + tool_input, + tool_output, + is_error, + payload: &payload, + }, ) { HookCommandOutcome::Allow { message } => { if let Some(message) = message { @@ -149,29 +161,23 @@ impl HookRunner { HookRunResult::allow(messages) } - fn run_command( - &self, - command: &str, - event: HookEvent, - tool_name: &str, - tool_input: &str, - tool_output: Option<&str>, - is_error: bool, - payload: &str, - ) -> HookCommandOutcome { + fn run_command(command: &str, request: HookCommandRequest<'_>) -> HookCommandOutcome { let mut child = shell_command(command); child.stdin(std::process::Stdio::piped()); child.stdout(std::process::Stdio::piped()); child.stderr(std::process::Stdio::piped()); - child.env("HOOK_EVENT", event.as_str()); - child.env("HOOK_TOOL_NAME", tool_name); - child.env("HOOK_TOOL_INPUT", tool_input); - child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" }); - if let Some(tool_output) = tool_output { + child.env("HOOK_EVENT", request.event.as_str()); + child.env("HOOK_TOOL_NAME", request.tool_name); + child.env("HOOK_TOOL_INPUT", request.tool_input); + child.env( + "HOOK_TOOL_IS_ERROR", + if request.is_error { "1" } else { "0" }, + ); + if let Some(tool_output) = request.tool_output { child.env("HOOK_TOOL_OUTPUT", tool_output); } - match child.output_with_stdin(payload.as_bytes()) { + match child.output_with_stdin(request.payload.as_bytes()) { Ok(output) => { let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); @@ -189,16 +195,18 @@ impl HookRunner { }, None => HookCommandOutcome::Warn { message: format!( - "{} hook `{command}` terminated by signal while handling `{tool_name}`", - event.as_str() + "{} hook `{command}` terminated by signal while handling `{}`", + request.event.as_str(), + request.tool_name ), }, } } Err(error) => HookCommandOutcome::Warn { message: format!( - "{} hook `{command}` failed to start for `{tool_name}`: {error}", - event.as_str() + "{} hook `{command}` failed to start for `{}`: {error}", + request.event.as_str(), + request.tool_name ), }, } From 40008b65138d4760923cec68d1002805608b89ce Mon Sep 17 00:00:00 2001 From: Yeachan-Heo Date: Wed, 1 Apr 2026 06:00:48 +0000 Subject: [PATCH 5/5] wip: grok provider abstraction --- .../crates/api/src/providers/openai_compat.rs | 31 ++++++++++++++-- .../api/tests/openai_compat_integration.rs | 35 +++++++++++++++++++ rust/crates/rusty-claude-cli/src/main.rs | 3 +- 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs index 8a0fe9c..e8210ae 100644 --- a/rust/crates/api/src/providers/openai_compat.rs +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -185,7 +185,7 @@ impl OpenAiCompatClient { &self, request: &MessageRequest, ) -> Result { - let request_url = format!("{}/chat/completions", self.base_url.trim_end_matches('/')); + let request_url = chat_completions_endpoint(&self.base_url); self.http .post(&request_url) .header("content-type", "application/json") @@ -866,6 +866,15 @@ pub fn read_base_url(config: OpenAiCompatConfig) -> String { std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) } +fn chat_completions_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/chat/completions") { + trimmed.to_string() + } else { + format!("{trimmed}/chat/completions") + } +} + fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { headers .get(REQUEST_ID_HEADER) @@ -927,8 +936,8 @@ impl StringExt for String { #[cfg(test)] mod tests { use super::{ - build_chat_completion_request, normalize_finish_reason, openai_tool_choice, - parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, + build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason, + openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, }; use crate::error::ApiError; use crate::types::{ @@ -1010,6 +1019,22 @@ mod tests { )); } + #[test] + fn endpoint_builder_accepts_base_urls_and_full_endpoints() { + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/chat/completions"), + "https://api.x.ai/v1/chat/completions" + ); + } + fn env_lock() -> std::sync::MutexGuard<'static, ()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs index 81a65f4..b345b1f 100644 --- a/rust/crates/api/tests/openai_compat_integration.rs +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -62,6 +62,41 @@ async fn send_message_uses_openai_compatible_endpoint_and_auth() { assert_eq!(body["tools"][0]["type"], json!("function")); } +#[tokio::test] +async fn send_message_accepts_full_chat_completions_endpoint_override() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_full_endpoint\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let endpoint_url = format!("{}/chat/completions", server.base_url()); + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(endpoint_url); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.total_tokens(), 10); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); +} + #[tokio::test] async fn stream_message_normalizes_text_and_multiple_tool_calls() { let state = Arc::new(Mutex::new(Vec::::new())); diff --git a/rust/crates/rusty-claude-cli/src/main.rs b/rust/crates/rusty-claude-cli/src/main.rs index dee54d9..847f94f 100644 --- a/rust/crates/rusty-claude-cli/src/main.rs +++ b/rust/crates/rusty-claude-cli/src/main.rs @@ -1907,13 +1907,14 @@ fn build_runtime( permission_mode: PermissionMode, ) -> Result, Box> { + let feature_config = build_runtime_feature_config()?; Ok(ConversationRuntime::new_with_features( session, ProviderRuntimeClient::new(model, enable_tools, emit_output, allowed_tools.clone())?, CliToolExecutor::new(allowed_tools, emit_output), permission_policy(permission_mode), system_prompt, - build_runtime_feature_config()?, + &feature_config, )) }