divisor.noise

Noise generation functions for Flux models.

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3
  4"""Noise generation functions for Flux models."""
  5
  6import math
  7from typing import Any, Optional
  8
  9import torch
 10from torch import Tensor
 11
 12from divisor.controller import rng
 13from divisor.registry import gfx_dtype
 14
 15
 16def get_noise(
 17    num_samples: int,
 18    height: int,
 19    width: int,
 20    seed: int,
 21    dtype: torch.dtype = torch.bfloat16,
 22    device: torch.device | None = None,
 23    version_2: bool = False,
 24) -> Tensor:
 25    """Generate noise tensor for Flux models.\n
 26    :param num_samples: Number of samples to generate
 27    :param height: Height of the image
 28    :param width: Width of the image
 29    :param dtype: Data type of the noise
 30    :param seed: Seed for the random number generator
 31    :param device: Device to generate the noise on
 32    :param model_type: Model type - "flux1" or "flux2" (default: "flux1")
 33    :returns: Noise tensor with shape appropriate for the model type
 34
 35    Flux1 shape: (num_samples, 16, 2 * ceil(height/16), 2 * ceil(width/16))\n
 36    Flux2 shape: (num_samples, 128, height // 16, width // 16)\n"""
 37
 38    generator_device = rng._torch_generator.device if rng._torch_generator is not None else torch.device("cpu")
 39    rng._torch_generator.manual_seed(seed)  # type: ignore # reset seed
 40    if version_2:
 41        shape = (num_samples, 128, height // 16, width // 16)
 42    else:
 43        shape = (num_samples, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16))
 44    noise = torch.randn(shape, dtype=dtype, generator=rng._torch_generator, device=generator_device)
 45
 46    if device is not None and generator_device != device:
 47        noise = noise.to(device)
 48    return noise
 49
 50
 51def prepare_4d_noise_for_3d_model(
 52    height: int,
 53    width: int,
 54    seed: int,
 55    t5: Optional[Any] = None,
 56    clip: Optional[Any] = None,
 57    prompt: Optional[str] = None,
 58    num_samples: int = 1,
 59    dtype: torch.dtype = torch.bfloat16,
 60    device: torch.device | None = None,
 61    version_2: bool = False,
 62) -> Tensor:
 63    """Generate noise and convert to 3D format for model input.\n
 64    Generates 4D noise tensor and converts it from (batch, channels, height, width) format to\n
 65    (batch, sequence_length, features) format based on model type.\n
 66    :param height: Height of the image
 67    :param width: Width of the image
 68    :param seed: Seed for random number generation
 69    :param t5: Optional T5 embedder instance (required for Flux1/XFlux1)
 70    :param clip: Optional CLIP embedder instance (required for Flux1/XFlux1)
 71    :param prompt: Optional prompt string (required for Flux1/XFlux1)
 72    :param num_samples: Number of samples to generate (default: 1)
 73    :param dtype: Data type of the noise (default: torch.bfloat16)
 74    :param device: Device to generate the noise on (default: None)
 75    :param version_2: Whether to use Flux2 format (default: False)
 76    :returns: 3D tensor with shape (batch, sequence_length, features)"""
 77
 78    noise_4d = get_noise(
 79        num_samples=num_samples,
 80        height=height,
 81        width=width,
 82        seed=seed,
 83        dtype=dtype,
 84        device=device,
 85        version_2=version_2,
 86    )
 87
 88    if t5 is not None and clip is not None and prompt is not None:
 89        from divisor.flux1.sampling import prepare
 90
 91        inp = prepare(t5, clip, noise_4d, prompt=prompt)
 92        return inp["img"]  # 3D format: (batch, sequence_length, features)
 93    else:
 94        from divisor.flux2.sampling import batched_prc_img
 95
 96        noise_3d, _ = batched_prc_img(noise_4d)  # 4D -> 3D: Ignore x_ids as controller doesn't need them
 97        return noise_3d
 98
 99
def log(t, eps=1e-20):
    """Return the elementwise natural log of ``t`` with its values floored at ``eps``.\n
    :param t: Input tensor
    :param eps: Lower bound applied to ``t`` before the logarithm (default: 1e-20)
    :returns: Tensor of ``log(max(t, eps))`` values"""
    floored = torch.clamp(t, min=eps)
    return floored.log()
