divisor.flux1.sampling
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->
# adapted BFL Flux code from https://github.com/black-forest-labs/flux

import math
import time
from typing import Callable, Optional

import torch
from torch import Tensor
from einops import rearrange, repeat
from nnll.console import nfo

from divisor.cli_menu import route_choices
from divisor.controller import ManualTimestepController, rng, variation_rng
from divisor.denoise_step import (
    create_clear_prediction_cache,
    create_denoise_step_fn,
    create_get_prediction,
    create_recompute_text_embeddings,
)
from divisor.flux1.model import Flux
from divisor.flux1.text_embedder import HFEmbedder
from divisor.interaction_context import InteractionContext
from divisor.registry import gfx_sync
from divisor.save import SaveFile
from divisor.state import (
    ImageEmbeddingState,
    InferenceState,
    TextEmbeddingState,
)


def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack image latents into 2x2 patch tokens and embed the prompt for a Flux1 forward pass.\n
    :param t5: T5 embedder; called once with the prompt list, its output becomes ``txt``
    :param clip: CLIP embedder; called once with the prompt list, its output becomes ``vec``
    :param img: Latent image indexed as (batch, channels, height, width); height and width must be even
    :param prompt: A single prompt, or a list of prompts (its length sets the batch size when ``img`` has batch 1)
    :returns: Dictionary with ``img``, ``img_ids``, ``txt``, ``txt_ids`` and ``vec``, all on ``img``'s device"""
    batch, _channels, height, width = img.shape
    prompts = [prompt] if isinstance(prompt, str) else prompt
    if batch == 1 and not isinstance(prompt, str):
        batch = len(prompt)

    def broadcast(tensor: Tensor) -> Tensor:
        # Expand a batch-of-one tensor to the working batch size; leave anything else untouched.
        if tensor.shape[0] == 1 and batch > 1:
            return repeat(tensor, "1 ... -> bs ...", bs=batch)
        return tensor

    tokens = broadcast(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

    rows, cols = height // 2, width // 2
    grid = torch.zeros(rows, cols, 3)
    # channel 0 stays zero; channels 1 and 2 hold each patch's row and column index
    grid[..., 1] += torch.arange(rows)[:, None]
    grid[..., 2] += torch.arange(cols)[None, :]
    positions = repeat(grid, "h w c -> b (h w) c", b=batch)

    txt = broadcast(t5(prompts))
    txt_ids = torch.zeros(batch, txt.shape[1], 3)
    vec = broadcast(clip(prompts))

    device = tokens.device
    return {
        "img": tokens,
        "img_ids": positions.to(device),
        "txt": txt.to(device),
        "txt_ids": txt_ids.to(device),
        "vec": vec.to(device),
    }


def time_shift(schedule_mu: float, schedule_sigma: float, original_timestep_tensor: Tensor) -> Tensor:
    """Warp timesteps with the logistic shift used by the Flux schedule.\n
    Each timestep ``t`` maps to ``e^mu / (e^mu + (1/t - 1)^sigma)``, so 1 stays 1 and 0 stays 0;
    a larger ``schedule_mu`` pushes intermediate values toward 1.\n
    :param schedule_mu: Shift amount (``mu`` in the formula above).
    :param schedule_sigma: Exponent applied to ``1/t - 1``.
    :param original_timestep_tensor: Tensor of original timesteps in [0,1].
    :returns: Adjusted timestep tensor."""
    scale = math.exp(schedule_mu)
    odds = (1 / original_timestep_tensor - 1) ** schedule_sigma
    return scale / (scale + odds)


def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Build the straight line passing through ``(x1, y1)`` and ``(x2, y2)``.\n
    :param x1: First anchor abscissa
    :param y1: Value of the line at ``x1``
    :param x2: Second anchor abscissa (must differ from ``x1``)
    :param y2: Value of the line at ``x2``
    :returns: Function evaluating the line at ``x`` (extrapolates outside ``[x1, x2]``)"""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build the descending list of timesteps used for sampling.\n
    :param num_steps: Number of denoising steps; the result has ``num_steps + 1`` entries
    :param image_seq_len: Number of image tokens, used to pick the shift amount
    :param base_shift: Shift amount at 256 image tokens
    :param max_shift: Shift amount at 4096 image tokens
    :param shift: Whether to warp the linear schedule with ``time_shift``
    :returns: Timesteps running from 1 down to 0"""
    # num_steps intervals need num_steps + 1 boundaries, the last one exactly 0
    grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return grid.tolist()

    # mu is interpolated linearly in the token count; a larger mu keeps more steps near t=1
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, grid).tolist()
@torch.inference_mode()
def denoise(
    model: Flux,
    settings: InferenceState,
):
    """Denoise using the Flux model, stepping interactively through a ManualTimestepController.\n
    Each loop iteration does the following:

    - reads the controller's current state;
    - recomputes text embeddings if the prompt changed;
    - hands control to ``route_choices`` for user interaction;
    - when an autoencoder and a resolution are available, decodes and saves a preview of
      ``sample - t * prediction``.\n
    :param model: Flux model instance
    :param settings: InferenceState containing all denoising configuration parameters
    :returns: ``controller.current_sample`` once the controller reports completion"""
    from contextlib import nullcontext

    from divisor.registry import gfx_device as default_device

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    vec = settings.vec
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    img_cond = settings.img_cond
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids
    denoise_device = settings.device if settings.device is not None else default_device
    initial_layer_dropout = settings.initial_layer_dropout
    t5 = settings.t5
    clip = settings.clip
    neg_pred_enabled = settings.neg_pred_enabled
    neg_txt = settings.neg_txt
    neg_txt_ids = settings.neg_txt_ids
    neg_vec = settings.neg_vec
    true_gs = settings.true_gs

    # Single-element lists act as mutable cells shared with the closures built below.
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Optional[Tensor]] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Optional[Tensor]] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[Optional[dict]] = [None]  # Cache state when prediction was generated
    controller_ref: list[Optional["ManualTimestepController"]] = [None]  # Reference to controller for closure access

    model_ref: list[Flux] = [model]
    target_device = img.device
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        model_device = target_device  # Assume model is already on correct device if we can't determine it
    if model_device != target_device:
        if model_device.type == "meta":
            # Meta tensors carry no data, so only fresh (uninitialized) storage can be allocated.
            model_ref[0] = model.to_empty(device=target_device)
        else:
            # `to_empty` would replace the loaded weights with uninitialized memory; `to` copies them.
            model_ref[0] = model.to(target_device)

    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    assert vec is not None, "vec (CLIP embeddings) is required for Flux1"
    current_vec: list[Tensor] = [vec]

    # Compare against None explicitly: truth-testing a multi-element tensor raises RuntimeError.
    has_negative = neg_pred_enabled and all(item is not None for item in (neg_txt, neg_txt_ids, neg_vec))
    current_neg_txt: list[Tensor] | None = None
    current_neg_txt_ids: list[Tensor] | None = None
    current_neg_vec: list[Tensor] | None = None
    if has_negative:
        current_neg_txt = [neg_txt]  # type: ignore
        current_neg_txt_ids = [neg_txt_ids]  # type: ignore
        current_neg_vec = [neg_vec]  # type: ignore
    else:
        true_gs = 1  # fall back to a scale of 1 when no negative embeddings are available
    current_prompt: list[Optional[str]] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(  # formatting
        img, t5, clip, current_txt, current_txt_ids, current_vec, current_prompt, clear_prediction_cache, is_flux2=False
    )

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
        neg_pred_enabled=neg_pred_enabled,
        current_neg_txt=current_neg_txt,  # pyright: ignore[reportArgumentType]
        current_neg_txt_ids=current_neg_txt_ids,  # pyright: ignore[reportArgumentType]
        current_neg_vec=current_neg_vec,  # pyright: ignore[reportArgumentType]
        # NOTE(review): int() truncates fractional guidance scales (e.g. 3.5 -> 3); confirm the intended type.
        true_gs=int(true_gs) if true_gs is not None else None,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=img_cond,
        img_cond_seq=img_cond_seq,
        img_cond_seq_ids=img_cond_seq_ids,
    )
    get_prediction = create_get_prediction(pred_set, img_set)

    denoise_step_fn = create_denoise_step_fn(  # formatting
        controller_ref, current_layer_dropout, previous_step_tensor, get_prediction
    )

    controller = ManualTimestepController(  # formatting
        timesteps=timesteps, initial_sample=img, denoise_step_fn=denoise_step_fn, initial_guidance=state.guidance
    )
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    use_autocast = denoise_device.type == "cuda"

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if t5 is not None and clip is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedders not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            t5=t5,
            clip=clip,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Unpack requires float32, but we'll convert back to correct dtype after
            intermediate = unpack(intermediate.float(), state.height, state.width)

            # The original bare `gfx_sync` expression never ran; call it so the timing covers queued device work.
            if callable(gfx_sync):
                gfx_sync()
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            context = torch.autocast(device_type=denoise_device.type, dtype=torch.bfloat16) if use_autocast else nullcontext()
            with context:
                # When autocast is disabled (e.g. MPS/CPU), ensure intermediate matches the VAE dtype
                if not use_autocast:
                    # Safely get encoder dtype, handling Mock objects in tests
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        # Fallback: use intermediate dtype if we can't get encoder dtype (for Mock objects in tests)
                        ae_dtype = intermediate.dtype
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)
                    intermediate_image = ae.decoder(z_adjusted)
                else:
                    intermediate_image = ae.decode(intermediate)
            if state.seed is not None:
                controller.store_state_in_chain(current_seed=state.seed)
            with SaveFile() as saver:
                saver.intermediate_image = intermediate_image  # set up image
                saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                saver.with_hyperchain()

    return controller.current_sample


