merge: clawcode-issue-9408-api-sse-streaming into main

This commit is contained in:
Jobdori
2026-04-03 05:08:03 +09:00
2 changed files with 160 additions and 31 deletions

View File

@@ -67,6 +67,7 @@ impl OpenAiCompatConfig {
pub struct OpenAiCompatClient { pub struct OpenAiCompatClient {
http: reqwest::Client, http: reqwest::Client,
api_key: String, api_key: String,
config: OpenAiCompatConfig,
base_url: String, base_url: String,
max_retries: u32, max_retries: u32,
initial_backoff: Duration, initial_backoff: Duration,
@@ -74,11 +75,15 @@ pub struct OpenAiCompatClient {
} }
impl OpenAiCompatClient { impl OpenAiCompatClient {
const fn config(&self) -> OpenAiCompatConfig {
self.config
}
#[must_use] #[must_use]
pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self { pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
Self { Self {
http: reqwest::Client::new(), http: reqwest::Client::new(),
api_key: api_key.into(), api_key: api_key.into(),
config,
base_url: read_base_url(config), base_url: read_base_url(config),
max_retries: DEFAULT_MAX_RETRIES, max_retries: DEFAULT_MAX_RETRIES,
initial_backoff: DEFAULT_INITIAL_BACKOFF, initial_backoff: DEFAULT_INITIAL_BACKOFF,
@@ -190,7 +195,7 @@ impl OpenAiCompatClient {
.post(&request_url) .post(&request_url)
.header("content-type", "application/json") .header("content-type", "application/json")
.bearer_auth(&self.api_key) .bearer_auth(&self.api_key)
.json(&build_chat_completion_request(request)) .json(&build_chat_completion_request(request, self.config()))
.send() .send()
.await .await
.map_err(ApiError::from) .map_err(ApiError::from)
@@ -633,7 +638,7 @@ struct ErrorBody {
message: Option<String>, message: Option<String>,
} }
fn build_chat_completion_request(request: &MessageRequest) -> Value { fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value {
let mut messages = Vec::new(); let mut messages = Vec::new();
if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
messages.push(json!({ messages.push(json!({
@@ -652,6 +657,10 @@ fn build_chat_completion_request(request: &MessageRequest) -> Value {
"stream": request.stream, "stream": request.stream,
}); });
if request.stream && should_request_stream_usage(config) {
payload["stream_options"] = json!({ "include_usage": true });
}
if let Some(tools) = &request.tools { if let Some(tools) = &request.tools {
payload["tools"] = payload["tools"] =
Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>()); Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
@@ -749,6 +758,10 @@ fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
} }
} }
fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool {
matches!(config.provider_name, "OpenAI")
}
fn normalize_response( fn normalize_response(
model: &str, model: &str,
response: ChatCompletionResponse, response: ChatCompletionResponse,
@@ -951,7 +964,8 @@ mod tests {
#[test] #[test]
fn request_translation_uses_openai_compatible_shape() { fn request_translation_uses_openai_compatible_shape() {
let payload = build_chat_completion_request(&MessageRequest { let payload = build_chat_completion_request(
&MessageRequest {
model: "grok-3".to_string(), model: "grok-3".to_string(),
max_tokens: 64, max_tokens: 64,
messages: vec![InputMessage { messages: vec![InputMessage {
@@ -977,7 +991,9 @@ mod tests {
}]), }]),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
stream: false, stream: false,
}); },
OpenAiCompatConfig::xai(),
);
assert_eq!(payload["messages"][0]["role"], json!("system")); assert_eq!(payload["messages"][0]["role"], json!("system"));
assert_eq!(payload["messages"][1]["role"], json!("user")); assert_eq!(payload["messages"][1]["role"], json!("user"));
@@ -986,6 +1002,42 @@ mod tests {
assert_eq!(payload["tool_choice"], json!("auto")); assert_eq!(payload["tool_choice"], json!("auto"));
} }
#[test]
fn openai_streaming_requests_include_usage_opt_in() {
let payload = build_chat_completion_request(
&MessageRequest {
model: "gpt-5".to_string(),
max_tokens: 64,
messages: vec![InputMessage::user_text("hello")],
system: None,
tools: None,
tool_choice: None,
stream: true,
},
OpenAiCompatConfig::openai(),
);
assert_eq!(payload["stream_options"], json!({"include_usage": true}));
}
#[test]
fn xai_streaming_requests_skip_openai_specific_usage_opt_in() {
let payload = build_chat_completion_request(
&MessageRequest {
model: "grok-3".to_string(),
max_tokens: 64,
messages: vec![InputMessage::user_text("hello")],
system: None,
tools: None,
tool_choice: None,
stream: true,
},
OpenAiCompatConfig::xai(),
);
assert!(payload.get("stream_options").is_none());
}
#[test] #[test]
fn tool_choice_translation_supports_required_function() { fn tool_choice_translation_supports_required_function() {
assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));

View File

@@ -5,8 +5,9 @@ use std::sync::{Mutex as StdMutex, OnceLock};
use api::{ use api::{
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OpenAiCompatClient,
OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, ToolChoice,
ToolDefinition,
}; };
use serde_json::json; use serde_json::json;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -195,6 +196,82 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() {
assert!(request.body.contains("\"stream\":true")); assert!(request.body.contains("\"stream\":true"));
} }
#[allow(clippy::await_holding_lock)]
#[tokio::test]
async fn openai_streaming_requests_opt_into_usage_chunks() {
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
let sse = concat!(
"data: {\"id\":\"chatcmpl_openai_stream\",\"model\":\"gpt-5\",\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n",
"data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
"data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}\n\n",
"data: [DONE]\n\n"
);
let server = spawn_server(
state.clone(),
vec![http_response_with_headers(
"200 OK",
"text/event-stream",
sse,
&[("x-request-id", "req_openai_stream")],
)],
)
.await;
let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
.with_base_url(server.base_url());
let mut stream = client
.stream_message(&sample_request(false))
.await
.expect("stream should start");
assert_eq!(stream.request_id(), Some("req_openai_stream"));
let mut events = Vec::new();
while let Some(event) = stream.next_event().await.expect("event should parse") {
events.push(event);
}
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
assert!(matches!(
events[1],
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
content_block: OutputContentBlock::Text { .. },
..
})
));
assert!(matches!(
events[2],
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
delta: ContentBlockDelta::TextDelta { .. },
..
})
));
assert!(matches!(
events[3],
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
));
assert!(matches!(
events[4],
StreamEvent::MessageDelta(MessageDeltaEvent { .. })
));
assert!(matches!(events[5], StreamEvent::MessageStop(_)));
match &events[4] {
StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
assert_eq!(usage.input_tokens, 9);
assert_eq!(usage.output_tokens, 4);
}
other => panic!("expected message delta, got {other:?}"),
}
let captured = state.lock().await;
let request = captured.first().expect("captured request");
assert_eq!(request.path, "/chat/completions");
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
assert_eq!(body["stream"], json!(true));
assert_eq!(body["stream_options"], json!({"include_usage": true}));
}
#[allow(clippy::await_holding_lock)] #[allow(clippy::await_holding_lock)]
#[tokio::test] #[tokio::test]
async fn provider_client_dispatches_xai_requests_from_env() { async fn provider_client_dispatches_xai_requests_from_env() {