divisor.denoise_step

Denoising step functions for interactive/manual controller-based denoising.

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3
  4"""Denoising step functions for interactive/manual controller-based denoising."""
  5
  6from typing import Any, Callable, Optional
  7
  8from einops import repeat
  9from divisor.registry import gfx_device
 10import torch
 11from torch import Tensor
 12
 13from divisor.state import ImageEmbeddingState, TextEmbeddingState
 14from divisor.variant import apply_variation_noise
 15
 16
 17def create_clear_prediction_cache(
 18    cached_prediction: list[Optional[Tensor]],
 19    cached_prediction_state: list[Optional[dict]],
 20) -> Callable[[], None]:
 21    """Create a function to clear the prediction cache.\n
 22    :param cached_prediction: Mutable list containing cached prediction
 23    :param cached_prediction_state: Mutable list containing cached prediction state
 24    :return: Function that clears the cache"""
 25
 26    def clear_prediction_cache():
 27        """Empty the prediction cache.\n"""
 28        cached_prediction[0] = None
 29        cached_prediction_state[0] = None
 30
 31    return clear_prediction_cache
 32
 33
 34def create_recompute_text_embeddings(
 35    img: Tensor,
 36    t5: Optional[Any],
 37    clip: Optional[Any],
 38    current_txt: list[Tensor],
 39    current_txt_ids: list[Tensor],
 40    current_vec: list[Tensor],
 41    current_prompt: list[Optional[str]],
 42    clear_prediction_cache: Callable[[], None],
 43    is_flux2: bool = False,
 44    text_embedder: Optional[Any] = None,
 45) -> Callable[[str], None]:
 46    """Create a function to recompute text embeddings when prompt changes.\n
 47    Supports both Flux1 (T5+CLIP) and Flux2 (Mistral) architectures.
 48    :param img: Image tensor for batch size reference
 49    :param t5: T5 embedder (Flux1 only, optional)
 50    :param clip: CLIP embedder (Flux1 only, optional)
 51    :param current_txt: Mutable list containing current text embeddings
 52    :param current_txt_ids: Mutable list containing current text IDs
 53    :param current_vec: Mutable list containing current CLIP embeddings (Flux1) or None (Flux2)
 54    :param current_prompt: Mutable list containing current prompt
 55    :param clear_prediction_cache: Function to clear prediction cache
 56    :param is_flux2: Whether this is for Flux2 model (uses different embedder)
 57    :param text_embedder: Text embedder for Flux2 (Mistral, optional)
 58    :return: Function that recomputes text embeddings"""
 59
 60    def recompute_text_embeddings(prompt: str) -> None:
 61        """Recompute text embeddings when prompt changes.\n
 62        :param prompt: New prompt text"""
 63        bs = img.shape[0]
 64        prompt_list = [prompt] if isinstance(prompt, str) else prompt
 65
 66        if is_flux2:
 67            # Flux2 uses Mistral embedder
 68            if text_embedder is None:
 69                return
 70            new_txt = text_embedder(prompt_list).to(img.device)
 71            if new_txt.shape[0] == 1 and bs > 1:
 72                new_txt = repeat(new_txt, "1 ... -> bs ...", bs=bs)
 73            # Flux2 uses 4D position IDs (t, h, w, l)
 74            # Generate IDs using the same approach as flux2/sampling.py
 75            try:
 76                from divisor.flux2.sampling import batched_prc_txt
 77
 78                new_txt, new_txt_ids = batched_prc_txt(new_txt)
 79            except ImportError:
 80                # Fallback: create simple IDs if import fails
 81                # This matches the structure expected by Flux2
 82                _l = new_txt.shape[1]
 83                coords = {
 84                    "t": torch.arange(1),
 85                    "h": torch.arange(1),  # dummy dimension
 86                    "w": torch.arange(1),  # dummy dimension
 87                    "l": torch.arange(_l),
 88                }
 89
 90                new_txt_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
 91                if bs > 1:
 92                    new_txt_ids = new_txt_ids.unsqueeze(0).repeat(bs, 1, 1)
 93                new_txt_ids = new_txt_ids.to(new_txt.device)
 94
 95            current_txt[0] = new_txt.to(img.device)
 96            current_txt_ids[0] = new_txt_ids.to(img.device)
 97            # Flux2 doesn't use separate CLIP embeddings
 98            if current_vec:
 99                current_vec[0] = None  # type: ignore
100        else:
101            # Flux1 uses T5 + CLIP
102            if t5 is None or clip is None:
103                return
104
105            # Compute new embeddings
106            new_txt = t5(prompt_list)
107            if new_txt.shape[0] == 1 and bs > 1:
108                new_txt = repeat(new_txt, "1 ... -> bs ...", bs=bs)
109            new_txt_ids = torch.zeros(bs, new_txt.shape[1], 3)
110
111            new_vec = clip(prompt_list)
112            if new_vec.shape[0] == 1 and bs > 1:
113                new_vec = repeat(new_vec, "1 ... -> bs ...", bs=bs)
114
115            # Update embeddings and move to correct device
116            current_txt[0] = new_txt.to(img.device)
117            current_txt_ids[0] = new_txt_ids.to(img.device)
118            current_vec[0] = new_vec.to(img.device)
119
120        current_prompt[0] = prompt
121
122        # Clear prediction cache since embeddings changed
123        clear_prediction_cache()
124
125    return recompute_text_embeddings
126
127
128def _is_flux2_model(model: Any) -> bool:
129    """Check if model is Flux2 type.\n
130    :param model: Model instance to check
131    :return: True if Flux2, False if Flux1"""
132    try:
133        from divisor.flux2.model import Flux2
134
135        return isinstance(model, Flux2)
136    except (ImportError, TypeError):
137        # Fallback: check by class name or signature
138        model_class_name = model.__class__.__name__
139        return "Flux2" in model_class_name or "flux2" in model_class_name.lower()
140
141
def create_get_prediction(pred_set: TextEmbeddingState, img_set: ImageEmbeddingState) -> Callable[[Tensor, float, float, Optional[list[int]]], Tensor]:
    """Create a function to generate model prediction with caching.\n
    :param pred_set: State holding the model reference, current (and optional negative) text embeddings, guidance state and the single-slot prediction cache
    :param img_set: State holding the image tensor, its position IDs and optional channel/sequence conditioning
    :return: Function that generates predictions with caching"""

    def get_prediction(
        sample: Tensor,
        t_curr: float,
        guidance_val: float,
        layer_dropouts_val: Optional[list[int]],
    ) -> Tensor:
        """Generate model prediction, reusing cached prediction if state hasn't changed.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param guidance_val: Guidance value (used for the cache key; the model itself is fed ``pred_set.state.guidance``)
        :param layer_dropouts_val: Layer dropout configuration
        :returns: Model prediction"""
        # Cheap cache fingerprint: the sample's shape plus its first element only.
        # Two different samples sharing shape and first value will collide.
        first_val = float(sample.flatten()[0].item()) if sample.numel() > 0 else 0.0
        current_state = {
            "sample_hash": hash((sample.shape, first_val)),
            "t_curr": t_curr,
            "guidance": guidance_val,
            "layer_dropout": layer_dropouts_val,
        }

        # Reuse the cached prediction only when every tracked field matches.
        cached_pred = pred_set.cached_prediction[0]
        cached_state = pred_set.cached_prediction_state[0] if cached_pred is not None else None
        if cached_pred is not None and cached_state is not None and all(cached_state[key] == value for key, value in current_state.items()):
            return cached_pred

        # Generate new prediction
        # Safely get model dtype, handling Mock objects in tests
        try:
            model_dtype = next(pred_set.model_ref[0].parameters()).dtype
        except (TypeError, StopIteration, AttributeError):
            # Fallback: use sample dtype if we can't get model dtype (for Mock objects in tests)
            model_dtype = sample.dtype
        use_autocast = gfx_device.type == "cuda"

        def to_model_dtype(tensor: Tensor) -> Tensor:
            """Cast to the model dtype when not on CUDA (autocast disabled, e.g. MPS); otherwise pass through."""
            return tensor if use_autocast else tensor.to(dtype=model_dtype)

        # Ensure sample is in correct dtype before any operations
        sample = to_model_dtype(sample)

        t_vec = torch.full((sample.shape[0],), t_curr, dtype=sample.dtype, device=sample.device)
        img_input = sample
        img_input_ids = img_set.img_ids

        if img_set.img_cond is not None:
            # Channel-wise conditioning is concatenated on the last dimension.
            img_input = torch.cat((sample, to_model_dtype(img_set.img_cond)), dim=-1)
        if img_set.img_cond_seq is not None:
            assert img_set.img_cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning"
            # Sequence conditioning is appended along dim 1, together with its IDs.
            img_input = torch.cat((img_input, to_model_dtype(img_set.img_cond_seq)), dim=1)
            img_input_ids = torch.cat((img_input_ids, img_set.img_cond_seq_ids), dim=1)
        img_input = to_model_dtype(img_input)

        # Determine model type and prepare inputs accordingly
        model = pred_set.model_ref[0]
        is_flux2 = _is_flux2_model(model)

        # NOTE(review): the guidance vector is built from pred_set.state.guidance, not guidance_val — confirm they always agree.
        guidance_vec = torch.full((img_set.img.shape[0],), pred_set.state.guidance, device=img_set.img.device, dtype=img_set.img.dtype)

        if is_flux2:
            # Flux2 uses x, x_ids, ctx, ctx_ids instead of img, img_ids, txt, txt_ids, y
            pred = model(
                x=img_input,
                x_ids=img_input_ids,
                timesteps=t_vec,
                ctx=to_model_dtype(pred_set.current_txt[0]),
                ctx_ids=pred_set.current_txt_ids[0],
                guidance=to_model_dtype(guidance_vec),
                layer_dropouts=layer_dropouts_val,
            )
        else:
            # Flux1 path: the guidance vector is zeroed out (the original multiplied by 0.0 twice;
            # once gives the same values). NOTE(review): confirm zero guidance is intended here.
            guidance_vec = to_model_dtype(guidance_vec * 0.0)

            use_negative = pred_set.neg_pred_enabled and all([pred_set.current_neg_txt, pred_set.current_neg_txt_ids, pred_set.current_neg_vec])

            # Positive pass: always uses the current (positive) embeddings.
            # Previously the negative embeddings were fed here whenever negative prompting was
            # enabled, so both passes were conditioned on the negative prompt and the guidance mix
            # below degenerated to neg_pred.
            pred = model(
                img=img_input,
                img_ids=img_input_ids,
                txt=to_model_dtype(pred_set.current_txt[0]),
                txt_ids=pred_set.current_txt_ids[0],
                y=to_model_dtype(pred_set.current_vec[0]),
                timesteps=t_vec,
                guidance=guidance_vec,
                layer_dropouts=layer_dropouts_val,
            )
            if use_negative:
                # Negative pass gets the same dtype handling as the positive pass.
                neg_pred = model(
                    img=img_input,
                    img_ids=img_input_ids,
                    txt=to_model_dtype(pred_set.current_neg_txt[0]),  # type: ignore
                    txt_ids=pred_set.current_neg_txt_ids[0],  # type: ignore
                    y=to_model_dtype(pred_set.current_neg_vec[0]),  # type: ignore
                    timesteps=t_vec,
                    guidance=guidance_vec,
                    layer_dropouts=layer_dropouts_val,
                )
                # Move from the negative prediction towards the positive one by true_gs.
                pred = neg_pred + pred_set.true_gs * (pred - neg_pred)

        if img_input_ids is not None:
            # Drop any positions appended by sequence conditioning.
            pred = pred[:, : sample.shape[1]]

        # Cache the prediction
        pred_set.cached_prediction[0] = pred
        pred_set.cached_prediction_state[0] = current_state

        return pred

    return get_prediction
301
302
def create_denoise_step_fn(
    controller_ref: list[Optional[Any]],
    current_layer_dropout: list[Optional[list[int]]],
    previous_step_tensor: list[Optional[Tensor]],
    get_prediction: Callable[[Tensor, float, float, Optional[list[int]]], Tensor],
) -> Callable[[Tensor, float, float, float], Tensor]:
    """Build the per-step denoising callback used by the controller.\n
    The list arguments are single-slot mutable holders; only index 0 is read or written.
    :param controller_ref: Mutable list containing the controller (or None)
    :param current_layer_dropout: Mutable list containing the layer dropout used when no controller is set
    :param previous_step_tensor: Mutable list containing the sample passed to the previous step
    :param get_prediction: Function to get model prediction
    :return: Denoising step function"""

    def denoise_step_fn(sample: Tensor, t_curr: float, t_prev: float, guidance_val: float) -> Tensor:
        """Advance the sample by one step: ``sample + (t_prev - t_curr) * prediction``.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param t_prev: Previous timestep
        :param guidance_val: Guidance value
        :returns: Updated sample tensor"""
        controller = controller_ref[0]
        if controller is None:
            # No controller: take dropout from the standalone holder and never mask.
            layer_dropouts = current_layer_dropout[0]
            use_mask = False
        else:
            use_mask = controller.use_previous_as_mask
            layer_dropouts = controller.current_state.layer_dropout

        pred = get_prediction(sample, t_curr, guidance_val, layer_dropouts)

        earlier = previous_step_tensor[0] if use_mask else None
        if earlier is not None and earlier.shape == pred.shape:
            # Min-max normalise the previous sample into a mask; a constant tensor yields all ones.
            low, high = earlier.min(), earlier.max()
            mask = (earlier - low) / (high - low) if high > low else torch.ones_like(earlier)
            # Higher mask values let more of the prediction through.
            pred = pred * mask

        result = sample + (t_prev - t_curr) * pred

        # Optional variation noise, driven by the controller's current state.
        controller = controller_ref[0]
        if controller is not None:
            state = controller.current_state
            if state.variation_seed is not None and state.variation_strength > 0.0:
                result = apply_variation_noise(
                    latent_sample=result,
                    variation_seed=state.variation_seed,
                    variation_strength=state.variation_strength,
                    mask=None,
                    variation_method="linear",
                )

        # Remember the incoming sample (not the result) for the next step's mask.
        previous_step_tensor[0] = sample.clone()
        return result

    return denoise_step_fn
def create_clear_prediction_cache( cached_prediction: list[typing.Optional[torch.Tensor]], cached_prediction_state: list[typing.Optional[dict]]) -> Callable[[], NoneType]:
def create_clear_prediction_cache(
    cached_prediction: list[Optional[Tensor]],
    cached_prediction_state: list[Optional[dict]],
) -> Callable[[], None]:
    """Build a zero-argument callback that resets the prediction cache.\n
    Both arguments are used as single-slot mutable holders: only index 0 is written.
    :param cached_prediction: Mutable list whose first slot holds the cached prediction
    :param cached_prediction_state: Mutable list whose first slot holds the state the prediction was cached for
    :return: Callback that sets both first slots to None"""

    def clear_prediction_cache():
        """Reset both cache slots to None.\n"""
        # Prediction slot first, then its state slot.
        for holder in (cached_prediction, cached_prediction_state):
            holder[0] = None

    return clear_prediction_cache

Create a function to clear the prediction cache.

Parameters
  • cached_prediction: Mutable list containing cached prediction
  • cached_prediction_state: Mutable list containing cached prediction state
Returns

Function that clears the cache

def create_recompute_text_embeddings( img: torch.Tensor, t5: Optional[Any], clip: Optional[Any], current_txt: list[torch.Tensor], current_txt_ids: list[torch.Tensor], current_vec: list[torch.Tensor], current_prompt: list[typing.Optional[str]], clear_prediction_cache: Callable[[], NoneType], is_flux2: bool = False, text_embedder: Optional[Any] = None) -> Callable[[str], NoneType]:
 35def create_recompute_text_embeddings(
 36    img: Tensor,
 37    t5: Optional[Any],
 38    clip: Optional[Any],
 39    current_txt: list[Tensor],
 40    current_txt_ids: list[Tensor],
 41    current_vec: list[Tensor],
 42    current_prompt: list[Optional[str]],
 43    clear_prediction_cache: Callable[[], None],
 44    is_flux2: bool = False,
 45    text_embedder: Optional[Any] = None,
 46) -> Callable[[str], None]:
 47    """Create a function to recompute text embeddings when prompt changes.\n
 48    Supports both Flux1 (T5+CLIP) and Flux2 (Mistral) architectures.
 49    :param img: Image tensor for batch size reference
 50    :param t5: T5 embedder (Flux1 only, optional)
 51    :param clip: CLIP embedder (Flux1 only, optional)
 52    :param current_txt: Mutable list containing current text embeddings
 53    :param current_txt_ids: Mutable list containing current text IDs
 54    :param current_vec: Mutable list containing current CLIP embeddings (Flux1) or None (Flux2)
 55    :param current_prompt: Mutable list containing current prompt
 56    :param clear_prediction_cache: Function to clear prediction cache
 57    :param is_flux2: Whether this is for Flux2 model (uses different embedder)
 58    :param text_embedder: Text embedder for Flux2 (Mistral, optional)
 59    :return: Function that recomputes text embeddings"""
 60
 61    def recompute_text_embeddings(prompt: str) -> None:
 62        """Recompute text embeddings when prompt changes.\n
 63        :param prompt: New prompt text"""
 64        bs = img.shape[0]
 65        prompt_list = [prompt] if isinstance(prompt, str) else prompt
 66
 67        if is_flux2:
 68            # Flux2 uses Mistral embedder
 69            if text_embedder is None:
 70                return
 71            new_txt = text_embedder(prompt_list).to(img.device)
 72            if new_txt.shape[0] == 1 and bs > 1:
 73                new_txt = repeat(new_txt, "1 ... -> bs ...", bs=bs)
 74            # Flux2 uses 4D position IDs (t, h, w, l)
 75            # Generate IDs using the same approach as flux2/sampling.py
 76            try:
 77                from divisor.flux2.sampling import batched_prc_txt
 78
 79                new_txt, new_txt_ids = batched_prc_txt(new_txt)
 80            except ImportError:
 81                # Fallback: create simple IDs if import fails
 82                # This matches the structure expected by Flux2
 83                _l = new_txt.shape[1]
 84                coords = {
 85                    "t": torch.arange(1),
 86                    "h": torch.arange(1),  # dummy dimension
 87                    "w": torch.arange(1),  # dummy dimension
 88                    "l": torch.arange(_l),
 89                }
 90
 91                new_txt_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
 92                if bs > 1:
 93                    new_txt_ids = new_txt_ids.unsqueeze(0).repeat(bs, 1, 1)
 94                new_txt_ids = new_txt_ids.to(new_txt.device)
 95
 96            current_txt[0] = new_txt.to(img.device)
 97            current_txt_ids[0] = new_txt_ids.to(img.device)
 98            # Flux2 doesn't use separate CLIP embeddings
 99            if current_vec:
100                current_vec[0] = None  # type: ignore
101        else:
102            # Flux1 uses T5 + CLIP
103            if t5 is None or clip is None:
104                return
105
106            # Compute new embeddings
107            new_txt = t5(prompt_list)
108            if new_txt.shape[0] == 1 and bs > 1:
109                new_txt = repeat(new_txt, "1 ... -> bs ...", bs=bs)
110            new_txt_ids = torch.zeros(bs, new_txt.shape[1], 3)
111
112            new_vec = clip(prompt_list)
113            if new_vec.shape[0] == 1 and bs > 1:
114                new_vec = repeat(new_vec, "1 ... -> bs ...", bs=bs)
115
116            # Update embeddings and move to correct device
117            current_txt[0] = new_txt.to(img.device)
118            current_txt_ids[0] = new_txt_ids.to(img.device)
119            current_vec[0] = new_vec.to(img.device)
120
121        current_prompt[0] = prompt
122
123        # Clear prediction cache since embeddings changed
124        clear_prediction_cache()
125
126    return recompute_text_embeddings

Create a function to recompute text embeddings when prompt changes.

Supports both Flux1 (T5+CLIP) and Flux2 (Mistral) architectures.

Parameters
  • img: Image tensor for batch size reference
  • t5: T5 embedder (Flux1 only, optional)
  • clip: CLIP embedder (Flux1 only, optional)
  • current_txt: Mutable list containing current text embeddings
  • current_txt_ids: Mutable list containing current text IDs
  • current_vec: Mutable list containing current CLIP embeddings (Flux1) or None (Flux2)
  • current_prompt: Mutable list containing current prompt
  • clear_prediction_cache: Function to clear prediction cache
  • is_flux2: Whether this is for Flux2 model (uses different embedder)
  • text_embedder: Text embedder for Flux2 (Mistral, optional)
Returns

Function that recomputes text embeddings

def create_get_prediction( pred_set: divisor.state.TextEmbeddingState, img_set: divisor.state.ImageEmbeddingState) -> Callable[[torch.Tensor, float, float, Optional[list[int]]], torch.Tensor]:
def create_get_prediction(pred_set: TextEmbeddingState, img_set: ImageEmbeddingState) -> Callable[[Tensor, float, float, Optional[list[int]]], Tensor]:
    """Create a function to generate model prediction with caching.\n
    :param pred_set: State holding the model reference, current (and optional negative) text embeddings, guidance state and the single-slot prediction cache
    :param img_set: State holding the image tensor, its position IDs and optional channel/sequence conditioning
    :return: Function that generates predictions with caching"""

    def get_prediction(
        sample: Tensor,
        t_curr: float,
        guidance_val: float,
        layer_dropouts_val: Optional[list[int]],
    ) -> Tensor:
        """Generate model prediction, reusing cached prediction if state hasn't changed.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param guidance_val: Guidance value (used for the cache key; the model itself is fed ``pred_set.state.guidance``)
        :param layer_dropouts_val: Layer dropout configuration
        :returns: Model prediction"""
        # Cheap cache fingerprint: the sample's shape plus its first element only.
        # Two different samples sharing shape and first value will collide.
        first_val = float(sample.flatten()[0].item()) if sample.numel() > 0 else 0.0
        current_state = {
            "sample_hash": hash((sample.shape, first_val)),
            "t_curr": t_curr,
            "guidance": guidance_val,
            "layer_dropout": layer_dropouts_val,
        }

        # Reuse the cached prediction only when every tracked field matches.
        cached_pred = pred_set.cached_prediction[0]
        cached_state = pred_set.cached_prediction_state[0] if cached_pred is not None else None
        if cached_pred is not None and cached_state is not None and all(cached_state[key] == value for key, value in current_state.items()):
            return cached_pred

        # Generate new prediction
        # Safely get model dtype, handling Mock objects in tests
        try:
            model_dtype = next(pred_set.model_ref[0].parameters()).dtype
        except (TypeError, StopIteration, AttributeError):
            # Fallback: use sample dtype if we can't get model dtype (for Mock objects in tests)
            model_dtype = sample.dtype
        use_autocast = gfx_device.type == "cuda"

        def to_model_dtype(tensor: Tensor) -> Tensor:
            """Cast to the model dtype when not on CUDA (autocast disabled, e.g. MPS); otherwise pass through."""
            return tensor if use_autocast else tensor.to(dtype=model_dtype)

        # Ensure sample is in correct dtype before any operations
        sample = to_model_dtype(sample)

        t_vec = torch.full((sample.shape[0],), t_curr, dtype=sample.dtype, device=sample.device)
        img_input = sample
        img_input_ids = img_set.img_ids

        if img_set.img_cond is not None:
            # Channel-wise conditioning is concatenated on the last dimension.
            img_input = torch.cat((sample, to_model_dtype(img_set.img_cond)), dim=-1)
        if img_set.img_cond_seq is not None:
            assert img_set.img_cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning"
            # Sequence conditioning is appended along dim 1, together with its IDs.
            img_input = torch.cat((img_input, to_model_dtype(img_set.img_cond_seq)), dim=1)
            img_input_ids = torch.cat((img_input_ids, img_set.img_cond_seq_ids), dim=1)
        img_input = to_model_dtype(img_input)

        # Determine model type and prepare inputs accordingly
        model = pred_set.model_ref[0]
        is_flux2 = _is_flux2_model(model)

        # NOTE(review): the guidance vector is built from pred_set.state.guidance, not guidance_val — confirm they always agree.
        guidance_vec = torch.full((img_set.img.shape[0],), pred_set.state.guidance, device=img_set.img.device, dtype=img_set.img.dtype)

        if is_flux2:
            # Flux2 uses x, x_ids, ctx, ctx_ids instead of img, img_ids, txt, txt_ids, y
            pred = model(
                x=img_input,
                x_ids=img_input_ids,
                timesteps=t_vec,
                ctx=to_model_dtype(pred_set.current_txt[0]),
                ctx_ids=pred_set.current_txt_ids[0],
                guidance=to_model_dtype(guidance_vec),
                layer_dropouts=layer_dropouts_val,
            )
        else:
            # Flux1 path: the guidance vector is zeroed out (the original multiplied by 0.0 twice;
            # once gives the same values). NOTE(review): confirm zero guidance is intended here.
            guidance_vec = to_model_dtype(guidance_vec * 0.0)

            use_negative = pred_set.neg_pred_enabled and all([pred_set.current_neg_txt, pred_set.current_neg_txt_ids, pred_set.current_neg_vec])

            # Positive pass: always uses the current (positive) embeddings.
            # Previously the negative embeddings were fed here whenever negative prompting was
            # enabled, so both passes were conditioned on the negative prompt and the guidance mix
            # below degenerated to neg_pred.
            pred = model(
                img=img_input,
                img_ids=img_input_ids,
                txt=to_model_dtype(pred_set.current_txt[0]),
                txt_ids=pred_set.current_txt_ids[0],
                y=to_model_dtype(pred_set.current_vec[0]),
                timesteps=t_vec,
                guidance=guidance_vec,
                layer_dropouts=layer_dropouts_val,
            )
            if use_negative:
                # Negative pass gets the same dtype handling as the positive pass.
                neg_pred = model(
                    img=img_input,
                    img_ids=img_input_ids,
                    txt=to_model_dtype(pred_set.current_neg_txt[0]),  # type: ignore
                    txt_ids=pred_set.current_neg_txt_ids[0],  # type: ignore
                    y=to_model_dtype(pred_set.current_neg_vec[0]),  # type: ignore
                    timesteps=t_vec,
                    guidance=guidance_vec,
                    layer_dropouts=layer_dropouts_val,
                )
                # Move from the negative prediction towards the positive one by true_gs.
                pred = neg_pred + pred_set.true_gs * (pred - neg_pred)

        if img_input_ids is not None:
            # Drop any positions appended by sequence conditioning.
            pred = pred[:, : sample.shape[1]]

        # Cache the prediction
        pred_set.cached_prediction[0] = pred
        pred_set.cached_prediction_state[0] = current_state

        return pred

    return get_prediction

Create a function to generate model prediction with caching.

Parameters
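  • pred_set: TextEmbeddingState holding the model reference, the current (and optional negative) text embeddings, guidance state, and the prediction cache
  • img_set: ImageEmbeddingState holding the image tensor, its position IDs, and optional conditioning tensors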
Returns

Function that generates predictions with caching

def create_denoise_step_fn( controller_ref: list[typing.Optional[typing.Any]], current_layer_dropout: list[typing.Optional[list[int]]], previous_step_tensor: list[typing.Optional[torch.Tensor]], get_prediction: Callable[[torch.Tensor, float, float, Optional[list[int]]], torch.Tensor]) -> Callable[[torch.Tensor, float, float, float], torch.Tensor]:
def create_denoise_step_fn(
    controller_ref: list[Optional[Any]],
    current_layer_dropout: list[Optional[list[int]]],
    previous_step_tensor: list[Optional[Tensor]],
    get_prediction: Callable[[Tensor, float, float, Optional[list[int]]], Tensor],
) -> Callable[[Tensor, float, float, float], Tensor]:
    """Build the single-step denoising callable handed to the controller.\n
    The one-element lists are shared mutable cells: they are read on every call, and
    ``previous_step_tensor`` is also overwritten on every call.
    :param controller_ref: One-element list holding the controller, or None when no controller is attached
    :param current_layer_dropout: One-element list holding the layer dropout used when no controller is attached
    :param previous_step_tensor: One-element list holding the sample passed to the preceding call, or None
    :param get_prediction: Callable invoked as ``get_prediction(sample, t_curr, guidance_val, layer_dropouts)``
    :return: Function that advances a sample by one denoising step"""

    def _min_max_mask(reference: Tensor) -> Tensor:
        """Rescale ``reference`` into [0, 1] by its own min/max.\n
        :param reference: Tensor to normalise
        :return: Normalised tensor, or all ones when ``reference`` has no value spread"""
        lowest = reference.min()
        highest = reference.max()
        if highest > lowest:
            return (reference - lowest) / (highest - lowest)
        # A constant tensor cannot be normalised (zero range); use a pass-through mask instead.
        return torch.ones_like(reference)

    def denoise_step_fn(sample: Tensor, t_curr: float, t_prev: float, guidance_val: float) -> Tensor:
        """Advance ``sample`` from ``t_curr`` to ``t_prev`` using one model prediction.\n
        :param sample: Sample tensor entering this step
        :param t_curr: Timestep the sample is currently at
        :param t_prev: Timestep to step to
        :param guidance_val: Guidance value forwarded to the prediction function
        :returns: Sample tensor after the step (and after variation noise, when enabled)"""

        controller = controller_ref[0]
        if controller is None:
            # Without a controller, masking is off and the standalone dropout cell is used.
            layer_dropouts = current_layer_dropout[0]
            masking_requested = False
        else:
            masking_requested = controller.use_previous_as_mask
            layer_dropouts = controller.current_state.layer_dropout

        prediction = get_prediction(sample, t_curr, guidance_val, layer_dropouts)

        if masking_requested:
            earlier_sample = previous_step_tensor[0]
            # Masking is skipped silently on the first step or when shapes disagree.
            if earlier_sample is not None and earlier_sample.shape == prediction.shape:
                # Mask values near 1 keep the prediction, values near 0 suppress it.
                prediction = prediction * _min_max_mask(earlier_sample)

        # Move the sample along the prediction by the (signed) timestep difference.
        stepped = sample + (t_prev - t_curr) * prediction

        # Optional variation noise, driven by the controller's current state.
        active_controller = controller_ref[0]
        if active_controller is not None:
            state = active_controller.current_state
            wants_variation = state.variation_seed is not None and state.variation_strength > 0.0
            if wants_variation:
                stepped = apply_variation_noise(
                    latent_sample=stepped,
                    variation_seed=state.variation_seed,
                    variation_strength=state.variation_strength,
                    mask=None,
                    variation_method="linear",
                )

        # Remember the *input* sample (not the result) so the next call can use it as a mask source.
        previous_step_tensor[0] = sample.clone()

        return stepped

    return denoise_step_fn

Create a denoising step function for the controller.

Parameters
  • controller_ref: Mutable list containing controller reference
  • current_layer_dropout: Mutable list containing current layer dropout
  • previous_step_tensor: Mutable list containing previous step tensor
  • get_prediction: Function to get model prediction
Returns

Denoising step function