divisor.acestep.apg_guidance

  1# SPDX-License-Identifier:Apache-2.0
  2# code from https://github.com/ace-step/ACE-Step
  3
  4import torch
  5
  6
  7class MomentumBuffer:
  8    def __init__(self, momentum: float = -0.75):
  9        self.momentum = momentum
 10        self.running_average = 0
 11
 12    def update(self, update_value: torch.Tensor):
 13        new_average = self.momentum * self.running_average
 14        self.running_average = update_value + new_average
 15
 16
 17def project(
 18    v0: torch.Tensor,  # [B, C, H, W]
 19    v1: torch.Tensor,  # [B, C, H, W]
 20    dims=[-1, -2],
 21):
 22    dtype = v0.dtype
 23    device_type = v0.device.type
 24    if device_type == "mps":
 25        v0, v1 = v0.cpu(), v1.cpu()
 26
 27    v0, v1 = v0.double(), v1.double()
 28    v1 = torch.nn.functional.normalize(v1, dim=dims)
 29    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
 30    v0_orthogonal = v0 - v0_parallel
 31    return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)
 32
 33
 34def apg_forward(
 35    pred_cond: torch.Tensor,  # [B, C, H, W]
 36    pred_uncond: torch.Tensor,  # [B, C, H, W]
 37    guidance_scale: float,
 38    momentum_buffer: MomentumBuffer = None,
 39    eta: float = 0.0,
 40    norm_threshold: float = 2.5,
 41    dims=[-1, -2],
 42):
 43    diff = pred_cond - pred_uncond
 44    if momentum_buffer is not None:
 45        momentum_buffer.update(diff)
 46        diff = momentum_buffer.running_average
 47
 48    if norm_threshold > 0:
 49        ones = torch.ones_like(diff)
 50        diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
 51        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
 52        diff = diff * scale_factor
 53
 54    diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
 55    normalized_update = diff_orthogonal + eta * diff_parallel
 56    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
 57    return pred_guided
 58
 59
 60def cfg_forward(cond_output, uncond_output, cfg_strength):
 61    return uncond_output + cfg_strength * (cond_output - uncond_output)
 62
 63
 64def cfg_double_condition_forward(
 65    cond_output,
 66    uncond_output,
 67    only_text_cond_output,
 68    guidance_scale_text,
 69    guidance_scale_lyric,
 70):
 71    return (1 - guidance_scale_text) * uncond_output + (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output + guidance_scale_lyric * cond_output
 72
 73
 74def optimized_scale(positive_flat, negative_flat):
 75    # Calculate dot production
 76    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
 77
 78    # Squared norm of uncondition
 79    squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
 80
 81    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
 82    st_star = dot_product / squared_norm
 83
 84    return st_star
 85
 86
 87def cfg_zero_star(
 88    noise_pred_with_cond,
 89    noise_pred_uncond,
 90    guidance_scale,
 91    i,
 92    zero_steps=1,
 93    use_zero_init=True,
 94):
 95    bsz = noise_pred_with_cond.shape[0]
 96    positive_flat = noise_pred_with_cond.view(bsz, -1)
 97    negative_flat = noise_pred_uncond.view(bsz, -1)
 98    alpha = optimized_scale(positive_flat, negative_flat)
 99    alpha = alpha.view(bsz, 1, 1, 1)
100    if (i <= zero_steps) and use_zero_init:
101        noise_pred = noise_pred_with_cond * 0.0
102    else:
103        noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_with_cond - noise_pred_uncond * alpha)
104    return noise_pred
class MomentumBuffer:
    """Momentum accumulator for the guidance direction used by ``apg_forward``.

    Every call to :meth:`update` replaces the stored value with
    ``update_value + momentum * running_average``. The default momentum is
    negative (-0.75), so by default the previous accumulator value is
    subtracted rather than added.
    """

    def __init__(self, momentum: float = -0.75):
        # Factor applied to the previous accumulator value on each update.
        self.momentum = momentum
        # Scalar 0 to start with, so the first update stores update_value as-is.
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into ``running_average`` (mutates the buffer)."""
        decayed_previous = self.momentum * self.running_average
        self.running_average = update_value + decayed_previous
def project(
    v0: torch.Tensor,  # [B, C, H, W]
    v1: torch.Tensor,  # [B, C, H, W]
    dims=(-1, -2),
):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    The projection is computed over ``dims`` in float64 for accuracy. The
    results are cast back to ``v0``'s dtype and returned on ``v0``'s device.

    Args:
        v0: Tensor to decompose.
        v1: Direction tensor; it is L2-normalized over ``dims`` before projecting.
        dims: Dimensions treated as the vector space for the dot product
            (a tuple or list of ints).

    Returns:
        ``(v0_parallel, v0_orthogonal)`` with ``v0_parallel + v0_orthogonal == v0``
        up to floating-point error.
    """
    dtype = v0.dtype
    device = v0.device
    if device.type == "mps":
        # Presumably because MPS lacks float64 support: do the math on the CPU.
        v0, v1 = v0.cpu(), v1.cpu()

    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=dims)
    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    # Cast first (on the compute device), then move back to the exact input
    # device. Using only ``device.type`` (e.g. "cuda") would land on the current
    # default device instead of, say, cuda:1.
    return v0_parallel.to(dtype).to(device), v0_orthogonal.to(dtype).to(device)