def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Reverse the 2x2 patch packing done in ``prepare``.\n
    :param x: Packed tokens laid out as ``(b, h*w, c*2*2)``
    :param height: Image height; the packed grid height is ``ceil(height / 16)``
    :param width: Image width; the packed grid width is ``ceil(width / 16)``
    :returns: Tensor of shape ``(b, c, 2 * ceil(height / 16), 2 * ceil(width / 16))``"""
    return rearrange(
        x,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=math.ceil(height / 16),
        w=math.ceil(width / 16),
        ph=2,
        pw=2,
    )
def
prepare( t5: divisor.flux1.text_embedder.HFEmbedder, clip: divisor.flux1.text_embedder.HFEmbedder, img: torch.Tensor, prompt: str | list[str]) -> dict[str, torch.Tensor]:
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack image latents into 2x2 patch tokens and embed the prompt for a Flux1 forward pass.\n
    :param t5: T5 embedder; called once with the prompt list, its output becomes ``txt``
    :param clip: CLIP embedder; called once with the prompt list, its output becomes ``vec``
    :param img: Latent image indexed as (batch, channels, height, width); height and width must be even
    :param prompt: A single prompt, or a list of prompts (its length sets the batch size when ``img`` has batch 1)
    :returns: Dictionary with ``img``, ``img_ids``, ``txt``, ``txt_ids`` and ``vec``, all on ``img``'s device"""
    batch, _channels, height, width = img.shape
    prompts = [prompt] if isinstance(prompt, str) else prompt
    if batch == 1 and not isinstance(prompt, str):
        batch = len(prompt)

    def broadcast(tensor: Tensor) -> Tensor:
        # Expand a batch-of-one tensor to the working batch size; leave anything else untouched.
        if tensor.shape[0] == 1 and batch > 1:
            return repeat(tensor, "1 ... -> bs ...", bs=batch)
        return tensor

    tokens = broadcast(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

    rows, cols = height // 2, width // 2
    grid = torch.zeros(rows, cols, 3)
    # channel 0 stays zero; channels 1 and 2 hold each patch's row and column index
    grid[..., 1] += torch.arange(rows)[:, None]
    grid[..., 2] += torch.arange(cols)[None, :]
    positions = repeat(grid, "h w c -> b (h w) c", b=batch)

    txt = broadcast(t5(prompts))
    txt_ids = torch.zeros(batch, txt.shape[1], 3)
    vec = broadcast(clip(prompts))

    device = tokens.device
    return {
        "img": tokens,
        "img_ids": positions.to(device),
        "txt": txt.to(device),
        "txt_ids": txt_ids.to(device),
        "vec": vec.to(device),
    }
Prepare the packed image latents, position ids, and text embeddings for the model.
Parameters
- t5: T5 embedder
- clip: CLIP embedder
- img: Image tensor
- prompt: Prompt (a single string, or a list of strings)
Returns
- Dictionary of input tensors: `img`, `img_ids`, `txt`, `txt_ids`, `vec`
def
time_shift( schedule_mu: float, schedule_sigma: float, original_timestep_tensor: torch.Tensor) -> torch.Tensor:
def time_shift(schedule_mu: float, schedule_sigma: float, original_timestep_tensor: Tensor) -> Tensor:
    """Warp timesteps with the logistic shift used by the Flux schedule.\n
    Each timestep ``t`` maps to ``e^mu / (e^mu + (1/t - 1)^sigma)``, so 1 stays 1 and 0 stays 0;
    a larger ``schedule_mu`` pushes intermediate values toward 1.\n
    :param schedule_mu: Shift amount (``mu`` in the formula above).
    :param schedule_sigma: Exponent applied to ``1/t - 1``.
    :param original_timestep_tensor: Tensor of original timesteps in [0,1].
    :returns: Adjusted timestep tensor."""
    scale = math.exp(schedule_mu)
    odds = (1 / original_timestep_tensor - 1) ** schedule_sigma
    return scale / (scale + odds)
Adjustable noise schedule. Compress or stretch any schedule to match a dynamic step sequence length.
Parameters
- schedule_mu: Shift parameter μ; larger values push timesteps toward 1.
- schedule_sigma: Exponent σ applied to (1/t − 1).
- original_timestep_tensor: Tensor of original timesteps in [0,1].
Returns
- Adjusted timestep tensor.
def
get_lin_function( x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
def
get_schedule( num_steps: int, image_seq_len: int, base_shift: float = 0.5, max_shift: float = 1.15, shift: bool = True) -> list[float]:
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build the descending list of timesteps used for sampling.\n
    :param num_steps: Number of denoising steps; the result has ``num_steps + 1`` entries
    :param image_seq_len: Number of image tokens, used to pick the shift amount
    :param base_shift: Shift amount at 256 image tokens
    :param max_shift: Shift amount at 4096 image tokens
    :param shift: Whether to warp the linear schedule with ``time_shift``
    :returns: Timesteps running from 1 down to 0"""
    # num_steps intervals need num_steps + 1 boundaries, the last one exactly 0
    grid = torch.linspace(1, 0, num_steps + 1)
    if not shift:
        return grid.tolist()

    # mu is interpolated linearly in the token count; a larger mu keeps more steps near t=1
    mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
    return time_shift(mu, 1.0, grid).tolist()
Generate a schedule of timesteps.
Parameters
- num_steps: Number of steps to generate
- image_seq_len: Length of the image sequence
- base_shift: Base shift value
- max_shift: Maximum shift value
- shift: Whether to shift the schedule
Returns
- List of num_steps + 1 timesteps running from 1 down to 0
@torch.inference_mode()
def
denoise( model: divisor.flux1.model.Flux, settings: divisor.state.InferenceState):
@torch.inference_mode()
def denoise(
    model: Flux,
    settings: InferenceState,
):
    """Denoise using the Flux model, stepping interactively through a ManualTimestepController.\n
    Each loop iteration does the following:

    - reads the controller's current state;
    - recomputes text embeddings if the prompt changed;
    - hands control to ``route_choices`` for user interaction;
    - when an autoencoder and a resolution are available, decodes and saves a preview of
      ``sample - t * prediction``.\n
    :param model: Flux model instance
    :param settings: InferenceState containing all denoising configuration parameters
    :returns: ``controller.current_sample`` once the controller reports completion"""
    from contextlib import nullcontext

    from divisor.registry import gfx_device as default_device

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    vec = settings.vec
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    img_cond = settings.img_cond
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids
    denoise_device = settings.device if settings.device is not None else default_device
    initial_layer_dropout = settings.initial_layer_dropout
    t5 = settings.t5
    clip = settings.clip
    neg_pred_enabled = settings.neg_pred_enabled
    neg_txt = settings.neg_txt
    neg_txt_ids = settings.neg_txt_ids
    neg_vec = settings.neg_vec
    true_gs = settings.true_gs

    # Single-element lists act as mutable cells shared with the closures built below.
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Optional[Tensor]] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Optional[Tensor]] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[Optional[dict]] = [None]  # Cache state when prediction was generated
    controller_ref: list[Optional["ManualTimestepController"]] = [None]  # Reference to controller for closure access

    model_ref: list[Flux] = [model]
    target_device = img.device
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        model_device = target_device  # Assume model is already on correct device if we can't determine it
    if model_device != target_device:
        if model_device.type == "meta":
            # Meta tensors carry no data, so only fresh (uninitialized) storage can be allocated.
            model_ref[0] = model.to_empty(device=target_device)
        else:
            # `to_empty` would replace the loaded weights with uninitialized memory; `to` copies them.
            model_ref[0] = model.to(target_device)

    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    assert vec is not None, "vec (CLIP embeddings) is required for Flux1"
    current_vec: list[Tensor] = [vec]

    # Compare against None explicitly: truth-testing a multi-element tensor raises RuntimeError.
    has_negative = neg_pred_enabled and all(item is not None for item in (neg_txt, neg_txt_ids, neg_vec))
    current_neg_txt: list[Tensor] | None = None
    current_neg_txt_ids: list[Tensor] | None = None
    current_neg_vec: list[Tensor] | None = None
    if has_negative:
        current_neg_txt = [neg_txt]  # type: ignore
        current_neg_txt_ids = [neg_txt_ids]  # type: ignore
        current_neg_vec = [neg_vec]  # type: ignore
    else:
        true_gs = 1  # fall back to a scale of 1 when no negative embeddings are available
    current_prompt: list[Optional[str]] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(  # formatting
        img, t5, clip, current_txt, current_txt_ids, current_vec, current_prompt, clear_prediction_cache, is_flux2=False
    )

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
        neg_pred_enabled=neg_pred_enabled,
        current_neg_txt=current_neg_txt,  # pyright: ignore[reportArgumentType]
        current_neg_txt_ids=current_neg_txt_ids,  # pyright: ignore[reportArgumentType]
        current_neg_vec=current_neg_vec,  # pyright: ignore[reportArgumentType]
        # NOTE(review): int() truncates fractional guidance scales (e.g. 3.5 -> 3); confirm the intended type.
        true_gs=int(true_gs) if true_gs is not None else None,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=img_cond,
        img_cond_seq=img_cond_seq,
        img_cond_seq_ids=img_cond_seq_ids,
    )
    get_prediction = create_get_prediction(pred_set, img_set)

    denoise_step_fn = create_denoise_step_fn(  # formatting
        controller_ref, current_layer_dropout, previous_step_tensor, get_prediction
    )

    controller = ManualTimestepController(  # formatting
        timesteps=timesteps, initial_sample=img, denoise_step_fn=denoise_step_fn, initial_guidance=state.guidance
    )
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    use_autocast = denoise_device.type == "cuda"

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if t5 is not None and clip is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedders not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            t5=t5,
            clip=clip,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Unpack requires float32, but we'll convert back to correct dtype after
            intermediate = unpack(intermediate.float(), state.height, state.width)

            # The original bare `gfx_sync` expression never ran; call it so the timing covers queued device work.
            if callable(gfx_sync):
                gfx_sync()
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            context = torch.autocast(device_type=denoise_device.type, dtype=torch.bfloat16) if use_autocast else nullcontext()
            with context:
                # When autocast is disabled (e.g. MPS/CPU), ensure intermediate matches the VAE dtype
                if not use_autocast:
                    # Safely get encoder dtype, handling Mock objects in tests
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        # Fallback: use intermediate dtype if we can't get encoder dtype (for Mock objects in tests)
                        ae_dtype = intermediate.dtype
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)
                    intermediate_image = ae.decoder(z_adjusted)
                else:
                    intermediate_image = ae.decode(intermediate)
            if state.seed is not None:
                controller.store_state_in_chain(current_seed=state.seed)
            with SaveFile() as saver:
                saver.intermediate_image = intermediate_image  # set up image
                saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                saver.with_hyperchain()

    return controller.current_sample
Denoise using the Flux model, stepping interactively through a ManualTimestepController.
Parameters
- model: Flux model instance
- settings: InferenceState containing all denoising configuration parameters
def
unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor: