divisor.variant

Variation noise functions for the denoising process.

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3
  4"""Variation noise functions for denoising process."""
  5
  6import math
  7
  8import torch
  9from torch import Tensor
 10
 11from divisor.controller import variation_rng
 12
 13
 14def mix_noise(from_noise: Tensor, to_noise: Tensor, strength: float, variation_method: str = "linear") -> Tensor:
 15    """Mix two noise tensors using specified method.\n
 16    :param from_noise: Source noise tensor
 17    :param to_noise: Target noise tensor to mix towards
 18    :param strength: Mixing strength (0.0 to 1.0)
 19    :param variation_method: Mixing method ('linear' or 'slerp')
 20    :returns: Mixed noise tensor
 21    """
 22    to_noise = to_noise.to(from_noise.device)
 23
 24    if variation_method == "slerp":
 25        # Spherical linear interpolation
 26        # Flatten for norm calculation (works with any tensor shape)
 27        from_flat = from_noise.flatten(start_dim=1)
 28        to_flat = to_noise.flatten(start_dim=1)
 29
 30        from_norm = torch.norm(from_flat, dim=1, keepdim=True)
 31        to_norm = torch.norm(to_flat, dim=1, keepdim=True)
 32
 33        # Normalize
 34        from_unit = from_flat / (from_norm + 1e-8)
 35        to_unit = to_flat / (to_norm + 1e-8)
 36
 37        # Dot product for angle
 38        dot = (from_unit * to_unit).sum(dim=1, keepdim=True)
 39        dot = torch.clamp(dot, -1.0, 1.0)
 40        theta = torch.acos(dot)
 41
 42        # Slerp formula
 43        sin_theta = torch.sin(theta)
 44        w1 = torch.sin((1 - strength) * theta) / (sin_theta + 1e-8)
 45        w2 = torch.sin(strength * theta) / (sin_theta + 1e-8)
 46
 47        # Apply weights and reshape back
 48        mixed_flat = w1 * from_flat + w2 * to_flat
 49        mixed_noise = mixed_flat.reshape(from_noise.shape)
 50    else:
 51        # Linear interpolation
 52        mixed_noise = (1 - strength) * from_noise + strength * to_noise
 53        # Scale factor correction for variance preservation
 54        scale_factor = math.sqrt((1 - strength) ** 2 + strength**2)
 55        mixed_noise = mixed_noise / (scale_factor + 1e-8)
 56
 57    return mixed_noise
 58
 59
 60def apply_variation_noise(
 61    latent_sample: Tensor,
 62    variation_seed: int | None,
 63    variation_strength: float,
 64    mask: Tensor | None = None,
 65    variation_method: str = "linear",
 66) -> Tensor:
 67    """Apply variation noise to the latent sample.\n
 68    :param latent_sample: Current sample tensor in 3D sequence format [batch, sequence, features]
 69    :param variation_seed: Seed for variation noise generation, or None to disable
 70    :param variation_strength: Strength of variation (0.0 to 1.0)
 71    :param mask: Optional mask tensor for selective application
 72    :param variation_method: Mixing method ('linear' or 'slerp')
 73    :returns: Sample with variation noise applied
 74    """
 75    if variation_seed is None or variation_strength == 0.0:
 76        return latent_sample
 77
 78    # Set seed for variation noise generation
 79    if variation_seed is not None:
 80        variation_rng.next_seed(variation_seed)
 81    else:
 82        variation_seed = variation_rng.next_seed()
 83
 84    # Get generator and its device
 85    variation_generator = variation_rng._torch_generator
 86    generator_device = variation_generator.device if variation_generator is not None else torch.device("cpu")
 87
 88    # Generate variation noise matching the sample shape
 89    # Create on generator's device first (required for MPS compatibility)
 90    variation_noise = torch.randn(
 91        latent_sample.shape,
 92        dtype=latent_sample.dtype,
 93        layout=latent_sample.layout,
 94        generator=variation_generator,
 95        device=generator_device,
 96    )
 97
 98    # Move to sample's device if different
 99    if generator_device != latent_sample.device:
