divisor.flux1.sampling

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3# adapted BFL Flux code from https://github.com/black-forest-labs/flux
  4
  5import math
  6import time
  7from typing import Callable, Optional
  8
  9import torch
 10from einops import rearrange, repeat
 11from nnll.console import nfo
 12
 13from torch import Tensor
 14
 15from divisor.cli_menu import route_choices
 16from divisor.controller import ManualTimestepController, rng, variation_rng
 17from divisor.denoise_step import (
 18    create_clear_prediction_cache,
 19    create_denoise_step_fn,
 20    create_get_prediction,
 21    create_recompute_text_embeddings,
 22)
 23from divisor.flux1.model import Flux
 24from divisor.flux1.text_embedder import HFEmbedder
 25from divisor.interaction_context import InteractionContext
 26from divisor.registry import gfx_sync
 27from divisor.save import SaveFile
 28from divisor.state import (
 29    ImageEmbeddingState,
 30    InferenceState,
 31    TextEmbeddingState,
 32)
 33
 34
 35def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
 36    """Prepare the text embeddings for the model.\n
 37    :param t5: T5 embedder
 38    :param clip: CLIP embedder
 39    :param img: Image tensor
 40    :param prompt: Prompt
 41    :returns: Dictionary of input tensors"""
 42    bs, c, h, w = img.shape
 43    if bs == 1 and not isinstance(prompt, str):
 44        bs = len(prompt)
 45
 46    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
 47    if img.shape[0] == 1 and bs > 1:
 48        img = repeat(img, "1 ... -> bs ...", bs=bs)
 49
 50    img_ids = torch.zeros(h // 2, w // 2, 3)
 51    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
 52    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
 53    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 54
 55    if isinstance(prompt, str):
 56        prompt = [prompt]
 57    txt = t5(prompt)
 58    if txt.shape[0] == 1 and bs > 1:
 59        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
 60    txt_ids = torch.zeros(bs, txt.shape[1], 3)
 61
 62    vec = clip(prompt)
 63    if vec.shape[0] == 1 and bs > 1:
 64        vec = repeat(vec, "1 ... -> bs ...", bs=bs)
 65
 66    return {
 67        "img": img,
 68        "img_ids": img_ids.to(img.device),
 69        "txt": txt.to(img.device),
 70        "txt_ids": txt_ids.to(img.device),
 71        "vec": vec.to(img.device),
 72    }
 73
 74
 75def time_shift(schedule_mu: float, schedule_sigma: float, original_timestep_tensor: Tensor) -> Tensor:
 76    """Adjustable noise schedule. Compress or stretch any schedule to match a dynamic step sequence length.\n
 77    :param schedule_mu: Original schedule parameter.
 78    :param schedule_sigma: Original schedule parameter.
 79    :param original_timestep_tensor: Tensor of original timesteps in [0,1].
 80    :returns: Adjusted timestep tensor."""
 81    return math.exp(schedule_mu) / (math.exp(schedule_mu) + (1 / original_timestep_tensor - 1) ** schedule_sigma)
 82
 83
 84def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
 85    m = (y2 - y1) / (x2 - x1)
 86    b = y1 - m * x1
 87    return lambda x: m * x + b
 88
 89
 90def get_schedule(
 91    num_steps: int,
 92    image_seq_len: int,
 93    base_shift: float = 0.5,
 94    max_shift: float = 1.15,
 95    shift: bool = True,
 96) -> list[float]:
 97    """Generate a schedule of timesteps.\n
 98    :param num_steps: Number of steps to generate
 99    :param image_seq_len: Length of the image sequence
100    :param base_shift: Base shift value
101    :param max_shift: Maximum shift value
102    :param shift: Whether to shift the schedule
103    :returns: List of timesteps"""
104    # extra step for zero
105    timesteps = torch.linspace(1, 0, num_steps + 1)
106
107    # shifting the schedule to favor high timesteps for higher signal images
108    if shift:
109        # estimate mu based on linear estimation between two points
110        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
111        timesteps = time_shift(mu, 1.0, timesteps)
112
113    return timesteps.tolist()
114
115
@torch.inference_mode()
def denoise(
    model: Flux,
    settings: InferenceState,
):
    """Run the interactive Flux denoising loop driven by a ManualTimestepController.\n
    Every pass hands control to ``route_choices`` and, when an autoencoder and a resolution
    are available, decodes and saves a preview of the current sample.\n
    :param model: Flux model instance
    :param settings: InferenceState containing all denoising configuration parameters
    :returns: ``controller.current_sample`` once the controller reports completion
    :raises AssertionError: If ``settings.vec`` is None"""
    from contextlib import nullcontext

    from divisor.registry import gfx_device as default_device

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    vec = settings.vec
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    img_cond = settings.img_cond
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids
    denoise_device = settings.device if settings.device is not None else default_device
    initial_layer_dropout = settings.initial_layer_dropout
    t5 = settings.t5
    clip = settings.clip
    neg_pred_enabled = settings.neg_pred_enabled
    neg_txt = settings.neg_txt
    neg_txt_ids = settings.neg_txt_ids
    neg_vec = settings.neg_vec
    true_gs = settings.true_gs

    # Single-element lists serve as mutable cells shared with the helper closures built below
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Optional[Tensor]] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Optional[Tensor]] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[Optional[dict]] = [None]  # Cache state when prediction was generated
    controller_ref: list[Optional["ManualTimestepController"]] = [None]  # Reference to controller for closure access

    model_ref: list[Flux] = [model]
    target_device = img.device
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        model_device = target_device  # Assume model is already on correct device if we can't determine it
    if model_device != target_device:
        if getattr(model_device, "type", None) == "meta":
            # Meta parameters carry no data, so storage can only be allocated, never copied
            model_ref[0] = model.to_empty(device=target_device)
        else:
            # to_empty() does not copy storage and would discard the loaded weights
            model_ref[0] = model.to(target_device)

    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    assert vec is not None, "vec (CLIP embeddings) is required for Flux1"
    current_vec: list[Tensor] = [vec]

    current_neg_txt: list[Tensor] | None = None
    current_neg_txt_ids: list[Tensor] | None = None
    current_neg_vec: list[Tensor] | None = None
    # Test for presence, not truthiness: bool() of a multi-element tensor raises RuntimeError
    has_negative_inputs = all(item is not None for item in (neg_txt, neg_txt_ids, neg_vec))
    if neg_pred_enabled and has_negative_inputs:
        current_neg_txt = [neg_txt]  # type: ignore
        current_neg_txt_ids = [neg_txt_ids]  # type: ignore
        current_neg_vec = [neg_vec]  # type: ignore
    else:
        # Without a complete set of negative embeddings the guidance scale is pinned to 1
        true_gs = 1
    current_prompt: list[Optional[str]] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(  # formatting
        img, t5, clip, current_txt, current_txt_ids, current_vec, current_prompt, clear_prediction_cache, is_flux2=False
    )

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
        neg_pred_enabled=neg_pred_enabled,
        current_neg_txt=current_neg_txt,  # pyright: ignore[reportArgumentType]
        current_neg_txt_ids=current_neg_txt_ids,  # pyright: ignore[reportArgumentType]
        current_neg_vec=current_neg_vec,  # pyright: ignore[reportArgumentType]
        # NOTE(review): int() truncates a fractional true_gs - confirm this is intended
        true_gs=int(true_gs) if true_gs is not None else None,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=img_cond,
        img_cond_seq=img_cond_seq,
        img_cond_seq_ids=img_cond_seq_ids,
    )
    get_prediction = create_get_prediction(pred_set, img_set)

    denoise_step_fn = create_denoise_step_fn(  # formatting
        controller_ref, current_layer_dropout, previous_step_tensor, get_prediction
    )

    controller = ManualTimestepController(  # formatting
        timesteps=timesteps, initial_sample=img, denoise_step_fn=denoise_step_fn, initial_guidance=state.guidance
    )
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if t5 is not None and clip is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedders not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            t5=t5,
            clip=clip,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Unpack requires float32, but we'll convert back to correct dtype after
            intermediate = unpack(intermediate.float(), state.height, state.width)

            # Wait for queued device work so the elapsed time below covers the whole step.
            # NOTE(review): assumes gfx_sync is a zero-argument callable - confirm in divisor.registry
            if callable(gfx_sync):
                gfx_sync()
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            use_autocast = denoise_device.type == "cuda"
            if use_autocast:
                context = torch.autocast(device_type=denoise_device.type, dtype=torch.bfloat16)
            else:
                context = nullcontext()
            with context:
                # When autocast is disabled (MPS), ensure intermediate is in correct dtype for VAE
                if not use_autocast:
                    # Get VAE encoder dtype to ensure intermediate matches (bfloat16)
                    # Safely get encoder dtype, handling Mock objects in tests
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        # Fallback: use intermediate dtype if we can't get encoder dtype (for Mock objects in tests)
                        ae_dtype = intermediate.dtype
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)
                    intermediate_image = ae.decoder(z_adjusted)
                else:
                    intermediate_image = ae.decode(intermediate)
                if state.seed is not None:
                    controller.store_state_in_chain(current_seed=state.seed)
                with SaveFile() as saver:
                    saver.intermediate_image = intermediate_image  # set up image
                    # NOTE(review): the hyperchain is wrapped in a one-element tuple - confirm SaveFile expects that
                    saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                    saver.with_hyperchain()

    return controller.current_sample
