diff --git a/crates/rmcp/src/service/server.rs b/crates/rmcp/src/service/server.rs index 666a7998..d011c225 100644 --- a/crates/rmcp/src/service/server.rs +++ b/crates/rmcp/src/service/server.rs @@ -16,7 +16,7 @@ use crate::{ model::{ CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification, ClientRequest, ClientResult, CreateMessageRequest, - CreateMessageRequestParams, CreateMessageResult, ErrorData, ListRootsRequest, + CreateMessageRequestParams, CreateMessageResult, EmptyResult, ErrorData, ListRootsRequest, ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam, ProgressNotification, ProgressNotificationParam, PromptListChangedNotification, ProtocolVersion, ResourceListChangedNotification, ResourceUpdatedNotification, @@ -147,22 +147,6 @@ where ))) } -/// Helper function to expect a notification from the stream -async fn expect_notification( - transport: &mut T, - context: &str, -) -> Result -where - T: Transport, -{ - let msg = expect_next_message(transport, context).await?; - let msg_clone = msg.clone(); - msg.into_notification() - .ok_or(ServerInitializeError::ExpectedInitializedNotification( - Some(msg_clone), - )) -} - pub async fn serve_server_with_ct( service: S, transport: T, @@ -246,12 +230,41 @@ where ServerInitializeError::transport::(error, "sending initialize response") })?; - // Wait for initialize notification - let notification = expect_notification(&mut transport, "initialize notification").await?; - let ClientNotification::InitializedNotification(_) = notification else { - return Err(ServerInitializeError::ExpectedInitializedNotification( - Some(ClientJsonRpcMessage::notification(notification)), - )); + // Wait for initialized notification. The MCP spec permits logging/setLevel and ping + // before initialized; VS Code sends setLevel immediately after the initialize response. + let notification = loop { + let msg = expect_next_message(&mut transport, "initialize notification").await?; + match msg { + ClientJsonRpcMessage::Notification(n) + if matches!( + n.notification, + ClientNotification::InitializedNotification(_) + ) => + { + break n.notification; + } + ClientJsonRpcMessage::Request(req) + if matches!( + req.request, + ClientRequest::SetLevelRequest(_) | ClientRequest::PingRequest(_) + ) => + { + transport + .send(ServerJsonRpcMessage::response( + ServerResult::EmptyResult(EmptyResult {}), + req.id, + )) + .await + .map_err(|error| { + ServerInitializeError::transport::(error, "sending pre-init response") + })?; + } + other => { + return Err(ServerInitializeError::ExpectedInitializedNotification( + Some(other), + )); + } + } }; let context = NotificationContext { meta: notification.get_meta().clone(), diff --git a/crates/rmcp/tests/test_server_initialization.rs b/crates/rmcp/tests/test_server_initialization.rs new file mode 100644 index 00000000..c07501f0 --- /dev/null +++ b/crates/rmcp/tests/test_server_initialization.rs @@ -0,0 +1,160 @@ +// cargo test --features "client" --package rmcp -- server_init +#![cfg(feature = "client")] +mod common; + +use common::handlers::TestServer; +use rmcp::{ + ServiceExt, + model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult}, + service::ServerInitializeError, + transport::{IntoTransport, Transport}, +}; + +fn msg(raw: &str) -> ClientJsonRpcMessage { + serde_json::from_str(raw).expect("invalid test message JSON") +} + +fn init_request() -> ClientJsonRpcMessage { + msg(r#"{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": { "name": "test-client", "version": "0.0.1" } + } + }"#) +} + +fn initialized_notification() -> ClientJsonRpcMessage { + msg(r#"{ "jsonrpc": "2.0", "method": "notifications/initialized" }"#) +} + +fn set_level_request(id: u64) -> ClientJsonRpcMessage { + msg(&format!( + r#"{{ "jsonrpc": "2.0", "id": {id}, "method": "logging/setLevel", "params": {{ "level": "info" }} }}"# + )) +} + +fn ping_request(id: u64) -> ClientJsonRpcMessage { + msg(&format!( + r#"{{ "jsonrpc": "2.0", "id": {id}, "method": "ping" }}"# + )) +} + +fn list_tools_request(id: u64) -> ClientJsonRpcMessage { + msg(&format!( + r#"{{ "jsonrpc": "2.0", "id": {id}, "method": "tools/list" }}"# + )) +} + +async fn do_initialize(client: &mut impl Transport) { + client.send(init_request()).await.unwrap(); + let _response = client.receive().await.unwrap(); +} + +// Server responds with EmptyResult to setLevel received before initialized. +#[tokio::test] +async fn server_init_set_level_response_is_empty_result() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + do_initialize(&mut client).await; + client.send(set_level_request(2)).await.unwrap(); + + let response = client.receive().await.unwrap(); + assert!( + matches!( + response, + ServerJsonRpcMessage::Response(ref r) + if matches!(r.result, ServerResult::EmptyResult(_)) + ), + "expected EmptyResult for setLevel, got: {response:?}" + ); +} + +// Server initializes successfully when setLevel is sent before the initialized notification. +#[tokio::test] +async fn server_init_succeeds_after_set_level_before_initialized() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = + tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + do_initialize(&mut client).await; + client.send(set_level_request(2)).await.unwrap(); + let _response = client.receive().await.unwrap(); + client.send(initialized_notification()).await.unwrap(); + + let result = server_handle.await.unwrap(); + assert!( + result.is_ok(), + "server should initialize successfully after setLevel" + ); + result.unwrap().cancel().await.unwrap(); +} + +// Server responds with EmptyResult to ping received before initialized. +#[tokio::test] +async fn server_init_ping_response_is_empty_result() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + do_initialize(&mut client).await; + client.send(ping_request(2)).await.unwrap(); + + let response = client.receive().await.unwrap(); + assert!( + matches!( + response, + ServerJsonRpcMessage::Response(ref r) + if matches!(r.result, ServerResult::EmptyResult(_)) + ), + "expected EmptyResult for ping, got: {response:?}" + ); +} + +// Server initializes successfully when ping is sent before the initialized notification. +#[tokio::test] +async fn server_init_succeeds_after_ping_before_initialized() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = + tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + do_initialize(&mut client).await; + client.send(ping_request(2)).await.unwrap(); + let _response = client.receive().await.unwrap(); + client.send(initialized_notification()).await.unwrap(); + + let result = server_handle.await.unwrap(); + assert!( + result.is_ok(), + "server should initialize successfully after ping" + ); + result.unwrap().cancel().await.unwrap(); +} + +// Server returns ExpectedInitializedNotification for any other message before initialized. +#[tokio::test] +async fn server_init_rejects_unexpected_message_before_initialized() { + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = + tokio::spawn(async move { TestServer::new().serve(server_transport).await }); + let mut client = IntoTransport::::into_transport(client_transport); + + do_initialize(&mut client).await; + client.send(list_tools_request(2)).await.unwrap(); + + let result = server_handle.await.unwrap(); + assert!( + matches!( + result, + Err(ServerInitializeError::ExpectedInitializedNotification(_)) + ), + "expected ExpectedInitializedNotification error" + ); +}