Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions crates/rmcp/src/service/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
model::{
CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage,
ClientNotification, ClientRequest, ClientResult, CreateMessageRequest,
CreateMessageRequestParams, CreateMessageResult, ErrorData, ListRootsRequest,
CreateMessageRequestParams, CreateMessageResult, EmptyResult, ErrorData, ListRootsRequest,
ListRootsResult, LoggingMessageNotification, LoggingMessageNotificationParam,
ProgressNotification, ProgressNotificationParam, PromptListChangedNotification,
ProtocolVersion, ResourceListChangedNotification, ResourceUpdatedNotification,
Expand Down Expand Up @@ -147,22 +147,6 @@ where
)))
}

/// Helper function to expect a notification from the stream
async fn expect_notification<T>(
transport: &mut T,
context: &str,
) -> Result<ClientNotification, ServerInitializeError>
where
T: Transport<RoleServer>,
{
let msg = expect_next_message(transport, context).await?;
let msg_clone = msg.clone();
msg.into_notification()
.ok_or(ServerInitializeError::ExpectedInitializedNotification(
Some(msg_clone),
))
}

pub async fn serve_server_with_ct<S, T, E, A>(
service: S,
transport: T,
Expand Down Expand Up @@ -246,12 +230,41 @@ where
ServerInitializeError::transport::<T>(error, "sending initialize response")
})?;

// Wait for initialize notification
let notification = expect_notification(&mut transport, "initialize notification").await?;
let ClientNotification::InitializedNotification(_) = notification else {
return Err(ServerInitializeError::ExpectedInitializedNotification(
Some(ClientJsonRpcMessage::notification(notification)),
));
// Wait for initialized notification. The MCP spec permits logging/setLevel and ping
// before initialized; VS Code sends setLevel immediately after the initialize response.
let notification = loop {
let msg = expect_next_message(&mut transport, "initialize notification").await?;
match msg {
ClientJsonRpcMessage::Notification(n)
if matches!(
n.notification,
ClientNotification::InitializedNotification(_)
) =>
{
break n.notification;
}
ClientJsonRpcMessage::Request(req)
if matches!(
req.request,
ClientRequest::SetLevelRequest(_) | ClientRequest::PingRequest(_)
) =>
{
transport
.send(ServerJsonRpcMessage::response(
ServerResult::EmptyResult(EmptyResult {}),
req.id,
))
.await
.map_err(|error| {
ServerInitializeError::transport::<T>(error, "sending pre-init response")
})?;
}
other => {
return Err(ServerInitializeError::ExpectedInitializedNotification(
Some(other),
));
}
}
};
let context = NotificationContext {
meta: notification.get_meta().clone(),
Expand Down
160 changes: 160 additions & 0 deletions crates/rmcp/tests/test_server_initialization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// cargo test --features "client" --package rmcp -- server_init
#![cfg(feature = "client")]
mod common;

use common::handlers::TestServer;
use rmcp::{
ServiceExt,
model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult},
service::ServerInitializeError,
transport::{IntoTransport, Transport},
};

fn msg(raw: &str) -> ClientJsonRpcMessage {
serde_json::from_str(raw).expect("invalid test message JSON")
}

fn init_request() -> ClientJsonRpcMessage {
msg(r#"{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-11-25",
"capabilities": {},
"clientInfo": { "name": "test-client", "version": "0.0.1" }
}
}"#)
}

fn initialized_notification() -> ClientJsonRpcMessage {
msg(r#"{ "jsonrpc": "2.0", "method": "notifications/initialized" }"#)
}

fn set_level_request(id: u64) -> ClientJsonRpcMessage {
msg(&format!(
r#"{{ "jsonrpc": "2.0", "id": {id}, "method": "logging/setLevel", "params": {{ "level": "info" }} }}"#
))
}

fn ping_request(id: u64) -> ClientJsonRpcMessage {
msg(&format!(
r#"{{ "jsonrpc": "2.0", "id": {id}, "method": "ping" }}"#
))
}

fn list_tools_request(id: u64) -> ClientJsonRpcMessage {
msg(&format!(
r#"{{ "jsonrpc": "2.0", "id": {id}, "method": "tools/list" }}"#
))
}

async fn do_initialize(client: &mut impl Transport<rmcp::RoleClient>) {
client.send(init_request()).await.unwrap();
let _response = client.receive().await.unwrap();
}

// Server responds with EmptyResult to setLevel received before initialized.
#[tokio::test]
async fn server_init_set_level_response_is_empty_result() {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await });
let mut client = IntoTransport::<rmcp::RoleClient, _, _>::into_transport(client_transport);

do_initialize(&mut client).await;
client.send(set_level_request(2)).await.unwrap();

let response = client.receive().await.unwrap();
assert!(
matches!(
response,
ServerJsonRpcMessage::Response(ref r)
if matches!(r.result, ServerResult::EmptyResult(_))
),
"expected EmptyResult for setLevel, got: {response:?}"
);
}

// Server initializes successfully when setLevel is sent before the initialized notification.
#[tokio::test]
async fn server_init_succeeds_after_set_level_before_initialized() {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server_handle =
tokio::spawn(async move { TestServer::new().serve(server_transport).await });
let mut client = IntoTransport::<rmcp::RoleClient, _, _>::into_transport(client_transport);

do_initialize(&mut client).await;
client.send(set_level_request(2)).await.unwrap();
let _response = client.receive().await.unwrap();
client.send(initialized_notification()).await.unwrap();

let result = server_handle.await.unwrap();
assert!(
result.is_ok(),
"server should initialize successfully after setLevel"
);
result.unwrap().cancel().await.unwrap();
}

// Server responds with EmptyResult to ping received before initialized.
#[tokio::test]
async fn server_init_ping_response_is_empty_result() {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let _server = tokio::spawn(async move { TestServer::new().serve(server_transport).await });
let mut client = IntoTransport::<rmcp::RoleClient, _, _>::into_transport(client_transport);

do_initialize(&mut client).await;
client.send(ping_request(2)).await.unwrap();

let response = client.receive().await.unwrap();
assert!(
matches!(
response,
ServerJsonRpcMessage::Response(ref r)
if matches!(r.result, ServerResult::EmptyResult(_))
),
"expected EmptyResult for ping, got: {response:?}"
);
}

// Server initializes successfully when ping is sent before the initialized notification.
#[tokio::test]
async fn server_init_succeeds_after_ping_before_initialized() {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server_handle =
tokio::spawn(async move { TestServer::new().serve(server_transport).await });
let mut client = IntoTransport::<rmcp::RoleClient, _, _>::into_transport(client_transport);

do_initialize(&mut client).await;
client.send(ping_request(2)).await.unwrap();
let _response = client.receive().await.unwrap();
client.send(initialized_notification()).await.unwrap();

let result = server_handle.await.unwrap();
assert!(
result.is_ok(),
"server should initialize successfully after ping"
);
result.unwrap().cancel().await.unwrap();
}

// Server returns ExpectedInitializedNotification for any other message before initialized.
#[tokio::test]
async fn server_init_rejects_unexpected_message_before_initialized() {
let (server_transport, client_transport) = tokio::io::duplex(4096);
let server_handle =
tokio::spawn(async move { TestServer::new().serve(server_transport).await });
let mut client = IntoTransport::<rmcp::RoleClient, _, _>::into_transport(client_transport);

do_initialize(&mut client).await;
client.send(list_tools_request(2)).await.unwrap();

let result = server_handle.await.unwrap();
assert!(
matches!(
result,
Err(ServerInitializeError::ExpectedInitializedNotification(_))
),
"expected ExpectedInitializedNotification error"
);
}