100        variation_noise = variation_noise.to(latent_sample.device)
101
102    if mask is None:
103        # Simple mixing without mask
104        result = mix_noise(latent_sample, variation_noise, variation_strength, variation_method)
105    else:
106        # Apply mask: mask=1 uses mixed noise, mask=0 uses original
107        mixed_noise_result = mix_noise(latent_sample, variation_noise, variation_strength, variation_method)
108        result = (mask == 1).float() * mixed_noise_result + (mask == 0).float() * latent_sample
109
110    return result
def mix_noise( from_noise: torch.Tensor, to_noise: torch.Tensor, strength: float, variation_method: str = 'linear') -> torch.Tensor:
def mix_noise(from_noise: Tensor, to_noise: Tensor, strength: float, variation_method: str = "linear") -> Tensor:
    """Mix two noise tensors using the specified method.\n
    :param from_noise: Source noise tensor. For 'slerp', dim 0 is treated as the batch dimension and each
        entry is interpolated as one flattened vector; 0-D/1-D inputs are treated as a single vector.
    :param to_noise: Target noise tensor to mix towards (moved to ``from_noise``'s device)
    :param strength: Mixing strength (0.0 to 1.0); 0.0 keeps the source, 1.0 moves fully to the target
    :param variation_method: Mixing method ('slerp'; any other value, e.g. 'linear', uses linear mixing)
    :returns: Mixed noise tensor with the shape of ``from_noise``
    """
    to_noise = to_noise.to(from_noise.device)

    if variation_method == "slerp":
        # Spherical linear interpolation, one flattened vector per leading-dim entry.
        if from_noise.dim() >= 2:
            from_flat = from_noise.flatten(start_dim=1)
            to_flat = to_noise.flatten(start_dim=1)
        else:
            # No batch dimension to keep: treat the whole tensor as a single vector.
            from_flat = from_noise.reshape(1, -1)
            to_flat = to_noise.reshape(1, -1)

        from_norm = torch.norm(from_flat, dim=1, keepdim=True)
        to_norm = torch.norm(to_flat, dim=1, keepdim=True)

        # Angle between the two directions; clamp so rounding never pushes acos outside [-1, 1].
        from_unit = from_flat / (from_norm + 1e-8)
        to_unit = to_flat / (to_norm + 1e-8)
        dot = torch.clamp((from_unit * to_unit).sum(dim=1, keepdim=True), -1.0, 1.0)
        theta = torch.acos(dot)

        sin_theta = torch.sin(theta)
        w1 = torch.sin((1 - strength) * theta) / (sin_theta + 1e-8)
        w2 = torch.sin(strength * theta) / (sin_theta + 1e-8)

        # When the vectors are (anti-)parallel, sin(theta) ~ 0 and the slerp weights are ill-defined:
        # exactly parallel inputs (e.g. a tensor mixed with itself) would get zero weights and an
        # all-zero output, anti-parallel ones would blow up. Fall back to linear weights there.
        degenerate = sin_theta.abs() < 1e-6
        w1 = torch.where(degenerate, torch.full_like(w1, 1 - strength), w1)
        w2 = torch.where(degenerate, torch.full_like(w2, strength), w2)

        mixed_noise = (w1 * from_flat + w2 * to_flat).reshape(from_noise.shape)
    else:
        # Linear interpolation.
        mixed_noise = (1 - strength) * from_noise + strength * to_noise
        # Rescale so the result keeps the input variance when the two inputs are
        # independent with equal variance: Var = ((1 - s)^2 + s^2) * sigma^2.
        scale_factor = math.sqrt((1 - strength) ** 2 + strength**2)
        mixed_noise = mixed_noise / (scale_factor + 1e-8)

    return mixed_noise

Mix two noise tensors using the specified method.

Parameters
  • from_noise: Source noise tensor
  • to_noise: Target noise tensor to mix towards
  • strength: Mixing strength (0.0 to 1.0)
  • variation_method: Mixing method ('linear' or 'slerp')

Returns
  • Mixed noise tensor
def apply_variation_noise( latent_sample: torch.Tensor, variation_seed: int | None, variation_strength: float, mask: torch.Tensor | None = None, variation_method: str = 'linear') -> torch.Tensor:
 61def apply_variation_noise(
 62    latent_sample: Tensor,
 63    variation_seed: int | None,
 64    variation_strength: float,
 65    mask: Tensor | None = None,
 66    variation_method: str = "linear",
 67) -> Tensor:
 68    """Apply variation noise to the latent sample.\n
 69    :param latent_sample: Current sample tensor in 3D sequence format [batch, sequence, features]
 70    :param variation_seed: Seed for variation noise generation, or None to disable
 71    :param variation_strength: Strength of variation (0.0 to 1.0)
 72    :param mask: Optional mask tensor for selective application
 73    :param variation_method: Mixing method ('linear' or 'slerp')
 74    :returns: Sample with variation noise applied
 75    """
 76    if variation_seed is None or variation_strength == 0.0:
 77        return latent_sample
 78
 79    # Set seed for variation noise generation
 80    if variation_seed is not None:
 81        variation_rng.next_seed(variation_seed)
 82    else:
 83        variation_seed = variation_rng.next_seed()
 84
 85    # Get generator and its device
 86    variation_generator = variation_rng._torch_generator
 87    generator_device = variation_generator.device if variation_generator is not None else torch.device("cpu")
 88
 89    # Generate variation noise matching the sample shape
 90    # Create on generator's device first (required for MPS compatibility)
 91    variation_noise = torch.randn(
 92        latent_sample.shape,
 93        dtype=latent_sample.dtype,
 94        layout=latent_sample.layout,
 95        generator=variation_generator,
 96        device=generator_device,
 97    )
 98
 99    # Move to sample's device if different
100    if generator_device != latent_sample.device:
101        variation_noise = variation_noise.to(latent_sample.device)
102
103    if mask is None:
104        # Simple mixing without mask
105        result = mix_noise(latent_sample, variation_noise, variation_strength, variation_method)
106    else:
107        # Apply mask: mask=1 uses mixed noise, mask=0 uses original
108        mixed_noise_result = mix_noise(latent_sample, variation_noise, variation_strength, variation_method)
109        result = (mask == 1).float() * mixed_noise_result + (mask == 0).float() * latent_sample
110
111    return result

Apply variation noise to the latent sample.

Parameters
  • latent_sample: Current sample tensor in 3D sequence format [batch, sequence, features]
  • variation_seed: Seed for variation noise generation, or None to disable
  • variation_strength: Strength of variation (0.0 to 1.0)
  • mask: Optional mask tensor for selective application
  • variation_method: Mixing method ('linear' or 'slerp')

Returns
  • Sample with variation noise applied