divisor.acestep.apg_guidance
# SPDX-License-Identifier: Apache-2.0
# code from https://github.com/ace-step/ACE-Step
"""Guidance combiners used when sampling: APG, CFG, two-scale CFG and CFG-Zero*."""

import torch


class MomentumBuffer:
    """Running accumulator with momentum, consumed by ``apg_forward``.

    Every ``update`` sets
    ``running_average = update_value + momentum * running_average``.
    Despite the attribute name this is not a normalised average; with the
    default negative momentum part of the previous value is subtracted.
    """

    def __init__(self, momentum: float = -0.75):
        # Weight applied to the previous running value on each update.
        self.momentum = momentum
        # Scalar 0 until the first update, which then yields update_value's values.
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into ``running_average`` in place (returns None)."""
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


def project(
    v0: torch.Tensor,  # [B, C, H, W]
    v1: torch.Tensor,  # [B, C, H, W]
    dims=(-1, -2),
):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    The inner product is taken over the axes in ``dims``; the remaining axes
    are treated independently. The math runs in float64.

    Args:
        v0: Tensor to decompose.
        v1: Direction tensor; it is L2-normalised over ``dims`` first.
        dims: Axis or sequence of axes to project over. Lists are still
            accepted; the default is a tuple to avoid a mutable default.

    Returns:
        ``(v0_parallel, v0_orthogonal)``, both cast back to ``v0``'s dtype and
        placed on ``v0``'s original device (including its index, e.g. ``cuda:1``).
    """
    dtype = v0.dtype
    # Keep the full device, not just device.type: moving results with "cuda"
    # would land them on the *current* CUDA device instead of v0's device.
    device = v0.device
    if device.type == "mps":
        # MPS has no float64 support, so the double-precision math runs on CPU.
        v0, v1 = v0.cpu(), v1.cpu()

    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=dims)
    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    # Cast the dtype before the device move so float64 never reaches MPS.
    return v0_parallel.to(dtype).to(device), v0_orthogonal.to(dtype).to(device)


def apg_forward(
    pred_cond: torch.Tensor,  # [B, C, H, W]
    pred_uncond: torch.Tensor,  # [B, C, H, W]
    guidance_scale: float,
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 2.5,
    dims=(-1, -2),
):
    """Apply adaptive projected guidance (APG) to a conditional prediction.

    Steps:
    1. ``diff = pred_cond - pred_uncond``; if ``momentum_buffer`` is given it
       is updated with ``diff`` (mutating the buffer) and its running value
       replaces ``diff``.
    2. If ``norm_threshold > 0``, ``diff`` is rescaled so its L2 norm over
       ``dims`` is at most ``norm_threshold``.
    3. ``diff`` is split into parts parallel/orthogonal to ``pred_cond`` and
       recombined as ``orthogonal + eta * parallel``.

    Returns:
        ``pred_cond + (guidance_scale - 1) * update``.
    """
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
        # Same values as min(ones_like(diff), ...), but the factor stays in
        # keepdim shape and broadcasts instead of allocating a full-size tensor.
        scale_factor = (norm_threshold / diff_norm).clamp(max=1.0)
        diff = diff * scale_factor

    diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided


def cfg_forward(cond_output, uncond_output, cfg_strength):
    """Classifier-free guidance: ``uncond + cfg_strength * (cond - uncond)``."""
    return uncond_output + cfg_strength * (cond_output - uncond_output)


def cfg_double_condition_forward(
    cond_output,
    uncond_output,
    only_text_cond_output,
    guidance_scale_text,
    guidance_scale_lyric,
):
    """Blend three predictions with separate text and lyric guidance scales.

    Equivalent to ``uncond + s_t * (text_only - uncond) + s_l * (cond - text_only)``
    with ``s_t = guidance_scale_text`` and ``s_l = guidance_scale_lyric``; the
    three weights sum to 1.
    """
    return (
        (1 - guidance_scale_text) * uncond_output
        + (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output
        + guidance_scale_lyric * cond_output
    )


def optimized_scale(positive_flat, negative_flat):
    """Least-squares scale of ``positive_flat`` onto ``negative_flat`` per row.

    Computes ``<positive, negative> / (||negative||^2 + 1e-8)``, reducing over
    dim 1 with ``keepdim=True`` (a ``(B, N)`` input gives ``(B, 1)``).
    """
    # Dot product between conditional and unconditional predictions.
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)

    # Squared norm of the unconditional prediction. The 1e-8 keeps all-zero
    # rows finite for float32/64 (it rounds to 0 in float16).
    squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8

    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
    st_star = dot_product / squared_norm

    return st_star