def apg_forward(
    pred_cond: torch.Tensor,  # [B, C, H, W]
    pred_uncond: torch.Tensor,  # [B, C, H, W]
    guidance_scale: float,
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 2.5,
    dims=(-1, -2),
):
    """Apply an APG guidance step to a conditional/unconditional prediction pair.

    The guidance direction ``pred_cond - pred_uncond`` goes through three steps:
    1. It is optionally passed through ``momentum_buffer``.
    2. It is rescaled so its L2 norm over ``dims`` is at most ``norm_threshold``.
    3. It is split into components parallel and orthogonal to ``pred_cond``.

    The orthogonal part plus ``eta`` times the parallel part is then added to
    ``pred_cond`` with weight ``guidance_scale - 1``.

    Args:
        pred_cond: Conditional prediction.
        pred_uncond: Unconditional prediction, subtracted from ``pred_cond``.
        guidance_scale: Guidance strength. The update is weighted by
            ``guidance_scale - 1``, so 1.0 contributes no update.
        momentum_buffer: If given, the direction is fed to its ``update`` and
            replaced by its ``running_average``. The buffer is mutated.
        eta: Weight of the component parallel to ``pred_cond`` (0 drops it).
        norm_threshold: If > 0, the direction is scaled down so its L2 norm
            over ``dims`` does not exceed this value. A value <= 0 disables it.
        dims: Dimensions over which the norm and the projection are computed.

    Returns:
        The guided prediction, with the shape of ``pred_cond``.
    """
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
        # Shrink-only factor min(1, threshold / norm). It is computed on the
        # reduced (keepdim) tensor and broadcast, so no full-size ones tensor is
        # allocated. A zero norm gives inf, which is clamped to 1.
        scale_factor = torch.clamp(norm_threshold / diff_norm, max=1.0)
        diff = diff * scale_factor

    diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
def cfg_forward(cond_output, uncond_output, cfg_strength):
    """Classifier-free guidance: move from the unconditional output toward the
    conditional one by ``cfg_strength`` times their difference."""
    guidance_direction = cond_output - uncond_output
    return uncond_output + cfg_strength * guidance_direction
def cfg_double_condition_forward(
    cond_output,
    uncond_output,
    only_text_cond_output,
    guidance_scale_text,
    guidance_scale_lyric,
):
    """Two-level guidance mixing unconditional, text-only and fully conditioned outputs.

    The result equals
    ``uncond + scale_text * (text_only - uncond) + scale_lyric * (cond - text_only)``,
    expanded here into one weighted sum of the three outputs.
    """
    uncond_term = (1 - guidance_scale_text) * uncond_output
    text_term = (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output
    cond_term = guidance_scale_lyric * cond_output
    return uncond_term + text_term + cond_term
def optimized_scale(positive_flat, negative_flat):
    """Per-row projection coefficient of ``positive_flat`` onto ``negative_flat``.

    For each row (dim 1 is reduced) this returns
    ``<positive, negative> / (||negative||^2 + 1e-8)`` with shape ``[rows, 1]``.
    The 1e-8 term guards against division by zero.
    """
    eps = 1e-8
    numerator = (positive_flat * negative_flat).sum(dim=1, keepdim=True)
    denominator = negative_flat.pow(2).sum(dim=1, keepdim=True) + eps
    return numerator / denominator
def cfg_zero_star(
    noise_pred_with_cond,
    noise_pred_uncond,
    guidance_scale,
    i,
    zero_steps=1,
    use_zero_init=True,
):
    """CFG-Zero*-style guidance with a per-sample rescaled unconditional branch.

    For each sample, ``alpha`` is the projection coefficient of the flattened
    conditional prediction onto the flattened unconditional one. The guided
    prediction is
    ``alpha * uncond + guidance_scale * (cond - alpha * uncond)``.

    Args:
        noise_pred_with_cond: Conditional prediction; dim 0 is the batch.
        noise_pred_uncond: Unconditional prediction, same shape.
        guidance_scale: Guidance strength.
        i: Current step index, compared against ``zero_steps``.
        zero_steps: While ``i <= zero_steps`` and ``use_zero_init`` is set, the
            returned prediction is all zeros.
        use_zero_init: Enables the zeroing of the early steps.

    Returns:
        A tensor with the shape of ``noise_pred_with_cond``.
    """
    if (i <= zero_steps) and use_zero_init:
        # The prediction is discarded for these steps, so skip the projection.
        return noise_pred_with_cond * 0.0

    bsz = noise_pred_with_cond.shape[0]
    # reshape (not view) so non-contiguous inputs are accepted too.
    positive_flat = noise_pred_with_cond.reshape(bsz, -1)
    negative_flat = noise_pred_uncond.reshape(bsz, -1)
    alpha = optimized_scale(positive_flat, negative_flat)
    # One scalar per sample, broadcast over every non-batch dim (any rank,
    # not only 4-D inputs).
    alpha = alpha.view(bsz, *([1] * (noise_pred_with_cond.dim() - 1)))
    noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_with_cond - noise_pred_uncond * alpha)
    return noise_pred