divisor.acestep.cpu_offload

 1# SPDX-License-Identifier:Apache-2.0
 2# adapted from https://github.com/ace-step/ACE-Step
 3
 4import torch
 5import functools
 6from typing import Callable, TypeVar
 7from divisor.registry import gfx_sync, empty_cache
 8
 9
class CpuOffloader:
    """Context manager that places a model on a device for a ``with`` block.

    On entry the model is moved to ``device`` with its captured dtype. On
    exit it is moved back to ``"cpu"``, and then the registry's graphics sync
    and cache-release helpers are invoked.

    Models carrying a ``torchao_quantized`` attribute are never moved. The
    sync and cache release on exit still run for them.

    Args:
        model: Object exposing ``.to(...)`` and a ``dtype`` attribute.
            Presumably a torch module; confirm against callers.
        device: Device the model is moved to on ``__enter__``. Despite the
            attribute name ``original_device``, this is the *target* device.
    """

    def __init__(self, model, device="cpu"):
        self.model = model
        # NOTE: kept as "original_device" for compatibility; it is the
        # device the model is moved *to* on entry, not where it started.
        self.original_device = device
        # Captured up front so the dtype is passed explicitly on entry.
        self.original_dtype = model.dtype

    def __enter__(self):
        """Move the model (unless quantized) to the target device; return it."""
        if not hasattr(self.model, "torchao_quantized"):
            self.model.to(self.original_device, dtype=self.original_dtype)
        return self.model

    def __exit__(self, *args):
        """Move the model (unless quantized) back to CPU, then sync and free cache."""
        if not hasattr(self.model, "torchao_quantized"):
            self.model.to("cpu")
        # These helpers were previously referenced without being called (bare
        # expression statements are no-ops), so no sync/cache release happened.
        gfx_sync()
        empty_cache()
        # Implicitly returns None (falsy), so exceptions raised inside the
        # ``with`` body propagate to the caller.
26
27
T = TypeVar("T")


def cpu_offload(model_attr: str):
    """Create a method decorator that runs the method with a model on ``self.device``.

    The instance the decorated method is bound to must provide ``cpu_offload``
    (truthy enables offloading), ``device``, and the attribute named by
    ``model_attr``. When ``self.cpu_offload`` is falsy the method is called
    directly. Otherwise the call runs inside ``CpuOffloader`` for that model.

    Args:
        model_attr: Name of the attribute on ``self`` that holds the model.

    Returns:
        A decorator for instance methods.
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.cpu_offload:
                # Read the target device first, then the model object.
                target_device = self.device
                offloaded_model = getattr(self, model_attr)
                with CpuOffloader(offloaded_model, target_device):
                    return func(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return decorator
class CpuOffloader:
    """Context manager that places a model on a device for a ``with`` block.

    On entry the model is moved to ``device`` with its captured dtype. On
    exit it is moved back to ``"cpu"``, and then the registry's graphics sync
    and cache-release helpers are invoked.

    Models carrying a ``torchao_quantized`` attribute are never moved. The
    sync and cache release on exit still run for them.

    Args:
        model: Object exposing ``.to(...)`` and a ``dtype`` attribute.
            Presumably a torch module; confirm against callers.
        device: Device the model is moved to on ``__enter__``. Despite the
            attribute name ``original_device``, this is the *target* device.
    """

    def __init__(self, model, device="cpu"):
        self.model = model
        # NOTE: kept as "original_device" for compatibility; it is the
        # device the model is moved *to* on entry, not where it started.
        self.original_device = device
        # Captured up front so the dtype is passed explicitly on entry.
        self.original_dtype = model.dtype

    def __enter__(self):
        """Move the model (unless quantized) to the target device; return it."""
        if not hasattr(self.model, "torchao_quantized"):
            self.model.to(self.original_device, dtype=self.original_dtype)
        return self.model

    def __exit__(self, *args):
        """Move the model (unless quantized) back to CPU, then sync and free cache."""
        if not hasattr(self.model, "torchao_quantized"):
            self.model.to("cpu")
        # These helpers were previously referenced without being called (bare
        # expression statements are no-ops), so no sync/cache release happened.
        gfx_sync()
        empty_cache()
        # Implicitly returns None (falsy), so exceptions raised inside the
        # ``with`` body propagate to the caller.
CpuOffloader(model, device='cpu')
12    def __init__(self, model, device="cpu"):
13        self.model = model
14        self.original_device = device
15        self.original_dtype = model.dtype
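model — the wrapped model object that is moved between devices.
original_device — the device the model is moved to on entry (despite the name, this is the target device; on exit the model always goes back to "cpu").
original_dtype — the model's dtype captured at construction and passed to `.to(...)` on entry.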
def cpu_offload(model_attr: str):
    """Create a method decorator that runs the method with a model on ``self.device``.

    The instance the decorated method is bound to must provide ``cpu_offload``
    (truthy enables offloading), ``device``, and the attribute named by
    ``model_attr``. When ``self.cpu_offload`` is falsy the method is called
    directly. Otherwise the call runs inside ``CpuOffloader`` for that model.

    Args:
        model_attr: Name of the attribute on ``self`` that holds the model.

    Returns:
        A decorator for instance methods.
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            if self.cpu_offload:
                # Read the target device first, then the model object.
                target_device = self.device
                offloaded_model = getattr(self, model_attr)
                with CpuOffloader(offloaded_model, target_device):
                    return func(self, *args, **kwargs)
            return func(self, *args, **kwargs)

        return wrapper

    return decorator