mirror of
https://github.com/instructkr/claude-code.git
synced 2026-04-03 08:18:48 +03:00
feat: HTTP/SSE server crate with axum (session management, event streaming)
This commit is contained in:
163
rust/Cargo.lock
generated
163
rust/Cargo.lock
generated
@@ -28,12 +28,86 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "axum"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
||||
dependencies = [
|
||||
"axum-core",
|
||||
"bytes",
|
||||
"form_urlencoded",
|
||||
"futures-util",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"itoa",
|
||||
"matchit",
|
||||
"memchr",
|
||||
"mime",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"serde_core",
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-core"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"sync_wrapper",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
@@ -262,7 +336,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"rustix 1.1.4",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -318,6 +392,17 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.32"
|
||||
@@ -338,6 +423,7 @@ checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
@@ -451,6 +537,12 @@ version = "1.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||
|
||||
[[package]]
|
||||
name = "httpdate"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.9.0"
|
||||
@@ -464,6 +556,7 @@ dependencies = [
|
||||
"http",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"smallvec",
|
||||
@@ -708,12 +801,24 @@ version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.8.9"
|
||||
@@ -1091,12 +1196,14 @@ dependencies = [
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tower",
|
||||
"tower-http",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"webpki-roots",
|
||||
]
|
||||
@@ -1145,7 +1252,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.4.15",
|
||||
"windows-sys 0.52.0",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1288,6 +1395,17 @@ dependencies = [
|
||||
"zmij",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_path_to_error"
|
||||
version = "0.1.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.1"
|
||||
@@ -1300,6 +1418,19 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"axum",
|
||||
"reqwest",
|
||||
"runtime",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha2"
|
||||
version = "0.10.9"
|
||||
@@ -1553,6 +1684,19 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tools"
|
||||
version = "0.1.0"
|
||||
@@ -1579,6 +1723,7 @@ dependencies = [
|
||||
"tokio",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1617,6 +1762,7 @@ version = "0.1.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
|
||||
dependencies = [
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tracing-core",
|
||||
]
|
||||
@@ -1791,6 +1937,19 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-streams"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.93"
|
||||
|
||||
20
rust/crates/server/Cargo.toml
Normal file
20
rust/crates/server/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "server"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
async-stream = "0.3"
|
||||
axum = "0.8"
|
||||
runtime = { path = "../runtime" }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json.workspace = true
|
||||
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "net", "time"] }
|
||||
|
||||
[dev-dependencies]
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
442
rust/crates/server/src/lib.rs
Normal file
442
rust/crates/server/src/lib.rs
Normal file
@@ -0,0 +1,442 @@
|
||||
use std::collections::HashMap;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use async_stream::stream;
|
||||
use axum::extract::{Path, State};
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::{get, post};
|
||||
use axum::{Json, Router};
|
||||
use runtime::{ConversationMessage, Session as RuntimeSession};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::{broadcast, RwLock};
|
||||
|
||||
pub type SessionId = String;
|
||||
pub type SessionStore = Arc<RwLock<HashMap<SessionId, Session>>>;
|
||||
|
||||
const BROADCAST_CAPACITY: usize = 64;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
sessions: SessionStore,
|
||||
next_session_id: Arc<AtomicU64>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
#[must_use]
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
next_session_id: Arc::new(AtomicU64::new(1)),
|
||||
}
|
||||
}
|
||||
|
||||
fn allocate_session_id(&self) -> SessionId {
|
||||
let id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
|
||||
format!("session-{id}")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Session {
|
||||
pub id: SessionId,
|
||||
pub created_at: u64,
|
||||
pub conversation: RuntimeSession,
|
||||
events: broadcast::Sender<SessionEvent>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn new(id: SessionId) -> Self {
|
||||
let (events, _) = broadcast::channel(BROADCAST_CAPACITY);
|
||||
Self {
|
||||
id,
|
||||
created_at: unix_timestamp_millis(),
|
||||
conversation: RuntimeSession::new(),
|
||||
events,
|
||||
}
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
|
||||
self.events.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum SessionEvent {
|
||||
Snapshot {
|
||||
session_id: SessionId,
|
||||
session: RuntimeSession,
|
||||
},
|
||||
Message {
|
||||
session_id: SessionId,
|
||||
message: ConversationMessage,
|
||||
},
|
||||
}
|
||||
|
||||
impl SessionEvent {
|
||||
fn event_name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Snapshot { .. } => "snapshot",
|
||||
Self::Message { .. } => "message",
|
||||
}
|
||||
}
|
||||
|
||||
fn to_sse_event(&self) -> Result<Event, serde_json::Error> {
|
||||
Ok(Event::default()
|
||||
.event(self.event_name())
|
||||
.data(serde_json::to_string(self)?))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
type ApiError = (StatusCode, Json<ErrorResponse>);
|
||||
type ApiResult<T> = Result<T, ApiError>;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct CreateSessionResponse {
|
||||
pub session_id: SessionId,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct SessionSummary {
|
||||
pub id: SessionId,
|
||||
pub created_at: u64,
|
||||
pub message_count: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct ListSessionsResponse {
|
||||
pub sessions: Vec<SessionSummary>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct SessionDetailsResponse {
|
||||
pub id: SessionId,
|
||||
pub created_at: u64,
|
||||
pub session: RuntimeSession,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct SendMessageRequest {
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn app(state: AppState) -> Router {
|
||||
Router::new()
|
||||
.route("/sessions", post(create_session).get(list_sessions))
|
||||
.route("/sessions/{id}", get(get_session))
|
||||
.route("/sessions/{id}/events", get(stream_session_events))
|
||||
.route("/sessions/{id}/message", post(send_message))
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
async fn create_session(
|
||||
State(state): State<AppState>,
|
||||
) -> (StatusCode, Json<CreateSessionResponse>) {
|
||||
let session_id = state.allocate_session_id();
|
||||
let session = Session::new(session_id.clone());
|
||||
|
||||
state
|
||||
.sessions
|
||||
.write()
|
||||
.await
|
||||
.insert(session_id.clone(), session);
|
||||
|
||||
(
|
||||
StatusCode::CREATED,
|
||||
Json(CreateSessionResponse { session_id }),
|
||||
)
|
||||
}
|
||||
|
||||
async fn list_sessions(State(state): State<AppState>) -> Json<ListSessionsResponse> {
|
||||
let sessions = state.sessions.read().await;
|
||||
let mut summaries = sessions
|
||||
.values()
|
||||
.map(|session| SessionSummary {
|
||||
id: session.id.clone(),
|
||||
created_at: session.created_at,
|
||||
message_count: session.conversation.messages.len(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
summaries.sort_by(|left, right| left.id.cmp(&right.id));
|
||||
|
||||
Json(ListSessionsResponse {
|
||||
sessions: summaries,
|
||||
})
|
||||
}
|
||||
|
||||
async fn get_session(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<SessionId>,
|
||||
) -> ApiResult<Json<SessionDetailsResponse>> {
|
||||
let sessions = state.sessions.read().await;
|
||||
let session = sessions
|
||||
.get(&id)
|
||||
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
|
||||
|
||||
Ok(Json(SessionDetailsResponse {
|
||||
id: session.id.clone(),
|
||||
created_at: session.created_at,
|
||||
session: session.conversation.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
async fn send_message(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<SessionId>,
|
||||
Json(payload): Json<SendMessageRequest>,
|
||||
) -> ApiResult<StatusCode> {
|
||||
let message = ConversationMessage::user_text(payload.message);
|
||||
let broadcaster = {
|
||||
let mut sessions = state.sessions.write().await;
|
||||
let session = sessions
|
||||
.get_mut(&id)
|
||||
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
|
||||
session.conversation.messages.push(message.clone());
|
||||
session.events.clone()
|
||||
};
|
||||
|
||||
let _ = broadcaster.send(SessionEvent::Message {
|
||||
session_id: id,
|
||||
message,
|
||||
});
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
async fn stream_session_events(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<SessionId>,
|
||||
) -> ApiResult<impl IntoResponse> {
|
||||
let (snapshot, mut receiver) = {
|
||||
let sessions = state.sessions.read().await;
|
||||
let session = sessions
|
||||
.get(&id)
|
||||
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
|
||||
(
|
||||
SessionEvent::Snapshot {
|
||||
session_id: session.id.clone(),
|
||||
session: session.conversation.clone(),
|
||||
},
|
||||
session.subscribe(),
|
||||
)
|
||||
};
|
||||
|
||||
let stream = stream! {
|
||||
if let Ok(event) = snapshot.to_sse_event() {
|
||||
yield Ok::<Event, Infallible>(event);
|
||||
}
|
||||
|
||||
loop {
|
||||
match receiver.recv().await {
|
||||
Ok(event) => {
|
||||
if let Ok(sse_event) = event.to_sse_event() {
|
||||
yield Ok::<Event, Infallible>(sse_event);
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(_)) => continue,
|
||||
Err(broadcast::error::RecvError::Closed) => break,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
|
||||
}
|
||||
|
||||
fn unix_timestamp_millis() -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("system time should be after epoch")
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
fn not_found(message: String) -> ApiError {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(ErrorResponse { error: message }),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
app, AppState, CreateSessionResponse, ListSessionsResponse, SessionDetailsResponse,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::timeout;
|
||||
|
||||
struct TestServer {
|
||||
address: SocketAddr,
|
||||
handle: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl TestServer {
|
||||
async fn spawn() -> Self {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("test listener should bind");
|
||||
let address = listener
|
||||
.local_addr()
|
||||
.expect("listener should report local address");
|
||||
let handle = tokio::spawn(async move {
|
||||
axum::serve(listener, app(AppState::default()))
|
||||
.await
|
||||
.expect("server should run");
|
||||
});
|
||||
|
||||
Self { address, handle }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!("http://{}{}", self.address, path)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestServer {
|
||||
fn drop(&mut self) {
|
||||
self.handle.abort();
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_session(client: &Client, server: &TestServer) -> CreateSessionResponse {
|
||||
client
|
||||
.post(server.url("/sessions"))
|
||||
.send()
|
||||
.await
|
||||
.expect("create request should succeed")
|
||||
.error_for_status()
|
||||
.expect("create request should return success")
|
||||
.json::<CreateSessionResponse>()
|
||||
.await
|
||||
.expect("create response should parse")
|
||||
}
|
||||
|
||||
async fn next_sse_frame(response: &mut reqwest::Response, buffer: &mut String) -> String {
|
||||
loop {
|
||||
if let Some(index) = buffer.find("\n\n") {
|
||||
let frame = buffer[..index].to_string();
|
||||
let remainder = buffer[index + 2..].to_string();
|
||||
*buffer = remainder;
|
||||
return frame;
|
||||
}
|
||||
|
||||
let next_chunk = timeout(Duration::from_secs(5), response.chunk())
|
||||
.await
|
||||
.expect("SSE stream should yield within timeout")
|
||||
.expect("SSE stream should remain readable")
|
||||
.expect("SSE stream should stay open");
|
||||
buffer.push_str(&String::from_utf8_lossy(&next_chunk));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn creates_and_lists_sessions() {
|
||||
let server = TestServer::spawn().await;
|
||||
let client = Client::new();
|
||||
|
||||
// given
|
||||
let created = create_session(&client, &server).await;
|
||||
|
||||
// when
|
||||
let sessions = client
|
||||
.get(server.url("/sessions"))
|
||||
.send()
|
||||
.await
|
||||
.expect("list request should succeed")
|
||||
.error_for_status()
|
||||
.expect("list request should return success")
|
||||
.json::<ListSessionsResponse>()
|
||||
.await
|
||||
.expect("list response should parse");
|
||||
let details = client
|
||||
.get(server.url(&format!("/sessions/{}", created.session_id)))
|
||||
.send()
|
||||
.await
|
||||
.expect("details request should succeed")
|
||||
.error_for_status()
|
||||
.expect("details request should return success")
|
||||
.json::<SessionDetailsResponse>()
|
||||
.await
|
||||
.expect("details response should parse");
|
||||
|
||||
// then
|
||||
assert_eq!(created.session_id, "session-1");
|
||||
assert_eq!(sessions.sessions.len(), 1);
|
||||
assert_eq!(sessions.sessions[0].id, created.session_id);
|
||||
assert_eq!(sessions.sessions[0].message_count, 0);
|
||||
assert_eq!(details.id, "session-1");
|
||||
assert!(details.session.messages.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn streams_message_events_and_persists_message_flow() {
|
||||
let server = TestServer::spawn().await;
|
||||
let client = Client::new();
|
||||
|
||||
// given
|
||||
let created = create_session(&client, &server).await;
|
||||
let mut response = client
|
||||
.get(server.url(&format!("/sessions/{}/events", created.session_id)))
|
||||
.send()
|
||||
.await
|
||||
.expect("events request should succeed")
|
||||
.error_for_status()
|
||||
.expect("events request should return success");
|
||||
let mut buffer = String::new();
|
||||
let snapshot_frame = next_sse_frame(&mut response, &mut buffer).await;
|
||||
|
||||
// when
|
||||
let send_status = client
|
||||
.post(server.url(&format!("/sessions/{}/message", created.session_id)))
|
||||
.json(&super::SendMessageRequest {
|
||||
message: "hello from test".to_string(),
|
||||
})
|
||||
.send()
|
||||
.await
|
||||
.expect("message request should succeed")
|
||||
.status();
|
||||
let message_frame = next_sse_frame(&mut response, &mut buffer).await;
|
||||
let details = client
|
||||
.get(server.url(&format!("/sessions/{}", created.session_id)))
|
||||
.send()
|
||||
.await
|
||||
.expect("details request should succeed")
|
||||
.error_for_status()
|
||||
.expect("details request should return success")
|
||||
.json::<SessionDetailsResponse>()
|
||||
.await
|
||||
.expect("details response should parse");
|
||||
|
||||
// then
|
||||
assert_eq!(send_status, reqwest::StatusCode::NO_CONTENT);
|
||||
assert!(snapshot_frame.contains("event: snapshot"));
|
||||
assert!(snapshot_frame.contains("\"session_id\":\"session-1\""));
|
||||
assert!(message_frame.contains("event: message"));
|
||||
assert!(message_frame.contains("hello from test"));
|
||||
assert_eq!(details.session.messages.len(), 1);
|
||||
assert_eq!(
|
||||
details.session.messages[0],
|
||||
runtime::ConversationMessage::user_text("hello from test")
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user