@jreiml jreiml commented Jan 10, 2026

What does this PR do?

Fixes the non_block=True behavior in ExternalZeroMQDistributedExecutor to properly implement the vLLM executor contract.

The previous implementation immediately called recv() and wrapped the result in an already-resolved Future. This PR adds a _DeferredZmqFuture that defers recv() until result() is called, allowing vLLM's EngineCore to overlap work (e.g., grammar bitmask computation for structured output) with remote model execution.

Related: #3934 added non_block parameter compatibility but didn't implement actual non-blocking behavior.

Test

The non_block=True code path is called by vLLM v1's EngineCore (see vllm/v1/engine/core.py). Existing CI tests that run vLLM v1 (tests/experimental/agent_loop, test_vllm_abort.py) exercise this code path through the full inference stack.

API and Usage Example

No API changes. Internal behavior change only.

Design & Code Changes

  • Add _DeferredZmqFuture class that stores sockets and defers recv() until result() is called
  • Add non_block and unique_reply_rank parameters to collective_rpc()
  • Add assertion enforcing max_concurrent_batches=1 (required for thread-safe ZMQ REQ/REP)
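The deferred-receive idea in the list above can be sketched roughly as follows. This is not the PR's code: `DeferredRecvFuture` and `FakeSocket` are hypothetical names, and `FakeSocket` stands in for a real ZMQ socket so the example runs without pyzmq. The only assumption about the socket is that it exposes a blocking `recv()` returning pickled bytes.

```python
import pickle
from concurrent.futures import Future


class DeferredRecvFuture(Future):
    """Hypothetical stand-in for _DeferredZmqFuture: recv() runs only inside result()."""

    def __init__(self, sockets):
        super().__init__()
        self._sockets = sockets

    def result(self, timeout=None):
        # On the first call, read one reply per socket and cache the outcome.
        # Later calls skip this branch and return the cached value.
        if not self.done():
            try:
                outputs = [pickle.loads(s.recv()) for s in self._sockets]
                self.set_result(outputs)
            except Exception as exc:
                self.set_exception(exc)
        return super().result(timeout)


class FakeSocket:
    """Test double (not a ZMQ socket) that counts recv() calls."""

    def __init__(self, payload):
        self._payload = pickle.dumps(payload)
        self.recv_calls = 0

    def recv(self):
        self.recv_calls += 1
        return self._payload


sockets = [FakeSocket({"rank": 0}), FakeSocket({"rank": 1})]
future = DeferredRecvFuture(sockets)

# The caller can do other work here; nothing has been received yet.
assert all(s.recv_calls == 0 for s in sockets)

outputs = future.result()  # recv() happens now, once per socket
assert outputs == [{"rank": 0}, {"rank": 1}]
```

A REQ socket has to receive its reply before it can send again, so a pending deferred future keeps its sockets busy until `result()` is called. That is consistent with restricting the executor to one in-flight batch, as the `max_concurrent_batches=1` assertion does.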

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request correctly implements the deferred execution for non-blocking calls in ExternalZeroMQDistributedExecutor by introducing _DeferredZmqFuture. This is a good fix that aligns with the vLLM executor contract and allows for overlapping computation. The addition of the assertion for max_concurrent_batches=1 is also a great defensive measure to ensure thread safety with ZMQ REQ/REP sockets. However, I've identified a critical security vulnerability. The implementation uses pickle.loads() to deserialize data received over the network. This is unsafe and can lead to remote code execution if the network is not completely secure. My review includes a comment with details on this issue.

try:
    outputs = []
    for socket in self._sockets:
        outputs.append(pickle.loads(socket.recv()))

security-critical critical

The use of pickle.loads() on data received from a network socket introduces a critical security vulnerability. Deserializing data with pickle can lead to arbitrary code execution if the data is crafted maliciously. While this communication is likely between trusted internal workers, it's a significant security risk if the network is not completely isolated and secure. An attacker who can intercept or inject traffic on this ZMQ channel could compromise the worker process.

It is strongly recommended to replace pickle with a safer serialization format, such as JSON. If complex Python objects must be transferred, consider using a library that provides cryptographically signed serialization to ensure data integrity and authenticity.
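One way to read the "cryptographically signed serialization" suggestion is to prepend an HMAC tag to each pickled frame and verify it before unpickling. The sketch below uses only the Python standard library. The function names, the frame layout (tag followed by body), and the pre-shared `key` are assumptions for illustration and are not part of this PR or of any vLLM API.

```python
import hashlib
import hmac
import pickle

# Length in bytes of a SHA-256 digest, used to split tag from body.
DIGEST_SIZE = hashlib.sha256().digest_size


def sign_payload(obj, key: bytes) -> bytes:
    """Pickle obj and prepend an HMAC-SHA256 tag computed over the pickled bytes."""
    body = pickle.dumps(obj)
    tag = hmac.new(key, body, hashlib.sha256).digest()
    return tag + body


def verify_and_load(frame: bytes, key: bytes):
    """Check the tag in constant time and unpickle only if it matches."""
    tag, body = frame[:DIGEST_SIZE], frame[DIGEST_SIZE:]
    expected = hmac.new(key, body, hashlib.sha256).digest()
    if not hmac.compare_digest(tag, expected):
        raise ValueError("HMAC mismatch; refusing to unpickle")
    return pickle.loads(body)


key = b"shared-secret-for-illustration-only"
frame = sign_payload({"rank": 0, "ok": True}, key)
assert verify_and_load(frame, key) == {"rank": 0, "ok": True}
```

This only guards against frames from a sender that does not hold the key. Any process that has the key can still send a malicious pickle, so key distribution and network isolation remain relevant.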

@jreiml jreiml Jan 10, 2026

This change follows the existing serialization logic, which already used pickle. Replacing pickle would have to be done in a separate PR.

- Add `_DeferredZmqFuture` class that defers ZMQ `recv()` until `result()` is called
- This properly implements the vLLM executor contract for `non_block=True`, allowing EngineCore to overlap work (e.g., grammar bitmask computation for structured output) with remote model execution
- Add assertion to enforce `max_concurrent_batches=1`, required for thread-safe ZMQ REQ/REP operation
@jreiml jreiml force-pushed the vllm-deferred-zmq-future branch from 7447b1f to 013a7a9 Compare January 10, 2026 17:47