divisor.acestep.cpu_offload
# SPDX-License-Identifier:Apache-2.0
# adapted from https://github.com/ace-step/ACE-Step

import torch
import functools
from typing import Callable, TypeVar
from divisor.registry import gfx_sync, empty_cache


class CpuOffloader:
    """Context manager that temporarily loads a model onto a device.

    On entry the model is moved to ``device`` and cast to the dtype it had
    when this offloader was constructed. On exit it is moved back to
    ``"cpu"``, and then ``gfx_sync()`` and ``empty_cache()`` from
    ``divisor.registry`` are called. Presumably these are an accelerator sync
    and an allocator-cache release; see that module for the exact semantics.
    Models that carry a ``torchao_quantized`` attribute are never moved.

    Args:
        model: Object exposing ``.dtype`` and ``.to(device, dtype=...)``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        device: Device the model is loaded onto inside the ``with`` block.
    """

    def __init__(self, model, device="cpu"):
        self.model = model
        # NOTE: despite the name, this is the *target* device used by
        # __enter__, not the device the model started on. The name is kept
        # for backward compatibility with external readers of the attribute.
        self.original_device = device
        # dtype captured once here and re-applied on every __enter__.
        self.original_dtype = model.dtype

    def __enter__(self):
        """Move the model onto the target device (unless quantized) and return it."""
        if not hasattr(self.model, "torchao_quantized"):
            self.model.to(self.original_device, dtype=self.original_dtype)
        return self.model

    def __exit__(self, *args) -> None:
        """Move the model back to the CPU. Exceptions are never suppressed."""
        if not hasattr(self.model, "torchao_quantized"):
            self.model.to("cpu")
            # Both helpers must be *called*; a bare name is a no-op statement.
            gfx_sync()
            empty_cache()
        return None


T = TypeVar("T")


def cpu_offload(model_attr: str):
    """Decorate a method so it runs with ``getattr(self, model_attr)`` loaded.

    When ``self.cpu_offload`` is truthy, the wrapped call executes inside a
    ``CpuOffloader`` that moves the named model to ``self.device`` and moves
    it back to the CPU afterwards. Otherwise the method is called directly.

    Args:
        model_attr: Name of the instance attribute that holds the model.

    Returns:
        A decorator that preserves the wrapped method's metadata.
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if not self.cpu_offload:
                return func(self, *args, **kwargs)

            # Device the model should live on while the method runs.
            device = self.device
            # Model looked up by attribute name on the instance.
            model = getattr(self, model_attr)

            with CpuOffloader(model, device):
                return func(self, *args, **kwargs)

        return wrapper

    return decorator
class
CpuOffloader:
class CpuOffloader:
    """Context manager that temporarily loads a model onto a device.

    On entry the model is moved to ``device`` and cast to the dtype it had
    when this offloader was constructed. On exit it is moved back to
    ``"cpu"``, and then ``gfx_sync()`` and ``empty_cache()`` are called.
    Presumably these are an accelerator sync and an allocator-cache release
    (imported from ``divisor.registry``; confirm semantics there). Models that
    carry a ``torchao_quantized`` attribute are never moved.

    Args:
        model: Object exposing ``.dtype`` and ``.to(device, dtype=...)``
            (presumably a ``torch.nn.Module`` -- confirm against callers).
        device: Device the model is loaded onto inside the ``with`` block.
    """

    def __init__(self, model, device="cpu"):
        self.model = model
        # NOTE: despite the name, this is the *target* device used by
        # __enter__, not the device the model started on. The name is kept
        # for backward compatibility with external readers of the attribute.
        self.original_device = device
        # dtype captured once here and re-applied on every __enter__.
        self.original_dtype = model.dtype

    def __enter__(self):
        """Move the model onto the target device (unless quantized) and return it."""
        if not hasattr(self.model, "torchao_quantized"):
            self.model.to(self.original_device, dtype=self.original_dtype)
        return self.model

    def __exit__(self, *args) -> None:
        """Move the model back to the CPU. Exceptions are never suppressed."""
        if not hasattr(self.model, "torchao_quantized"):
            self.model.to("cpu")
            # Both helpers must be *called*; a bare name is a no-op statement.
            gfx_sync()
            empty_cache()
        return None
def
cpu_offload(model_attr: str):
def cpu_offload(model_attr: str):
    """Return a method decorator that optionally offloads a model around the call.

    The decorated method's instance is consulted on every call. If its
    ``cpu_offload`` flag is falsy, the method runs unchanged. Otherwise the
    model stored under ``model_attr`` is wrapped in a ``CpuOffloader`` that
    targets ``self.device`` for the duration of the call.

    Args:
        model_attr: Attribute name on the instance that refers to the model.
    """

    def _decorate(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def _call_with_offload(self, *args, **kwargs):
            if self.cpu_offload:
                # Read the target device first, then the model, same order as before.
                target_device = self.device
                managed_model = getattr(self, model_attr)
                with CpuOffloader(managed_model, target_device):
                    return func(self, *args, **kwargs)
            # Offloading disabled: plain pass-through call.
            return func(self, *args, **kwargs)

        return _call_with_offload

    return _decorate