divisor.variant
Variation noise functions for the denoising process.
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

"""Variation noise functions for the denoising process."""

import math

import torch
from torch import Tensor

from divisor.controller import variation_rng


def mix_noise(from_noise: Tensor, to_noise: Tensor, strength: float, variation_method: str = "linear") -> Tensor:
    """Mix two noise tensors using the specified method.\n
    With ``'slerp'`` every batch item is flattened and interpolated along the arc between the
    two tensors. When the two directions are parallel or anti-parallel the arc is undefined
    (``sin(theta)`` is ~0), so plain linear weights are used for that item instead.
    Any other ``variation_method`` value selects the rescaled linear mix.
    :param from_noise: Source noise tensor
    :param to_noise: Target noise tensor to mix towards (moved to the device of ``from_noise``)
    :param strength: Mixing strength (0.0 to 1.0)
    :param variation_method: Mixing method ('linear' or 'slerp')
    :returns: Mixed noise tensor
    """
    to_noise = to_noise.to(from_noise.device)

    if variation_method == "slerp":
        # Below this |sin(theta)| the slerp weights are numerically meaningless
        parallel_eps = 1e-6

        # Flatten everything after the first dimension so one angle is computed per batch item
        from_flat = from_noise.flatten(start_dim=1)
        to_flat = to_noise.flatten(start_dim=1)

        # Unit directions; the epsilon guards against all-zero rows
        from_unit = from_flat / (torch.norm(from_flat, dim=1, keepdim=True) + 1e-8)
        to_unit = to_flat / (torch.norm(to_flat, dim=1, keepdim=True) + 1e-8)

        # Angle between the two directions (clamp keeps acos inside its domain)
        dot = torch.clamp((from_unit * to_unit).sum(dim=1, keepdim=True), -1.0, 1.0)
        theta = torch.acos(dot)
        sin_theta = torch.sin(theta)

        # Dividing by ~0 used to collapse both weights to 0 (parallel inputs) or blow them up
        # (anti-parallel inputs); substitute a safe divisor and linear weights for those rows.
        degenerate = sin_theta.abs() < parallel_eps
        safe_sin = torch.where(degenerate, torch.ones_like(sin_theta), sin_theta)
        w1 = torch.where(
            degenerate,
            torch.full_like(theta, 1.0 - strength),
            torch.sin((1 - strength) * theta) / safe_sin,
        )
        w2 = torch.where(
            degenerate,
            torch.full_like(theta, float(strength)),
            torch.sin(strength * theta) / safe_sin,
        )

        # Apply weights and restore the original shape
        mixed_flat = w1 * from_flat + w2 * to_flat
        mixed_noise = mixed_flat.reshape(from_noise.shape)
    else:
        # Linear interpolation
        mixed_noise = (1 - strength) * from_noise + strength * to_noise
        # Scale factor correction for variance preservation
        scale_factor = math.sqrt((1 - strength) ** 2 + strength**2)
        mixed_noise = mixed_noise / (scale_factor + 1e-8)

    return mixed_noise


def apply_variation_noise(
    latent_sample: Tensor,
    variation_seed: int | None,
    variation_strength: float,
    mask: Tensor | None = None,
    variation_method: str = "linear",
) -> Tensor:
    """Apply variation noise to the latent sample.\n
    :param latent_sample: Current sample tensor in 3D sequence format [batch, sequence, features]
    :param variation_seed: Seed for variation noise generation, or None to disable
    :param variation_strength: Strength of variation (0.0 to 1.0); 0.0 disables variation
    :param mask: Optional mask tensor for selective application; elements equal to 1 receive the
        mixed noise, every other element keeps the original sample value
    :param variation_method: Mixing method ('linear' or 'slerp')
    :returns: Sample with variation noise applied, or ``latent_sample`` itself when disabled
    """
    # Variation is disabled without a seed or with zero strength
    if variation_seed is None or variation_strength == 0.0:
        return latent_sample

    # Set seed for variation noise generation
    variation_rng.next_seed(variation_seed)

    # Get generator and its device
    variation_generator = variation_rng._torch_generator
    generator_device = variation_generator.device if variation_generator is not None else torch.device("cpu")

    # Generate variation noise matching the sample shape
    # Create on generator's device first (required for MPS compatibility)
    variation_noise = torch.randn(
        latent_sample.shape,
        dtype=latent_sample.dtype,
        layout=latent_sample.layout,
        generator=variation_generator,
        device=generator_device,
    )

    # Move to sample's device if different
    if generator_device != latent_sample.device:
        variation_noise = variation_noise.to(latent_sample.device)

    mixed = mix_noise(latent_sample, variation_noise, variation_strength, variation_method)
    if mask is None:
        return mixed

    # Select per element instead of multiplying by (mask == 1) / (mask == 0) indicators:
    # the indicator sum zeroed any element whose mask value was neither 0 nor 1 and
    # upcast half-precision samples to float32.
    mask = mask.to(latent_sample.device)
    return torch.where(mask == 1, mixed, latent_sample)