def cfg_zero_star(
    noise_pred_with_cond,
    noise_pred_uncond,
    guidance_scale,
    i,
    zero_steps=1,
    use_zero_init=True,
):
    """CFG-Zero*-style guidance with an optimised unconditional scale.

    For steps ``i <= zero_steps`` (when ``use_zero_init``) the prediction is
    zeroed. Otherwise the unconditional prediction is scaled per batch item
    by ``optimized_scale`` before standard CFG is applied. Inputs may have any
    rank >= 1 with the batch on dim 0.
    """
    if (i <= zero_steps) and use_zero_init:
        # Zero-init: no alpha is needed, so skip computing it.
        return noise_pred_with_cond * 0.0

    bsz = noise_pred_with_cond.shape[0]
    # reshape (not view) so non-contiguous inputs are accepted.
    positive_flat = noise_pred_with_cond.reshape(bsz, -1)
    negative_flat = noise_pred_uncond.reshape(bsz, -1)
    alpha = optimized_scale(positive_flat, negative_flat)
    # One scale per batch item, broadcast over every remaining axis of any rank
    # (a hard-coded (bsz, 1, 1, 1) mis-broadcasts non-4D inputs).
    alpha = alpha.view(bsz, *([1] * (noise_pred_with_cond.dim() - 1)))
    scaled_uncond = noise_pred_uncond * alpha
    return scaled_uncond + guidance_scale * (noise_pred_with_cond - scaled_uncond)