320
321
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Rearrange a packed sequence of 2x2 patches back into a four-dimensional tensor.\n
    :param x: Tensor laid out as ``b (h w) (c ph pw)``
    :param height: Height used to derive the patch grid, ``h = ceil(height / 16)``
    :param width: Width used to derive the patch grid, ``w = ceil(width / 16)``
    :returns: Tensor laid out as ``b c (h ph) (w pw)`` with ``ph = pw = 2``"""
    grid_rows = math.ceil(height / 16)
    grid_cols = math.ceil(width / 16)
    return rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=grid_rows, w=grid_cols, ph=2, pw=2)
def prepare( t5: divisor.flux1.text_embedder.HFEmbedder, clip: divisor.flux1.text_embedder.HFEmbedder, img: torch.Tensor, prompt: str | list[str]) -> dict[str, torch.Tensor]:
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    """Pack the image into 2x2 patches and encode the prompt into model inputs.\n
    :param t5: T5 embedder, called with the list of prompts
    :param clip: CLIP embedder, called with the list of prompts
    :param img: Four-dimensional image tensor whose last two dimensions are divisible by 2
    :param prompt: A single prompt or a list of prompts
    :returns: Dictionary with the keys ``img``, ``img_ids``, ``txt``, ``txt_ids`` and ``vec``,
        every tensor placed on the device of the packed image"""
    batch, _, height, width = img.shape
    prompts = [prompt] if isinstance(prompt, str) else prompt
    # A single image paired with a list of prompts is expanded to one image per prompt
    if batch == 1 and not isinstance(prompt, str):
        batch = len(prompts)

    def broadcast(tensor: Tensor) -> Tensor:
        """Repeat a batch-of-one tensor along dim 0 until it matches ``batch``."""
        if tensor.shape[0] == 1 and batch > 1:
            return repeat(tensor, "1 ... -> bs ...", bs=batch)
        return tensor

    packed = broadcast(rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2))

    # Channel 0 stays zero; channels 1 and 2 hold the row / column index of every patch
    rows, cols = height // 2, width // 2
    position_ids = torch.zeros(rows, cols, 3)
    position_ids[..., 1] += torch.arange(rows)[:, None]
    position_ids[..., 2] += torch.arange(cols)[None, :]
    position_ids = repeat(position_ids, "h w c -> b (h w) c", b=batch)

    text_tokens = broadcast(t5(prompts))
    text_ids = torch.zeros(batch, text_tokens.shape[1], 3)

    pooled = broadcast(clip(prompts))

    device = packed.device
    return {
        "img": packed,
        "img_ids": position_ids.to(device),
        "txt": text_tokens.to(device),
        "txt_ids": text_ids.to(device),
        "vec": pooled.to(device),
    }

Prepare the text embeddings for the model.

Parameters
  • t5: T5 embedder
  • clip: CLIP embedder
  • img: Image tensor
  • prompt: Prompt

Returns
  • Dictionary of input tensors
def time_shift( schedule_mu: float, schedule_sigma: float, original_timestep_tensor: torch.Tensor) -> torch.Tensor:
def time_shift(schedule_mu: float, schedule_sigma: float, original_timestep_tensor: Tensor) -> Tensor:
    """Warp a timestep schedule: ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.\n
    :param schedule_mu: Shift parameter; enters the formula as ``exp(schedule_mu)``.
    :param schedule_sigma: Exponent applied to ``1 / t - 1``.
    :param original_timestep_tensor: Tensor of original timesteps in [0,1].
    :returns: Adjusted timestep tensor."""
    exp_mu = math.exp(schedule_mu)
    inverse_odds = (1 / original_timestep_tensor - 1) ** schedule_sigma
    return exp_mu / (exp_mu + inverse_odds)

Adjustable noise schedule. Compress or stretch any schedule to match a dynamic step sequence length.

Parameters
  • schedule_mu: Original schedule parameter.
  • schedule_sigma: Original schedule parameter.
  • original_timestep_tensor: Tensor of original timesteps in [0,1].

Returns
  • Adjusted timestep tensor.
def get_lin_function( x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
    """Build the straight line passing through ``(x1, y1)`` and ``(x2, y2)``.\n
    :param x1: x coordinate of the first point
    :param y1: y coordinate of the first point
    :param x2: x coordinate of the second point
    :param y2: y coordinate of the second point
    :returns: Function mapping ``x`` to the value of that line at ``x``"""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def line(x: float) -> float:
        return slope * x + intercept

    return line
def get_schedule( num_steps: int, image_seq_len: int, base_shift: float = 0.5, max_shift: float = 1.15, shift: bool = True) -> list[float]:
 91def get_schedule(
 92    num_steps: int,
 93    image_seq_len: int,
 94    base_shift: float = 0.5,
 95    max_shift: float = 1.15,
 96    shift: bool = True,
 97) -> list[float]:
 98    """Generate a schedule of timesteps.\n
 99    :param num_steps: Number of steps to generate