def mix_noise(from_noise: Tensor, to_noise: Tensor, strength: float, variation_method: str = "linear") -> Tensor:
    """Mix two noise tensors using the specified method.\n
    With ``'slerp'`` every batch item is flattened and interpolated along the arc between the
    two tensors. When the two directions are parallel or anti-parallel the arc is undefined
    (``sin(theta)`` is ~0), so plain linear weights are used for that item instead.
    Any other ``variation_method`` value selects the rescaled linear mix.
    :param from_noise: Source noise tensor
    :param to_noise: Target noise tensor to mix towards (moved to the device of ``from_noise``)
    :param strength: Mixing strength (0.0 to 1.0)
    :param variation_method: Mixing method ('linear' or 'slerp')
    :returns: Mixed noise tensor
    """
    to_noise = to_noise.to(from_noise.device)

    if variation_method == "slerp":
        # Below this |sin(theta)| the slerp weights are numerically meaningless
        parallel_eps = 1e-6

        # Flatten everything after the first dimension so one angle is computed per batch item
        from_flat = from_noise.flatten(start_dim=1)
        to_flat = to_noise.flatten(start_dim=1)

        # Unit directions; the epsilon guards against all-zero rows
        from_unit = from_flat / (torch.norm(from_flat, dim=1, keepdim=True) + 1e-8)
        to_unit = to_flat / (torch.norm(to_flat, dim=1, keepdim=True) + 1e-8)

        # Angle between the two directions (clamp keeps acos inside its domain)
        dot = torch.clamp((from_unit * to_unit).sum(dim=1, keepdim=True), -1.0, 1.0)
        theta = torch.acos(dot)
        sin_theta = torch.sin(theta)

        # Dividing by ~0 used to collapse both weights to 0 (parallel inputs) or blow them up
        # (anti-parallel inputs); substitute a safe divisor and linear weights for those rows.
        degenerate = sin_theta.abs() < parallel_eps
        safe_sin = torch.where(degenerate, torch.ones_like(sin_theta), sin_theta)
        w1 = torch.where(
            degenerate,
            torch.full_like(theta, 1.0 - strength),
            torch.sin((1 - strength) * theta) / safe_sin,
        )
        w2 = torch.where(
            degenerate,
            torch.full_like(theta, float(strength)),
            torch.sin(strength * theta) / safe_sin,
        )

        # Apply weights and restore the original shape
        mixed_flat = w1 * from_flat + w2 * to_flat
        mixed_noise = mixed_flat.reshape(from_noise.shape)
    else:
        # Linear interpolation
        mixed_noise = (1 - strength) * from_noise + strength * to_noise
        # Scale factor correction for variance preservation
        scale_factor = math.sqrt((1 - strength) ** 2 + strength**2)
        mixed_noise = mixed_noise / (scale_factor + 1e-8)

    return mixed_noise
Mix two noise tensors using specified method.
Parameters
- from_noise: Source noise tensor
- to_noise: Target noise tensor to mix towards
- strength: Mixing strength (0.0 to 1.0)
- variation_method: Mixing method ('linear' or 'slerp')
Returns
- Mixed noise tensor
def
apply_variation_noise( latent_sample: torch.Tensor, variation_seed: int | None, variation_strength: float, mask: torch.Tensor | None = None, variation_method: str = 'linear') -> torch.Tensor:
61def apply_variation_noise( 62 latent_sample: Tensor, 63 variation_seed: int | None, 64 variation_strength: float, 65 mask: Tensor | None = None, 66 variation_method: str = "linear", 67) -> Tensor: 68 """Apply variation noise to the latent sample.\n 69 :param latent_sample: Current sample tensor in 3D sequence format [batch, sequence, features] 70 :param variation_seed: Seed for variation noise generation, or None to disable 71 :param variation_strength: Strength of variation (0.0 to 1.0) 72 :param mask: Optional mask tensor for selective application 73 :param variation_method: Mixing method ('linear' or 'slerp') 74 :returns: Sample with variation noise applied 75 """ 76 if variation_seed is None or variation_strength == 0.0: 77 return latent_sample 78 79 # Set seed for variation noise generation 80 if variation_seed is not None: 81 variation_rng.next_seed(variation_seed) 82 else: 83 variation_seed = variation_rng.next_seed() 84 85 # Get generator and its device 86 variation_generator = variation_rng._torch_generator 87 generator_device = variation_generator.device if variation_generator is not None else torch.device("cpu") 88 89 # Generate variation noise matching the sample shape 90 # Create on generator's device first (required for MPS compatibility) 91 variation_noise = torch.randn( 92 latent_sample.shape, 93 dtype=latent_sample.dtype, 94 layout=latent_sample.layout, 95 generator=variation_generator, 96 device=generator_device, 97 ) 98 99 # Move to sample's device if different 100 if generator_device != latent_sample.device: 101 variation_noise = variation_noise.to(latent_sample.device) 102 103 if mask is None: 104 # Simple mixing without mask 105 result = mix_noise(latent_sample, variation_noise, variation_strength, variation_method) 106 else: 107 # Apply mask: mask=1 uses mixed noise, mask=0 uses original 108 mixed_noise_result = mix_noise(latent_sample, variation_noise, variation_strength, variation_method) 109 result = (mask == 1).float() * 
mixed_noise_result + (mask == 0).float() * latent_sample 110 111 return result
Apply variation noise to the latent sample.
Parameters
- latent_sample: Current sample tensor in 3D sequence format [batch, sequence, features]
- variation_seed: Seed for variation noise generation, or None to disable
- variation_strength: Strength of variation (0.0 to 1.0)
- mask: Optional mask tensor for selective application
- variation_method: Mixing method ('linear' or 'slerp')
Returns
- Sample with variation noise applied