mirror of
https://github.com/instructkr/claude-code.git
synced 2026-04-03 17:18:49 +03:00
1051 lines
32 KiB
Rust
1051 lines
32 KiB
Rust
use std::collections::{BTreeMap, VecDeque};
|
|
use std::time::Duration;
|
|
|
|
use serde::Deserialize;
|
|
use serde_json::{json, Value};
|
|
|
|
use crate::error::ApiError;
|
|
use crate::types::{
|
|
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
|
InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest,
|
|
MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent,
|
|
ToolChoice, ToolDefinition, ToolResultContentBlock, Usage,
|
|
};
|
|
|
|
use super::{Provider, ProviderFuture};
|
|
|
|
/// Base URL used for xAI when the `XAI_BASE_URL` override is not applied.
pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1";
/// Base URL used for OpenAI when the `OPENAI_BASE_URL` override is not applied.
pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";
/// Response header consulted first when extracting a request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Response header consulted when `request-id` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";
/// Delay before the first retry; later retries double it.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
/// Ceiling applied to every computed retry delay.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2);
/// Number of retries allowed after the initial attempt.
const DEFAULT_MAX_RETRIES: u32 = 2;
|
|
|
|
/// Static description of one OpenAI-compatible provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct OpenAiCompatConfig {
    /// Provider name; reported in missing-credential errors and used to pick
    /// the credential environment variable list.
    pub provider_name: &'static str,
    /// Environment variable that holds the API key.
    pub api_key_env: &'static str,
    /// Environment variable that may override the base URL.
    pub base_url_env: &'static str,
    /// Base URL used when the override variable cannot be read.
    pub default_base_url: &'static str,
}
|
|
|
|
/// Credential environment variables reported for the xAI provider.
const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"];
/// Credential environment variables reported for the OpenAI provider.
const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"];
|
|
|
|
impl OpenAiCompatConfig {
|
|
#[must_use]
|
|
pub const fn xai() -> Self {
|
|
Self {
|
|
provider_name: "xAI",
|
|
api_key_env: "XAI_API_KEY",
|
|
base_url_env: "XAI_BASE_URL",
|
|
default_base_url: DEFAULT_XAI_BASE_URL,
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub const fn openai() -> Self {
|
|
Self {
|
|
provider_name: "OpenAI",
|
|
api_key_env: "OPENAI_API_KEY",
|
|
base_url_env: "OPENAI_BASE_URL",
|
|
default_base_url: DEFAULT_OPENAI_BASE_URL,
|
|
}
|
|
}
|
|
#[must_use]
|
|
pub fn credential_env_vars(self) -> &'static [&'static str] {
|
|
match self.provider_name {
|
|
"xAI" => XAI_ENV_VARS,
|
|
"OpenAI" => OPENAI_ENV_VARS,
|
|
_ => &[],
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct OpenAiCompatClient {
|
|
http: reqwest::Client,
|
|
api_key: String,
|
|
base_url: String,
|
|
max_retries: u32,
|
|
initial_backoff: Duration,
|
|
max_backoff: Duration,
|
|
}
|
|
|
|
impl OpenAiCompatClient {
|
|
#[must_use]
|
|
pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
|
|
Self {
|
|
http: reqwest::Client::new(),
|
|
api_key: api_key.into(),
|
|
base_url: read_base_url(config),
|
|
max_retries: DEFAULT_MAX_RETRIES,
|
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
|
max_backoff: DEFAULT_MAX_BACKOFF,
|
|
}
|
|
}
|
|
|
|
pub fn from_env(config: OpenAiCompatConfig) -> Result<Self, ApiError> {
|
|
let Some(api_key) = read_env_non_empty(config.api_key_env)? else {
|
|
return Err(ApiError::missing_credentials(
|
|
config.provider_name,
|
|
config.credential_env_vars(),
|
|
));
|
|
};
|
|
Ok(Self::new(api_key, config))
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
|
|
self.base_url = base_url.into();
|
|
self
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn with_retry_policy(
|
|
mut self,
|
|
max_retries: u32,
|
|
initial_backoff: Duration,
|
|
max_backoff: Duration,
|
|
) -> Self {
|
|
self.max_retries = max_retries;
|
|
self.initial_backoff = initial_backoff;
|
|
self.max_backoff = max_backoff;
|
|
self
|
|
}
|
|
|
|
pub async fn send_message(
|
|
&self,
|
|
request: &MessageRequest,
|
|
) -> Result<MessageResponse, ApiError> {
|
|
let request = MessageRequest {
|
|
stream: false,
|
|
..request.clone()
|
|
};
|
|
let response = self.send_with_retry(&request).await?;
|
|
let request_id = request_id_from_headers(response.headers());
|
|
let payload = response.json::<ChatCompletionResponse>().await?;
|
|
let mut normalized = normalize_response(&request.model, payload)?;
|
|
if normalized.request_id.is_none() {
|
|
normalized.request_id = request_id;
|
|
}
|
|
Ok(normalized)
|
|
}
|
|
|
|
pub async fn stream_message(
|
|
&self,
|
|
request: &MessageRequest,
|
|
) -> Result<MessageStream, ApiError> {
|
|
let response = self
|
|
.send_with_retry(&request.clone().with_streaming())
|
|
.await?;
|
|
Ok(MessageStream {
|
|
request_id: request_id_from_headers(response.headers()),
|
|
response,
|
|
parser: OpenAiSseParser::new(),
|
|
pending: VecDeque::new(),
|
|
done: false,
|
|
state: StreamState::new(request.model.clone()),
|
|
})
|
|
}
|
|
|
|
async fn send_with_retry(
|
|
&self,
|
|
request: &MessageRequest,
|
|
) -> Result<reqwest::Response, ApiError> {
|
|
let mut attempts = 0;
|
|
|
|
let last_error = loop {
|
|
attempts += 1;
|
|
let retryable_error = match self.send_raw_request(request).await {
|
|
Ok(response) => match expect_success(response).await {
|
|
Ok(response) => return Ok(response),
|
|
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
|
|
Err(error) => return Err(error),
|
|
},
|
|
Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error,
|
|
Err(error) => return Err(error),
|
|
};
|
|
|
|
if attempts > self.max_retries {
|
|
break retryable_error;
|
|
}
|
|
|
|
tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
|
|
};
|
|
|
|
Err(ApiError::RetriesExhausted {
|
|
attempts,
|
|
last_error: Box::new(last_error),
|
|
})
|
|
}
|
|
|
|
async fn send_raw_request(
|
|
&self,
|
|
request: &MessageRequest,
|
|
) -> Result<reqwest::Response, ApiError> {
|
|
let request_url = chat_completions_endpoint(&self.base_url);
|
|
self.http
|
|
.post(&request_url)
|
|
.header("content-type", "application/json")
|
|
.bearer_auth(&self.api_key)
|
|
.json(&build_chat_completion_request(request))
|
|
.send()
|
|
.await
|
|
.map_err(ApiError::from)
|
|
}
|
|
|
|
fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
|
|
let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
|
|
return Err(ApiError::BackoffOverflow {
|
|
attempt,
|
|
base_delay: self.initial_backoff,
|
|
});
|
|
};
|
|
Ok(self
|
|
.initial_backoff
|
|
.checked_mul(multiplier)
|
|
.map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
|
|
}
|
|
}
|
|
|
|
impl Provider for OpenAiCompatClient {
|
|
type Stream = MessageStream;
|
|
|
|
fn send_message<'a>(
|
|
&'a self,
|
|
request: &'a MessageRequest,
|
|
) -> ProviderFuture<'a, MessageResponse> {
|
|
Box::pin(async move { self.send_message(request).await })
|
|
}
|
|
|
|
fn stream_message<'a>(
|
|
&'a self,
|
|
request: &'a MessageRequest,
|
|
) -> ProviderFuture<'a, Self::Stream> {
|
|
Box::pin(async move { self.stream_message(request).await })
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct MessageStream {
|
|
request_id: Option<String>,
|
|
response: reqwest::Response,
|
|
parser: OpenAiSseParser,
|
|
pending: VecDeque<StreamEvent>,
|
|
done: bool,
|
|
state: StreamState,
|
|
}
|
|
|
|
impl MessageStream {
|
|
#[must_use]
|
|
pub fn request_id(&self) -> Option<&str> {
|
|
self.request_id.as_deref()
|
|
}
|
|
|
|
pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
|
|
loop {
|
|
if let Some(event) = self.pending.pop_front() {
|
|
return Ok(Some(event));
|
|
}
|
|
|
|
if self.done {
|
|
self.pending.extend(self.state.finish()?);
|
|
if let Some(event) = self.pending.pop_front() {
|
|
return Ok(Some(event));
|
|
}
|
|
return Ok(None);
|
|
}
|
|
|
|
match self.response.chunk().await? {
|
|
Some(chunk) => {
|
|
for parsed in self.parser.push(&chunk)? {
|
|
self.pending.extend(self.state.ingest_chunk(parsed)?);
|
|
}
|
|
}
|
|
None => {
|
|
self.done = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Incremental parser that turns response bytes into chat-completion chunks.
#[derive(Debug, Default)]
struct OpenAiSseParser {
    /// Bytes received so far that do not yet form a complete frame.
    buffer: Vec<u8>,
}
|
|
|
|
impl OpenAiSseParser {
|
|
fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
fn push(&mut self, chunk: &[u8]) -> Result<Vec<ChatCompletionChunk>, ApiError> {
|
|
self.buffer.extend_from_slice(chunk);
|
|
let mut events = Vec::new();
|
|
|
|
while let Some(frame) = next_sse_frame(&mut self.buffer) {
|
|
if let Some(event) = parse_sse_frame(&frame)? {
|
|
events.push(event);
|
|
}
|
|
}
|
|
|
|
Ok(events)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
struct StreamState {
|
|
model: String,
|
|
message_started: bool,
|
|
text_started: bool,
|
|
text_finished: bool,
|
|
finished: bool,
|
|
stop_reason: Option<String>,
|
|
usage: Option<Usage>,
|
|
tool_calls: BTreeMap<u32, ToolCallState>,
|
|
}
|
|
|
|
impl StreamState {
|
|
fn new(model: String) -> Self {
|
|
Self {
|
|
model,
|
|
message_started: false,
|
|
text_started: false,
|
|
text_finished: false,
|
|
finished: false,
|
|
stop_reason: None,
|
|
usage: None,
|
|
tool_calls: BTreeMap::new(),
|
|
}
|
|
}
|
|
|
|
fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result<Vec<StreamEvent>, ApiError> {
|
|
let mut events = Vec::new();
|
|
if !self.message_started {
|
|
self.message_started = true;
|
|
events.push(StreamEvent::MessageStart(MessageStartEvent {
|
|
message: MessageResponse {
|
|
id: chunk.id.clone(),
|
|
kind: "message".to_string(),
|
|
role: "assistant".to_string(),
|
|
content: Vec::new(),
|
|
model: chunk.model.clone().unwrap_or_else(|| self.model.clone()),
|
|
stop_reason: None,
|
|
stop_sequence: None,
|
|
usage: Usage {
|
|
input_tokens: 0,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
output_tokens: 0,
|
|
},
|
|
request_id: None,
|
|
},
|
|
}));
|
|
}
|
|
|
|
if let Some(usage) = chunk.usage {
|
|
self.usage = Some(Usage {
|
|
input_tokens: usage.prompt_tokens,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
output_tokens: usage.completion_tokens,
|
|
});
|
|
}
|
|
|
|
for choice in chunk.choices {
|
|
if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) {
|
|
if !self.text_started {
|
|
self.text_started = true;
|
|
events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
|
index: 0,
|
|
content_block: OutputContentBlock::Text {
|
|
text: String::new(),
|
|
},
|
|
}));
|
|
}
|
|
events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
|
|
index: 0,
|
|
delta: ContentBlockDelta::TextDelta { text: content },
|
|
}));
|
|
}
|
|
|
|
for tool_call in choice.delta.tool_calls {
|
|
let state = self.tool_calls.entry(tool_call.index).or_default();
|
|
state.apply(tool_call);
|
|
let block_index = state.block_index();
|
|
if !state.started {
|
|
if let Some(start_event) = state.start_event()? {
|
|
state.started = true;
|
|
events.push(StreamEvent::ContentBlockStart(start_event));
|
|
} else {
|
|
continue;
|
|
}
|
|
}
|
|
if let Some(delta_event) = state.delta_event() {
|
|
events.push(StreamEvent::ContentBlockDelta(delta_event));
|
|
}
|
|
if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped {
|
|
state.stopped = true;
|
|
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
|
|
index: block_index,
|
|
}));
|
|
}
|
|
}
|
|
|
|
if let Some(finish_reason) = choice.finish_reason {
|
|
self.stop_reason = Some(normalize_finish_reason(&finish_reason));
|
|
if finish_reason == "tool_calls" {
|
|
for state in self.tool_calls.values_mut() {
|
|
if state.started && !state.stopped {
|
|
state.stopped = true;
|
|
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
|
|
index: state.block_index(),
|
|
}));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(events)
|
|
}
|
|
|
|
fn finish(&mut self) -> Result<Vec<StreamEvent>, ApiError> {
|
|
if self.finished {
|
|
return Ok(Vec::new());
|
|
}
|
|
self.finished = true;
|
|
|
|
let mut events = Vec::new();
|
|
if self.text_started && !self.text_finished {
|
|
self.text_finished = true;
|
|
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
|
|
index: 0,
|
|
}));
|
|
}
|
|
|
|
for state in self.tool_calls.values_mut() {
|
|
if !state.started {
|
|
if let Some(start_event) = state.start_event()? {
|
|
state.started = true;
|
|
events.push(StreamEvent::ContentBlockStart(start_event));
|
|
if let Some(delta_event) = state.delta_event() {
|
|
events.push(StreamEvent::ContentBlockDelta(delta_event));
|
|
}
|
|
}
|
|
}
|
|
if state.started && !state.stopped {
|
|
state.stopped = true;
|
|
events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent {
|
|
index: state.block_index(),
|
|
}));
|
|
}
|
|
}
|
|
|
|
if self.message_started {
|
|
events.push(StreamEvent::MessageDelta(MessageDeltaEvent {
|
|
delta: MessageDelta {
|
|
stop_reason: Some(
|
|
self.stop_reason
|
|
.clone()
|
|
.unwrap_or_else(|| "end_turn".to_string()),
|
|
),
|
|
stop_sequence: None,
|
|
},
|
|
usage: self.usage.clone().unwrap_or(Usage {
|
|
input_tokens: 0,
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
output_tokens: 0,
|
|
}),
|
|
}));
|
|
events.push(StreamEvent::MessageStop(MessageStopEvent {}));
|
|
}
|
|
Ok(events)
|
|
}
|
|
}
|
|
|
|
/// Accumulated state of one streamed tool call.
#[derive(Debug, Default)]
struct ToolCallState {
    /// Tool call index as reported in the chunk deltas.
    openai_index: u32,
    /// Tool call id, once a delta has supplied one.
    id: Option<String>,
    /// Function name, once a delta has supplied one.
    name: Option<String>,
    /// Concatenation of every argument fragment received so far.
    arguments: String,
    /// Byte length of `arguments` already emitted as delta events.
    emitted_len: usize,
    /// Whether the content block start event has been emitted.
    started: bool,
    /// Whether the content block stop event has been emitted.
    stopped: bool,
}
|
|
|
|
impl ToolCallState {
|
|
fn apply(&mut self, tool_call: DeltaToolCall) {
|
|
self.openai_index = tool_call.index;
|
|
if let Some(id) = tool_call.id {
|
|
self.id = Some(id);
|
|
}
|
|
if let Some(name) = tool_call.function.name {
|
|
self.name = Some(name);
|
|
}
|
|
if let Some(arguments) = tool_call.function.arguments {
|
|
self.arguments.push_str(&arguments);
|
|
}
|
|
}
|
|
|
|
const fn block_index(&self) -> u32 {
|
|
self.openai_index + 1
|
|
}
|
|
|
|
fn start_event(&self) -> Result<Option<ContentBlockStartEvent>, ApiError> {
|
|
let Some(name) = self.name.clone() else {
|
|
return Ok(None);
|
|
};
|
|
let id = self
|
|
.id
|
|
.clone()
|
|
.unwrap_or_else(|| format!("tool_call_{}", self.openai_index));
|
|
Ok(Some(ContentBlockStartEvent {
|
|
index: self.block_index(),
|
|
content_block: OutputContentBlock::ToolUse {
|
|
id,
|
|
name,
|
|
input: json!({}),
|
|
},
|
|
}))
|
|
}
|
|
|
|
fn delta_event(&mut self) -> Option<ContentBlockDeltaEvent> {
|
|
if self.emitted_len >= self.arguments.len() {
|
|
return None;
|
|
}
|
|
let delta = self.arguments[self.emitted_len..].to_string();
|
|
self.emitted_len = self.arguments.len();
|
|
Some(ContentBlockDeltaEvent {
|
|
index: self.block_index(),
|
|
delta: ContentBlockDelta::InputJsonDelta {
|
|
partial_json: delta,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChatCompletionResponse {
|
|
id: String,
|
|
model: String,
|
|
choices: Vec<ChatChoice>,
|
|
#[serde(default)]
|
|
usage: Option<OpenAiUsage>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChatChoice {
|
|
message: ChatMessage,
|
|
#[serde(default)]
|
|
finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChatMessage {
|
|
role: String,
|
|
#[serde(default)]
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
tool_calls: Vec<ResponseToolCall>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ResponseToolCall {
|
|
id: String,
|
|
function: ResponseToolFunction,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ResponseToolFunction {
|
|
name: String,
|
|
arguments: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct OpenAiUsage {
|
|
#[serde(default)]
|
|
prompt_tokens: u32,
|
|
#[serde(default)]
|
|
completion_tokens: u32,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChatCompletionChunk {
|
|
id: String,
|
|
#[serde(default)]
|
|
model: Option<String>,
|
|
#[serde(default)]
|
|
choices: Vec<ChunkChoice>,
|
|
#[serde(default)]
|
|
usage: Option<OpenAiUsage>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ChunkChoice {
|
|
delta: ChunkDelta,
|
|
#[serde(default)]
|
|
finish_reason: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Default, Deserialize)]
|
|
struct ChunkDelta {
|
|
#[serde(default)]
|
|
content: Option<String>,
|
|
#[serde(default)]
|
|
tool_calls: Vec<DeltaToolCall>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DeltaToolCall {
|
|
#[serde(default)]
|
|
index: u32,
|
|
#[serde(default)]
|
|
id: Option<String>,
|
|
#[serde(default)]
|
|
function: DeltaFunction,
|
|
}
|
|
|
|
#[derive(Debug, Default, Deserialize)]
|
|
struct DeltaFunction {
|
|
#[serde(default)]
|
|
name: Option<String>,
|
|
#[serde(default)]
|
|
arguments: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ErrorEnvelope {
|
|
error: ErrorBody,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct ErrorBody {
|
|
#[serde(rename = "type")]
|
|
error_type: Option<String>,
|
|
message: Option<String>,
|
|
}
|
|
|
|
fn build_chat_completion_request(request: &MessageRequest) -> Value {
|
|
let mut messages = Vec::new();
|
|
if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
|
|
messages.push(json!({
|
|
"role": "system",
|
|
"content": system,
|
|
}));
|
|
}
|
|
for message in &request.messages {
|
|
messages.extend(translate_message(message));
|
|
}
|
|
|
|
let mut payload = json!({
|
|
"model": request.model,
|
|
"max_tokens": request.max_tokens,
|
|
"messages": messages,
|
|
"stream": request.stream,
|
|
});
|
|
|
|
if let Some(tools) = &request.tools {
|
|
payload["tools"] =
|
|
Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
|
|
}
|
|
if let Some(tool_choice) = &request.tool_choice {
|
|
payload["tool_choice"] = openai_tool_choice(tool_choice);
|
|
}
|
|
|
|
payload
|
|
}
|
|
|
|
fn translate_message(message: &InputMessage) -> Vec<Value> {
|
|
match message.role.as_str() {
|
|
"assistant" => {
|
|
let mut text = String::new();
|
|
let mut tool_calls = Vec::new();
|
|
for block in &message.content {
|
|
match block {
|
|
InputContentBlock::Text { text: value } => text.push_str(value),
|
|
InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({
|
|
"id": id,
|
|
"type": "function",
|
|
"function": {
|
|
"name": name,
|
|
"arguments": input.to_string(),
|
|
}
|
|
})),
|
|
InputContentBlock::ToolResult { .. } => {}
|
|
}
|
|
}
|
|
if text.is_empty() && tool_calls.is_empty() {
|
|
Vec::new()
|
|
} else {
|
|
vec![json!({
|
|
"role": "assistant",
|
|
"content": (!text.is_empty()).then_some(text),
|
|
"tool_calls": tool_calls,
|
|
})]
|
|
}
|
|
}
|
|
_ => message
|
|
.content
|
|
.iter()
|
|
.filter_map(|block| match block {
|
|
InputContentBlock::Text { text } => Some(json!({
|
|
"role": "user",
|
|
"content": text,
|
|
})),
|
|
InputContentBlock::ToolResult {
|
|
tool_use_id,
|
|
content,
|
|
is_error,
|
|
} => Some(json!({
|
|
"role": "tool",
|
|
"tool_call_id": tool_use_id,
|
|
"content": flatten_tool_result_content(content),
|
|
"is_error": is_error,
|
|
})),
|
|
InputContentBlock::ToolUse { .. } => None,
|
|
})
|
|
.collect(),
|
|
}
|
|
}
|
|
|
|
fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String {
|
|
content
|
|
.iter()
|
|
.map(|block| match block {
|
|
ToolResultContentBlock::Text { text } => text.clone(),
|
|
ToolResultContentBlock::Json { value } => value.to_string(),
|
|
})
|
|
.collect::<Vec<_>>()
|
|
.join("\n")
|
|
}
|
|
|
|
fn openai_tool_definition(tool: &ToolDefinition) -> Value {
|
|
json!({
|
|
"type": "function",
|
|
"function": {
|
|
"name": tool.name,
|
|
"description": tool.description,
|
|
"parameters": tool.input_schema,
|
|
}
|
|
})
|
|
}
|
|
|
|
fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
|
|
match tool_choice {
|
|
ToolChoice::Auto => Value::String("auto".to_string()),
|
|
ToolChoice::Any => Value::String("required".to_string()),
|
|
ToolChoice::Tool { name } => json!({
|
|
"type": "function",
|
|
"function": { "name": name },
|
|
}),
|
|
}
|
|
}
|
|
|
|
fn normalize_response(
|
|
model: &str,
|
|
response: ChatCompletionResponse,
|
|
) -> Result<MessageResponse, ApiError> {
|
|
let choice = response
|
|
.choices
|
|
.into_iter()
|
|
.next()
|
|
.ok_or(ApiError::InvalidSseFrame(
|
|
"chat completion response missing choices",
|
|
))?;
|
|
let mut content = Vec::new();
|
|
if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) {
|
|
content.push(OutputContentBlock::Text { text });
|
|
}
|
|
for tool_call in choice.message.tool_calls {
|
|
content.push(OutputContentBlock::ToolUse {
|
|
id: tool_call.id,
|
|
name: tool_call.function.name,
|
|
input: parse_tool_arguments(&tool_call.function.arguments),
|
|
});
|
|
}
|
|
|
|
Ok(MessageResponse {
|
|
id: response.id,
|
|
kind: "message".to_string(),
|
|
role: choice.message.role,
|
|
content,
|
|
model: response.model.if_empty_then(model.to_string()),
|
|
stop_reason: choice
|
|
.finish_reason
|
|
.map(|value| normalize_finish_reason(&value)),
|
|
stop_sequence: None,
|
|
usage: Usage {
|
|
input_tokens: response
|
|
.usage
|
|
.as_ref()
|
|
.map_or(0, |usage| usage.prompt_tokens),
|
|
cache_creation_input_tokens: 0,
|
|
cache_read_input_tokens: 0,
|
|
output_tokens: response
|
|
.usage
|
|
.as_ref()
|
|
.map_or(0, |usage| usage.completion_tokens),
|
|
},
|
|
request_id: None,
|
|
})
|
|
}
|
|
|
|
fn parse_tool_arguments(arguments: &str) -> Value {
|
|
serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments }))
|
|
}
|
|
|
|
/// Removes and returns the first complete SSE frame from `buffer`.
///
/// A frame ends at the earliest blank-line separator, either `\n\n` or
/// `\r\n\r\n`. The separator is consumed but not included in the returned
/// text, and invalid UTF-8 is replaced lossily. Returns `None`, leaving
/// `buffer` untouched, when no complete frame is buffered yet.
fn next_sse_frame(buffer: &mut Vec<u8>) -> Option<String> {
    let find = |needle: &[u8]| {
        buffer
            .windows(needle.len())
            .position(|window| window == needle)
            .map(|position| (position, needle.len()))
    };

    // Pick whichever separator occurs first. Preferring `\n\n` outright would
    // merge a CRLF-terminated frame with a following LF-terminated one.
    let (position, separator_len) = match (find(b"\n\n"), find(b"\r\n\r\n")) {
        (Some(lf), Some(crlf)) => {
            if crlf.0 < lf.0 {
                crlf
            } else {
                lf
            }
        }
        (lf, crlf) => lf.or(crlf)?,
    };

    let frame: Vec<u8> = buffer.drain(..position + separator_len).collect();
    Some(String::from_utf8_lossy(&frame[..position]).into_owned())
}
|
|
|
|
fn parse_sse_frame(frame: &str) -> Result<Option<ChatCompletionChunk>, ApiError> {
|
|
let trimmed = frame.trim();
|
|
if trimmed.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
|
|
let mut data_lines = Vec::new();
|
|
for line in trimmed.lines() {
|
|
if line.starts_with(':') {
|
|
continue;
|
|
}
|
|
if let Some(data) = line.strip_prefix("data:") {
|
|
data_lines.push(data.trim_start());
|
|
}
|
|
}
|
|
if data_lines.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
let payload = data_lines.join("\n");
|
|
if payload == "[DONE]" {
|
|
return Ok(None);
|
|
}
|
|
serde_json::from_str(&payload)
|
|
.map(Some)
|
|
.map_err(ApiError::from)
|
|
}
|
|
|
|
fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
|
|
match std::env::var(key) {
|
|
Ok(value) if !value.is_empty() => Ok(Some(value)),
|
|
Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
|
|
Err(error) => Err(ApiError::from(error)),
|
|
}
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn has_api_key(key: &str) -> bool {
|
|
read_env_non_empty(key)
|
|
.ok()
|
|
.and_then(std::convert::identity)
|
|
.is_some()
|
|
}
|
|
|
|
#[must_use]
|
|
pub fn read_base_url(config: OpenAiCompatConfig) -> String {
|
|
std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string())
|
|
}
|
|
|
|
/// Builds the chat-completions URL from `base_url`.
///
/// Trailing slashes are removed first. If the result already ends in
/// `/chat/completions` it is returned as is; otherwise that path is appended.
fn chat_completions_endpoint(base_url: &str) -> String {
    const SUFFIX: &str = "/chat/completions";

    let base = base_url.trim_end_matches('/');
    let mut endpoint = String::with_capacity(base.len() + SUFFIX.len());
    endpoint.push_str(base);
    if !base.ends_with(SUFFIX) {
        endpoint.push_str(SUFFIX);
    }
    endpoint
}
|
|
|
|
fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
|
|
headers
|
|
.get(REQUEST_ID_HEADER)
|
|
.or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
|
|
.and_then(|value| value.to_str().ok())
|
|
.map(ToOwned::to_owned)
|
|
}
|
|
|
|
async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
|
|
let status = response.status();
|
|
if status.is_success() {
|
|
return Ok(response);
|
|
}
|
|
|
|
let body = response.text().await.unwrap_or_default();
|
|
let parsed_error = serde_json::from_str::<ErrorEnvelope>(&body).ok();
|
|
let retryable = is_retryable_status(status);
|
|
|
|
Err(ApiError::Api {
|
|
status,
|
|
error_type: parsed_error
|
|
.as_ref()
|
|
.and_then(|error| error.error.error_type.clone()),
|
|
message: parsed_error
|
|
.as_ref()
|
|
.and_then(|error| error.error.message.clone()),
|
|
body,
|
|
retryable,
|
|
})
|
|
}
|
|
|
|
const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
|
|
matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
|
|
}
|
|
|
|
/// Maps a chat-completion finish reason onto the stop reason used by the
/// message types: `stop` becomes `end_turn`, `tool_calls` becomes
/// `tool_use`, and any other value is passed through unchanged.
fn normalize_finish_reason(value: &str) -> String {
    let normalized = if value == "stop" {
        "end_turn"
    } else if value == "tool_calls" {
        "tool_use"
    } else {
        value
    };
    String::from(normalized)
}
|
|
|
|
/// Extension trait providing a fallback for empty strings.
trait StringExt {
    /// Returns `self` unless it is empty, in which case `fallback` is returned.
    fn if_empty_then(self, fallback: String) -> String;
}

impl StringExt for String {
    fn if_empty_then(self, fallback: String) -> String {
        if !self.is_empty() {
            return self;
        }
        fallback
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::{
        build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason,
        openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig,
    };
    use crate::error::ApiError;
    use crate::types::{
        InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition,
        ToolResultContentBlock,
    };
    use serde_json::json;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// A user message with text and a tool result, plus a system prompt and
    /// one tool, must translate to system/user/tool messages in that order.
    #[test]
    fn request_translation_uses_openai_compatible_shape() {
        let payload = build_chat_completion_request(&MessageRequest {
            model: "grok-3".to_string(),
            max_tokens: 64,
            messages: vec![InputMessage {
                role: "user".to_string(),
                content: vec![
                    InputContentBlock::Text {
                        text: "hello".to_string(),
                    },
                    InputContentBlock::ToolResult {
                        tool_use_id: "tool_1".to_string(),
                        content: vec![ToolResultContentBlock::Json {
                            value: json!({"ok": true}),
                        }],
                        is_error: false,
                    },
                ],
            }],
            system: Some("be helpful".to_string()),
            tools: Some(vec![ToolDefinition {
                name: "weather".to_string(),
                description: Some("Get weather".to_string()),
                input_schema: json!({"type": "object"}),
            }]),
            tool_choice: Some(ToolChoice::Auto),
            stream: false,
        });

        assert_eq!(payload["messages"][0]["role"], json!("system"));
        assert_eq!(payload["messages"][1]["role"], json!("user"));
        assert_eq!(payload["messages"][2]["role"], json!("tool"));
        assert_eq!(payload["tools"][0]["type"], json!("function"));
        assert_eq!(payload["tool_choice"], json!("auto"));
    }

    #[test]
    fn tool_choice_translation_supports_required_function() {
        assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
        assert_eq!(
            openai_tool_choice(&ToolChoice::Tool {
                name: "weather".to_string(),
            }),
            json!({"type": "function", "function": {"name": "weather"}})
        );
    }

    #[test]
    fn parses_tool_arguments_fallback() {
        assert_eq!(
            parse_tool_arguments("{\"city\":\"Paris\"}"),
            json!({"city": "Paris"})
        );
        assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"}));
    }

    #[test]
    fn missing_xai_api_key_is_provider_specific() {
        let _lock = env_lock();
        // Remember the caller's value so the test leaves the environment as
        // it found it.
        let previous = std::env::var_os("XAI_API_KEY");
        std::env::remove_var("XAI_API_KEY");
        let result = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai());
        if let Some(value) = previous {
            std::env::set_var("XAI_API_KEY", value);
        }

        let error = result.expect_err("missing key should error");
        assert!(matches!(
            error,
            ApiError::MissingCredentials {
                provider: "xAI",
                ..
            }
        ));
    }

    #[test]
    fn endpoint_builder_accepts_base_urls_and_full_endpoints() {
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/"),
            "https://api.x.ai/v1/chat/completions"
        );
        assert_eq!(
            chat_completions_endpoint("https://api.x.ai/v1/chat/completions"),
            "https://api.x.ai/v1/chat/completions"
        );
    }

    /// Serializes tests that touch process-wide environment variables.
    ///
    /// A poisoned lock is recovered rather than unwrapped, so one failing
    /// env test does not make every later one panic on the lock itself.
    fn env_lock() -> MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn normalizes_stop_reasons() {
        assert_eq!(normalize_finish_reason("stop"), "end_turn");
        assert_eq!(normalize_finish_reason("tool_calls"), "tool_use");
    }
}
|