102
103
def gumbel_noise(t, generator=None):
    """Draw noise shaped like ``t`` as ``-log(-log(U))`` with ``U`` sampled uniformly on [0, 1).\n
    :param t: Tensor whose shape, dtype and device the noise copies
    :param generator: Optional ``torch.Generator`` passed to the uniform sampler
    :returns: Noise tensor with the same shape as ``t``"""
    uniform = torch.zeros_like(t)
    uniform.uniform_(0, 1, generator=generator)  # in-place fill
    inner = log(uniform)  # clamped log, see ``log``
    return -log(-inner)
107
108
def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
    """Pick indices along ``dim`` by argmax of temperature-scaled ``t`` plus Gumbel noise.\n
    :param t: Input tensor (e.g. logits)
    :param temperature: Divisor applied to ``t``; floored at 1e-10 to avoid division by zero
    :param dim: Dimension reduced by the argmax (default: -1)
    :param generator: Optional ``torch.Generator`` forwarded to ``gumbel_noise``
    :returns: Tensor of argmax indices"""
    safe_temperature = max(temperature, 1e-10)
    scaled = t / safe_temperature
    perturbed = scaled + gumbel_noise(t, generator=generator)
    return perturbed.argmax(dim=dim)
111
112
def add_gumbel_noise(logits, temperature):
    """Return ``logits + temperature * G`` where ``G = -log(-log(U))`` and ``U ~ Uniform(0, 1)``.\n
    No argmax is applied here; the perturbed logits are returned.\n
    When ``abs(temperature)`` is below 1e-9 the input is returned unchanged.\n
    This additive form avoids the exp() and division used by the alternative formulation.\n
    :param logits: Tensor of logits
    :param temperature: Scale applied to the Gumbel noise
    :returns: The original ``logits`` object at (effectively) zero temperature, otherwise the
        perturbed logits cast to ``gfx_dtype``"""
    if abs(temperature) < 1e-9:  # effectively zero temperature: nothing to add
        return logits

    working_dtype = gfx_dtype
    promoted = logits.to(working_dtype)
    uniform = torch.rand_like(promoted, dtype=working_dtype)

    # Small epsilon inside each log keeps the argument away from exactly zero.
    eps = 1e-20
    neg_log_uniform = -torch.log(uniform + eps)
    gumbel = -torch.log(neg_log_uniform + eps)
    return promoted + temperature * gumbel
127
128
def add_unstable_gumbel_noise(logits, temperature):
    """Perturb logits for Gumbel-max sampling using the exp()/division formulation.\n
    The Gumbel max is a method for sampling categorical distributions.\n
    arXiv:2409.02908 low-precision Gumbel Max improves MDM perplexity score but reduces generation quality,\n
    so the computation runs in ``gfx_dtype`` (per the original note: float64, or float32 on mps --
    confirm against ``divisor.registry``).\n
    :param logits: Tensor of logits
    :param temperature: Exponent applied to ``-log(U)``; 0 returns ``logits`` unchanged
    :returns: ``exp(logits) / (-log(U)) ** temperature`` with ``U ~ Uniform(0, 1)``"""
    target_dtype = gfx_dtype
    if temperature == 0:
        return logits

    cast_logits = logits.to(target_dtype)
    uniform = torch.rand_like(cast_logits, dtype=target_dtype)
    denominator = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(cast_logits) / denominator
def get_noise( num_samples: int, height: int, width: int, seed: int, dtype: torch.dtype = torch.bfloat16, device: torch.device | None = None, version_2: bool = False) -> torch.Tensor:
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    seed: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | None = None,
    version_2: bool = False,
) -> Tensor:
    """Generate a seeded noise tensor for Flux models.\n
    :param num_samples: Number of samples to generate
    :param height: Height of the image
    :param width: Width of the image
    :param seed: Seed applied to the random number generator before sampling
    :param dtype: Data type of the noise
    :param device: Device the noise is moved to after sampling; None leaves it on the generator's device
    :param version_2: Use the Flux2 shape when True, the Flux1 shape otherwise (default: False)
    :returns: Noise tensor with shape appropriate for the selected model version

    Flux1 shape: (num_samples, 16, 2 * ceil(height/16), 2 * ceil(width/16))\n
    Flux2 shape: (num_samples, 128, height // 16, width // 16)\n"""

    generator = rng._torch_generator
    if generator is None:
        # The shared generator is missing: sample from a private CPU generator rather than
        # dereferencing None. The shared rng object itself is left untouched.
        generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)  # reset seed on every call
    generator_device = generator.device

    if version_2:
        shape = (num_samples, 128, height // 16, width // 16)
    else:
        shape = (num_samples, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16))

    # Sample on the generator's own device, then move only if a different device was requested.
    noise = torch.randn(shape, dtype=dtype, generator=generator, device=generator_device)
    if device is not None and generator_device != device:
        noise = noise.to(device)
    return noise

Generate noise tensor for Flux models.

Parameters
  • num_samples: Number of samples to generate
  • height: Height of the image
  • width: Width of the image
  • dtype: Data type of the noise
  • seed: Seed for the random number generator
  • device: Device to generate the noise on
  • version_2: Use the Flux2 noise shape instead of the Flux1 shape (default: False)

Returns
  Noise tensor with shape appropriate for the model type

Flux1 shape: (num_samples, 16, 2 * ceil(height/16), 2 * ceil(width/16))

Flux2 shape: (num_samples, 128, height // 16, width // 16)

def prepare_4d_noise_for_3d_model( height: int, width: int, seed: int, t5: Optional[Any] = None, clip: Optional[Any] = None, prompt: Optional[str] = None, num_samples: int = 1, dtype: torch.dtype = torch.bfloat16, device: torch.device | None = None, version_2: bool = False) -> torch.Tensor:
def prepare_4d_noise_for_3d_model(
    height: int,
    width: int,
    seed: int,
    t5: Optional[Any] = None,
    clip: Optional[Any] = None,
    prompt: Optional[str] = None,
    num_samples: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | None = None,
    version_2: bool = False,
) -> Tensor:
    """Create 4D noise with ``get_noise`` and repack it into the 3D layout used as model input.\n
    When ``t5``, ``clip`` and ``prompt`` are all supplied, the noise goes through
    ``divisor.flux1.sampling.prepare`` and its ``"img"`` entry is returned. Otherwise it goes
    through ``divisor.flux2.sampling.batched_prc_img`` and the second return value is discarded.\n
    NOTE(review): the conversion path depends only on t5/clip/prompt; ``version_2`` is only
    forwarded to ``get_noise`` -- confirm this is intended.\n
    :param height: Height of the image
    :param width: Width of the image
    :param seed: Seed for random number generation
    :param t5: Optional T5 embedder instance (needed for the flux1 path)
    :param clip: Optional CLIP embedder instance (needed for the flux1 path)
    :param prompt: Optional prompt string (needed for the flux1 path)
    :param num_samples: Number of samples to generate (default: 1)
    :param dtype: Data type of the noise (default: torch.bfloat16)
    :param device: Device forwarded to ``get_noise`` (default: None)
    :param version_2: Forwarded to ``get_noise`` to select the Flux2 noise shape (default: False)
    :returns: 3D tensor with shape (batch, sequence_length, features)"""

    latent_noise = get_noise(
        num_samples,
        height,
        width,
        seed,
        dtype=dtype,
        device=device,
        version_2=version_2,
    )

    if t5 is None or clip is None or prompt is None:
        # No complete text conditioning: use the flux2 packing helper.
        from divisor.flux2.sampling import batched_prc_img

        packed, _unused_ids = batched_prc_img(latent_noise)  # ids are not needed by the controller
        return packed

    from divisor.flux1.sampling import prepare

    prepared = prepare(t5, clip, latent_noise, prompt=prompt)
    return prepared["img"]  # 3D format: (batch, sequence_length, features)

Generate noise and convert to 3D format for model input.

Generates 4D noise tensor and converts it from (batch, channels, height, width) format to

(batch, sequence_length, features) format based on model type.

Parameters
  • height: Height of the image
  • width: Width of the image
  • seed: Seed for random number generation
  • t5: Optional T5 embedder instance (required for Flux1/XFlux1)
  • clip: Optional CLIP embedder instance (required for Flux1/XFlux1)
  • prompt: Optional prompt string (required for Flux1/XFlux1)
  • num_samples: Number of samples to generate (default: 1)
  • dtype: Data type of the noise (default: torch.bfloat16)
  • device: Device to generate the noise on (default: None)
  • version_2: Whether to use Flux2 format (default: False)

Returns
  3D tensor with shape (batch, sequence_length, features)
def log(t, eps=1e-20):
def log(t, eps=1e-20):
    """Return the elementwise natural log of ``t`` with its values floored at ``eps``.\n
    :param t: Input tensor
    :param eps: Lower bound applied to ``t`` before the logarithm (default: 1e-20)
    :returns: Tensor of ``log(max(t, eps))`` values"""
    floored = torch.clamp(t, min=eps)
    return floored.log()
def gumbel_noise(t, generator=None):
def gumbel_noise(t, generator=None):
    """Draw noise shaped like ``t`` as ``-log(-log(U))`` with ``U`` sampled uniformly on [0, 1).\n
    :param t: Tensor whose shape, dtype and device the noise copies
    :param generator: Optional ``torch.Generator`` passed to the uniform sampler
    :returns: Noise tensor with the same shape as ``t``"""
    uniform = torch.zeros_like(t)
    uniform.uniform_(0, 1, generator=generator)  # in-place fill
    inner = log(uniform)  # clamped log, see ``log``
    return -log(-inner)
def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
    """Pick indices along ``dim`` by argmax of temperature-scaled ``t`` plus Gumbel noise.\n
    :param t: Input tensor (e.g. logits)
    :param temperature: Divisor applied to ``t``; floored at 1e-10 to avoid division by zero
    :param dim: Dimension reduced by the argmax (default: -1)
    :param generator: Optional ``torch.Generator`` forwarded to ``gumbel_noise``
    :returns: Tensor of argmax indices"""
    safe_temperature = max(temperature, 1e-10)
    scaled = t / safe_temperature
    perturbed = scaled + gumbel_noise(t, generator=generator)
    return perturbed.argmax(dim=dim)
def add_gumbel_noise(logits, temperature):
def add_gumbel_noise(logits, temperature):
    """Return ``logits + temperature * G`` where ``G = -log(-log(U))`` and ``U ~ Uniform(0, 1)``.\n
    No argmax is applied here; the perturbed logits are returned.\n
    When ``abs(temperature)`` is below 1e-9 the input is returned unchanged.\n
    This additive form avoids the exp() and division used by the alternative formulation.\n
    :param logits: Tensor of logits
    :param temperature: Scale applied to the Gumbel noise
    :returns: The original ``logits`` object at (effectively) zero temperature, otherwise the
        perturbed logits cast to ``gfx_dtype``"""
    if abs(temperature) < 1e-9:  # effectively zero temperature: nothing to add
        return logits

    working_dtype = gfx_dtype
    promoted = logits.to(working_dtype)
    uniform = torch.rand_like(promoted, dtype=working_dtype)

    # Small epsilon inside each log keeps the argument away from exactly zero.
    eps = 1e-20
    neg_log_uniform = -torch.log(uniform + eps)
    gumbel = -torch.log(neg_log_uniform + eps)
    return promoted + temperature * gumbel

Adds Gumbel noise to logits for stochastic sampling.

Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).

