mirror of
https://github.com/instructkr/claude-code.git
synced 2026-04-03 10:28:51 +03:00
feat: HTTP/SSE server crate with axum (session management, event streaming)
This commit is contained in:
163
rust/Cargo.lock
generated
163
rust/Cargo.lock
generated
@@ -28,12 +28,86 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-stream"
|
||||||
|
version = "0.3.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||||
|
dependencies = [
|
||||||
|
"async-stream-impl",
|
||||||
|
"futures-core",
|
||||||
|
"pin-project-lite",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-stream-impl"
|
||||||
|
version = "0.3.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "atomic-waker"
|
name = "atomic-waker"
|
||||||
version = "1.1.2"
|
version = "1.1.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum"
|
||||||
|
version = "0.8.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
||||||
|
dependencies = [
|
||||||
|
"axum-core",
|
||||||
|
"bytes",
|
||||||
|
"form_urlencoded",
|
||||||
|
"futures-util",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"http-body-util",
|
||||||
|
"hyper",
|
||||||
|
"hyper-util",
|
||||||
|
"itoa",
|
||||||
|
"matchit",
|
||||||
|
"memchr",
|
||||||
|
"mime",
|
||||||
|
"percent-encoding",
|
||||||
|
"pin-project-lite",
|
||||||
|
"serde_core",
|
||||||
|
"serde_json",
|
||||||
|
"serde_path_to_error",
|
||||||
|
"serde_urlencoded",
|
||||||
|
"sync_wrapper",
|
||||||
|
"tokio",
|
||||||
|
"tower",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-core"
|
||||||
|
version = "0.5.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"futures-core",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"http-body-util",
|
||||||
|
"mime",
|
||||||
|
"pin-project-lite",
|
||||||
|
"sync_wrapper",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "base64"
|
name = "base64"
|
||||||
version = "0.22.1"
|
version = "0.22.1"
|
||||||
@@ -262,7 +336,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"rustix 1.1.4",
|
"rustix 1.1.4",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -318,6 +392,17 @@ version = "0.3.32"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-macro"
|
||||||
|
version = "0.3.32"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
@@ -338,6 +423,7 @@ checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"futures-core",
|
"futures-core",
|
||||||
"futures-io",
|
"futures-io",
|
||||||
|
"futures-macro",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
"memchr",
|
"memchr",
|
||||||
@@ -451,6 +537,12 @@ version = "1.10.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "httpdate"
|
||||||
|
version = "1.0.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hyper"
|
name = "hyper"
|
||||||
version = "1.9.0"
|
version = "1.9.0"
|
||||||
@@ -464,6 +556,7 @@ dependencies = [
|
|||||||
"http",
|
"http",
|
||||||
"http-body",
|
"http-body",
|
||||||
"httparse",
|
"httparse",
|
||||||
|
"httpdate",
|
||||||
"itoa",
|
"itoa",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
@@ -708,12 +801,24 @@ version = "0.1.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matchit"
|
||||||
|
version = "0.8.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "memchr"
|
name = "memchr"
|
||||||
version = "2.8.0"
|
version = "2.8.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mime"
|
||||||
|
version = "0.3.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "miniz_oxide"
|
name = "miniz_oxide"
|
||||||
version = "0.8.9"
|
version = "0.8.9"
|
||||||
@@ -1091,12 +1196,14 @@ dependencies = [
|
|||||||
"sync_wrapper",
|
"sync_wrapper",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
|
"tokio-util",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"url",
|
"url",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"wasm-bindgen-futures",
|
"wasm-bindgen-futures",
|
||||||
|
"wasm-streams",
|
||||||
"web-sys",
|
"web-sys",
|
||||||
"webpki-roots",
|
"webpki-roots",
|
||||||
]
|
]
|
||||||
@@ -1145,7 +1252,7 @@ dependencies = [
|
|||||||
"errno",
|
"errno",
|
||||||
"libc",
|
"libc",
|
||||||
"linux-raw-sys 0.4.15",
|
"linux-raw-sys 0.4.15",
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1288,6 +1395,17 @@ dependencies = [
|
|||||||
"zmij",
|
"zmij",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_path_to_error"
|
||||||
|
version = "0.1.20"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457"
|
||||||
|
dependencies = [
|
||||||
|
"itoa",
|
||||||
|
"serde",
|
||||||
|
"serde_core",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "serde_urlencoded"
|
name = "serde_urlencoded"
|
||||||
version = "0.7.1"
|
version = "0.7.1"
|
||||||
@@ -1300,6 +1418,19 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "server"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"async-stream",
|
||||||
|
"axum",
|
||||||
|
"reqwest",
|
||||||
|
"runtime",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sha2"
|
name = "sha2"
|
||||||
version = "0.10.9"
|
version = "0.10.9"
|
||||||
@@ -1553,6 +1684,19 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-util"
|
||||||
|
version = "0.7.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
|
"pin-project-lite",
|
||||||
|
"tokio",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tools"
|
name = "tools"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
@@ -1579,6 +1723,7 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1617,6 +1762,7 @@ version = "0.1.44"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
|
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"log",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
]
|
]
|
||||||
@@ -1791,6 +1937,19 @@ dependencies = [
|
|||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasm-streams"
|
||||||
|
version = "0.4.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"js-sys",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"wasm-bindgen-futures",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "web-sys"
|
name = "web-sys"
|
||||||
version = "0.3.93"
|
version = "0.3.93"
|
||||||
|
|||||||
20
rust/crates/server/Cargo.toml
Normal file
20
rust/crates/server/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
[package]
|
||||||
|
name = "server"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
publish.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
async-stream = "0.3"
|
||||||
|
axum = "0.8"
|
||||||
|
runtime = { path = "../runtime" }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json.workspace = true
|
||||||
|
tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "net", "time"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] }
|
||||||
|
|
||||||
|
[lints]
|
||||||
|
workspace = true
|
||||||
442
rust/crates/server/src/lib.rs
Normal file
442
rust/crates/server/src/lib.rs
Normal file
@@ -0,0 +1,442 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::convert::Infallible;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||||
|
|
||||||
|
use async_stream::stream;
|
||||||
|
use axum::extract::{Path, State};
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::sse::{Event, KeepAlive, Sse};
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use axum::{Json, Router};
|
||||||
|
use runtime::{ConversationMessage, Session as RuntimeSession};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tokio::sync::{broadcast, RwLock};
|
||||||
|
|
||||||
|
pub type SessionId = String;
|
||||||
|
pub type SessionStore = Arc<RwLock<HashMap<SessionId, Session>>>;
|
||||||
|
|
||||||
|
const BROADCAST_CAPACITY: usize = 64;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
sessions: SessionStore,
|
||||||
|
next_session_id: Arc<AtomicU64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
#[must_use]
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
next_session_id: Arc::new(AtomicU64::new(1)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn allocate_session_id(&self) -> SessionId {
|
||||||
|
let id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
|
||||||
|
format!("session-{id}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for AppState {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Session {
|
||||||
|
pub id: SessionId,
|
||||||
|
pub created_at: u64,
|
||||||
|
pub conversation: RuntimeSession,
|
||||||
|
events: broadcast::Sender<SessionEvent>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Session {
|
||||||
|
fn new(id: SessionId) -> Self {
|
||||||
|
let (events, _) = broadcast::channel(BROADCAST_CAPACITY);
|
||||||
|
Self {
|
||||||
|
id,
|
||||||
|
created_at: unix_timestamp_millis(),
|
||||||
|
conversation: RuntimeSession::new(),
|
||||||
|
events,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
|
||||||
|
self.events.subscribe()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
enum SessionEvent {
|
||||||
|
Snapshot {
|
||||||
|
session_id: SessionId,
|
||||||
|
session: RuntimeSession,
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
session_id: SessionId,
|
||||||
|
message: ConversationMessage,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SessionEvent {
|
||||||
|
fn event_name(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Self::Snapshot { .. } => "snapshot",
|
||||||
|
Self::Message { .. } => "message",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_sse_event(&self) -> Result<Event, serde_json::Error> {
|
||||||
|
Ok(Event::default()
|
||||||
|
.event(self.event_name())
|
||||||
|
.data(serde_json::to_string(self)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct ErrorResponse {
|
||||||
|
error: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
type ApiError = (StatusCode, Json<ErrorResponse>);
|
||||||
|
type ApiResult<T> = Result<T, ApiError>;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct CreateSessionResponse {
|
||||||
|
pub session_id: SessionId,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct SessionSummary {
|
||||||
|
pub id: SessionId,
|
||||||
|
pub created_at: u64,
|
||||||
|
pub message_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct ListSessionsResponse {
|
||||||
|
pub sessions: Vec<SessionSummary>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct SessionDetailsResponse {
|
||||||
|
pub id: SessionId,
|
||||||
|
pub created_at: u64,
|
||||||
|
pub session: RuntimeSession,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct SendMessageRequest {
|
||||||
|
pub message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn app(state: AppState) -> Router {
|
||||||
|
Router::new()
|
||||||
|
.route("/sessions", post(create_session).get(list_sessions))
|
||||||
|
.route("/sessions/{id}", get(get_session))
|
||||||
|
.route("/sessions/{id}/events", get(stream_session_events))
|
||||||
|
.route("/sessions/{id}/message", post(send_message))
|
||||||
|
.with_state(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_session(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> (StatusCode, Json<CreateSessionResponse>) {
|
||||||
|
let session_id = state.allocate_session_id();
|
||||||
|
let session = Session::new(session_id.clone());
|
||||||
|
|
||||||
|
state
|
||||||
|
.sessions
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(session_id.clone(), session);
|
||||||
|
|
||||||
|
(
|
||||||
|
StatusCode::CREATED,
|
||||||
|
Json(CreateSessionResponse { session_id }),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_sessions(State(state): State<AppState>) -> Json<ListSessionsResponse> {
|
||||||
|
let sessions = state.sessions.read().await;
|
||||||
|
let mut summaries = sessions
|
||||||
|
.values()
|
||||||
|
.map(|session| SessionSummary {
|
||||||
|
id: session.id.clone(),
|
||||||
|
created_at: session.created_at,
|
||||||
|
message_count: session.conversation.messages.len(),
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
summaries.sort_by(|left, right| left.id.cmp(&right.id));
|
||||||
|
|
||||||
|
Json(ListSessionsResponse {
|
||||||
|
sessions: summaries,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_session(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<SessionId>,
|
||||||
|
) -> ApiResult<Json<SessionDetailsResponse>> {
|
||||||
|
let sessions = state.sessions.read().await;
|
||||||
|
let session = sessions
|
||||||
|
.get(&id)
|
||||||
|
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
|
||||||
|
|
||||||
|
Ok(Json(SessionDetailsResponse {
|
||||||
|
id: session.id.clone(),
|
||||||
|
created_at: session.created_at,
|
||||||
|
session: session.conversation.clone(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_message(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<SessionId>,
|
||||||
|
Json(payload): Json<SendMessageRequest>,
|
||||||
|
) -> ApiResult<StatusCode> {
|
||||||
|
let message = ConversationMessage::user_text(payload.message);
|
||||||
|
let broadcaster = {
|
||||||
|
let mut sessions = state.sessions.write().await;
|
||||||
|
let session = sessions
|
||||||
|
.get_mut(&id)
|
||||||
|
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
|
||||||
|
session.conversation.messages.push(message.clone());
|
||||||
|
session.events.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let _ = broadcaster.send(SessionEvent::Message {
|
||||||
|
session_id: id,
|
||||||
|
message,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(StatusCode::NO_CONTENT)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn stream_session_events(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<SessionId>,
|
||||||
|
) -> ApiResult<impl IntoResponse> {
|
||||||
|
let (snapshot, mut receiver) = {
|
||||||
|
let sessions = state.sessions.read().await;
|
||||||
|
let session = sessions
|
||||||
|
.get(&id)
|
||||||
|
.ok_or_else(|| not_found(format!("session `{id}` not found")))?;
|
||||||
|
(
|
||||||
|
SessionEvent::Snapshot {
|
||||||
|
session_id: session.id.clone(),
|
||||||
|
session: session.conversation.clone(),
|
||||||
|
},
|
||||||
|
session.subscribe(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = stream! {
|
||||||
|
if let Ok(event) = snapshot.to_sse_event() {
|
||||||
|
yield Ok::<Event, Infallible>(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match receiver.recv().await {
|
||||||
|
Ok(event) => {
|
||||||
|
if let Ok(sse_event) = event.to_sse_event() {
|
||||||
|
yield Ok::<Event, Infallible>(sse_event);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(broadcast::error::RecvError::Lagged(_)) => continue,
|
||||||
|
Err(broadcast::error::RecvError::Closed) => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unix_timestamp_millis() -> u64 {
|
||||||
|
SystemTime::now()
|
||||||
|
.duration_since(UNIX_EPOCH)
|
||||||
|
.expect("system time should be after epoch")
|
||||||
|
.as_millis() as u64
|
||||||
|
}
|
||||||
|
|
||||||
|
fn not_found(message: String) -> ApiError {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(ErrorResponse { error: message }),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{
|
||||||
|
app, AppState, CreateSessionResponse, ListSessionsResponse, SessionDetailsResponse,
|
||||||
|
};
|
||||||
|
use reqwest::Client;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
struct TestServer {
|
||||||
|
address: SocketAddr,
|
||||||
|
handle: JoinHandle<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestServer {
|
||||||
|
async fn spawn() -> Self {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0")
|
||||||
|
.await
|
||||||
|
.expect("test listener should bind");
|
||||||
|
let address = listener
|
||||||
|
.local_addr()
|
||||||
|
.expect("listener should report local address");
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
axum::serve(listener, app(AppState::default()))
|
||||||
|
.await
|
||||||
|
.expect("server should run");
|
||||||
|
});
|
||||||
|
|
||||||
|
Self { address, handle }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn url(&self, path: &str) -> String {
|
||||||
|
format!("http://{}{}", self.address, path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TestServer {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.handle.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn create_session(client: &Client, server: &TestServer) -> CreateSessionResponse {
|
||||||
|
client
|
||||||
|
.post(server.url("/sessions"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("create request should succeed")
|
||||||
|
.error_for_status()
|
||||||
|
.expect("create request should return success")
|
||||||
|
.json::<CreateSessionResponse>()
|
||||||
|
.await
|
||||||
|
.expect("create response should parse")
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn next_sse_frame(response: &mut reqwest::Response, buffer: &mut String) -> String {
|
||||||
|
loop {
|
||||||
|
if let Some(index) = buffer.find("\n\n") {
|
||||||
|
let frame = buffer[..index].to_string();
|
||||||
|
let remainder = buffer[index + 2..].to_string();
|
||||||
|
*buffer = remainder;
|
||||||
|
return frame;
|
||||||
|
}
|
||||||
|
|
||||||
|
let next_chunk = timeout(Duration::from_secs(5), response.chunk())
|
||||||
|
.await
|
||||||
|
.expect("SSE stream should yield within timeout")
|
||||||
|
.expect("SSE stream should remain readable")
|
||||||
|
.expect("SSE stream should stay open");
|
||||||
|
buffer.push_str(&String::from_utf8_lossy(&next_chunk));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn creates_and_lists_sessions() {
|
||||||
|
let server = TestServer::spawn().await;
|
||||||
|
let client = Client::new();
|
||||||
|
|
||||||
|
// given
|
||||||
|
let created = create_session(&client, &server).await;
|
||||||
|
|
||||||
|
// when
|
||||||
|
let sessions = client
|
||||||
|
.get(server.url("/sessions"))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("list request should succeed")
|
||||||
|
.error_for_status()
|
||||||
|
.expect("list request should return success")
|
||||||
|
.json::<ListSessionsResponse>()
|
||||||
|
.await
|
||||||
|
.expect("list response should parse");
|
||||||
|
let details = client
|
||||||
|
.get(server.url(&format!("/sessions/{}", created.session_id)))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("details request should succeed")
|
||||||
|
.error_for_status()
|
||||||
|
.expect("details request should return success")
|
||||||
|
.json::<SessionDetailsResponse>()
|
||||||
|
.await
|
||||||
|
.expect("details response should parse");
|
||||||
|
|
||||||
|
// then
|
||||||
|
assert_eq!(created.session_id, "session-1");
|
||||||
|
assert_eq!(sessions.sessions.len(), 1);
|
||||||
|
assert_eq!(sessions.sessions[0].id, created.session_id);
|
||||||
|
assert_eq!(sessions.sessions[0].message_count, 0);
|
||||||
|
assert_eq!(details.id, "session-1");
|
||||||
|
assert!(details.session.messages.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn streams_message_events_and_persists_message_flow() {
|
||||||
|
let server = TestServer::spawn().await;
|
||||||
|
let client = Client::new();
|
||||||
|
|
||||||
|
// given
|
||||||
|
let created = create_session(&client, &server).await;
|
||||||
|
let mut response = client
|
||||||
|
.get(server.url(&format!("/sessions/{}/events", created.session_id)))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("events request should succeed")
|
||||||
|
.error_for_status()
|
||||||
|
.expect("events request should return success");
|
||||||
|
let mut buffer = String::new();
|
||||||
|
let snapshot_frame = next_sse_frame(&mut response, &mut buffer).await;
|
||||||
|
|
||||||
|
// when
|
||||||
|
let send_status = client
|
||||||
|
.post(server.url(&format!("/sessions/{}/message", created.session_id)))
|
||||||
|
.json(&super::SendMessageRequest {
|
||||||
|
message: "hello from test".to_string(),
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("message request should succeed")
|
||||||
|
.status();
|
||||||
|
let message_frame = next_sse_frame(&mut response, &mut buffer).await;
|
||||||
|
let details = client
|
||||||
|
.get(server.url(&format!("/sessions/{}", created.session_id)))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.expect("details request should succeed")
|
||||||
|
.error_for_status()
|
||||||
|
.expect("details request should return success")
|
||||||
|
.json::<SessionDetailsResponse>()
|
||||||
|
.await
|
||||||
|
.expect("details response should parse");
|
||||||
|
|
||||||
|
// then
|
||||||
|
assert_eq!(send_status, reqwest::StatusCode::NO_CONTENT);
|
||||||
|
assert!(snapshot_frame.contains("event: snapshot"));
|
||||||
|
assert!(snapshot_frame.contains("\"session_id\":\"session-1\""));
|
||||||
|
assert!(message_frame.contains("event: message"));
|
||||||
|
assert!(message_frame.contains("hello from test"));
|
||||||
|
assert_eq!(details.session.messages.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
details.session.messages[0],
|
||||||
|
runtime::ConversationMessage::user_text("hello from test")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user