divisor.denoise_step
Denoising step functions for interactive/manual controller-based denoising.
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

"""Denoising step functions for interactive/manual controller-based denoising."""

from typing import Any, Callable, Optional

from einops import repeat
from divisor.registry import gfx_device
import torch
from torch import Tensor

from divisor.state import ImageEmbeddingState, TextEmbeddingState
from divisor.variant import apply_variation_noise


def create_clear_prediction_cache(
    cached_prediction: list[Optional[Tensor]],
    cached_prediction_state: list[Optional[dict]],
) -> Callable[[], None]:
    """Create a function to clear the prediction cache.\n
    :param cached_prediction: Mutable single-slot list containing the cached prediction
    :param cached_prediction_state: Mutable single-slot list containing the state the cached prediction was computed for
    :return: Function that sets slot 0 of both lists to None"""

    def clear_prediction_cache():
        """Empty the prediction cache.\n"""
        cached_prediction[0] = None
        cached_prediction_state[0] = None

    return clear_prediction_cache


def create_recompute_text_embeddings(
    img: Tensor,
    t5: Optional[Any],
    clip: Optional[Any],
    current_txt: list[Tensor],
    current_txt_ids: list[Tensor],
    current_vec: list[Tensor],
    current_prompt: list[Optional[str]],
    clear_prediction_cache: Callable[[], None],
    is_flux2: bool = False,
    text_embedder: Optional[Any] = None,
) -> Callable[[str], None]:
    """Create a function to recompute text embeddings when prompt changes.\n
    Supports both Flux1 (T5+CLIP) and Flux2 (Mistral) architectures.
    :param img: Image tensor; its first dimension is the batch size and its device is the target device
    :param t5: T5 embedder (Flux1 only, optional)
    :param clip: CLIP embedder (Flux1 only, optional)
    :param current_txt: Mutable list containing current text embeddings
    :param current_txt_ids: Mutable list containing current text IDs
    :param current_vec: Mutable list containing current CLIP embeddings (Flux1) or None (Flux2)
    :param current_prompt: Mutable list containing current prompt
    :param clear_prediction_cache: Function to clear prediction cache
    :param is_flux2: Whether this is for Flux2 model (uses different embedder)
    :param text_embedder: Text embedder for Flux2 (Mistral, optional)
    :return: Function that recomputes text embeddings"""

    def recompute_text_embeddings(prompt: str) -> None:
        """Recompute text embeddings when prompt changes.\n
        Returns without touching any state when the embedder(s) required by the selected architecture are missing.
        :param prompt: New prompt text"""
        bs = img.shape[0]
        prompt_list = [prompt] if isinstance(prompt, str) else prompt

        def match_batch(embedding: Tensor) -> Tensor:
            """Tile a single-row embedding along dim 0 so it lines up with the image batch."""
            if embedding.shape[0] == 1 and bs > 1:
                return repeat(embedding, "1 ... -> bs ...", bs=bs)
            return embedding

        if is_flux2:
            # Flux2 uses Mistral embedder
            if text_embedder is None:
                return
            new_txt = match_batch(text_embedder(prompt_list).to(img.device))
            # Flux2 uses 4D position IDs (t, h, w, l)
            # Generate IDs using the same approach as flux2/sampling.py
            try:
                from divisor.flux2.sampling import batched_prc_txt
            except ImportError:
                # Fallback: build (t, h, w, l) IDs locally; t/h/w are single-entry dummy axes.
                seq_len = new_txt.shape[1]
                single_ids = torch.cartesian_prod(
                    torch.arange(1),
                    torch.arange(1),
                    torch.arange(1),
                    torch.arange(seq_len),
                )
                # Always add the batch axis (including batch size 1) so the IDs stay
                # 3-D and aligned with new_txt, whatever the batch size.
                new_txt_ids = single_ids.unsqueeze(0).repeat(new_txt.shape[0], 1, 1)
            else:
                new_txt, new_txt_ids = batched_prc_txt(new_txt)

            current_txt[0] = new_txt.to(img.device)
            current_txt_ids[0] = new_txt_ids.to(img.device)
            # Flux2 doesn't use separate CLIP embeddings
            if current_vec:
                current_vec[0] = None  # type: ignore
        else:
            # Flux1 uses T5 + CLIP
            if t5 is None or clip is None:
                return

            # Compute new embeddings
            new_txt = match_batch(t5(prompt_list))
            new_txt_ids = torch.zeros(bs, new_txt.shape[1], 3)
            new_vec = match_batch(clip(prompt_list))

            # Update embeddings and move to correct device
            current_txt[0] = new_txt.to(img.device)
            current_txt_ids[0] = new_txt_ids.to(img.device)
            current_vec[0] = new_vec.to(img.device)

        current_prompt[0] = prompt

        # Clear prediction cache since embeddings changed
        clear_prediction_cache()

    return recompute_text_embeddings


def _is_flux2_model(model: Any) -> bool:
    """Check if model is Flux2 type.\n
    :param model: Model instance to check
    :return: True if Flux2, False if Flux1"""
    try:
        from divisor.flux2.model import Flux2

        return isinstance(model, Flux2)
    except (ImportError, TypeError):
        # Fallback: case-insensitive class-name match (this also covers the exact "Flux2" spelling)
        return "flux2" in model.__class__.__name__.lower()


def create_get_prediction(pred_set: TextEmbeddingState, img_set: ImageEmbeddingState) -> Callable[[Tensor, float, float, Optional[list[int]]], Tensor]:
    """Create a function to generate model prediction with caching.\n
    :param pred_set: State holding the model reference, current (and optional negative) text embeddings, guidance settings and the prediction cache slots
    :param img_set: State holding the image tensor, its position IDs and optional channel/sequence conditioning
    :return: Function that generates predictions with caching"""

    def has_negative_embeddings() -> bool:
        """Tell whether negative prediction is enabled and every negative embedding slot is filled."""
        if not pred_set.neg_pred_enabled:
            return False
        holders = (pred_set.current_neg_txt, pred_set.current_neg_txt_ids, pred_set.current_neg_vec)
        return all(bool(holder) and holder[0] is not None for holder in holders)

    def get_prediction(
        sample: Tensor,
        t_curr: float,
        guidance_val: float,
        layer_dropouts_val: Optional[list[int]],
    ) -> Tensor:
        """Generate model prediction, reusing cached prediction if state hasn't changed.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param guidance_val: Guidance value; it is part of the cache key only, the guidance vector sent to the model is built from ``pred_set.state.guidance``
        :param layer_dropouts_val: Layer dropout configuration
        :returns: Model prediction"""
        # Cheap cache key: the sample's shape plus its first element only, so two
        # samples that agree on both are treated as the same sample.
        first_val = float(sample.flatten()[0].item()) if sample.numel() > 0 else 0.0
        current_state = {
            "sample_hash": hash((sample.shape, first_val)),
            "t_curr": t_curr,
            "guidance": guidance_val,
            "layer_dropout": layer_dropouts_val,
        }

        cached_pred = pred_set.cached_prediction[0]
        cached_state = pred_set.cached_prediction_state[0]
        if cached_pred is not None and cached_state is not None and all(cached_state[key] == value for key, value in current_state.items()):
            return cached_pred

        model = pred_set.model_ref[0]
        try:
            model_dtype = next(model.parameters()).dtype
        except (TypeError, StopIteration, AttributeError):
            # Fallback: use sample dtype if we can't get model dtype (for Mock objects in tests)
            model_dtype = sample.dtype
        use_autocast = gfx_device.type == "cuda"

        def to_model_dtype(tensor: Tensor) -> Tensor:
            """Cast to the model dtype when autocast is unavailable; otherwise pass through."""
            return tensor if use_autocast else tensor.to(dtype=model_dtype)

        sample = to_model_dtype(sample)
        t_vec = torch.full((sample.shape[0],), t_curr, dtype=sample.dtype, device=sample.device)
        img_input = sample
        img_input_ids = img_set.img_ids

        if img_set.img_cond is not None:
            # Channel-wise conditioning is concatenated on the last dimension
            img_input = torch.cat((sample, to_model_dtype(img_set.img_cond)), dim=-1)
        if img_set.img_cond_seq is not None:
            assert img_set.img_cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning"
            # Sequence conditioning is appended along dim 1 together with its IDs
            img_input = torch.cat((img_input, to_model_dtype(img_set.img_cond_seq)), dim=1)
            img_input_ids = torch.cat((img_input_ids, img_set.img_cond_seq_ids), dim=1)

        img_input = to_model_dtype(img_input)
        t_vec = to_model_dtype(t_vec)
        guidance_vec = torch.full((img_set.img.shape[0],), pred_set.state.guidance, device=img_set.img.device, dtype=img_set.img.dtype)

        if _is_flux2_model(model):
            # Flux2 uses x, x_ids, ctx, ctx_ids instead of img, img_ids, txt, txt_ids, y
            pred = model(
                x=img_input,
                x_ids=img_input_ids,
                timesteps=t_vec,
                ctx=to_model_dtype(pred_set.current_txt[0]),
                ctx_ids=pred_set.current_txt_ids[0],
                guidance=to_model_dtype(guidance_vec),
                layer_dropouts=layer_dropouts_val,
            )
        else:
            # NOTE(review): the Flux1 guidance vector is multiplied by 0.0, as in the
            # original code; kept unchanged - confirm that zeroing it is intentional.
            guidance_vec = to_model_dtype(guidance_vec * 0.0)

            # The main pass always uses the positive prompt (and its IDs); the negative
            # prompt only feeds the second pass below. Running the main pass on the
            # negative embeddings would make both passes identical and cancel the
            # guidance term.
            pred = model(
                img=img_input,
                img_ids=img_input_ids,
                txt=to_model_dtype(pred_set.current_txt[0]),
                txt_ids=pred_set.current_txt_ids[0],
                y=to_model_dtype(pred_set.current_vec[0]),
                timesteps=t_vec,
                guidance=guidance_vec,
                layer_dropouts=layer_dropouts_val,
            )
            if has_negative_embeddings():
                neg_pred = model(
                    img=img_input,
                    img_ids=img_input_ids,
                    txt=to_model_dtype(pred_set.current_neg_txt[0]),  # type: ignore
                    txt_ids=pred_set.current_neg_txt_ids[0],  # type: ignore
                    y=to_model_dtype(pred_set.current_neg_vec[0]),  # type: ignore
                    timesteps=t_vec,
                    guidance=guidance_vec,
                    layer_dropouts=layer_dropouts_val,
                )
                # Guidance using true_gs: start from the negative prediction and move toward the positive one
                pred = neg_pred + pred_set.true_gs * (pred - neg_pred)

        if img_input_ids is not None:
            # Drop positions appended by sequence conditioning
            pred = pred[:, : sample.shape[1]]

        # Cache the prediction
        pred_set.cached_prediction[0] = pred
        pred_set.cached_prediction_state[0] = current_state

        return pred

    return get_prediction


def create_denoise_step_fn(
    controller_ref: list[Optional[Any]],
    current_layer_dropout: list[Optional[list[int]]],
    previous_step_tensor: list[Optional[Tensor]],
    get_prediction: Callable[[Tensor, float, float, Optional[list[int]]], Tensor],
) -> Callable[[Tensor, float, float, float], Tensor]:
    """Create a denoising step function for the controller.\n
    :param controller_ref: Mutable list containing controller reference
    :param current_layer_dropout: Mutable list containing current layer dropout (used when no controller is set)
    :param previous_step_tensor: Mutable list containing previous step tensor
    :param get_prediction: Function to get model prediction
    :return: Denoising step function"""

    def denoise_step_fn(sample: Tensor, t_curr: float, t_prev: float, guidance_val: float) -> Tensor:
        """Single denoising step function for the controller.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param t_prev: Previous timestep
        :param guidance_val: Guidance value
        :returns: Updated sample tensor"""

        if controller_ref[0] is not None:
            use_mask = controller_ref[0].use_previous_as_mask
            layer_dropouts = controller_ref[0].current_state.layer_dropout
        else:
            layer_dropouts = current_layer_dropout[0]
            use_mask = False
        pred = get_prediction(sample, t_curr, guidance_val, layer_dropouts)
        if use_mask and previous_step_tensor[0] is not None:
            prev_tensor = previous_step_tensor[0]
            if prev_tensor.shape == pred.shape:
                prev_min = prev_tensor.min()
                prev_max = prev_tensor.max()
                if prev_max > prev_min:
                    # Min-max normalise the previous sample into a [0, 1] mask
                    mask = (prev_tensor - prev_min) / (prev_max - prev_min)
                else:
                    mask = torch.ones_like(prev_tensor)
                # Apply mask to prediction: mask controls how much the prediction affects the result
                # Higher mask values = more effect from prediction, lower = less effect
                pred = pred * mask

        dt = t_prev - t_curr

        # Advance the sample along the prediction by the timestep difference
        result = sample + dt * pred

        # Apply variation noise if enabled
        if controller_ref[0] is not None:
            state = controller_ref[0].current_state
            if state.variation_seed is not None and state.variation_strength > 0.0:
                result = apply_variation_noise(
                    latent_sample=result,
                    variation_seed=state.variation_seed,
                    variation_strength=state.variation_strength,
                    mask=None,
                    variation_method="linear",
                )

        # Store the incoming sample (not the result) as previous for next step
        previous_step_tensor[0] = sample.clone()

        return result

    return denoise_step_fn
def
create_clear_prediction_cache( cached_prediction: list[typing.Optional[torch.Tensor]], cached_prediction_state: list[typing.Optional[dict]]) -> Callable[[], NoneType]:
def create_clear_prediction_cache(
    cached_prediction: list[Optional[Tensor]],
    cached_prediction_state: list[Optional[dict]],
) -> Callable[[], None]:
    """Build a zero-argument callback that resets the prediction cache.\n
    :param cached_prediction: Single-slot mutable list holding the cached prediction
    :param cached_prediction_state: Single-slot mutable list holding the state the cached prediction was computed for
    :return: Callback that sets slot 0 of both lists to None"""

    def clear_prediction_cache():
        """Reset slot 0 of both cache lists to None.\n"""
        for slot_holder in (cached_prediction, cached_prediction_state):
            slot_holder[0] = None

    return clear_prediction_cache
Create a function to clear the prediction cache.
Parameters
- cached_prediction: Mutable list containing cached prediction
- cached_prediction_state: Mutable list containing cached prediction state
Returns
Function that clears the cache
def
create_recompute_text_embeddings( img: torch.Tensor, t5: Optional[Any], clip: Optional[Any], current_txt: list[torch.Tensor], current_txt_ids: list[torch.Tensor], current_vec: list[torch.Tensor], current_prompt: list[typing.Optional[str]], clear_prediction_cache: Callable[[], NoneType], is_flux2: bool = False, text_embedder: Optional[Any] = None) -> Callable[[str], NoneType]:
def create_recompute_text_embeddings(
    img: Tensor,
    t5: Optional[Any],
    clip: Optional[Any],
    current_txt: list[Tensor],
    current_txt_ids: list[Tensor],
    current_vec: list[Tensor],
    current_prompt: list[Optional[str]],
    clear_prediction_cache: Callable[[], None],
    is_flux2: bool = False,
    text_embedder: Optional[Any] = None,
) -> Callable[[str], None]:
    """Create a function to recompute text embeddings when prompt changes.\n
    Supports both Flux1 (T5+CLIP) and Flux2 (Mistral) architectures.
    :param img: Image tensor; its first dimension is the batch size and its device is the target device
    :param t5: T5 embedder (Flux1 only, optional)
    :param clip: CLIP embedder (Flux1 only, optional)
    :param current_txt: Mutable list containing current text embeddings
    :param current_txt_ids: Mutable list containing current text IDs
    :param current_vec: Mutable list containing current CLIP embeddings (Flux1) or None (Flux2)
    :param current_prompt: Mutable list containing current prompt
    :param clear_prediction_cache: Function to clear prediction cache
    :param is_flux2: Whether this is for Flux2 model (uses different embedder)
    :param text_embedder: Text embedder for Flux2 (Mistral, optional)
    :return: Function that recomputes text embeddings"""

    def recompute_text_embeddings(prompt: str) -> None:
        """Recompute text embeddings when prompt changes.\n
        Returns without touching any state when the embedder(s) required by the selected architecture are missing.
        :param prompt: New prompt text"""
        bs = img.shape[0]
        prompt_list = [prompt] if isinstance(prompt, str) else prompt

        def match_batch(embedding: Tensor) -> Tensor:
            """Tile a single-row embedding along dim 0 so it lines up with the image batch."""
            if embedding.shape[0] == 1 and bs > 1:
                return repeat(embedding, "1 ... -> bs ...", bs=bs)
            return embedding

        if is_flux2:
            # Flux2 uses Mistral embedder
            if text_embedder is None:
                return
            new_txt = match_batch(text_embedder(prompt_list).to(img.device))
            # Flux2 uses 4D position IDs (t, h, w, l)
            # Generate IDs using the same approach as flux2/sampling.py
            try:
                from divisor.flux2.sampling import batched_prc_txt
            except ImportError:
                # Fallback: build (t, h, w, l) IDs locally; t/h/w are single-entry dummy axes.
                seq_len = new_txt.shape[1]
                single_ids = torch.cartesian_prod(
                    torch.arange(1),
                    torch.arange(1),
                    torch.arange(1),
                    torch.arange(seq_len),
                )
                # Always add the batch axis (including batch size 1) so the IDs stay
                # 3-D and aligned with new_txt, whatever the batch size.
                new_txt_ids = single_ids.unsqueeze(0).repeat(new_txt.shape[0], 1, 1)
            else:
                new_txt, new_txt_ids = batched_prc_txt(new_txt)

            current_txt[0] = new_txt.to(img.device)
            current_txt_ids[0] = new_txt_ids.to(img.device)
            # Flux2 doesn't use separate CLIP embeddings
            if current_vec:
                current_vec[0] = None  # type: ignore
        else:
            # Flux1 uses T5 + CLIP
            if t5 is None or clip is None:
                return

            # Compute new embeddings
            new_txt = match_batch(t5(prompt_list))
            new_txt_ids = torch.zeros(bs, new_txt.shape[1], 3)
            new_vec = match_batch(clip(prompt_list))

            # Update embeddings and move to correct device
            current_txt[0] = new_txt.to(img.device)
            current_txt_ids[0] = new_txt_ids.to(img.device)
            current_vec[0] = new_vec.to(img.device)

        current_prompt[0] = prompt

        # Clear prediction cache since embeddings changed
        clear_prediction_cache()

    return recompute_text_embeddings
Create a function to recompute text embeddings when prompt changes.
Supports both Flux1 (T5+CLIP) and Flux2 (Mistral) architectures.
Parameters
- img: Image tensor for batch size reference
- t5: T5 embedder (Flux1 only, optional)
- clip: CLIP embedder (Flux1 only, optional)
- current_txt: Mutable list containing current text embeddings
- current_txt_ids: Mutable list containing current text IDs
- current_vec: Mutable list containing current CLIP embeddings (Flux1) or None (Flux2)
- current_prompt: Mutable list containing current prompt
- clear_prediction_cache: Function to clear prediction cache
- is_flux2: Whether this is for Flux2 model (uses different embedder)
- text_embedder: Text embedder for Flux2 (Mistral, optional)
Returns
Function that recomputes text embeddings
def
create_get_prediction( pred_set: divisor.state.TextEmbeddingState, img_set: divisor.state.ImageEmbeddingState) -> Callable[[torch.Tensor, float, float, Optional[list[int]]], torch.Tensor]:
def create_get_prediction(pred_set: TextEmbeddingState, img_set: ImageEmbeddingState) -> Callable[[Tensor, float, float, Optional[list[int]]], Tensor]:
    """Create a function to generate model prediction with caching.\n
    :param pred_set: State holding the model reference, current (and optional negative) text embeddings, guidance settings and the prediction cache slots
    :param img_set: State holding the image tensor, its position IDs and optional channel/sequence conditioning
    :return: Function that generates predictions with caching"""

    def has_negative_embeddings() -> bool:
        """Tell whether negative prediction is enabled and every negative embedding slot is filled."""
        if not pred_set.neg_pred_enabled:
            return False
        holders = (pred_set.current_neg_txt, pred_set.current_neg_txt_ids, pred_set.current_neg_vec)
        return all(bool(holder) and holder[0] is not None for holder in holders)

    def get_prediction(
        sample: Tensor,
        t_curr: float,
        guidance_val: float,
        layer_dropouts_val: Optional[list[int]],
    ) -> Tensor:
        """Generate model prediction, reusing cached prediction if state hasn't changed.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param guidance_val: Guidance value; it is part of the cache key only, the guidance vector sent to the model is built from ``pred_set.state.guidance``
        :param layer_dropouts_val: Layer dropout configuration
        :returns: Model prediction"""
        # Cheap cache key: the sample's shape plus its first element only, so two
        # samples that agree on both are treated as the same sample.
        first_val = float(sample.flatten()[0].item()) if sample.numel() > 0 else 0.0
        current_state = {
            "sample_hash": hash((sample.shape, first_val)),
            "t_curr": t_curr,
            "guidance": guidance_val,
            "layer_dropout": layer_dropouts_val,
        }

        cached_pred = pred_set.cached_prediction[0]
        cached_state = pred_set.cached_prediction_state[0]
        if cached_pred is not None and cached_state is not None and all(cached_state[key] == value for key, value in current_state.items()):
            return cached_pred

        model = pred_set.model_ref[0]
        try:
            model_dtype = next(model.parameters()).dtype
        except (TypeError, StopIteration, AttributeError):
            # Fallback: use sample dtype if we can't get model dtype (for Mock objects in tests)
            model_dtype = sample.dtype
        use_autocast = gfx_device.type == "cuda"

        def to_model_dtype(tensor: Tensor) -> Tensor:
            """Cast to the model dtype when autocast is unavailable; otherwise pass through."""
            return tensor if use_autocast else tensor.to(dtype=model_dtype)

        sample = to_model_dtype(sample)
        t_vec = torch.full((sample.shape[0],), t_curr, dtype=sample.dtype, device=sample.device)
        img_input = sample
        img_input_ids = img_set.img_ids

        if img_set.img_cond is not None:
            # Channel-wise conditioning is concatenated on the last dimension
            img_input = torch.cat((sample, to_model_dtype(img_set.img_cond)), dim=-1)
        if img_set.img_cond_seq is not None:
            assert img_set.img_cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning"
            # Sequence conditioning is appended along dim 1 together with its IDs
            img_input = torch.cat((img_input, to_model_dtype(img_set.img_cond_seq)), dim=1)
            img_input_ids = torch.cat((img_input_ids, img_set.img_cond_seq_ids), dim=1)

        img_input = to_model_dtype(img_input)
        t_vec = to_model_dtype(t_vec)
        guidance_vec = torch.full((img_set.img.shape[0],), pred_set.state.guidance, device=img_set.img.device, dtype=img_set.img.dtype)

        if _is_flux2_model(model):
            # Flux2 uses x, x_ids, ctx, ctx_ids instead of img, img_ids, txt, txt_ids, y
            pred = model(
                x=img_input,
                x_ids=img_input_ids,
                timesteps=t_vec,
                ctx=to_model_dtype(pred_set.current_txt[0]),
                ctx_ids=pred_set.current_txt_ids[0],
                guidance=to_model_dtype(guidance_vec),
                layer_dropouts=layer_dropouts_val,
            )
        else:
            # NOTE(review): the Flux1 guidance vector is multiplied by 0.0, as in the
            # original code; kept unchanged - confirm that zeroing it is intentional.
            guidance_vec = to_model_dtype(guidance_vec * 0.0)

            # The main pass always uses the positive prompt (and its IDs); the negative
            # prompt only feeds the second pass below. Running the main pass on the
            # negative embeddings would make both passes identical and cancel the
            # guidance term.
            pred = model(
                img=img_input,
                img_ids=img_input_ids,
                txt=to_model_dtype(pred_set.current_txt[0]),
                txt_ids=pred_set.current_txt_ids[0],
                y=to_model_dtype(pred_set.current_vec[0]),
                timesteps=t_vec,
                guidance=guidance_vec,
                layer_dropouts=layer_dropouts_val,
            )
            if has_negative_embeddings():
                neg_pred = model(
                    img=img_input,
                    img_ids=img_input_ids,
                    txt=to_model_dtype(pred_set.current_neg_txt[0]),  # type: ignore
                    txt_ids=pred_set.current_neg_txt_ids[0],  # type: ignore
                    y=to_model_dtype(pred_set.current_neg_vec[0]),  # type: ignore
                    timesteps=t_vec,
                    guidance=guidance_vec,
                    layer_dropouts=layer_dropouts_val,
                )
                # Guidance using true_gs: start from the negative prediction and move toward the positive one
                pred = neg_pred + pred_set.true_gs * (pred - neg_pred)

        if img_input_ids is not None:
            # Drop positions appended by sequence conditioning
            pred = pred[:, : sample.shape[1]]

        # Cache the prediction
        pred_set.cached_prediction[0] = pred
        pred_set.cached_prediction_state[0] = current_state

        return pred

    return get_prediction
Create a function to generate model prediction with caching.
Parameters
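- pred_set: TextEmbeddingState holding the model reference, the current text embeddings, and the prediction cache
- img_set: ImageEmbeddingState holding the image tensor, its position IDs, and any optional conditioning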
Returns
Function that generates predictions with caching
def
create_denoise_step_fn( controller_ref: list[typing.Optional[typing.Any]], current_layer_dropout: list[typing.Optional[list[int]]], previous_step_tensor: list[typing.Optional[torch.Tensor]], get_prediction: Callable[[torch.Tensor, float, float, Optional[list[int]]], torch.Tensor]) -> Callable[[torch.Tensor, float, float, float], torch.Tensor]:
def create_denoise_step_fn(
    controller_ref: list[Optional[Any]],
    current_layer_dropout: list[Optional[list[int]]],
    previous_step_tensor: list[Optional[Tensor]],
    get_prediction: Callable[[Tensor, float, float, Optional[list[int]]], Tensor],
) -> Callable[[Tensor, float, float, float], Tensor]:
    """Build the per-step update callable handed to the controller.\n
    :param controller_ref: Single-slot mutable list with the controller (or None)
    :param current_layer_dropout: Single-slot mutable list with the layer dropout used when no controller is set
    :param previous_step_tensor: Single-slot mutable list that receives a copy of each incoming sample
    :param get_prediction: Callable returning the model prediction for a sample
    :return: Callable performing one denoising step"""

    def denoise_step_fn(sample: Tensor, t_curr: float, t_prev: float, guidance_val: float) -> Tensor:
        """Advance ``sample`` from ``t_curr`` to ``t_prev`` using one model prediction.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param t_prev: Previous timestep
        :param guidance_val: Guidance value forwarded to ``get_prediction``
        :returns: Updated sample tensor"""

        controller = controller_ref[0]
        if controller is None:
            mask_enabled = False
            dropouts = current_layer_dropout[0]
        else:
            mask_enabled = controller.use_previous_as_mask
            dropouts = controller.current_state.layer_dropout

        prediction = get_prediction(sample, t_curr, guidance_val, dropouts)

        if mask_enabled:
            earlier = previous_step_tensor[0]
            if earlier is not None and earlier.shape == prediction.shape:
                low = earlier.min()
                high = earlier.max()
                # Min-max normalise the previous sample; a constant tensor yields an all-ones mask
                weights = (earlier - low) / (high - low) if high > low else torch.ones_like(earlier)
                # Larger weights let more of the prediction through
                prediction = prediction * weights

        # Move the sample along the prediction by the timestep difference
        stepped = sample + (t_prev - t_curr) * prediction

        # Optional variation noise, driven by the controller's current state
        live_controller = controller_ref[0]
        if live_controller is not None:
            ctrl_state = live_controller.current_state
            wants_variation = ctrl_state.variation_seed is not None and ctrl_state.variation_strength > 0.0
            if wants_variation:
                stepped = apply_variation_noise(
                    latent_sample=stepped,
                    variation_seed=ctrl_state.variation_seed,
                    variation_strength=ctrl_state.variation_strength,
                    mask=None,
                    variation_method="linear",
                )

        # Remember the incoming sample (not the result) for the next call
        previous_step_tensor[0] = sample.clone()

        return stepped

    return denoise_step_fn
Create a denoising step function for the controller.
Parameters
- controller_ref: Mutable list containing controller reference
- current_layer_dropout: Mutable list containing current layer dropout
- previous_step_tensor: Mutable list containing previous step tensor
- get_prediction: Function to get model prediction
Returns
Denoising step function