This version is more numerically stable than a version involving exp() and division.

def add_unstable_gumbel_noise(logits, temperature):
def add_unstable_gumbel_noise(logits, temperature):
    """Perturb logits for Gumbel-max sampling using the exp()/division formulation.\n
    The Gumbel max is a method for sampling categorical distributions.\n
    arXiv:2409.02908 low-precision Gumbel Max improves MDM perplexity score but reduces generation quality,\n
    so the computation runs in ``gfx_dtype`` (per the original note: float64, or float32 on mps --
    confirm against ``divisor.registry``).\n
    :param logits: Tensor of logits
    :param temperature: Exponent applied to ``-log(U)``; 0 returns ``logits`` unchanged
    :returns: ``exp(logits) / (-log(U)) ** temperature`` with ``U ~ Uniform(0, 1)``"""
    target_dtype = gfx_dtype
    if temperature == 0:
        return logits

    cast_logits = logits.to(target_dtype)
    uniform = torch.rand_like(cast_logits, dtype=target_dtype)
    denominator = torch.pow(-torch.log(uniform), temperature)
    return torch.exp(cast_logits) / denominator

The Gumbel max is a method for sampling categorical distributions.

arXiv:2409.02908 low-precision Gumbel Max improves MDM perplexity score but reduces generation quality.

Thus, we use float64... unless mps, in which case we must use float32