import numpy as np
import torch

# NOTE: the import paths below are assumed from vLLM's v1 KV-offload module
# layout; adjust them if the actual package structure differs.
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.worker import (
    OffloadingHandler,
    TransferResult,
    TransferSpec,
)

logger = init_logger(__name__)


class CpuGpuOffloadingHandler(OffloadingHandler):
    def __init__(
        self,
        gpu_block_size: int,
        cpu_block_size: int,
        num_cpu_blocks: int,
        gpu_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        # a CPU block must hold a whole number of GPU blocks
        assert cpu_block_size % gpu_block_size == 0
        self.block_size_factor = cpu_block_size // gpu_block_size
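        # e.g. gpu_block_size=16 and cpu_block_size=64 give block_size_factor=4:
        # each CPU block stores the contents of 4 consecutive GPU blocks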
        # cuda streams for gpu->cpu and cpu->gpu transfers
        self.d2h_stream = torch.cuda.Stream()
        self.h2d_stream = torch.cuda.Stream()
        # job_id -> transfer cuda event
        self.transfer_events: dict[int, torch.Event] = {}
        # list of cuda events available for re-use
        self.events_pool: list[torch.Event] = []

        pin_memory = is_pin_memory_available()

        # allocate cpu tensors
        logger.info("Allocating %d CPU tensors...", len(gpu_caches))
        self.gpu_tensors: list[torch.Tensor] = []
        self.cpu_tensors: list[torch.Tensor] = []
        self.kv_dim_before_num_blocks: list[bool] = []
        for layer_name, gpu_tensor in gpu_caches.items():
            self.gpu_tensors.append(gpu_tensor)

            # probe the backend with sentinel values to locate the num_blocks
            # dimension in the KV cache shape
            gpu_shape = gpu_tensor.shape
            attn_backend = attn_backends[layer_name]
            test_shape = attn_backend.get_kv_cache_shape(
                num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256
            )
            if len(gpu_shape) != len(test_shape):
                # cross-layers tensor
                # shape is (num_blocks, ...)
                assert len(gpu_shape) == len(test_shape) + 1
                num_blocks_idx = 0
                self.kv_dim_before_num_blocks.append(False)
            elif test_shape[0] == 1234:
                # shape is (num_blocks, ...)
                num_blocks_idx = 0
                self.kv_dim_before_num_blocks.append(False)
            else:
                # shape should be (2, num_blocks, ...)
                assert test_shape[0] == 2
                assert test_shape[1] == 1234
                assert gpu_shape[0] == 2
                num_blocks_idx = 1
                self.kv_dim_before_num_blocks.append(True)
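            # e.g. (illustrative) a FlashAttention-style layout of
            # (2, num_blocks, block_size, num_kv_heads, head_size) makes the
            # probe return test_shape = (2, 1234, 16, 8, 256), so the else
            # branch selects num_blocks_idx = 1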
            # same layout as the GPU cache, but with the num_blocks dim sized
            # in GPU-block-sized units: num_cpu_blocks * block_size_factor
            cpu_shape = list(gpu_shape)
            cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor

            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
            self.cpu_tensors.append(
                torch.zeros(
                    cpu_shape,
                    dtype=gpu_tensor.dtype,
                    device="cpu",
                    pin_memory=pin_memory,
                )
            )
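
    # Illustration (assumed numbers, not from the original source): for a GPU
    # cache of shape (2, 8192, 16, 8, 256) with block_size_factor=4 and
    # num_cpu_blocks=1000, the CPU tensor gets shape (2, 4000, 16, 8, 256):
    # 1000 CPU blocks, each spanning 4 GPU-block-sized sub-blocks.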
    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        src_spec, dst_spec = spec
        if isinstance(src_spec, CPULoadStoreSpec):
            # cpu->gpu (load) runs on the host-to-device stream
            assert isinstance(dst_spec, GPULoadStoreSpec)
            stream = self.h2d_stream
            src_tensors = self.cpu_tensors
            dst_tensors = self.gpu_tensors
            src_block_size_factor = self.block_size_factor
            dst_block_size_factor = 1
        else:
            # gpu->cpu (store) runs on the device-to-host stream
            assert isinstance(src_spec, GPULoadStoreSpec)
            assert isinstance(dst_spec, CPULoadStoreSpec)
            stream = self.d2h_stream
            src_tensors = self.gpu_tensors
            dst_tensors = self.cpu_tensors
            src_block_size_factor = 1
            dst_block_size_factor = self.block_size_factor
        src_blocks = src_spec.block_ids
        dst_blocks = dst_spec.block_ids
        assert src_blocks.ndim == 1
        assert dst_blocks.ndim == 1

        # counts in GPU-block-sized sub-block units
        src_sub_block_count = src_blocks.size * src_block_size_factor
        dst_sub_block_count = dst_blocks.size * dst_block_size_factor
        # if the first src CPU block is only partially needed, skip its leading
        # sub-blocks so the src and dst sub-block counts line up
        src_sub_blocks_to_skip = -dst_blocks.size % src_block_size_factor
        assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
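        # worked example (illustrative): a cpu->gpu load of 5 GPU blocks from
        # 2 CPU blocks with block_size_factor=4 gives src_sub_block_count=8,
        # dst_sub_block_count=5, and src_sub_blocks_to_skip = -5 % 4 = 3, so
        # the 3 unused leading sub-blocks of the first CPU block are skipped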
        # build a (num_sub_blocks, 2) array of (src, dst) sub-block id pairs
        src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
        expand_block_ids(
            src_blocks,
            src_block_size_factor,
            src_to_dst[:, 0],
            skip_count=src_sub_blocks_to_skip,
        )
        expand_block_ids(dst_blocks, dst_block_size_factor, src_to_dst[:, 1])
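        # e.g. (assumed expand_block_ids semantics: block b covers sub-blocks
        # b*factor .. b*factor + factor-1) CPU block ids [3, 5] with factor 2
        # expand to sub-block ids [6, 7, 10, 11]; GPU block ids expand 1:1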
        src_to_dst_tensor = torch.from_numpy(src_to_dst)

        # re-use a pooled event if one is available
        event = self.events_pool.pop() if self.events_pool else torch.Event()
        with torch.cuda.stream(stream):
            for src_tensor, dst_tensor, kv_dim in zip(
                src_tensors, dst_tensors, self.kv_dim_before_num_blocks
            ):
                if kv_dim:
                    # (2, num_blocks, ...) layout: swap K and V separately
                    src_key_cache = src_tensor[0]
                    dst_key_cache = dst_tensor[0]
                    ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
                    src_value_cache = src_tensor[1]
                    dst_value_cache = dst_tensor[1]
                    ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
                else:
                    # (num_blocks, ...) layout: swap the whole tensor
                    ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
        event.record(stream)
        self.transfer_events[job_id] = event

        # success
        return True

    def get_finished(self) -> list[TransferResult]:
        results: list[TransferResult] = []
        for job_id, event in self.transfer_events.items():
            # query() is True once all work recorded before the event finished
            if event.query():
                results.append((job_id, True))
                self.events_pool.append(event)
        # remove completed jobs only after iterating, to avoid mutating the
        # dict while looping over it
        for job_id, _ in results:
            del self.transfer_events[job_id]
        return results
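

# The helper `expand_block_ids` is referenced above but not shown in this
# section. A minimal sketch of what it presumably does, assuming block b maps
# to sub-block ids b*factor .. b*factor + factor-1 and that skip_count drops
# the leading expanded entries; in the original module it would live near the
# top of the file:
def expand_block_ids(
    block_ids: np.ndarray,
    block_size_factor: int,
    output: np.ndarray,
    skip_count: int = 0,
) -> None:
    # block b covers sub-blocks [b * factor, (b + 1) * factor)
    expanded = (
        block_ids[:, None] * block_size_factor + np.arange(block_size_factor)
    ).ravel()
    output[:] = expanded[skip_count:]


# Minimal usage sketch (illustrative; the exact construction of the load/store
# specs is assumed, not taken from this section):
#
#   handler = CpuGpuOffloadingHandler(
#       gpu_block_size=16,
#       cpu_block_size=64,
#       num_cpu_blocks=1024,
#       gpu_caches=gpu_caches,        # layer_name -> GPU KV cache tensor
#       attn_backends=attn_backends,  # layer_name -> AttentionBackend class
#   )
#   # offload 8 GPU blocks into 2 CPU blocks (uses the d2h stream)
#   ok = handler.transfer_async(
#       job_id=1,
#       spec=(GPULoadStoreSpec(...), CPULoadStoreSpec(...)),
#   )
#   # later, poll for completion: returns [(1, True)] once the copy is done
#   finished = handler.get_finished()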