100    :param image_seq_len: Length of the image sequence
101    :param base_shift: Base shift value
102    :param max_shift: Maximum shift value
103    :param shift: Whether to shift the schedule
104    :returns: List of timesteps"""
105    # extra step for zero
106    timesteps = torch.linspace(1, 0, num_steps + 1)
107
108    # shifting the schedule to favor high timesteps for higher signal images
109    if shift:
110        # estimate mu based on linear estimation between two points
111        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
112        timesteps = time_shift(mu, 1.0, timesteps)
113
114    return timesteps.tolist()

Generate a schedule of timesteps.

Parameters
  • num_steps: Number of steps to generate
  • image_seq_len: Length of the image sequence
  • base_shift: Base shift value
  • max_shift: Maximum shift value
  • shift: Whether to shift the schedule

Returns
  • List of timesteps
@torch.inference_mode()
def denoise( model: divisor.flux1.model.Flux, settings: divisor.state.InferenceState):
@torch.inference_mode()
def denoise(
    model: Flux,
    settings: InferenceState,
):
    """Run the interactive Flux denoising loop driven by a ManualTimestepController.\n
    Every pass hands control to ``route_choices`` and, when an autoencoder and a resolution
    are available, decodes and saves a preview of the current sample.\n
    :param model: Flux model instance
    :param settings: InferenceState containing all denoising configuration parameters
    :returns: ``controller.current_sample`` once the controller reports completion
    :raises AssertionError: If ``settings.vec`` is None"""
    from contextlib import nullcontext

    from divisor.registry import gfx_device as default_device

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    vec = settings.vec
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    img_cond = settings.img_cond
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids
    denoise_device = settings.device if settings.device is not None else default_device
    initial_layer_dropout = settings.initial_layer_dropout
    t5 = settings.t5
    clip = settings.clip
    neg_pred_enabled = settings.neg_pred_enabled
    neg_txt = settings.neg_txt
    neg_txt_ids = settings.neg_txt_ids
    neg_vec = settings.neg_vec
    true_gs = settings.true_gs

    # Single-element lists serve as mutable cells shared with the helper closures built below
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Optional[Tensor]] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Optional[Tensor]] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[Optional[dict]] = [None]  # Cache state when prediction was generated
    controller_ref: list[Optional["ManualTimestepController"]] = [None]  # Reference to controller for closure access

    model_ref: list[Flux] = [model]
    target_device = img.device
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        model_device = target_device  # Assume model is already on correct device if we can't determine it
    if model_device != target_device:
        if getattr(model_device, "type", None) == "meta":
            # Meta parameters carry no data, so storage can only be allocated, never copied
            model_ref[0] = model.to_empty(device=target_device)
        else:
            # to_empty() does not copy storage and would discard the loaded weights
            model_ref[0] = model.to(target_device)

    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    assert vec is not None, "vec (CLIP embeddings) is required for Flux1"
    current_vec: list[Tensor] = [vec]

    current_neg_txt: list[Tensor] | None = None
    current_neg_txt_ids: list[Tensor] | None = None
    current_neg_vec: list[Tensor] | None = None
    # Test for presence, not truthiness: bool() of a multi-element tensor raises RuntimeError
    has_negative_inputs = all(item is not None for item in (neg_txt, neg_txt_ids, neg_vec))
    if neg_pred_enabled and has_negative_inputs:
        current_neg_txt = [neg_txt]  # type: ignore
        current_neg_txt_ids = [neg_txt_ids]  # type: ignore
        current_neg_vec = [neg_vec]  # type: ignore
    else:
        # Without a complete set of negative embeddings the guidance scale is pinned to 1
        true_gs = 1
    current_prompt: list[Optional[str]] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(  # formatting
        img, t5, clip, current_txt, current_txt_ids, current_vec, current_prompt, clear_prediction_cache, is_flux2=False
    )

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
        neg_pred_enabled=neg_pred_enabled,
        current_neg_txt=current_neg_txt,  # pyright: ignore[reportArgumentType]
        current_neg_txt_ids=current_neg_txt_ids,  # pyright: ignore[reportArgumentType]
        current_neg_vec=current_neg_vec,  # pyright: ignore[reportArgumentType]
        # NOTE(review): int() truncates a fractional true_gs - confirm this is intended
        true_gs=int(true_gs) if true_gs is not None else None,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=img_cond,
        img_cond_seq=img_cond_seq,
        img_cond_seq_ids=img_cond_seq_ids,
    )
    get_prediction = create_get_prediction(pred_set, img_set)

    denoise_step_fn = create_denoise_step_fn(  # formatting
        controller_ref, current_layer_dropout, previous_step_tensor, get_prediction
    )

    controller = ManualTimestepController(  # formatting
        timesteps=timesteps, initial_sample=img, denoise_step_fn=denoise_step_fn, initial_guidance=state.guidance
    )
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if t5 is not None and clip is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedders not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            t5=t5,
            clip=clip,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Unpack requires float32, but we'll convert back to correct dtype after
            intermediate = unpack(intermediate.float(), state.height, state.width)

            # Wait for queued device work so the elapsed time below covers the whole step.
            # NOTE(review): assumes gfx_sync is a zero-argument callable - confirm in divisor.registry
            if callable(gfx_sync):
                gfx_sync()
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            use_autocast = denoise_device.type == "cuda"
            if use_autocast:
                context = torch.autocast(device_type=denoise_device.type, dtype=torch.bfloat16)
            else:
                context = nullcontext()
            with context:
                # When autocast is disabled (MPS), ensure intermediate is in correct dtype for VAE
                if not use_autocast:
                    # Get VAE encoder dtype to ensure intermediate matches (bfloat16)
                    # Safely get encoder dtype, handling Mock objects in tests
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        # Fallback: use intermediate dtype if we can't get encoder dtype (for Mock objects in tests)
                        ae_dtype = intermediate.dtype
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)
                    intermediate_image = ae.decoder(z_adjusted)
                else:
                    intermediate_image = ae.decode(intermediate)
                if state.seed is not None:
                    controller.store_state_in_chain(current_seed=state.seed)
                with SaveFile() as saver:
                    saver.intermediate_image = intermediate_image  # set up image
                    # NOTE(review): the hyperchain is wrapped in a one-element tuple - confirm SaveFile expects that
                    saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                    saver.with_hyperchain()

    return controller.current_sample

Denoise using Flux model with optional ManualTimestepController.

Parameters
  • model: Flux model instance
  • settings: InferenceState containing all denoising configuration parameters
def unpack(x: torch.Tensor, height: int, width: int) -> torch.Tensor:
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Rearrange a packed sequence of 2x2 patches back into a four-dimensional tensor.\n
    :param x: Tensor laid out as ``b (h w) (c ph pw)``
    :param height: Height used to derive the patch grid, ``h = ceil(height / 16)``
    :param width: Width used to derive the patch grid, ``w = ceil(width / 16)``
    :returns: Tensor laid out as ``b c (h ph) (w pw)`` with ``ph = pw = 2``"""
    grid_rows = math.ceil(height / 16)
    grid_cols = math.ceil(width / 16)
    return rearrange(x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=grid_rows, w=grid_cols, ph=2, pw=2)