mirror of
https://github.com/instructkr/claude-code.git
synced 2026-04-06 11:18:51 +03:00
Preserve usage accounting on OpenAI SSE streams
OpenAI chat-completions streams can emit a final usage chunk when the
client opts in, but the Rust transport was not requesting it. This
keeps provider config on the client and adds stream_options.include_usage
only for OpenAI streams so normalized message_delta usage reflects the
transport without changing xAI request bodies.

Constraint: Keep xAI request bodies unchanged because provider-specific streaming knobs may differ
Rejected: Enable stream_options for every OpenAI-compatible provider | risks sending unsupported params to xAI-style endpoints
Confidence: high
Scope-risk: narrow
Directive: Keep provider-specific streaming flags tied to OpenAiCompatConfig instead of inferring provider behavior from URLs
Tested: cargo clippy -p api --tests -- -D warnings
Tested: cargo test -p api openai_streaming_requests -- --nocapture
Tested: cargo test -p api xai_streaming_requests_skip_openai_specific_usage_opt_in -- --nocapture
Tested: cargo test -p api request_translation_uses_openai_compatible_shape -- --nocapture
Tested: cargo test -p api stream_message_normalizes_text_and_multiple_tool_calls -- --exact --nocapture
Not-tested: Live OpenAI or xAI network calls
This commit is contained in:
@@ -67,6 +67,7 @@ impl OpenAiCompatConfig {
|
|||||||
pub struct OpenAiCompatClient {
|
pub struct OpenAiCompatClient {
|
||||||
http: reqwest::Client,
|
http: reqwest::Client,
|
||||||
api_key: String,
|
api_key: String,
|
||||||
|
config: OpenAiCompatConfig,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
max_retries: u32,
|
max_retries: u32,
|
||||||
initial_backoff: Duration,
|
initial_backoff: Duration,
|
||||||
@@ -74,11 +75,15 @@ pub struct OpenAiCompatClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiCompatClient {
|
impl OpenAiCompatClient {
|
||||||
|
const fn config(&self) -> OpenAiCompatConfig {
|
||||||
|
self.config
|
||||||
|
}
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
|
pub fn new(api_key: impl Into<String>, config: OpenAiCompatConfig) -> Self {
|
||||||
Self {
|
Self {
|
||||||
http: reqwest::Client::new(),
|
http: reqwest::Client::new(),
|
||||||
api_key: api_key.into(),
|
api_key: api_key.into(),
|
||||||
|
config,
|
||||||
base_url: read_base_url(config),
|
base_url: read_base_url(config),
|
||||||
max_retries: DEFAULT_MAX_RETRIES,
|
max_retries: DEFAULT_MAX_RETRIES,
|
||||||
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
initial_backoff: DEFAULT_INITIAL_BACKOFF,
|
||||||
@@ -190,7 +195,7 @@ impl OpenAiCompatClient {
|
|||||||
.post(&request_url)
|
.post(&request_url)
|
||||||
.header("content-type", "application/json")
|
.header("content-type", "application/json")
|
||||||
.bearer_auth(&self.api_key)
|
.bearer_auth(&self.api_key)
|
||||||
.json(&build_chat_completion_request(request))
|
.json(&build_chat_completion_request(request, self.config()))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(ApiError::from)
|
.map_err(ApiError::from)
|
||||||
@@ -633,7 +638,7 @@ struct ErrorBody {
|
|||||||
message: Option<String>,
|
message: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_chat_completion_request(request: &MessageRequest) -> Value {
|
fn build_chat_completion_request(request: &MessageRequest, config: OpenAiCompatConfig) -> Value {
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
|
if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) {
|
||||||
messages.push(json!({
|
messages.push(json!({
|
||||||
@@ -652,6 +657,10 @@ fn build_chat_completion_request(request: &MessageRequest) -> Value {
|
|||||||
"stream": request.stream,
|
"stream": request.stream,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if request.stream && should_request_stream_usage(config) {
|
||||||
|
payload["stream_options"] = json!({ "include_usage": true });
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(tools) = &request.tools {
|
if let Some(tools) = &request.tools {
|
||||||
payload["tools"] =
|
payload["tools"] =
|
||||||
Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
|
Value::Array(tools.iter().map(openai_tool_definition).collect::<Vec<_>>());
|
||||||
@@ -749,6 +758,10 @@ fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn should_request_stream_usage(config: OpenAiCompatConfig) -> bool {
|
||||||
|
matches!(config.provider_name, "OpenAI")
|
||||||
|
}
|
||||||
|
|
||||||
fn normalize_response(
|
fn normalize_response(
|
||||||
model: &str,
|
model: &str,
|
||||||
response: ChatCompletionResponse,
|
response: ChatCompletionResponse,
|
||||||
@@ -951,33 +964,36 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn request_translation_uses_openai_compatible_shape() {
|
fn request_translation_uses_openai_compatible_shape() {
|
||||||
let payload = build_chat_completion_request(&MessageRequest {
|
let payload = build_chat_completion_request(
|
||||||
model: "grok-3".to_string(),
|
&MessageRequest {
|
||||||
max_tokens: 64,
|
model: "grok-3".to_string(),
|
||||||
messages: vec![InputMessage {
|
max_tokens: 64,
|
||||||
role: "user".to_string(),
|
messages: vec![InputMessage {
|
||||||
content: vec![
|
role: "user".to_string(),
|
||||||
InputContentBlock::Text {
|
content: vec![
|
||||||
text: "hello".to_string(),
|
InputContentBlock::Text {
|
||||||
},
|
text: "hello".to_string(),
|
||||||
InputContentBlock::ToolResult {
|
},
|
||||||
tool_use_id: "tool_1".to_string(),
|
InputContentBlock::ToolResult {
|
||||||
content: vec![ToolResultContentBlock::Json {
|
tool_use_id: "tool_1".to_string(),
|
||||||
value: json!({"ok": true}),
|
content: vec![ToolResultContentBlock::Json {
|
||||||
}],
|
value: json!({"ok": true}),
|
||||||
is_error: false,
|
}],
|
||||||
},
|
is_error: false,
|
||||||
],
|
},
|
||||||
}],
|
],
|
||||||
system: Some("be helpful".to_string()),
|
}],
|
||||||
tools: Some(vec![ToolDefinition {
|
system: Some("be helpful".to_string()),
|
||||||
name: "weather".to_string(),
|
tools: Some(vec![ToolDefinition {
|
||||||
description: Some("Get weather".to_string()),
|
name: "weather".to_string(),
|
||||||
input_schema: json!({"type": "object"}),
|
description: Some("Get weather".to_string()),
|
||||||
}]),
|
input_schema: json!({"type": "object"}),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
}]),
|
||||||
stream: false,
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
});
|
stream: false,
|
||||||
|
},
|
||||||
|
OpenAiCompatConfig::xai(),
|
||||||
|
);
|
||||||
|
|
||||||
assert_eq!(payload["messages"][0]["role"], json!("system"));
|
assert_eq!(payload["messages"][0]["role"], json!("system"));
|
||||||
assert_eq!(payload["messages"][1]["role"], json!("user"));
|
assert_eq!(payload["messages"][1]["role"], json!("user"));
|
||||||
@@ -986,6 +1002,42 @@ mod tests {
|
|||||||
assert_eq!(payload["tool_choice"], json!("auto"));
|
assert_eq!(payload["tool_choice"], json!("auto"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn openai_streaming_requests_include_usage_opt_in() {
|
||||||
|
let payload = build_chat_completion_request(
|
||||||
|
&MessageRequest {
|
||||||
|
model: "gpt-5".to_string(),
|
||||||
|
max_tokens: 64,
|
||||||
|
messages: vec![InputMessage::user_text("hello")],
|
||||||
|
system: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
stream: true,
|
||||||
|
},
|
||||||
|
OpenAiCompatConfig::openai(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(payload["stream_options"], json!({"include_usage": true}));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn xai_streaming_requests_skip_openai_specific_usage_opt_in() {
|
||||||
|
let payload = build_chat_completion_request(
|
||||||
|
&MessageRequest {
|
||||||
|
model: "grok-3".to_string(),
|
||||||
|
max_tokens: 64,
|
||||||
|
messages: vec![InputMessage::user_text("hello")],
|
||||||
|
system: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
stream: true,
|
||||||
|
},
|
||||||
|
OpenAiCompatConfig::xai(),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(payload.get("stream_options").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn tool_choice_translation_supports_required_function() {
|
fn tool_choice_translation_supports_required_function() {
|
||||||
assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
|
assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required"));
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ use std::sync::{Mutex as StdMutex, OnceLock};
|
|||||||
|
|
||||||
use api::{
|
use api::{
|
||||||
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent,
|
||||||
InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig,
|
InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OpenAiCompatClient,
|
||||||
OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition,
|
OpenAiCompatConfig, OutputContentBlock, ProviderClient, StreamEvent, ToolChoice,
|
||||||
|
ToolDefinition,
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
@@ -195,6 +196,82 @@ async fn stream_message_normalizes_text_and_multiple_tool_calls() {
|
|||||||
assert!(request.body.contains("\"stream\":true"));
|
assert!(request.body.contains("\"stream\":true"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::await_holding_lock)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn openai_streaming_requests_opt_into_usage_chunks() {
|
||||||
|
let state = Arc::new(Mutex::new(Vec::<CapturedRequest>::new()));
|
||||||
|
let sse = concat!(
|
||||||
|
"data: {\"id\":\"chatcmpl_openai_stream\",\"model\":\"gpt-5\",\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n",
|
||||||
|
"data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}]}\n\n",
|
||||||
|
"data: {\"id\":\"chatcmpl_openai_stream\",\"choices\":[],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}\n\n",
|
||||||
|
"data: [DONE]\n\n"
|
||||||
|
);
|
||||||
|
let server = spawn_server(
|
||||||
|
state.clone(),
|
||||||
|
vec![http_response_with_headers(
|
||||||
|
"200 OK",
|
||||||
|
"text/event-stream",
|
||||||
|
sse,
|
||||||
|
&[("x-request-id", "req_openai_stream")],
|
||||||
|
)],
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let client = OpenAiCompatClient::new("openai-test-key", OpenAiCompatConfig::openai())
|
||||||
|
.with_base_url(server.base_url());
|
||||||
|
let mut stream = client
|
||||||
|
.stream_message(&sample_request(false))
|
||||||
|
.await
|
||||||
|
.expect("stream should start");
|
||||||
|
|
||||||
|
assert_eq!(stream.request_id(), Some("req_openai_stream"));
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
while let Some(event) = stream.next_event().await.expect("event should parse") {
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(matches!(events[0], StreamEvent::MessageStart(_)));
|
||||||
|
assert!(matches!(
|
||||||
|
events[1],
|
||||||
|
StreamEvent::ContentBlockStart(ContentBlockStartEvent {
|
||||||
|
content_block: OutputContentBlock::Text { .. },
|
||||||
|
..
|
||||||
|
})
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[2],
|
||||||
|
StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent {
|
||||||
|
delta: ContentBlockDelta::TextDelta { .. },
|
||||||
|
..
|
||||||
|
})
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[3],
|
||||||
|
StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 })
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
events[4],
|
||||||
|
StreamEvent::MessageDelta(MessageDeltaEvent { .. })
|
||||||
|
));
|
||||||
|
assert!(matches!(events[5], StreamEvent::MessageStop(_)));
|
||||||
|
|
||||||
|
match &events[4] {
|
||||||
|
StreamEvent::MessageDelta(MessageDeltaEvent { usage, .. }) => {
|
||||||
|
assert_eq!(usage.input_tokens, 9);
|
||||||
|
assert_eq!(usage.output_tokens, 4);
|
||||||
|
}
|
||||||
|
other => panic!("expected message delta, got {other:?}"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let captured = state.lock().await;
|
||||||
|
let request = captured.first().expect("captured request");
|
||||||
|
assert_eq!(request.path, "/chat/completions");
|
||||||
|
let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body");
|
||||||
|
assert_eq!(body["stream"], json!(true));
|
||||||
|
assert_eq!(body["stream_options"], json!({"include_usage": true}));
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::await_holding_lock)]
|
#[allow(clippy::await_holding_lock)]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn provider_client_dispatches_xai_requests_from_env() {
|
async fn provider_client_dispatches_xai_requests_from_env() {
|
||||||
|
|||||||
Reference in New Issue
Block a user