class OffloadingConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, spec: OffloadingSpec):
        """Initialize scheduler-side offloading state.

        Args:
            spec (OffloadingSpec): provides the offloading manager plus the
                vLLM / KV-cache configuration this scheduler operates on.
        """
        self.config = SchedulerOffloadConfig.from_spec(spec)
        self.manager: OffloadingManager = spec.get_manager()
        # currently treat all groups as full attention
        self.lookup_groups: list[int] = list(
            range(len(spec.kv_cache_config.kv_cache_groups))
        )
        self._req_status: dict[ReqId, RequestOffloadState] = {}
        self._current_batch_load_jobs: dict[int, TransferJob] = {}
        self._current_batch_jobs_to_flush: set[int] = set()
        # if GPU prefix caching is enabled,
        # track loaded blocks to avoid redundant loads
        self._blocks_being_loaded: set[OffloadKey] | None = (
            set() if spec.vllm_config.cache_config.enable_prefix_caching else None
        )
        # Job ID counter shared by loads and stores.
        self._job_counter: int = 0
        self._jobs: dict[int, TransferJobStatus] = {}
        # block_id -> pending store job_ids. Populated only for finished
        # requests (running-request blocks are protected by their ref_cnt).
        self._block_id_to_pending_jobs: dict[int, set[int]] = {}
