divisor.noise
Noise generation functions for Flux models.
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

"""Noise generation functions for Flux models."""

import math
from typing import Any, Optional

import torch
from torch import Tensor

from divisor.controller import rng
from divisor.registry import gfx_dtype


def get_noise(
    num_samples: int,
    height: int,
    width: int,
    seed: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | None = None,
    version_2: bool = False,
) -> Tensor:
    """Generate a seeded noise tensor for Flux models.\n
    :param num_samples: Number of samples to generate
    :param height: Height of the image
    :param width: Width of the image
    :param seed: Seed applied to the random number generator before sampling
    :param dtype: Data type of the noise
    :param device: Device the returned noise is placed on (default: the generator's device)
    :param version_2: Generate Flux2-shaped noise instead of Flux1-shaped noise (default: False)
    :returns: Noise tensor with shape appropriate for the model type

    Flux1 shape: (num_samples, 16, 2 * ceil(height/16), 2 * ceil(width/16))\n
    Flux2 shape: (num_samples, 128, height // 16, width // 16)\n"""

    generator = rng._torch_generator
    if generator is None:
        # No shared generator is configured: use a private CPU generator so the requested
        # seed is still honoured instead of crashing on `None.manual_seed`.
        generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)  # reset seed so identical seeds reproduce identical noise

    if version_2:
        shape = (num_samples, 128, height // 16, width // 16)
    else:
        shape = (num_samples, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16))

    # Sample on the generator's own device, then move to the requested device if it differs.
    noise = torch.randn(shape, dtype=dtype, generator=generator, device=generator.device)
    if device is not None and generator.device != device:
        noise = noise.to(device)
    return noise


def prepare_4d_noise_for_3d_model(
    height: int,
    width: int,
    seed: int,
    t5: Optional[Any] = None,
    clip: Optional[Any] = None,
    prompt: Optional[str] = None,
    num_samples: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | None = None,
    version_2: bool = False,
) -> Tensor:
    """Generate noise and convert to 3D format for model input.\n
    Generates 4D noise tensor and converts it from (batch, channels, height, width) format to\n
    (batch, sequence_length, features) format based on model type.\n
    :param height: Height of the image
    :param width: Width of the image
    :param seed: Seed for random number generation
    :param t5: Optional T5 embedder instance (required for Flux1/XFlux1)
    :param clip: Optional CLIP embedder instance (required for Flux1/XFlux1)
    :param prompt: Optional prompt string (required for Flux1/XFlux1)
    :param num_samples: Number of samples to generate (default: 1)
    :param dtype: Data type of the noise (default: torch.bfloat16)
    :param device: Device the noise is placed on (default: None)
    :param version_2: Whether to use Flux2 format (default: False)
    :returns: 3D tensor with shape (batch, sequence_length, features)"""

    noise_4d = get_noise(
        num_samples=num_samples,
        height=height,
        width=width,
        seed=seed,
        dtype=dtype,
        device=device,
        version_2=version_2,
    )

    if t5 is not None and clip is not None and prompt is not None:
        from divisor.flux1.sampling import prepare

        # Flux1 path: keep only the packed image tokens from `prepare`'s output dict.
        inp = prepare(t5, clip, noise_4d, prompt=prompt)
        return inp["img"]  # 3D format: (batch, sequence_length, features)

    # NOTE(review): this branch is selected by embedder/prompt availability, not by `version_2`,
    # so Flux1-shaped noise reaches `batched_prc_img` whenever one of them is missing — confirm intended.
    from divisor.flux2.sampling import batched_prc_img

    noise_3d, _ = batched_prc_img(noise_4d)  # 4D -> 3D: Ignore x_ids as controller doesn't need them
    return noise_3d


def log(t, eps=1e-20):
    """Natural log of `t`, clamped below at `eps` so zeros do not produce -inf."""
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t, generator=None):
    """Sample standard Gumbel(0, 1) noise with the same shape, dtype and device as `t`.\n
    :param t: Reference tensor (only its shape/dtype/device are used)
    :param generator: Optional torch generator for reproducible sampling
    :returns: -log(-log(U)) with U ~ Uniform(0, 1), using the clamped `log` helper"""
    noise = torch.zeros_like(t).uniform_(0, 1, generator=generator)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1, generator=None):
    """Draw a categorical sample via the Gumbel-max trick.\n
    :param t: Logits
    :param temperature: Softmax temperature; clamped below at 1e-10 to avoid division by zero
    :param dim: Dimension to take the argmax over
    :param generator: Optional torch generator for reproducible sampling
    :returns: Indices of argmax(t / temperature + Gumbel noise) along `dim`"""
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t, generator=generator)).argmax(dim=dim)


def add_gumbel_noise(logits, temperature):
    """Add temperature-scaled Gumbel noise to logits for stochastic sampling.\n
    Returns logits + temperature * G with G ~ Gumbel(0, 1); for temperature > 0, taking argmax of the\n
    result samples from softmax(logits / temperature) (the Gumbel-max trick).\n
    This log-space form is more numerically stable than the exp()/division form in add_unstable_gumbel_noise.\n
    :param logits: Unnormalized scores
    :param temperature: Noise scale; |temperature| < 1e-9 returns `logits` unchanged
    :returns: Noisy logits cast to `gfx_dtype` (or the original `logits` when temperature is ~0)"""
    if abs(temperature) < 1e-9:  # Effectively zero temperature
        return logits

    max_device_precision = gfx_dtype
    logits = logits.to(max_device_precision)
    noise = torch.rand_like(logits, dtype=max_device_precision)
    # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1); small epsilon keeps both logs finite.
    standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
    return logits + temperature * standard_gumbel_noise


def add_unstable_gumbel_noise(logits, temperature):
    """Perturb logits with Gumbel noise in probability space (the Gumbel-max trick).\n
    Computes exp(logits) / (-log U) ** temperature with U ~ Uniform[0, 1); its argmax matches that of\n
    logits + temperature * G with G ~ Gumbel(0, 1), but exp() can overflow for large logits.\n
    arXiv:2409.02908: low-precision Gumbel-max improves MDM perplexity but reduces generation quality,\n
    so the computation runs in `gfx_dtype` (presumably float64, or float32 on MPS — confirm in divisor.registry).\n
    :param logits: Unnormalized scores
    :param temperature: Noise exponent; exactly 0 returns `logits` unchanged
    :returns: Perturbed, unnormalized scores"""
    precision = gfx_dtype
    if temperature == 0:
        return logits
    logits = logits.to(precision)
    noise = torch.rand_like(logits, dtype=precision)
    # Renamed from `gumbel_noise` so the module-level function of that name is not shadowed.
    exponential_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / exponential_noise
"""Adds Gumbel noise to logits for stochastic sampling.\n 115 Equivalent to argmax(logits + temperature * G) where G ~ Gumbel(0,1).\n 116 This version is more numerically stable than a version involving exp() and division.""" 117 if abs(temperature) < 1e-9: # Effectively zero temperature 118 return logits 119 120 max_device_precision = gfx_dtype 121 logits = logits.to(max_device_precision) 122 noise = torch.rand_like(logits, dtype=max_device_precision) 123 # Standard Gumbel noise: -log(-log(U)), U ~ Uniform(0,1) Add small epsilon for numerical stability inside logs 124 125 standard_gumbel_noise = -torch.log(-torch.log(noise + 1e-20) + 1e-20) 126 return logits + temperature * standard_gumbel_noise 127 128 129def add_unstable_gumbel_noise(logits, temperature): 130 """The Gumbel max is a method for sampling categorical distributions.\n 131 arXiv:2409.02908 low-precision Gumbel Max improves MDM perplexity score but reduces generation quality.\n 132 Thus, we use float64... unless mps, in which case we must use float32""" 133 precision = gfx_dtype 134 if temperature == 0: 135 return logits 136 logits = logits.to(precision) 137 noise = torch.rand_like(logits, dtype=precision) 138 gumbel_noise = (-torch.log(noise)) ** temperature 139 return logits.exp() / gumbel_noise
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    seed: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | None = None,
    version_2: bool = False,
) -> Tensor:
    """Generate a seeded noise tensor for Flux models.\n
    :param num_samples: Number of samples to generate
    :param height: Height of the image
    :param width: Width of the image
    :param seed: Seed applied to the random number generator before sampling
    :param dtype: Data type of the noise
    :param device: Device the returned noise is placed on (default: the generator's device)
    :param version_2: Generate Flux2-shaped noise instead of Flux1-shaped noise (default: False)
    :returns: Noise tensor with shape appropriate for the model type

    Flux1 shape: (num_samples, 16, 2 * ceil(height/16), 2 * ceil(width/16))\n
    Flux2 shape: (num_samples, 128, height // 16, width // 16)\n"""

    generator = rng._torch_generator
    if generator is None:
        # No shared generator is configured: use a private CPU generator so the requested
        # seed is still honoured instead of crashing on `None.manual_seed`.
        generator = torch.Generator(device="cpu")
    generator.manual_seed(seed)  # reset seed so identical seeds reproduce identical noise

    if version_2:
        shape = (num_samples, 128, height // 16, width // 16)
    else:
        shape = (num_samples, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16))

    # Sample on the generator's own device, then move to the requested device if it differs.
    noise = torch.randn(shape, dtype=dtype, generator=generator, device=generator.device)
    if device is not None and generator.device != device:
        noise = noise.to(device)
    return noise
Generate noise tensor for Flux models.
Parameters
- num_samples: Number of samples to generate
- height: Height of the image
- width: Width of the image
- dtype: Data type of the noise
- seed: Seed for the random number generator
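- device: Device the returned noise is placed on (noise is sampled on the generator's device, then moved if needed)
- version_2: Whether to generate Flux2-shaped noise instead of Flux1-shaped noise (default: False)

Returns: Noise tensor with shape appropriate for the model type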
Flux1 shape: (num_samples, 16, 2 * ceil(height/16), 2 * ceil(width/16))
Flux2 shape: (num_samples, 128, height // 16, width // 16)
def prepare_4d_noise_for_3d_model(
    height: int,
    width: int,
    seed: int,
    t5: Optional[Any] = None,
    clip: Optional[Any] = None,
    prompt: Optional[str] = None,
    num_samples: int = 1,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | None = None,
    version_2: bool = False,
) -> Tensor:
    """Create seeded noise and flatten it into token form for model input.\n
    The 4D noise (batch, channels, height, width) from get_noise is converted into a 3D\n
    (batch, sequence_length, features) tensor using the Flux1 or Flux2 packing helper.\n
    :param height: Height of the image
    :param width: Width of the image
    :param seed: Seed for random number generation
    :param t5: Optional T5 embedder instance (required for Flux1/XFlux1)
    :param clip: Optional CLIP embedder instance (required for Flux1/XFlux1)
    :param prompt: Optional prompt string (required for Flux1/XFlux1)
    :param num_samples: Number of samples to generate (default: 1)
    :param dtype: Data type of the noise (default: torch.bfloat16)
    :param device: Device the noise is placed on (default: None)
    :param version_2: Whether to use Flux2 format (default: False)
    :returns: 3D tensor with shape (batch, sequence_length, features)"""
    latent = get_noise(
        num_samples=num_samples,
        height=height,
        width=width,
        seed=seed,
        dtype=dtype,
        device=device,
        version_2=version_2,
    )

    has_flux1_inputs = all(item is not None for item in (t5, clip, prompt))
    if not has_flux1_inputs:
        # NOTE(review): chosen by embedder/prompt availability, not `version_2` — Flux1-shaped
        # noise lands here whenever any of t5/clip/prompt is missing; confirm intended.
        from divisor.flux2.sampling import batched_prc_img

        tokens, _position_ids = batched_prc_img(latent)  # position ids are not needed by the controller
        return tokens

    from divisor.flux1.sampling import prepare

    # Only the packed image tokens are kept from `prepare`'s output dict.
    return prepare(t5, clip, latent, prompt=prompt)["img"]
Generate noise and convert to 3D format for model input.
Generates 4D noise tensor and converts it from (batch, channels, height, width) format to
(batch, sequence_length, features) format based on model type.
Parameters
- height: Height of the image
- width: Width of the image
- seed: Seed for random number generation
- t5: Optional T5 embedder instance (required for Flux1/XFlux1)
- clip: Optional CLIP embedder instance (required for Flux1/XFlux1)
- prompt: Optional prompt string (required for Flux1/XFlux1)
- num_samples: Number of samples to generate (default: 1)
- dtype: Data type of the noise (default: torch.bfloat16)
- device: Device the noise is placed on (default: None)
- version_2: Whether to use Flux2 format (default: False)

Returns: 3D tensor with shape (batch, sequence_length, features)
def add_gumbel_noise(logits, temperature):
    """Add temperature-scaled Gumbel noise to logits for stochastic sampling.\n
    Returns logits + temperature * G with G ~ Gumbel(0, 1); for temperature > 0, taking argmax of the\n
    result samples from softmax(logits / temperature) (the Gumbel-max trick).\n
    This log-space form is more numerically stable than the exp()/division form in add_unstable_gumbel_noise.\n
    :param logits: Unnormalized scores
    :param temperature: Noise scale; |temperature| < 1e-9 returns `logits` unchanged
    :returns: Noisy logits cast to `gfx_dtype` (or the original `logits` when temperature is ~0)"""
    if abs(temperature) < 1e-9:
        # Effectively zero temperature: deterministic, no noise added.
        return logits

    work_dtype = gfx_dtype
    precise_logits = logits.to(work_dtype)
    uniform = torch.rand_like(precise_logits, dtype=work_dtype)

    # Epsilon guards both logarithms against log(0).
    eps = 1e-20
    gumbel = -torch.log(-torch.log(uniform + eps) + eps)
    return precise_logits + temperature * gumbel
Adds Gumbel noise to logits for stochastic sampling.
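Returns logits + temperature * G where G ~ Gumbel(0, 1); for temperature > 0, taking the argmax of the result samples from softmax(logits / temperature) (the Gumbel-max trick).
This log-space version is more numerically stable than add_unstable_gumbel_noise, which uses exp() and division.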
def add_unstable_gumbel_noise(logits, temperature):
    """Perturb logits with Gumbel noise in probability space (the Gumbel-max trick).\n
    Computes exp(logits) / (-log U) ** temperature with U ~ Uniform[0, 1); its argmax matches that of\n
    logits + temperature * G with G ~ Gumbel(0, 1), but exp() can overflow for large logits.\n
    arXiv:2409.02908: low-precision Gumbel-max improves MDM perplexity but reduces generation quality,\n
    so the computation runs in `gfx_dtype` (presumably float64, or float32 on MPS — confirm in divisor.registry).\n
    :param logits: Unnormalized scores
    :param temperature: Noise exponent; exactly 0 returns `logits` unchanged
    :returns: Perturbed, unnormalized scores"""
    work_dtype = gfx_dtype
    if temperature == 0:
        return logits

    promoted = logits.to(work_dtype)
    uniform = torch.rand_like(promoted, dtype=work_dtype)
    # -log(U) is Exponential(1); raising it to `temperature` scales the perturbation.
    denominator = (-torch.log(uniform)) ** temperature
    return promoted.exp() / denominator
The Gumbel-max trick is a method for sampling from categorical distributions.
According to arXiv:2409.02908, a low-precision Gumbel-max improves the MDM perplexity score but reduces generation quality.
Thus, the computation is done in `gfx_dtype`, which is intended to be float64, except on MPS, where float64 is unsupported and float32 must be used.