
Conversation

Contributor

@sdavidbd commented Jun 8, 2025

🎯 Purpose

This PR implements the design proposed in RFC #19329.

It adds a mechanism to recover from KV load failures in vLLM’s KV connector path by:

  • Detecting failed KV block loads
  • Automatically rescheduling affected requests for recomputation from a valid prefix

The implementation supports both synchronous and asynchronous loading methods, improving robustness when using external systems for KV cache offload or transfer.
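
A minimal sketch of the rollback step described above, using hypothetical names (PendingRequest, block_ids, num_computed_tokens) rather than the actual vLLM scheduler internals:

from dataclasses import dataclass


@dataclass
class PendingRequest:
    # Hypothetical stand-in for a scheduled request.
    block_ids: list[int]      # KV blocks backing the computed prefix
    num_computed_tokens: int  # tokens currently considered computed


def recover_from_load_failure(
    requests: list[PendingRequest],
    invalid_block_ids: set[int],
    block_size: int,
) -> int:
    """Truncate affected requests back to the last valid block boundary."""
    rescheduled = 0
    for req in requests:
        # Index of the first block in the computed prefix that failed, if any.
        first_bad = next(
            (i for i, blk in enumerate(req.block_ids)
             if blk in invalid_block_ids),
            None,
        )
        if first_bad is None:
            continue  # request does not touch any failed block
        # Everything from the first failed block onward must be recomputed.
        req.num_computed_tokens = min(req.num_computed_tokens,
                                      first_bad * block_size)
        rescheduled += 1
    return rescheduled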


🧪 Test Plan

✅ Unit Tests

Coverage:

  • Recovery logic for sync and async KV load failures
  • Shared block scenarios (only applicable to sync loading)
  • Correct request rescheduling and computed-token count adaptation

Run:

pytest -v tests/v1/kv_connector/unit/test_output_aggreagator.py
pytest -v tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py

Results:

✔ test_output_aggreagator.py
  ├─ test_aggregate_workers_output
  └─ test_async_aggregate_workers_output

✔ test_kv_load_failure_recovery.py
  ├─ test_async_load_failure (3 cases)
  ├─ test_sync_load_failure (3 cases)
  ├─ test_sync_load_failure_with_shared_blocks (3 cases)
  └─ test_async_progressive_load_failure (2 cases)

📦 End-to-End Example

A full integration test is available at:

examples/offline_inference/kv_load_failure_recovery/

This test demonstrates the recovery mechanism using fault-injecting connectors that simulate failed KV loads.
Test flow:

  1. Prefill stage — Saves KV data to the local filesystem
  2. Decode stage 1 — Successful sync KV load (baseline)
  3. Decode stage 2 — Failed sync KV load
  4. Decode stage 3 — Failed async KV load

See the README.md for details.

Run:

cd examples/offline_inference/kv_load_failure_recovery
./run.sh

Results:

All outputs in simulated failure cases match the baseline decode.

...
INFO:root:Simulating failure to load all KV blocks for the first load request. Total blocks: 62
...
INFO 08-07 17:21:21 [scheduler.py:1037] Recovered from KV load failure: 2 request(s) rescheduled (1022 tokens affected).
...
INFO:root:Simulating failure to load all KV blocks for the first load request. Total blocks: 62
...
INFO 08-07 17:21:39 [scheduler.py:1037] Recovered from KV load failure: 1 request(s) rescheduled (0 tokens affected).
...
✅ Outputs match: recovery successful.


github-actions bot commented Jun 8, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Contributor

@gemini-code-assist bot left a comment


Hello @sdavidbd, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello! Gemini or gemini-code-assist here, providing a summary of this pull request.

This PR introduces support for recovering from KV load failures within vLLM's v1 KV connector path. The primary goal is to enhance the robustness of vLLM when utilizing external systems for KV cache offload or transfer. The mechanism involves detecting when KV blocks fail to load and automatically rescheduling the affected requests for recomputation starting from a valid prefix. A self-contained example is included to demonstrate and test this recovery process by simulating load failures.

Highlights

  • KV Load Failure Recovery: Implements a mechanism to detect and recover from failures when loading KV cache blocks from an external source using the v1 KV connector.
  • Request Rescheduling: Requests that depend on failed-to-load KV blocks are automatically rescheduled. The scheduler identifies the point of failure and truncates the request's computed prefix, allowing it to be recomputed correctly.
  • Test Example Added: A new, self-contained example (examples/offline_inference/kv_load_failure_recovery/) is added to demonstrate the recovery process using a custom connector that simulates load failures.
  • Rogue Connector for Simulation: A RogueSharedStorageConnector is introduced in the test example to intentionally fail KV block loads for specific requests, allowing verification of the recovery logic.

Changelog

  • examples/offline_inference/kv_load_failure_recovery/README.md
    • New file: Adds a README explaining the purpose and structure of the KV load failure recovery test example.
  • examples/offline_inference/kv_load_failure_recovery/decode_example.py
    • New file: Adds the decode stage script for the test example.
  • Includes a --simulate-failure argument to switch between the standard SharedStorageConnector and the RogueSharedStorageConnector.
  • examples/offline_inference/kv_load_failure_recovery/prefill_example.py
    • New file: Adds the prefill stage script for the test example.
    • Generates prompts and saves the initial prefill output to output.txt for the decode stage.
  • examples/offline_inference/kv_load_failure_recovery/rogue_shared_storage_connector.py
    • New file: Defines RogueSharedStorageConnector, a subclass of SharedStorageConnector.
    • Overrides bind_connector_metadata to identify the first load request and store its block IDs as invalid_block_ids, effectively simulating a load failure for those blocks.
    • Implements get_block_ids_with_load_errors to return the set of simulated invalid block IDs.
  • examples/offline_inference/kv_load_failure_recovery/run.sh
    • New file: A helper script to clean up previous test runs, execute the prefill stage, run a normal decode, and then run a decode with simulated failure.
  • vllm/distributed/kv_transfer/kv_connector/v1/base.py
    • Adds a new abstract method get_block_ids_with_load_errors to KVConnectorBase_V1 to allow connectors to report failed block loads. Default implementation returns None.
  • vllm/v1/core/block_pool.py
    • Modifies the cache_full_blocks method to use >= instead of == when checking if num_cached_blocks is sufficient, potentially preventing redundant operations.
  • vllm/v1/core/sched/scheduler.py
    • Adds logic in update_from_output to check for invalid_block_ids reported by the model runner.
    • If invalid blocks are found, iterates through running requests to see if they use any of these blocks.
    • Requests using invalid blocks are marked for rescheduling by adjusting their num_computed_tokens to the last valid block boundary.
    • Adds logging to report the number of requests and tokens rescheduled due to KV load failure.
  • vllm/v1/core/single_type_kv_cache_manager.py
    • Adds a check in cache_blocks to return early if num_cached_blocks is already greater than or equal to num_full_blocks, similar to the change in block_pool.py.
  • vllm/v1/outputs.py
    • Adds invalid_block_ids: Optional[set[int]] field to the ModelRunnerOutput dataclass to carry information about failed block loads from the worker to the scheduler.
  • vllm/v1/worker/gpu_model_runner.py
    • Imports is_v1_kv_transfer_group.
    • In execute_model, checks if the KV transfer group is v1.
    • If it's a v1 group, calls get_block_ids_with_load_errors on the connector before clearing metadata.
    • Passes the retrieved invalid_block_ids to the ModelRunnerOutput.
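
To see how these pieces fit together, here is a rough sketch with hypothetical stand-ins (the actual vLLM classes and signatures differ in detail):

from dataclasses import dataclass
from typing import Optional


class ConnectorSketch:
    # Hypothetical stand-in for KVConnectorBase_V1.

    def get_block_ids_with_load_errors(self) -> Optional[set[int]]:
        # Default: nothing failed. Real connectors override this to report
        # the IDs of blocks whose KV data could not be loaded.
        return None


class FaultInjectingConnector(ConnectorSketch):
    # Rogue-style connector that simulates load failures (illustrative only).

    def __init__(self) -> None:
        self._invalid_block_ids: set[int] = set()

    def fail_blocks(self, block_ids: set[int]) -> None:
        # Pretend these blocks fail to load on the next report.
        self._invalid_block_ids |= block_ids

    def get_block_ids_with_load_errors(self) -> Optional[set[int]]:
        return self._invalid_block_ids or None


@dataclass
class RunnerOutputSketch:
    # Hypothetical stand-in for ModelRunnerOutput with the new field.
    invalid_block_ids: Optional[set[int]] = None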

@mergify bot added the documentation and v1 labels on Jun 8, 2025
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable feature for KV load failure recovery in vLLM, enhancing its robustness when dealing with external KV cache systems. The overall approach, including detecting failed blocks and rescheduling requests, is well-conceived. The addition of a self-contained example with a fault-injecting connector (RogueSharedStorageConnector) is excellent for testing and demonstrating the feature.

I've identified a few areas for improvement, primarily concerning potential correctness issues in the core scheduling logic and GPU model runner, as well as some minor suggestions for the example code.

Summary of Findings

  • Potential correctness issue in scheduler loop: The loop condition for checking invalid blocks in vllm/v1/core/sched/scheduler.py (line 735) compares a block index with a token count, which could lead to incorrect behavior. It's recommended to iterate explicitly over the blocks corresponding to the computed prefix.
  • Potential UnboundLocalError in GPU model runner: In vllm/v1/worker/gpu_model_runner.py, the invalid_block_ids variable might be uninitialized if is_v1_kv_transfer_group() is false, leading to an error when constructing ModelRunnerOutput. Initializing it to None beforehand is advised (see the sketch after this list).
  • Example code improvements: Minor refactoring opportunities and type hint suggestions exist in the example files (decode_example.py, rogue_shared_storage_connector.py) to improve clarity and precision.
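
The UnboundLocalError finding boils down to a variable bound only inside a conditional; a minimal sketch of the suggested fix, with hypothetical names:

def collect_invalid_block_ids(connector, has_v1_kv_transfer_group: bool):
    # Bind the variable unconditionally so that constructing the runner
    # output can never raise UnboundLocalError when the branch is skipped.
    invalid_block_ids = None
    if has_v1_kv_transfer_group:
        invalid_block_ids = connector.get_block_ids_with_load_errors()
    return invalid_block_ids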

Merge Readiness

The pull request introduces a significant and useful feature for KV load failure recovery. The core idea and the testing example are well-implemented. However, there are a couple of high-severity issues in the core logic (scheduler.py and gpu_model_runner.py) that need to be addressed to ensure correctness and prevent potential runtime errors. Once these are resolved, and the minor suggestions for the example code are considered, the PR should be in good shape for merging. I am unable to approve pull requests, so please ensure other reviewers approve this code before merging.

@sdavidbd force-pushed the feature/kv-load-failure-recovery branch 2 times, most recently from 08bade8 to 20f0419 on June 9, 2025 12:47
@orozery
Contributor

orozery commented Jun 10, 2025

@sdavidbd the problem I see here is in the multi-worker case.
For example, if you use tensor-parallelism or pipeline parallelism.
Each worker may fail loading on different block ids.
Although each worker reports its own invalid_block_ids, the scheduler will only get the invalid_block_ids of the first-rank worker.
This is due to the implementation of MultiprocExecutor.execute_model which simply discards the ModelRunnerOutput of all but the first worker.

In theory, the same problem holds for the existing fields ModelRunnerOutput.finished_sending and ModelRunnerOutput.finished_recving.
The current solution in the code for this is having the first-rank worker communicate with all other workers to aggregate input from all workers to its own output.
This is what was done in NixlConnector.get_finished.

However, I think it is better to have the aggregation done by the scheduler, not by the workers, and to have a single aggregation covering all fields of ModelRunnerOutput.

I'm actually planning on opening up a PR in the upcoming days that will add a generic KVConnectorMetadata field to ModelRunnerOutput. This field will be used to transfer ALL required connector information from workers to the scheduler. This includes finished_sending and finished_recving, and potentially your invalid_block_ids.

With this approach, the get_block_ids_with_load_errors will become a scheduler-side connector call and will extract the invalid_block_ids from the abstract KVConnectorMetadata.

My motivation for this is my plan to add an offloading connector, and be able to report offloading status from workers to the scheduler.

cc @njhill
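
A minimal sketch of the union-style aggregation this implies, assuming a hypothetical list of per-rank outputs (the actual MultiprocExecutor.execute_model API may differ):

def aggregate_invalid_block_ids(per_rank_outputs) -> set[int]:
    # Under tensor or pipeline parallelism each rank may fail on different
    # blocks, so reports must be unioned rather than taking rank 0's only.
    invalid: set[int] = set()
    for output in per_rank_outputs:  # one ModelRunnerOutput-like object per rank
        if getattr(output, "invalid_block_ids", None):
            invalid |= output.invalid_block_ids
    return invalid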

@sdavidbd
Contributor Author

(Quoting @orozery's comment above.)

@orozery Thanks a lot for the detailed feedback — you're absolutely right. I'm aware of this gap in the multi-worker setup, and I'm already working on adding aggregation logic for invalid_block_ids at the MultiprocExecutor level to address it.

I also really like your idea of encapsulating KV connector–related metadata in a dedicated KVConnectorOutput (or similar — probably best to distinguish it from the existing KVConnectorMetadata) field within ModelRunnerOutput. It seems like a clean and extensible way to surface connector-specific signals to the scheduler.

Looking forward to your upcoming PR — happy to align with it or adapt my changes accordingly.

@sdavidbd marked this pull request as draft on June 11, 2025 12:17
@sdavidbd
Contributor Author

Converting this PR to draft for now - I'm working on extending support for tensor-parallel (TP) setups and adding more unit tests to improve coverage and reliability. Will mark as ready for review once that's complete.

@njhill
Member

njhill commented Jun 20, 2025

@sdavidbd we need generic failure handling logic... would it make more sense to just have the returned list of request ids that have finished loading reflect a failed/succeeded status (or equivalently include a mutually exclusive list of failed load req ids)?

The scheduler already knows the associated block ids and so would be able to recompute accordingly. This could then cover other kinds of failure too.

@sdavidbd
Contributor Author

(Quoting @njhill's suggestion above.)

Thanks @njhill — appreciate the suggestion!

I actually think that returning the list of failed block IDs is already a fairly generic and extensible approach, for a few reasons:

  1. Connector visibility: The connector is inherently aware of which blocks failed to load, but it may not have enough context to map failures back to specific request IDs. We’d prefer not to require connectors to track per-request mappings, as this would increase implementation complexity and reduce flexibility.
  2. One-to-many dependency: Multiple requests in the same batch may share the same KV blocks. In case of a block load failure, all dependent requests should be rescheduled. Reporting failed blocks allows the scheduler to identify the affected requests based on its existing request-to-blocks mapping (see the sketch after this list).
  3. Granularity of recovery: Treating a request as failed based on any load error may be too coarse. A request might attempt to load 100 blocks and fail only on the last one. In such cases, it’s more efficient to recompute only the failed blocks rather than discarding all previously loaded ones.
  4. Future enhancement of KV cache reuse: Currently, KV reuse in vLLM is limited to a prefix of the prompt. We're planning to enhance this to support reuse of arbitrary subsets of the prompt blocks. This is particularly important when offloading KV cache to external backends with their own eviction policies or when sporadic load failures occur. By allowing connectors to report the exact set of failed blocks, we lay the groundwork for fine-grained recovery in the future.
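
A small sketch of points 1 and 2, assuming a hypothetical request-to-blocks mapping of the kind the scheduler already maintains:

def find_affected_requests(
    request_blocks: dict[str, list[int]],  # req_id -> block IDs in use
    invalid_block_ids: set[int],
) -> list[str]:
    # Invert the scheduler's request -> blocks view: any request touching a
    # failed block is affected, including requests that merely share the
    # block and never asked the connector to load it themselves.
    return [
        req_id
        for req_id, blocks in request_blocks.items()
        if invalid_block_ids.intersection(blocks)
    ]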

@njhill
Member

njhill commented Jul 3, 2025

Thanks @sdavidbd, and sorry for the delay getting back to this.

I guess I'm still not convinced that block ids would be "simpler" - since we already pass back ids of finished requests, just having these flagged as failed or succeeded seems simpler to me, and would be straightforward on the scheduler side, for example, to just fall back to locally prefilling any failed requests.

Since errors would presumably be quite rare, the granularity optimization seems like it might be unnecessary.

One thing I have been wondering about related to this, though, is how we should handle cases where concurrent requests want to load overlapping prefixes. If there is an async kv load already in flight for a subset of the blocks, the request should ideally just subscribe / wait behind that (in addition to loading additional blocks in parallel, if needed). I don't think that's how it currently works.

@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator


shouldn't this be part of the test suite rather than under examples?

Member


@sdavidbd also wondering about this. Do we have an equivalent of this covered in the tests as well?

Contributor Author


I do have similar logic covered in the unit tests — specifically async/sync loading and shared blocks in the sync case. That said, the unit tests focus on the recovery logic in isolation, while this example serves as an integration test: it exercises the full end-to-end flow and verifies overall correctness.

@NickLucche
Collaborator

Thanks for the work @sdavidbd !

re: @orozery

I'm actually planning on opening up a PR in the upcoming days that will add a generic KVConnectorMetadata field to ModelRunnerOutput.

This is actually independent of having the reduction carried out in the scheduler, though.
We could still have rank 0 be in charge of aggregating data from the other ranks, so that the actual logic remains where the data is produced. Perhaps I am just missing the advantage of scheduler-side reduction.
But we definitely need to aggregate, as you pointed out.

having these (request_id) flagged as failed or succeeded seems simpler to me

I agree here; I think it would make the aggregation above simpler and would move less data around the ranks. I believe we can overload the single communication step where we already exchange finished_req_ids.

@sdavidbd
Contributor Author

sdavidbd commented Jul 3, 2025

Thanks @njhill - really appreciate the follow-up.

You're absolutely right that today errors are relatively rare. However, as KV cache offloading gains traction - especially with the rise of disaggregated storage solutions - we do expect failures to become more common. Failures could stem from transient disconnections or external eviction policies, and designing for them upfront helps ensure resilience.

Looking forward, KV cache offloading is expected to play a key role in enabling KV reuse across various scenarios: preempted requests, multi-turn conversations, long shared document queries, and more. In these cases, externally computed tokens are loaded as part of running prefill requests and can potentially be shared by multiple requests in the same batch.

Critically, the connector typically does not have full visibility into all requests that might share these blocks - specifically, requests beyond the first will treat the blocks as if they were computed locally rather than loaded externally. Reporting block-level failures allows the scheduler, which does maintain a global view of request-to-block mappings, to identify and reschedule only the affected requests efficiently.

Moreover, the future enhancement I mentioned - supporting reuse of arbitrary prompt blocks rather than only the prefix - is an important and feasible optimization we are actively working on. The block-level failure reporting design directly aligns with and enables this finer-grained reuse and recovery path.

Regarding your point about overlapping async KV loads: I agree - currently, overlapping blocks are duplicated locally instead of being shared. While this gap is certainly addressable, it also raises a broader question:
Why do disaggregated decode requests need to fully wait for remotely computed KV blocks to finish loading before starting at all? Couldn't we schedule the requests immediately and load the KV blocks asynchronously, layer by layer, during the forward pass?
This approach could resolve the duplication issue and allow us to overlap block loading with computation, further optimizing end-to-end latency.

Thanks again for the constructive discussion - happy to further align on this if you'd like!

@njhill
Member

njhill commented Jul 5, 2025

Thanks @sdavidbd. Maybe it's better to have the connector API be block-based? (see discussion in other PR #19555 (comment)).

Couldn't we schedule the requests immediately and load the KV blocks asynchronously, layer by layer, during the forward pass?

This is already supported by the connector API, but there may be a trade-off since we need to load at least the first layer before the forward pass starts, which will add latency to every other request in the batch. It might also slow down subsequent layers if the loading takes longer than computation of the prior layer. There might also be greater overhead associated with multiple per-layer transfers. But it's something we could experiment with. We want to at least support async for the initial handshake.

I guess we need to think through the sync and async loading cases with respect to block sharing. You're right that the sharing happens automatically/implicitly in the sync case right now, it would be good to have that work for the async case too.

@sdavidbd force-pushed the feature/kv-load-failure-recovery branch from 20f0419 to cf89f4d on July 13, 2025 15:02
David Ben-David added 17 commits September 30, 2025 12:28
Connectors now consistently return a set of failed block IDs (empty if none).
This removes the need for Optional handling, simplifying the code without
introducing any performance impact.

@sdavidbd force-pushed the feature/kv-load-failure-recovery branch from 3af28dc to 46dcace on September 30, 2025 11:58
@sdavidbd
Contributor Author

@njhill I’ve simplified the handling here by removing generator rollback from this PR. Rollback of the generator state in the case of sync-KV-load failures can be handled in a follow-up PR, along with unifying and refactoring the rollback mechanism more broadly.

This keeps the logic here focused on correcting state alignment for discarded output tokens. It also directly addresses your comments around overhead, clarity, and duplication without overcomplicating this change.

@sdavidbd
Contributor Author

The v1-test-others failure appears to originate from changes introduced in PR #25896. Build reference: https://buildkite.com/vllm/ci/builds/32924/steps/canvas?sid=019999e7-f4f7-4f46-b184-91f105f6cee4.

@njhill
Member

njhill commented Sep 30, 2025

Thanks @sdavidbd that does seem much cleaner.

Could you confirm the implications of not handling the generator rollback?

Seeded outputs may not be consistent in the face of kv load failure/recovery, would that only apply to the sync kv-load-failure case rather than async? (sorry I know I could figure this out but thought I would save time by asking!)

@sdavidbd
Contributor Author

(Quoting @njhill's question above.)

@njhill Yes, exactly — it would only apply to the sync kv-load-failure case. In the async case, there isn’t a cached state for the request yet.

@njhill
Member

njhill commented Sep 30, 2025

ok, let's get this merged!

Would be good to address the generator thing as a follow-on.

@simon-mo merged commit 9a9f48d into vllm-project:main on Sep 30, 2025
45 of 48 checks passed
pdasigi pushed a commit to pdasigi/vllm that referenced this pull request Oct 2, 2025
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
tomeras91 pushed a commit to tomeras91/vllm that referenced this pull request Oct 6, 2025