def _generate_job_id(self) -> int:
job_id = self._job_counter
self._job_counter += 1
return job_id
def _maximal_prefix_lookup(
self, keys: Iterable[OffloadKey], req_context: ReqContext
) -> int | None:
"""Find the length of the maximal prefix of offloaded blocks."""
hit_count = 0
defer_lookup = False
for key in keys:
result = self.manager.lookup(key, req_context)
if result is None:
defer_lookup = True
# continue lookup to allow manager to kick-off async lookups
# for all blocks (until a miss is detected)
result = True
if not result:
break
hit_count += 1
return hit_count if not defer_lookup else None
def _sliding_window_lookup(
self,
keys: Sequence[OffloadKey],
sliding_window_size: int,
req_context: ReqContext,
) -> int | None:
"""Find the maximal ending position of consecutive offloaded blocks
within a sliding window."""
defer_lookup = False
consecutive_hits = 0
for idx in range(len(keys) - 1, -1, -1):
result = self.manager.lookup(keys[idx], req_context)
if result is None:
defer_lookup = True
# continue lookup to allow manager to kick-off async lookups
# for all blocks (until a hit is detected)
result = False
if not result:
consecutive_hits = 0
else:
consecutive_hits += 1
if consecutive_hits == sliding_window_size:
return idx + sliding_window_size if not defer_lookup else None
return consecutive_hits if not defer_lookup else None
    def get_num_new_matched_tokens(
        self, request: Request, num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded beyond the
        num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple with the following elements:
                - The number of tokens that can be loaded beyond what is
                  already computed.
                  If None, it means that the connector needs more time to
                  determine the number of matched tokens, and the scheduler
                  should query for this request again later.
                - `True` if tokens will be loaded asynchronously
                  (between scheduler steps).
        """
        if req_status := self._req_status.get(request.request_id):
            # request seen before (e.g. re-scheduled after preemption):
            # make sure block IDs are cleared
            for group_state in req_status.group_states:
                group_state.block_ids.clear()
        else:
            req_status = RequestOffloadState(config=self.config, req=request)
            self._req_status[request.request_id] = req_status
        req_status.update_offload_keys()
        req_status.num_locally_computed_tokens = num_computed_tokens

        # let the manager refresh the recency of all of this request's keys
        for gs in req_status.group_states:
            self.manager.touch(gs.offload_keys)

        # Start with the full request size as the maximum loadable
        max_hit_size_tokens: int = req_status.req.num_tokens
        num_hit_tokens: int = 0
        defer_lookup = False
        delay_request = False
        for group_idx in self.lookup_groups:
            group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx]
            offloaded_block_size = group_config.offloaded_block_size
            offload_keys = req_status.group_states[group_idx].offload_keys

            num_blocks = max_hit_size_tokens // offloaded_block_size
            assert len(offload_keys) >= num_blocks

            # Constrain to block-aligned boundary for this group
            max_hit_size_tokens = num_blocks * offloaded_block_size
            num_hit_tokens = max_hit_size_tokens - num_computed_tokens
            if num_hit_tokens < offloaded_block_size:
                # we can only load less than a block, better skip
                return 0, False

            start_block_idx = num_computed_tokens // offloaded_block_size
            offload_keys = offload_keys[start_block_idx:num_blocks]

            # Full attention relies on all previous KV cache blocks.
            # Thus, we search for a maximal prefix of KV cache which are all cached.
            block_hits = self._maximal_prefix_lookup(
                offload_keys, req_status.req_context
            )
            if block_hits == 0:
                return 0, False
            if block_hits is None:
                # backend is still resolving the lookup asynchronously;
                # keep scanning remaining groups, report "not ready" below
                defer_lookup = True
            else:
                # Further constrain based on what's actually available by backend
                max_hit_size_tokens = offloaded_block_size * (
                    start_block_idx + block_hits
                )
                num_hit_tokens = max_hit_size_tokens - num_computed_tokens
                if num_hit_tokens < offloaded_block_size:
                    # we can only load less than a block, better skip
                    return 0, False

            # NOTE: when block_hits is None, it is falsy and the check below
            # is skipped entirely.
            if (
                block_hits
                and self._blocks_being_loaded
                and any(
                    key in self._blocks_being_loaded
                    for key in offload_keys[:block_hits]
                )
            ):
                # hit blocks are being loaded, delay request
                delay_request = True

        if defer_lookup:
            logger.debug(
                "Offloading manager delayed request %s as backend requested",
                req_status.req.request_id,
            )
            return None, False

        if delay_request:
            logger.debug(
                "Delaying request %s since some of its blocks are already being loaded",
                req_status.req.request_id,
            )
            return None, False

        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            request.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )
        return num_hit_tokens, True
    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """Register a load job for the externally matched tokens of `request`.

        Called after the scheduler has allocated GPU blocks. Translates the
        offloaded hit (beyond the locally computed tokens) into a
        (src, dst) transfer spec and queues the load job for this batch.

        Args:
            request (Request): the request object.
            blocks (KVCacheBlocks): the GPU blocks allocated per KV group.
            num_external_tokens (int): number of tokens to load from the
                offloading backend; 0 means nothing to do.
        """
        if num_external_tokens == 0:
            return
        req_status = self._req_status[request.request_id]
        num_locally_computed_tokens = req_status.num_locally_computed_tokens
        num_cached_tokens = num_locally_computed_tokens + num_external_tokens

        params = req_status.req_context.kv_transfer_params
        do_remote_decode = params is not None and params.get("do_remote_decode")

        keys_to_load: list[OffloadKey] = []
        dst_block_ids: list[int] = []
        # per group
        group_sizes: list[int] = []
        block_indices: list[int] = []
        for group_config, group_state, group_blocks in zip(
            self.config.kv_group_configs,
            req_status.group_states,
            blocks.blocks,
        ):
            gpu_block_size = group_config.gpu_block_size
            offloaded_block_size = group_config.offloaded_block_size
            offload_keys = group_state.offload_keys

            num_gpu_blocks = cdiv(num_cached_tokens, gpu_block_size)
            assert len(group_blocks) >= num_gpu_blocks

            # Locally computed GPU blocks carry a block_hash; the first
            # non-null block without one marks where loading must begin.
            num_locally_computed_gpu_blocks = num_gpu_blocks
            # Skip null placeholder blocks (used for sliding window or mamba padding).
            for i, block in enumerate(group_blocks[:num_gpu_blocks]):
                if not block.is_null and block.block_hash is None:
                    num_locally_computed_gpu_blocks = i
                    break
            assert (
                num_locally_computed_tokens
                <= num_locally_computed_gpu_blocks * gpu_block_size
            )
            num_pending_gpu_blocks = num_gpu_blocks - num_locally_computed_gpu_blocks

            # NOTE(review): cdiv includes a trailing partial offloaded block
            # here, unlike the floor division used in lookup — presumably
            # num_cached_tokens is block-aligned at this point; confirm.
            num_blocks = cdiv(num_cached_tokens, offloaded_block_size)
            assert len(offload_keys) >= num_blocks

            if num_pending_gpu_blocks:
                start_block_idx = (
                    num_locally_computed_gpu_blocks // self.config.block_size_factor
                )
                keys_to_load.extend(offload_keys[start_block_idx:num_blocks])
                dst_block_ids.extend(
                    block.block_id
                    for block in group_blocks[
                        num_locally_computed_gpu_blocks:num_gpu_blocks
                    ]
                )
                group_sizes.append(num_pending_gpu_blocks)
                block_indices.append(num_locally_computed_gpu_blocks)

            if not do_remote_decode:
                # For P/D prefill requests (do_remote_decode=True), we do
                # NOT skip saving the hit prefix, as we need to stream the
                # entire KV cache so a remote decode node can consume it.
                group_state.next_stored_block_idx = num_blocks

        # Fence dst blocks against finished-request pending stores.
        if (
            self._block_id_to_pending_jobs
            and not self._block_id_to_pending_jobs.keys().isdisjoint(dst_block_ids)
        ):
            self._current_batch_jobs_to_flush.update(
                jid
                for bid in dst_block_ids
                for jid in self._block_id_to_pending_jobs.get(bid, ())
            )

        src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context)
        dst_spec = GPULoadStoreSpec(
            dst_block_ids, group_sizes=group_sizes, block_indices=block_indices
        )

        load_job_id = self._generate_job_id()
        self._current_batch_load_jobs[load_job_id] = TransferJob(
            req_id=request.request_id,
            transfer_spec=(src_spec, dst_spec),
        )
        # a load can only be issued when no other jobs are pending.
        assert not req_status.transfer_jobs
        req_status.transfer_jobs.add(load_job_id)
        self._jobs[load_job_id] = TransferJobStatus(
            req_id=request.request_id,
            pending_count=self.config.num_workers,
            keys=set(keys_to_load),
            is_store=False,
        )
        if self._blocks_being_loaded is not None:
            self._blocks_being_loaded.update(keys_to_load)
    def _build_store_jobs(
        self,
        scheduler_output: SchedulerOutput,
    ) -> dict[int, TransferJob]:
        """Create store (offload) jobs for tokens computed in this batch.

        Walks both new and cached requests in `scheduler_output`, determines
        which offloaded-block-aligned spans became storable, asks the
        manager to prepare the store, and builds per-job GPU->backend
        transfer specs.

        Returns:
            Mapping of job_id -> TransferJob to hand to the workers.
        """
        block_size_factor = self.config.block_size_factor
        store_jobs: dict[int, TransferJob] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
            req_status = self._req_status[req_id]
            req_status.update_offload_keys()
            req = req_status.req

            if preempted:
                # preemption invalidates previously tracked GPU block IDs
                for group_state in req_status.group_states:
                    group_state.block_ids.clear()

            if new_block_id_groups:
                req_status.update_block_id_groups(new_block_id_groups)
                # Fence new blocks against in-flight stores.
                if self._block_id_to_pending_jobs:
                    new_blocks_flat = [
                        bid for new_blocks in new_block_id_groups for bid in new_blocks
                    ]
                    if not self._block_id_to_pending_jobs.keys().isdisjoint(
                        new_blocks_flat
                    ):
                        self._current_batch_jobs_to_flush.update(
                            jid
                            for bid in new_blocks_flat
                            for jid in self._block_id_to_pending_jobs.get(bid, ())
                        )

            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens
            # with async scheduling, some tokens may be missing
            num_offloadable_tokens = min(num_tokens_after_batch, req.num_tokens)

            # Filter out blocks skipped due to sliding window attention / SSM
            new_offload_keys: list[OffloadKey] = []
            for group_config, group_state in zip(
                self.config.kv_group_configs, req_status.group_states
            ):
                num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
                start_block_idx = group_state.next_stored_block_idx
                if num_blocks <= start_block_idx:
                    # nothing new became block-aligned for this group
                    continue

                offload_keys = group_state.offload_keys[start_block_idx:num_blocks]

                # For each block to offload, take the last corresponding GPU block.
                # e.g. if block size factor is 3 and GPU block IDs are
                # 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
                # We will use these GPU blocks to determine if the block needs
                # offloading, or (if the GPU block ID is 0) this block should
                # be skipped due to sliding window attention / SSM.
                # We know that if a block is skipped, then all the previous blocks
                # are skipped as well. This is why we take the last of each block.
                offload_block_ids = group_state.block_ids[
                    start_block_idx * block_size_factor
                    + block_size_factor
                    - 1 : num_blocks * block_size_factor : block_size_factor
                ]

                assert len(offload_keys) == len(offload_block_ids)
                for offload_key, block_id in zip(offload_keys, offload_block_ids):
                    if block_id != 0:
                        new_offload_keys.append(offload_key)

            if not new_offload_keys:
                # nothing to store; still record progress up to this point
                req_status.advance_stored_idx(num_offloadable_tokens)
                continue

            store_output = self.manager.prepare_store(
                new_offload_keys, req_status.req_context
            )
            if store_output is None:
                # NOTE(review): next_stored_block_idx is NOT advanced here,
                # so the same blocks will be retried next step — confirm this
                # is the intended behavior for transient backend failures.
                logger.warning("Request %s: cannot store blocks", req_id)
                continue
            if not store_output.keys_to_store:
                # backend already has everything; just record progress
                req_status.advance_stored_idx(num_offloadable_tokens)
                continue

            # refresh recency of all of the request's keys before storing
            for group_state in req_status.group_states:
                self.manager.touch(group_state.offload_keys)

            keys_to_store = set(store_output.keys_to_store)

            group_sizes: list[int] = []
            block_indices: list[int] = []
            src_block_ids: list[int] = []
            for group_config, group_state in zip(
                self.config.kv_group_configs, req_status.group_states
            ):
                num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
                start_block_idx = group_state.next_stored_block_idx
                block_ids = group_state.block_ids

                num_group_blocks = 0
                start_gpu_block_idx: int | None = None
                for idx, offload_key in enumerate(
                    group_state.offload_keys[start_block_idx:num_blocks]
                ):
                    if offload_key not in keys_to_store:
                        continue
                    offloaded_block_idx = start_block_idx + idx
                    gpu_block_idx = offloaded_block_idx * block_size_factor
                    num_group_blocks += block_size_factor
                    for i in range(block_size_factor):
                        block_id = block_ids[gpu_block_idx + i]
                        if block_id == 0:
                            # skipped blocks cannot appear after non-skipped blocks
                            assert start_gpu_block_idx is None
                            continue
                        elif start_gpu_block_idx is None:
                            start_gpu_block_idx = gpu_block_idx + i
                        src_block_ids.append(block_id)
                group_sizes.append(num_group_blocks)
                block_indices.append(start_gpu_block_idx or 0)

                group_state.next_stored_block_idx = num_blocks

            src_spec = GPULoadStoreSpec(
                src_block_ids, group_sizes=group_sizes, block_indices=block_indices
            )
            dst_spec = store_output.store_spec

            job_id = self._generate_job_id()
            # a store can only be issued when no load is pending.
            if req_status.transfer_jobs:
                any_jid = next(iter(req_status.transfer_jobs))
                assert self._jobs[any_jid].is_store
            req_status.transfer_jobs.add(job_id)
            self._jobs[job_id] = TransferJobStatus(
                req_id=req_id,
                pending_count=self.config.num_workers,
                keys=set(keys_to_store),
                is_store=True,
                gpu_block_ids=src_block_ids,
            )
            store_jobs[job_id] = TransferJob(
                req_id=req_id, transfer_spec=(src_spec, dst_spec)
            )

            logger.debug(
                "Request %s offloading %s blocks upto %d tokens (job %d)",
                req_id,
                len(keys_to_store),
                num_offloadable_tokens,
                job_id,
            )
        return store_jobs
def build_connector_meta(
self, scheduler_output: SchedulerOutput
) -> KVConnectorMetadata:
for req_id in scheduler_output.preempted_req_ids or ():
req_status = self._req_status.get(req_id)
if req_status is None or not req_status.transfer_jobs:
continue
any_jid = next(iter(req_status.transfer_jobs))
assert self._jobs[any_jid].is_store
self._current_batch_jobs_to_flush.update(req_status.transfer_jobs)
meta = OffloadingConnectorMetadata(
load_jobs=self._current_batch_load_jobs,
store_jobs=self._build_store_jobs(scheduler_output),
jobs_to_flush=self._current_batch_jobs_to_flush,
)
self._current_batch_load_jobs = {}
self._current_batch_jobs_to_flush = set()
return meta
    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        meta = connector_output.kv_connector_worker_meta
        if not isinstance(meta, OffloadingWorkerMetadata):
            # no worker metadata this step; treat as "no jobs completed"
            assert meta is None
            meta = OffloadingWorkerMetadata()

        for job_id, count in meta.completed_jobs.items():
            assert count > 0
            job_status = self._jobs[job_id]
            # a job completes only once all workers have acknowledged it
            job_status.pending_count -= count
            if job_status.pending_count > 0:
                continue
            assert job_status.pending_count == 0

            if job_status.is_store:
                self.manager.complete_store(job_status.keys)
            else:
                self.manager.complete_load(job_status.keys)
                if self._blocks_being_loaded:
                    # loaded blocks are now available; stop delaying
                    # requests that hit them
                    self._blocks_being_loaded.difference_update(job_status.keys)

            req_status = self._req_status[job_status.req_id]
            if self._block_id_to_pending_jobs and req_status.req.is_finished():
                # drop the block fences this finished-request job held
                for bid in job_status.gpu_block_ids or ():
                    pending = self._block_id_to_pending_jobs[bid]
                    pending.remove(job_id)
                    if not pending:
                        del self._block_id_to_pending_jobs[bid]

            del self._jobs[job_id]
            req_status.transfer_jobs.remove(job_id)
            # last outstanding job of a finished request: forget the request
            if not req_status.transfer_jobs and req_status.req.is_finished():
                del self._req_status[job_status.req_id]
def request_finished(
self,
request: Request,
block_ids: list[int],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called when a request has finished, before its blocks are freed.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
# TODO(orozery): possibly kickoff offload for last block
# which may have been deferred due to async scheduling
req_status = self._req_status.get(request.request_id)
if req_status is None:
return False, None
if not req_status.transfer_jobs:
del self._req_status[request.request_id]
return False, None
# Pending stores will outlive the request's block ownership.
# Register them so future block reuse triggers a flush.
for job_id in req_status.transfer_jobs:
job_status = self._jobs[job_id]
for bid in job_status.gpu_block_ids or ():
self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id)
return False, None
def take_events(self) -> Iterable[KVCacheEvent]:
"""Take the KV cache events from the connector.
Returns:
A list of KV cache events.
"""
for event in self.manager.take_events():
block_hashes = [get_offload_block_hash(key) for key in event.keys]
if event.removed:
yield BlockRemoved(block_hashes=block_hashes, medium=event.medium)
else:
yield BlockStored(
block_hashes=block_hashes,
parent_block_hash=None,
token_ids=[],
lora_id=None,
block_size=0,
medium=event.medium,
lora_name=None,
)
    def shutdown(self) -> None:
        """Shut down the underlying offloading manager, releasing its resources."""
        self.manager.shutdown()