use std::collections::{BTreeMap, VecDeque}; use std::time::Duration; use serde::Deserialize; use serde_json::{json, Value}; use crate::error::ApiError; use crate::types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, }; use super::{Provider, ProviderFuture}; pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; const REQUEST_ID_HEADER: &str = "request-id"; const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); const DEFAULT_MAX_RETRIES: u32 = 2; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct OpenAiCompatConfig { pub provider_name: &'static str, pub api_key_env: &'static str, pub base_url_env: &'static str, pub default_base_url: &'static str, } const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; impl OpenAiCompatConfig { #[must_use] pub const fn xai() -> Self { Self { provider_name: "xAI", api_key_env: "XAI_API_KEY", base_url_env: "XAI_BASE_URL", default_base_url: DEFAULT_XAI_BASE_URL, } } #[must_use] pub const fn openai() -> Self { Self { provider_name: "OpenAI", api_key_env: "OPENAI_API_KEY", base_url_env: "OPENAI_BASE_URL", default_base_url: DEFAULT_OPENAI_BASE_URL, } } #[must_use] pub fn credential_env_vars(self) -> &'static [&'static str] { match self.provider_name { "xAI" => XAI_ENV_VARS, "OpenAI" => OPENAI_ENV_VARS, _ => &[], } } } #[derive(Debug, Clone)] pub struct OpenAiCompatClient { http: reqwest::Client, api_key: String, base_url: String, max_retries: u32, initial_backoff: Duration, max_backoff: Duration, } impl OpenAiCompatClient { #[must_use] pub fn new(api_key: impl Into, config: OpenAiCompatConfig) -> Self { Self { http: reqwest::Client::new(), api_key: api_key.into(), base_url: read_base_url(config), max_retries: DEFAULT_MAX_RETRIES, initial_backoff: DEFAULT_INITIAL_BACKOFF, max_backoff: DEFAULT_MAX_BACKOFF, } } pub fn from_env(config: OpenAiCompatConfig) -> Result { let Some(api_key) = read_env_non_empty(config.api_key_env)? else { return Err(ApiError::missing_credentials( config.provider_name, config.credential_env_vars(), )); }; Ok(Self::new(api_key, config)) } #[must_use] pub fn with_base_url(mut self, base_url: impl Into) -> Self { self.base_url = base_url.into(); self } #[must_use] pub fn with_retry_policy( mut self, max_retries: u32, initial_backoff: Duration, max_backoff: Duration, ) -> Self { self.max_retries = max_retries; self.initial_backoff = initial_backoff; self.max_backoff = max_backoff; self } pub async fn send_message( &self, request: &MessageRequest, ) -> Result { let request = MessageRequest { stream: false, ..request.clone() }; let response = self.send_with_retry(&request).await?; let request_id = request_id_from_headers(response.headers()); let payload = response.json::().await?; let mut normalized = normalize_response(&request.model, payload)?; if normalized.request_id.is_none() { normalized.request_id = request_id; } Ok(normalized) } pub async fn stream_message( &self, request: &MessageRequest, ) -> Result { let response = self .send_with_retry(&request.clone().with_streaming()) .await?; Ok(MessageStream { request_id: request_id_from_headers(response.headers()), response, parser: OpenAiSseParser::new(), pending: VecDeque::new(), done: false, state: StreamState::new(request.model.clone()), }) } async fn send_with_retry( &self, request: &MessageRequest, ) -> Result { let mut attempts = 0; let last_error = loop { attempts += 1; let retryable_error = match self.send_raw_request(request).await { Ok(response) => match expect_success(response).await { Ok(response) => return Ok(response), Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, Err(error) => return Err(error), }, Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, Err(error) => return Err(error), }; if attempts > self.max_retries { break retryable_error; } tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; }; Err(ApiError::RetriesExhausted { attempts, last_error: Box::new(last_error), }) } async fn send_raw_request( &self, request: &MessageRequest, ) -> Result { let request_url = chat_completions_endpoint(&self.base_url); self.http .post(&request_url) .header("content-type", "application/json") .bearer_auth(&self.api_key) .json(&build_chat_completion_request(request)) .send() .await .map_err(ApiError::from) } fn backoff_for_attempt(&self, attempt: u32) -> Result { let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { return Err(ApiError::BackoffOverflow { attempt, base_delay: self.initial_backoff, }); }; Ok(self .initial_backoff .checked_mul(multiplier) .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) } } impl Provider for OpenAiCompatClient { type Stream = MessageStream; fn send_message<'a>( &'a self, request: &'a MessageRequest, ) -> ProviderFuture<'a, MessageResponse> { Box::pin(async move { self.send_message(request).await }) } fn stream_message<'a>( &'a self, request: &'a MessageRequest, ) -> ProviderFuture<'a, Self::Stream> { Box::pin(async move { self.stream_message(request).await }) } } #[derive(Debug)] pub struct MessageStream { request_id: Option, response: reqwest::Response, parser: OpenAiSseParser, pending: VecDeque, done: bool, state: StreamState, } impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { self.request_id.as_deref() } pub async fn next_event(&mut self) -> Result, ApiError> { loop { if let Some(event) = self.pending.pop_front() { return Ok(Some(event)); } if self.done { self.pending.extend(self.state.finish()?); if let Some(event) = self.pending.pop_front() { return Ok(Some(event)); } return Ok(None); } match self.response.chunk().await? { Some(chunk) => { for parsed in self.parser.push(&chunk)? { self.pending.extend(self.state.ingest_chunk(parsed)?); } } None => { self.done = true; } } } } } #[derive(Debug, Default)] struct OpenAiSseParser { buffer: Vec, } impl OpenAiSseParser { fn new() -> Self { Self::default() } fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { self.buffer.extend_from_slice(chunk); let mut events = Vec::new(); while let Some(frame) = next_sse_frame(&mut self.buffer) { if let Some(event) = parse_sse_frame(&frame)? { events.push(event); } } Ok(events) } } #[derive(Debug)] struct StreamState { model: String, message_started: bool, text_started: bool, text_finished: bool, finished: bool, stop_reason: Option, usage: Option, tool_calls: BTreeMap, } impl StreamState { fn new(model: String) -> Self { Self { model, message_started: false, text_started: false, text_finished: false, finished: false, stop_reason: None, usage: None, tool_calls: BTreeMap::new(), } } fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { let mut events = Vec::new(); if !self.message_started { self.message_started = true; events.push(StreamEvent::MessageStart(MessageStartEvent { message: MessageResponse { id: chunk.id.clone(), kind: "message".to_string(), role: "assistant".to_string(), content: Vec::new(), model: chunk.model.clone().unwrap_or_else(|| self.model.clone()), stop_reason: None, stop_sequence: None, usage: Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: 0, }, request_id: None, }, })); } if let Some(usage) = chunk.usage { self.usage = Some(Usage { input_tokens: usage.prompt_tokens, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: usage.completion_tokens, }); } for choice in chunk.choices { if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) { if !self.text_started { self.text_started = true; events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { index: 0, content_block: OutputContentBlock::Text { text: String::new(), }, })); } events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { index: 0, delta: ContentBlockDelta::TextDelta { text: content }, })); } for tool_call in choice.delta.tool_calls { let state = self.tool_calls.entry(tool_call.index).or_default(); state.apply(tool_call); let block_index = state.block_index(); if !state.started { if let Some(start_event) = state.start_event()? { state.started = true; events.push(StreamEvent::ContentBlockStart(start_event)); } else { continue; } } if let Some(delta_event) = state.delta_event() { events.push(StreamEvent::ContentBlockDelta(delta_event)); } if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped { state.stopped = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: block_index, })); } } if let Some(finish_reason) = choice.finish_reason { self.stop_reason = Some(normalize_finish_reason(&finish_reason)); if finish_reason == "tool_calls" { for state in self.tool_calls.values_mut() { if state.started && !state.stopped { state.stopped = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: state.block_index(), })); } } } } } Ok(events) } fn finish(&mut self) -> Result, ApiError> { if self.finished { return Ok(Vec::new()); } self.finished = true; let mut events = Vec::new(); if self.text_started && !self.text_finished { self.text_finished = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0, })); } for state in self.tool_calls.values_mut() { if !state.started { if let Some(start_event) = state.start_event()? { state.started = true; events.push(StreamEvent::ContentBlockStart(start_event)); if let Some(delta_event) = state.delta_event() { events.push(StreamEvent::ContentBlockDelta(delta_event)); } } } if state.started && !state.stopped { state.stopped = true; events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: state.block_index(), })); } } if self.message_started { events.push(StreamEvent::MessageDelta(MessageDeltaEvent { delta: MessageDelta { stop_reason: Some( self.stop_reason .clone() .unwrap_or_else(|| "end_turn".to_string()), ), stop_sequence: None, }, usage: self.usage.clone().unwrap_or(Usage { input_tokens: 0, cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: 0, }), })); events.push(StreamEvent::MessageStop(MessageStopEvent {})); } Ok(events) } } #[derive(Debug, Default)] struct ToolCallState { openai_index: u32, id: Option, name: Option, arguments: String, emitted_len: usize, started: bool, stopped: bool, } impl ToolCallState { fn apply(&mut self, tool_call: DeltaToolCall) { self.openai_index = tool_call.index; if let Some(id) = tool_call.id { self.id = Some(id); } if let Some(name) = tool_call.function.name { self.name = Some(name); } if let Some(arguments) = tool_call.function.arguments { self.arguments.push_str(&arguments); } } const fn block_index(&self) -> u32 { self.openai_index + 1 } fn start_event(&self) -> Result, ApiError> { let Some(name) = self.name.clone() else { return Ok(None); }; let id = self .id .clone() .unwrap_or_else(|| format!("tool_call_{}", self.openai_index)); Ok(Some(ContentBlockStartEvent { index: self.block_index(), content_block: OutputContentBlock::ToolUse { id, name, input: json!({}), }, })) } fn delta_event(&mut self) -> Option { if self.emitted_len >= self.arguments.len() { return None; } let delta = self.arguments[self.emitted_len..].to_string(); self.emitted_len = self.arguments.len(); Some(ContentBlockDeltaEvent { index: self.block_index(), delta: ContentBlockDelta::InputJsonDelta { partial_json: delta, }, }) } } #[derive(Debug, Deserialize)] struct ChatCompletionResponse { id: String, model: String, choices: Vec, #[serde(default)] usage: Option, } #[derive(Debug, Deserialize)] struct ChatChoice { message: ChatMessage, #[serde(default)] finish_reason: Option, } #[derive(Debug, Deserialize)] struct ChatMessage { role: String, #[serde(default)] content: Option, #[serde(default)] tool_calls: Vec, } #[derive(Debug, Deserialize)] struct ResponseToolCall { id: String, function: ResponseToolFunction, } #[derive(Debug, Deserialize)] struct ResponseToolFunction { name: String, arguments: String, } #[derive(Debug, Deserialize)] struct OpenAiUsage { #[serde(default)] prompt_tokens: u32, #[serde(default)] completion_tokens: u32, } #[derive(Debug, Deserialize)] struct ChatCompletionChunk { id: String, #[serde(default)] model: Option, #[serde(default)] choices: Vec, #[serde(default)] usage: Option, } #[derive(Debug, Deserialize)] struct ChunkChoice { delta: ChunkDelta, #[serde(default)] finish_reason: Option, } #[derive(Debug, Default, Deserialize)] struct ChunkDelta { #[serde(default)] content: Option, #[serde(default)] tool_calls: Vec, } #[derive(Debug, Deserialize)] struct DeltaToolCall { #[serde(default)] index: u32, #[serde(default)] id: Option, #[serde(default)] function: DeltaFunction, } #[derive(Debug, Default, Deserialize)] struct DeltaFunction { #[serde(default)] name: Option, #[serde(default)] arguments: Option, } #[derive(Debug, Deserialize)] struct ErrorEnvelope { error: ErrorBody, } #[derive(Debug, Deserialize)] struct ErrorBody { #[serde(rename = "type")] error_type: Option, message: Option, } fn build_chat_completion_request(request: &MessageRequest) -> Value { let mut messages = Vec::new(); if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { messages.push(json!({ "role": "system", "content": system, })); } for message in &request.messages { messages.extend(translate_message(message)); } let mut payload = json!({ "model": request.model, "max_tokens": request.max_tokens, "messages": messages, "stream": request.stream, }); if let Some(tools) = &request.tools { payload["tools"] = Value::Array(tools.iter().map(openai_tool_definition).collect::>()); } if let Some(tool_choice) = &request.tool_choice { payload["tool_choice"] = openai_tool_choice(tool_choice); } payload } fn translate_message(message: &InputMessage) -> Vec { match message.role.as_str() { "assistant" => { let mut text = String::new(); let mut tool_calls = Vec::new(); for block in &message.content { match block { InputContentBlock::Text { text: value } => text.push_str(value), InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({ "id": id, "type": "function", "function": { "name": name, "arguments": input.to_string(), } })), InputContentBlock::ToolResult { .. } => {} } } if text.is_empty() && tool_calls.is_empty() { Vec::new() } else { vec![json!({ "role": "assistant", "content": (!text.is_empty()).then_some(text), "tool_calls": tool_calls, })] } } _ => message .content .iter() .filter_map(|block| match block { InputContentBlock::Text { text } => Some(json!({ "role": "user", "content": text, })), InputContentBlock::ToolResult { tool_use_id, content, is_error, } => Some(json!({ "role": "tool", "tool_call_id": tool_use_id, "content": flatten_tool_result_content(content), "is_error": is_error, })), InputContentBlock::ToolUse { .. } => None, }) .collect(), } } fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { content .iter() .map(|block| match block { ToolResultContentBlock::Text { text } => text.clone(), ToolResultContentBlock::Json { value } => value.to_string(), }) .collect::>() .join("\n") } fn openai_tool_definition(tool: &ToolDefinition) -> Value { json!({ "type": "function", "function": { "name": tool.name, "description": tool.description, "parameters": tool.input_schema, } }) } fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { match tool_choice { ToolChoice::Auto => Value::String("auto".to_string()), ToolChoice::Any => Value::String("required".to_string()), ToolChoice::Tool { name } => json!({ "type": "function", "function": { "name": name }, }), } } fn normalize_response( model: &str, response: ChatCompletionResponse, ) -> Result { let choice = response .choices .into_iter() .next() .ok_or(ApiError::InvalidSseFrame( "chat completion response missing choices", ))?; let mut content = Vec::new(); if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) { content.push(OutputContentBlock::Text { text }); } for tool_call in choice.message.tool_calls { content.push(OutputContentBlock::ToolUse { id: tool_call.id, name: tool_call.function.name, input: parse_tool_arguments(&tool_call.function.arguments), }); } Ok(MessageResponse { id: response.id, kind: "message".to_string(), role: choice.message.role, content, model: response.model.if_empty_then(model.to_string()), stop_reason: choice .finish_reason .map(|value| normalize_finish_reason(&value)), stop_sequence: None, usage: Usage { input_tokens: response .usage .as_ref() .map_or(0, |usage| usage.prompt_tokens), cache_creation_input_tokens: 0, cache_read_input_tokens: 0, output_tokens: response .usage .as_ref() .map_or(0, |usage| usage.completion_tokens), }, request_id: None, }) } fn parse_tool_arguments(arguments: &str) -> Value { serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments })) } fn next_sse_frame(buffer: &mut Vec) -> Option { let separator = buffer .windows(2) .position(|window| window == b"\n\n") .map(|position| (position, 2)) .or_else(|| { buffer .windows(4) .position(|window| window == b"\r\n\r\n") .map(|position| (position, 4)) })?; let (position, separator_len) = separator; let frame = buffer.drain(..position + separator_len).collect::>(); let frame_len = frame.len().saturating_sub(separator_len); Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) } fn parse_sse_frame(frame: &str) -> Result, ApiError> { let trimmed = frame.trim(); if trimmed.is_empty() { return Ok(None); } let mut data_lines = Vec::new(); for line in trimmed.lines() { if line.starts_with(':') { continue; } if let Some(data) = line.strip_prefix("data:") { data_lines.push(data.trim_start()); } } if data_lines.is_empty() { return Ok(None); } let payload = data_lines.join("\n"); if payload == "[DONE]" { return Ok(None); } serde_json::from_str(&payload) .map(Some) .map_err(ApiError::from) } fn read_env_non_empty(key: &str) -> Result, ApiError> { match std::env::var(key) { Ok(value) if !value.is_empty() => Ok(Some(value)), Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), Err(error) => Err(ApiError::from(error)), } } #[must_use] pub fn has_api_key(key: &str) -> bool { read_env_non_empty(key) .ok() .and_then(std::convert::identity) .is_some() } #[must_use] pub fn read_base_url(config: OpenAiCompatConfig) -> String { std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) } fn chat_completions_endpoint(base_url: &str) -> String { let trimmed = base_url.trim_end_matches('/'); if trimmed.ends_with("/chat/completions") { trimmed.to_string() } else { format!("{trimmed}/chat/completions") } } fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { headers .get(REQUEST_ID_HEADER) .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) .and_then(|value| value.to_str().ok()) .map(ToOwned::to_owned) } async fn expect_success(response: reqwest::Response) -> Result { let status = response.status(); if status.is_success() { return Ok(response); } let body = response.text().await.unwrap_or_default(); let parsed_error = serde_json::from_str::(&body).ok(); let retryable = is_retryable_status(status); Err(ApiError::Api { status, error_type: parsed_error .as_ref() .and_then(|error| error.error.error_type.clone()), message: parsed_error .as_ref() .and_then(|error| error.error.message.clone()), body, retryable, }) } const fn is_retryable_status(status: reqwest::StatusCode) -> bool { matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) } fn normalize_finish_reason(value: &str) -> String { match value { "stop" => "end_turn", "tool_calls" => "tool_use", other => other, } .to_string() } trait StringExt { fn if_empty_then(self, fallback: String) -> String; } impl StringExt for String { fn if_empty_then(self, fallback: String) -> String { if self.is_empty() { fallback } else { self } } } #[cfg(test)] mod tests { use super::{ build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason, openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, }; use crate::error::ApiError; use crate::types::{ InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, ToolResultContentBlock, }; use serde_json::json; use std::sync::{Mutex, OnceLock}; #[test] fn request_translation_uses_openai_compatible_shape() { let payload = build_chat_completion_request(&MessageRequest { model: "grok-3".to_string(), max_tokens: 64, messages: vec![InputMessage { role: "user".to_string(), content: vec![ InputContentBlock::Text { text: "hello".to_string(), }, InputContentBlock::ToolResult { tool_use_id: "tool_1".to_string(), content: vec![ToolResultContentBlock::Json { value: json!({"ok": true}), }], is_error: false, }, ], }], system: Some("be helpful".to_string()), tools: Some(vec![ToolDefinition { name: "weather".to_string(), description: Some("Get weather".to_string()), input_schema: json!({"type": "object"}), }]), tool_choice: Some(ToolChoice::Auto), stream: false, }); assert_eq!(payload["messages"][0]["role"], json!("system")); assert_eq!(payload["messages"][1]["role"], json!("user")); assert_eq!(payload["messages"][2]["role"], json!("tool")); assert_eq!(payload["tools"][0]["type"], json!("function")); assert_eq!(payload["tool_choice"], json!("auto")); } #[test] fn tool_choice_translation_supports_required_function() { assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); assert_eq!( openai_tool_choice(&ToolChoice::Tool { name: "weather".to_string(), }), json!({"type": "function", "function": {"name": "weather"}}) ); } #[test] fn parses_tool_arguments_fallback() { assert_eq!( parse_tool_arguments("{\"city\":\"Paris\"}"), json!({"city": "Paris"}) ); assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"})); } #[test] fn missing_xai_api_key_is_provider_specific() { let _lock = env_lock(); std::env::remove_var("XAI_API_KEY"); let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai()) .expect_err("missing key should error"); assert!(matches!( error, ApiError::MissingCredentials { provider: "xAI", .. } )); } #[test] fn endpoint_builder_accepts_base_urls_and_full_endpoints() { assert_eq!( chat_completions_endpoint("https://api.x.ai/v1"), "https://api.x.ai/v1/chat/completions" ); assert_eq!( chat_completions_endpoint("https://api.x.ai/v1/"), "https://api.x.ai/v1/chat/completions" ); assert_eq!( chat_completions_endpoint("https://api.x.ai/v1/chat/completions"), "https://api.x.ai/v1/chat/completions" ); } fn env_lock() -> std::sync::MutexGuard<'static, ()> { static LOCK: OnceLock> = OnceLock::new(); LOCK.get_or_init(|| Mutex::new(())) .lock() .expect("env lock") } #[test] fn normalizes_stop_reasons() { assert_eq!(normalize_finish_reason("stop"), "end_turn"); assert_eq!(normalize_finish_reason("tool_calls"), "tool_use"); } }