diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index fdd35ed3f..5905d150a 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -526,7 +526,7 @@ describe("StreamableHTTPClientTransport", () => { await transport["_startOrAuthSse"]({}); expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue"); - (requestInit.headers as Headers).set("X-Custom-Header","SecondCustomValue"); + (requestInit.headers as Headers).set("X-Custom-Header", "SecondCustomValue"); await transport.send({ jsonrpc: "2.0", method: "test", params: {} } as JSONRPCMessage); expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("SecondCustomValue"); @@ -605,7 +605,7 @@ describe("StreamableHTTPClientTransport", () => { maxRetries: 1, maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity - } + } }); const errorSpy = jest.fn(); @@ -653,7 +653,7 @@ describe("StreamableHTTPClientTransport", () => { maxRetries: 1, maxReconnectionDelay: 1000, // Ensure it doesn't retry indefinitely reconnectionDelayGrowFactor: 1 // No exponential backoff for simplicity - } + } }); const errorSpy = jest.fn(); @@ -1001,4 +1001,69 @@ describe("StreamableHTTPClientTransport", () => { expect(global.fetch).not.toHaveBeenCalled(); }); }); + + describe("attemptSSE option", () => { + it("should not attempt SSE connection when attemptSSE is false", async () => { + const transport = new StreamableHTTPClientTransport( + new URL("http://localhost:1234/mcp"), + { attemptSSE: false } + ); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "initialize", + params: {}, + id: "test-id" + }; + + (global.fetch as jest.Mock).mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers(), + }); + + await transport.send(message); + + // Should only make one POST request, no GET request for SSE + expect(global.fetch).toHaveBeenCalledTimes(1); + expect(global.fetch).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + method: "POST", + }) + ); + }); + + it("should attempt SSE connection by default when attemptSSE is not specified", async () => { + const transport = new StreamableHTTPClientTransport( + new URL("http://localhost:1234/mcp") + ); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "notifications/initialized", + params: {} + }; + + (global.fetch as jest.Mock) + .mockResolvedValueOnce({ + ok: true, + status: 202, + headers: new Headers(), + }) + .mockResolvedValueOnce({ + ok: true, + status: 200, + headers: new Headers({ "content-type": "text/event-stream" }), + body: new ReadableStream(), + }); + + await transport.send(message); + + // Should make POST request and then GET request for SSE + expect(global.fetch).toHaveBeenCalledTimes(2); + expect(global.fetch).toHaveBeenNthCalledWith(1, expect.anything(), expect.objectContaining({ method: "POST" })); + expect(global.fetch).toHaveBeenNthCalledWith(2, expect.anything(), expect.objectContaining({ method: "GET" })); + }); + }); }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 12714ea44..fe356b744 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -114,6 +114,12 @@ export type StreamableHTTPClientTransportOptions = { * When not provided and connecting to a server that supports session IDs, the server will generate a new session ID. */ sessionId?: string; + + /** + * If false, do not attempt to open an SSE (Server-Sent Events) connection. + * Default is true (attempt SSE connection). + */ + attemptSSE?: boolean; }; /** @@ -131,6 +137,7 @@ export class StreamableHTTPClientTransport implements Transport { private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; private _protocolVersion?: string; + private _attemptSSE: boolean; onclose?: () => void; onerror?: (error: Error) => void; @@ -147,6 +154,7 @@ export class StreamableHTTPClientTransport implements Transport { this._fetch = opts?.fetch; this._sessionId = opts?.sessionId; this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; + this._attemptSSE = opts?.attemptSSE !== false; } private async _authThenStart(): Promise { @@ -166,7 +174,10 @@ export class StreamableHTTPClientTransport implements Transport { throw new UnauthorizedError(); } - return await this._startOrAuthSse({ resumptionToken: undefined }); + if (this._attemptSSE) { + return await this._startOrAuthSse({ resumptionToken: undefined }); + } + return; } private async _commonHeaders(): Promise { @@ -196,6 +207,9 @@ export class StreamableHTTPClientTransport implements Transport { private async _startOrAuthSse(options: StartSSEOptions): Promise { const { resumptionToken } = options; + if (!this._attemptSSE) { + return; + } try { // Try to open an initial SSE stream with GET to listen for server messages // This is optional according to the spec - server may not support it @@ -411,7 +425,9 @@ export class StreamableHTTPClientTransport implements Transport { if (resumptionToken) { // If we have at last event ID, we need to reconnect the SSE stream - this._startOrAuthSse({ resumptionToken, replayMessageId: isJSONRPCRequest(message) ? message.id : undefined }).catch(err => this.onerror?.(err)); + if (this._attemptSSE) { + this._startOrAuthSse({ resumptionToken, replayMessageId: isJSONRPCRequest(message) ? message.id : undefined }).catch(err => this.onerror?.(err)); + } return; } @@ -459,7 +475,7 @@ export class StreamableHTTPClientTransport implements Transport { if (response.status === 202) { // if the accepted notification is initialized, we start the SSE stream // if it's supported by the server - if (isInitializedNotification(message)) { + if (isInitializedNotification(message) && this._attemptSSE) { // Start without a lastEventId since this is a fresh connection this._startOrAuthSse({ resumptionToken: undefined }).catch(err => this.onerror?.(err)); } diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index ddb274196..1c1f1abeb 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -420,7 +420,8 @@ async function connect(url?: string): Promise { transport = new StreamableHTTPClientTransport( new URL(serverUrl), { - sessionId: sessionId + sessionId: sessionId, + // attemptSSE: false, // Set to false to disable SSE connection attempts } );