class
MomentumBuffer:
class MomentumBuffer:
    """Running accumulator with momentum, used by the APG guidance step.

    Each call to ``update`` replaces the stored value with
    ``update_value + momentum * running_average``. The attribute is called an
    average, but it is not normalised. With the default negative momentum,
    part of the previous value is subtracted.
    """

    def __init__(self, momentum: float = -0.75):
        # Multiplier applied to the previously stored value on each update.
        self.momentum = momentum
        # Plain scalar 0 until the first update arrives.
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into the stored value in place; returns None."""
        previous = self.running_average
        self.running_average = update_value + self.momentum * previous
def
project(v0: torch.Tensor, v1: torch.Tensor, dims=[-1, -2]):
def project(
    v0: torch.Tensor,  # [B, C, H, W]
    v1: torch.Tensor,  # [B, C, H, W]
    dims=(-1, -2),
):
    """Split ``v0`` into components parallel and orthogonal to ``v1``.

    The inner product is taken over the axes in ``dims``; the remaining axes
    are treated independently. The math runs in float64.

    Args:
        v0: Tensor to decompose.
        v1: Direction tensor; it is L2-normalised over ``dims`` first.
        dims: Axis or sequence of axes to project over. Lists are still
            accepted; the default is a tuple to avoid a mutable default.

    Returns:
        ``(v0_parallel, v0_orthogonal)``, both cast back to ``v0``'s dtype and
        placed on ``v0``'s original device (including its index, e.g. ``cuda:1``).
    """
    dtype = v0.dtype
    # Keep the full device, not just device.type: moving results with "cuda"
    # would land them on the *current* CUDA device instead of v0's device.
    device = v0.device
    if device.type == "mps":
        # MPS has no float64 support, so the double-precision math runs on CPU.
        v0, v1 = v0.cpu(), v1.cpu()

    v0, v1 = v0.double(), v1.double()
    v1 = torch.nn.functional.normalize(v1, dim=dims)
    v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    # Cast the dtype before the device move so float64 never reaches MPS.
    return v0_parallel.to(dtype).to(device), v0_orthogonal.to(dtype).to(device)
def
apg_forward( pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, momentum_buffer: MomentumBuffer = None, eta: float = 0.0, norm_threshold: float = 2.5, dims=[-1, -2]):
def apg_forward(
    pred_cond: torch.Tensor,  # [B, C, H, W]
    pred_uncond: torch.Tensor,  # [B, C, H, W]
    guidance_scale: float,
    momentum_buffer: MomentumBuffer = None,
    eta: float = 0.0,
    norm_threshold: float = 2.5,
    dims=(-1, -2),
):
    """Apply adaptive projected guidance (APG) to a conditional prediction.

    Steps:
    1. ``diff = pred_cond - pred_uncond``; if ``momentum_buffer`` is given it
       is updated with ``diff`` (mutating the buffer) and its running value
       replaces ``diff``.
    2. If ``norm_threshold > 0``, ``diff`` is rescaled so its L2 norm over
       ``dims`` is at most ``norm_threshold``.
    3. ``diff`` is split by ``project`` into parts parallel/orthogonal to
       ``pred_cond`` and recombined as ``orthogonal + eta * parallel``.

    Returns:
        ``pred_cond + (guidance_scale - 1) * update``.
    """
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=dims, keepdim=True)
        # Same values as min(ones_like(diff), ...), but the factor stays in
        # keepdim shape and broadcasts instead of allocating a full-size tensor.
        scale_factor = (norm_threshold / diff_norm).clamp(max=1.0)
        diff = diff * scale_factor

    diff_parallel, diff_orthogonal = project(diff, pred_cond, dims)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
def
cfg_forward(cond_output, uncond_output, cfg_strength):
def
cfg_double_condition_forward( cond_output, uncond_output, only_text_cond_output, guidance_scale_text, guidance_scale_lyric):
65def cfg_double_condition_forward( 66 cond_output, 67 uncond_output, 68 only_text_cond_output, 69 guidance_scale_text, 70 guidance_scale_lyric, 71): 72 return (1 - guidance_scale_text) * uncond_output + (guidance_scale_text - guidance_scale_lyric) * only_text_cond_output + guidance_scale_lyric * cond_output
def
optimized_scale(positive_flat, negative_flat):
def optimized_scale(positive_flat, negative_flat):
    """Per-row least-squares scale of ``positive_flat`` onto ``negative_flat``.

    Returns ``<positive, negative> / (||negative||^2 + 1e-8)``, reduced over
    dim 1 with ``keepdim=True`` (a ``(B, N)`` input yields ``(B, 1)``).
    """
    numerator = (positive_flat * negative_flat).sum(dim=1, keepdim=True)
    # 1e-8 keeps all-zero rows finite for float32/64 inputs; note it is below
    # float16's smallest subnormal, so it offers no protection in half precision.
    denominator = negative_flat.pow(2).sum(dim=1, keepdim=True) + 1e-8
    return numerator / denominator
def
cfg_zero_star( noise_pred_with_cond, noise_pred_uncond, guidance_scale, i, zero_steps=1, use_zero_init=True):
def cfg_zero_star(
    noise_pred_with_cond,
    noise_pred_uncond,
    guidance_scale,
    i,
    zero_steps=1,
    use_zero_init=True,
):
    """CFG-Zero*-style guidance with an optimised unconditional scale.

    For steps ``i <= zero_steps`` (when ``use_zero_init``) the prediction is
    zeroed. Otherwise the unconditional prediction is scaled per batch item
    by ``optimized_scale`` before standard CFG is applied. Inputs may have any
    rank >= 1 with the batch on dim 0.
    """
    if (i <= zero_steps) and use_zero_init:
        # Zero-init: no alpha is needed, so skip computing it.
        return noise_pred_with_cond * 0.0

    bsz = noise_pred_with_cond.shape[0]
    # reshape (not view) so non-contiguous inputs are accepted.
    positive_flat = noise_pred_with_cond.reshape(bsz, -1)
    negative_flat = noise_pred_uncond.reshape(bsz, -1)
    alpha = optimized_scale(positive_flat, negative_flat)
    # One scale per batch item, broadcast over every remaining axis of any rank
    # (a hard-coded (bsz, 1, 1, 1) mis-broadcasts non-4D inputs).
    alpha = alpha.view(bsz, *([1] * (noise_pred_with_cond.dim() - 1)))
    scaled_uncond = noise_pred_uncond * alpha
    return scaled_uncond + guidance_scale * (noise_pred_with_cond - scaled_uncond)