divisor.xflux1.sampling

  1# SPDX-License-Identifier:Apache-2.0
  2# original XFlux code from https://github.com/TencentARC/FluxKits
  3
  4import time
  5from typing import Callable
  6
  7import torch
  8from nnll.console import nfo
  9from torch import Tensor
 10
 11from divisor.cli_menu import route_choices
 12from divisor.controller import ManualTimestepController, rng, variation_rng
 13from divisor.denoise_step import (
 14    create_clear_prediction_cache,
 15    create_denoise_step_fn,
 16    create_recompute_text_embeddings,
 17)
 18from divisor.flux1.sampling import prepare, unpack
 19from divisor.interaction_context import InteractionContext
 20from divisor.registry import gfx_device, gfx_sync
 21from divisor.save import SaveFile
 22from divisor.state import (
 23    ImageEmbeddingState,
 24    InferenceState,
 25    StepStateXFlux1,
 26    TextEmbeddingState,
 27)
 28from divisor.xflux1.model import XFlux
 29
 30
 31def create_get_prediction_xflux1(
 32    pred_set: TextEmbeddingState,
 33    img_set: ImageEmbeddingState,
 34    add_set: StepStateXFlux1,
 35) -> Callable[[Tensor, float, float, list[int] | None], Tensor]:
 36    """Create a function to generate model prediction with caching for XFlux1.\n
 37    :param config: GetPredictionSettings containing all configuration parameters
 38    :return: Function that generates predictions with caching"""
 39
 40    def get_prediction(
 41        sample: Tensor,
 42        t_curr: float,
 43        guidance_val: float,
 44        layer_dropouts_val: list[int] | None,
 45    ) -> Tensor:
 46        """Generate model prediction, reusing cached prediction if state hasn't changed.\n
 47        :param sample: Current sample tensor
 48        :param t_curr: Current timestep
 49        :param guidance_val: Guidance value
 50        :param layer_dropouts_val: Layer dropout configuration
 51        :returns: Model prediction"""
 52        # Create a simple hash of the sample tensor for cache key
 53        if sample.numel() > 0:
 54            first_val = float(sample.flatten()[0].item())
 55        else:
 56            first_val = 0.0
 57        sample_hash = hash((sample.shape, first_val))
 58
 59        # Check if we can reuse cached prediction
 60        current_state = {
 61            "sample_hash": sample_hash,
 62            "t_curr": t_curr,
 63            "guidance": guidance_val,
 64            "layer_dropout": layer_dropouts_val,
 65        }
 66
 67        if (
 68            pred_set.cached_prediction[0] is not None
 69            and pred_set.cached_prediction_state[0] is not None
 70            and pred_set.cached_prediction_state[0]["sample_hash"] == current_state["sample_hash"]
 71            and pred_set.cached_prediction_state[0]["t_curr"] == current_state["t_curr"]
 72            and pred_set.cached_prediction_state[0]["guidance"] == current_state["guidance"]
 73            and pred_set.cached_prediction_state[0]["layer_dropout"] == current_state["layer_dropout"]
 74        ):
 75            return pred_set.cached_prediction[0]
 76
 77        # Generate new prediction
 78        try:
 79            model_dtype = next(pred_set.model_ref[0].parameters()).dtype
 80        except (TypeError, StopIteration, AttributeError):
 81            model_dtype = sample.dtype
 82        use_autocast = gfx_device.type == "cuda"
 83
 84        if not use_autocast:
 85            sample = sample.to(dtype=model_dtype)
 86
 87        t_vec = torch.full((sample.shape[0],), t_curr, dtype=sample.dtype, device=sample.device)
 88        guidance_vec = torch.full((img_set.img.shape[0],), guidance_val, device=img_set.img.device, dtype=img_set.img.dtype)
 89
 90        if not use_autocast:
 91            img_input = sample.to(dtype=model_dtype)
 92            txt_input = pred_set.current_txt[0].to(dtype=model_dtype)
 93            vec_input = pred_set.current_vec[0].to(dtype=model_dtype)
 94            t_vec = t_vec.to(dtype=model_dtype)
 95            guidance_vec = guidance_vec.to(dtype=model_dtype)
 96        else:
 97            img_input = sample
 98            txt_input = pred_set.current_txt[0]
 99            vec_input = pred_set.current_vec[0]
100
101        # Get current timestep index for CFG check
102        timestep_index = add_set.current_timestep_index[0]
103
104        kwargs = {}
105        if "image_proj" in pred_set.model_ref[0].__dict__:
106            kwargs = {"image_proj": img_set.image_proj}
107        if "ip_scale" in pred_set.model_ref[0].__dict__:
108            kwargs.setdefault("ip_scale", img_set.ip_scale)
109
110        # Generate positive prediction
111        pred = pred_set.model_ref[0](
112            img=img_input,
113            img_ids=img_set.img_ids,
114            txt=txt_input,
115            txt_ids=pred_set.current_txt_ids[0],
116            y=vec_input,
117            timesteps=t_vec,
118            guidance=guidance_vec,
119            **kwargs,
120        )
121
122        # Apply CFG if enabled and past start timestep
123        if (
124            pred_set.neg_pred_enabled
125            and all([pred_set.current_neg_txt, pred_set.current_neg_txt_ids, pred_set.current_neg_vec])
126            and timestep_index >= add_set.timestep_to_start_cfg
127        ):
128            if not use_autocast:
129                neg_txt_input = pred_set.current_neg_txt[0].to(dtype=model_dtype)  # type: ignore
130                neg_vec_input = pred_set.current_neg_vec[0].to(dtype=model_dtype)  # type: ignore
131            else:
132                neg_txt_input = pred_set.current_neg_txt[0]  # type: ignore
133                neg_vec_input = pred_set.current_neg_vec[0]  # type: ignore
134
135            neg_pred = pred_set.model_ref[0](
136                img=img_input,
137                img_ids=img_set.img_ids,
138                txt=neg_txt_input,
139                txt_ids=pred_set.current_neg_txt_ids[0],  # type: ignore
140                y=neg_vec_input,
141                timesteps=t_vec,
142                guidance=guidance_vec,
143                image_proj=img_set.neg_image_proj,
144                ip_scale=img_set.neg_ip_scale,
145            )
146            pred = neg_pred + pred_set.true_gs * (pred - neg_pred)
147
148        # Cache the prediction
149        pred_set.cached_prediction[0] = pred
150        pred_set.cached_prediction_state[0] = current_state
151
152        return pred
153
154    return get_prediction
155
156
@torch.inference_mode()
def denoise(
    model: XFlux,
    settings: InferenceState,
):
    """Denoise using XFlux model with optional ManualTimestepController.\n
    Builds the prediction and step closures, then loops until the controller reports
    completion. Each iteration hands the controller and its state to ``route_choices``
    and, when an autoencoder and a resolution are set, decodes and saves a preview.\n
    :param model: XFlux model instance
    :param settings: Denoising state containing all denoising configuration parameters
    :returns: The controller's current sample after the loop ends
    :raises AssertionError: If ``settings.vec`` is None"""
    # Imported once per call rather than on every loop iteration
    from contextlib import nullcontext

    from divisor.registry import gfx_device as default_device

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    vec = settings.vec
    neg_pred_enabled = settings.neg_pred_enabled
    neg_txt = settings.neg_txt
    neg_txt_ids = settings.neg_txt_ids
    neg_vec = settings.neg_vec
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    true_gs = settings.true_gs
    timestep_to_start_cfg = settings.timestep_to_start_cfg
    image_proj = settings.image_proj
    neg_image_proj = settings.neg_image_proj
    ip_scale = settings.ip_scale if settings.ip_scale is not None else torch.tensor(1.0)
    neg_ip_scale = settings.neg_ip_scale if settings.neg_ip_scale is not None else torch.tensor(1.0)
    initial_layer_dropout = settings.initial_layer_dropout
    t5 = settings.t5
    clip = settings.clip

    # Single-element lists act as mutable cells shared with the closures below
    # this is ignored for schnell
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Tensor | None] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Tensor | None] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[dict | None] = [None]  # Cache state when prediction was generated
    controller_ref: list[ManualTimestepController | None] = [None]  # Reference to controller for closure access
    current_timestep_index: list[int] = [0]  # Track current timestep index for CFG

    model_ref: list[XFlux] = [model]
    # Ensure model is on the correct device (fixes meta device issue)
    target_device = img.device
    # Safely get model device, handling Mock objects in tests
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        # Fallback for Mock objects or models without parameters
        # Assume model is already on correct device if we can't determine it
        model_device = target_device
    if model_device != target_device:
        if model_device.type == "meta":
            # Meta parameters have no storage to copy, so allocation is all that can be done
            model_ref[0] = model.to_empty(device=target_device)
        else:
            # to_empty() does not copy storage; real weights must be moved with to()
            model_ref[0] = model.to(target_device)

    # Store embeddings in mutable containers so they can be updated when prompt changes
    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    assert vec is not None, "vec (CLIP embeddings) is required for XFlux1"
    current_vec: list[Tensor] = [vec]

    # Handle negative prompts: XFlux1 supports negative prompts but they're optional.
    # When enabled but incomplete, embed an empty prompt if both encoders are
    # available; otherwise negative prediction is switched off.
    if neg_pred_enabled and (neg_txt is None or neg_txt_ids is None or neg_vec is None):
        if t5 is not None and clip is not None:
            neg_inp = prepare(t5, clip, img, prompt="")
            neg_txt = neg_inp["txt"]
            neg_txt_ids = neg_inp["txt_ids"]
            neg_vec = neg_inp["vec"]
        else:
            neg_pred_enabled = False

    if neg_pred_enabled and neg_txt is not None and neg_txt_ids is not None and neg_vec is not None:
        current_neg_txt: list[Tensor] = [neg_txt]  # type: ignore
        current_neg_txt_ids: list[Tensor] = [neg_txt_ids]  # type: ignore
        current_neg_vec: list[Tensor] = [neg_vec]  # type: ignore
    else:
        current_neg_txt: list[Tensor] | None = None
        current_neg_txt_ids: list[Tensor] | None = None
        current_neg_vec: list[Tensor] | None = None
        neg_pred_enabled = False
    current_prompt: list[str | None] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(img, t5, clip, current_txt, current_txt_ids, current_vec, current_prompt, clear_prediction_cache, is_flux2=False)

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
        neg_pred_enabled=neg_pred_enabled,
        current_neg_txt=current_neg_txt,  # type: ignore
        current_neg_txt_ids=current_neg_txt_ids,  # type: ignore
        current_neg_vec=current_neg_vec,  # type: ignore
        true_gs=true_gs,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=None,
        img_cond_seq=None,
        img_cond_seq_ids=None,
        image_proj=image_proj,
        neg_image_proj=neg_image_proj,
        ip_scale=ip_scale,
        neg_ip_scale=neg_ip_scale,
    )
    add_set = StepStateXFlux1(
        timestep_to_start_cfg=timestep_to_start_cfg,
        current_timestep_index=current_timestep_index,
    )
    get_prediction = create_get_prediction_xflux1(pred_set, img_set, add_set)

    denoise_step_fn = create_denoise_step_fn(controller_ref, current_layer_dropout, previous_step_tensor, get_prediction)

    controller = ManualTimestepController(timesteps=timesteps, initial_sample=img, denoise_step_fn=denoise_step_fn, initial_guidance=state.guidance)
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    # Seed the controller with whichever optional values the initial state carries
    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Update timestep index for CFG check
        current_timestep_index[0] = getattr(state, "timestep_index", 0)

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if t5 is not None and clip is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedders not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            t5=t5,
            clip=clip,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            # Preview latent: current sample minus timestep-scaled prediction
            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Unpack requires float32, but we'll convert back to correct dtype after
            intermediate = unpack(intermediate.float(), state.height, state.width)

            # NOTE(review): this bare name is evaluated but never called, so no device
            # synchronisation happens before t1 is taken; presumably `gfx_sync()` was
            # intended - confirm its signature in divisor.registry before changing.
            gfx_sync
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            if default_device.type == "cuda":
                context = torch.autocast(device_type=default_device.type, dtype=torch.bfloat16)
            else:
                context = nullcontext()
            with context:
                # When autocast is disabled (non-CUDA), ensure intermediate is in correct dtype for VAE
                if default_device.type != "cuda":
                    # The encoder's parameter dtype is used as the VAE dtype
                    # Safely get encoder dtype, handling Mock objects in tests
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        # Fallback: use intermediate dtype if we can't get encoder dtype (for Mock objects in tests)
                        ae_dtype = intermediate.dtype
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)
                    intermediate_image = ae.decoder(z_adjusted)
                else:
                    intermediate_image = ae.decode(intermediate)
                if state.seed is not None:
                    controller.store_state_in_chain(current_seed=state.seed)
                with SaveFile() as saver:
                    saver.intermediate_image = intermediate_image  # set up image
                    saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                    saver.with_hyperchain()

    return controller.current_sample
def create_get_prediction_xflux1( pred_set: divisor.state.TextEmbeddingState, img_set: divisor.state.ImageEmbeddingState, add_set: divisor.state.StepStateXFlux1) -> Callable[[torch.Tensor, float, float, Optional[list[int]]], torch.Tensor]:
 32def create_get_prediction_xflux1(
 33    pred_set: TextEmbeddingState,
 34    img_set: ImageEmbeddingState,
 35    add_set: StepStateXFlux1,
 36) -> Callable[[Tensor, float, float, list[int] | None], Tensor]:
 37    """Create a function to generate model prediction with caching for XFlux1.\n
 38    :param config: GetPredictionSettings containing all configuration parameters
 39    :return: Function that generates predictions with caching"""
 40
 41    def get_prediction(
 42        sample: Tensor,
 43        t_curr: float,
 44        guidance_val: float,
 45        layer_dropouts_val: list[int] | None,
 46    ) -> Tensor:
 47        """Generate model prediction, reusing cached prediction if state hasn't changed.\n
 48        :param sample: Current sample tensor
 49        :param t_curr: Current timestep
 50        :param guidance_val: Guidance value
 51        :param layer_dropouts_val: Layer dropout configuration
 52        :returns: Model prediction"""
 53        # Create a simple hash of the sample tensor for cache key
 54        if sample.numel() > 0:
 55            first_val = float(sample.flatten()[0].item())
 56        else:
 57            first_val = 0.0
 58        sample_hash = hash((sample.shape, first_val))
 59
 60        # Check if we can reuse cached prediction
 61        current_state = {
 62            "sample_hash": sample_hash,
 63            "t_curr": t_curr,
 64            "guidance": guidance_val,
 65            "layer_dropout": layer_dropouts_val,
 66        }
 67
 68        if (
 69            pred_set.cached_prediction[0] is not None
 70            and pred_set.cached_prediction_state[0] is not None
 71            and pred_set.cached_prediction_state[0]["sample_hash"] == current_state["sample_hash"]
 72            and pred_set.cached_prediction_state[0]["t_curr"] == current_state["t_curr"]
 73            and pred_set.cached_prediction_state[0]["guidance"] == current_state["guidance"]
 74            and pred_set.cached_prediction_state[0]["layer_dropout"] == current_state["layer_dropout"]
 75        ):
 76            return pred_set.cached_prediction[0]
 77
 78        # Generate new prediction
 79        try:
 80            model_dtype = next(pred_set.model_ref[0].parameters()).dtype
 81        except (TypeError, StopIteration, AttributeError):
 82            model_dtype = sample.dtype
 83        use_autocast = gfx_device.type == "cuda"
 84
 85        if not use_autocast:
 86            sample = sample.to(dtype=model_dtype)
 87
 88        t_vec = torch.full((sample.shape[0],), t_curr, dtype=sample.dtype, device=sample.device)
 89        guidance_vec = torch.full((img_set.img.shape[0],), guidance_val, device=img_set.img.device, dtype=img_set.img.dtype)
 90
 91        if not use_autocast:
 92            img_input = sample.to(dtype=model_dtype)
 93            txt_input = pred_set.current_txt[0].to(dtype=model_dtype)
 94            vec_input = pred_set.current_vec[0].to(dtype=model_dtype)
 95            t_vec = t_vec.to(dtype=model_dtype)
 96            guidance_vec = guidance_vec.to(dtype=model_dtype)
 97        else:
 98            img_input = sample
 99            txt_input = pred_set.current_txt[0]
100            vec_input = pred_set.current_vec[0]
101
102        # Get current timestep index for CFG check
103        timestep_index = add_set.current_timestep_index[0]
104
105        kwargs = {}
106        if "image_proj" in pred_set.model_ref[0].__dict__:
107            kwargs = {"image_proj": img_set.image_proj}
108        if "ip_scale" in pred_set.model_ref[0].__dict__:
109            kwargs.setdefault("ip_scale", img_set.ip_scale)
110
111        # Generate positive prediction
112        pred = pred_set.model_ref[0](
113            img=img_input,
114            img_ids=img_set.img_ids,
115            txt=txt_input,
116            txt_ids=pred_set.current_txt_ids[0],
117            y=vec_input,
118            timesteps=t_vec,
119            guidance=guidance_vec,
120            **kwargs,
121        )
122
123        # Apply CFG if enabled and past start timestep
124        if (
125            pred_set.neg_pred_enabled
126            and all([pred_set.current_neg_txt, pred_set.current_neg_txt_ids, pred_set.current_neg_vec])
127            and timestep_index >= add_set.timestep_to_start_cfg
128        ):
129            if not use_autocast:
130                neg_txt_input = pred_set.current_neg_txt[0].to(dtype=model_dtype)  # type: ignore
131                neg_vec_input = pred_set.current_neg_vec[0].to(dtype=model_dtype)  # type: ignore
132            else:
133                neg_txt_input = pred_set.current_neg_txt[0]  # type: ignore
134                neg_vec_input = pred_set.current_neg_vec[0]  # type: ignore
135
136            neg_pred = pred_set.model_ref[0](
137                img=img_input,
138                img_ids=img_set.img_ids,
139                txt=neg_txt_input,
140                txt_ids=pred_set.current_neg_txt_ids[0],  # type: ignore
141                y=neg_vec_input,
142                timesteps=t_vec,
143                guidance=guidance_vec,
144                image_proj=img_set.neg_image_proj,
145                ip_scale=img_set.neg_ip_scale,
146            )
147            pred = neg_pred + pred_set.true_gs * (pred - neg_pred)
148
149        # Cache the prediction
150        pred_set.cached_prediction[0] = pred
151        pred_set.cached_prediction_state[0] = current_state
152
153        return pred
154
155    return get_prediction

Create a function to generate model prediction with caching for XFlux1.

Parameters
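  • pred_set: TextEmbeddingState holding the model reference, text embeddings, prediction cache and CFG settings
  • img_set: ImageEmbeddingState holding the image ids, image tensor and IP-adapter projections/scales
  • add_set: StepStateXFlux1 holding the CFG start index and the current timestep index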
Returns

Function that generates predictions with caching

@torch.inference_mode()
def denoise( model: divisor.xflux1.model.XFlux, settings: divisor.state.InferenceState):
@torch.inference_mode()
def denoise(
    model: XFlux,
    settings: InferenceState,
):
    """Denoise using XFlux model with optional ManualTimestepController.\n
    Builds the prediction and step closures, then loops until the controller reports
    completion. Each iteration hands the controller and its state to ``route_choices``
    and, when an autoencoder and a resolution are set, decodes and saves a preview.\n
    :param model: XFlux model instance
    :param settings: Denoising state containing all denoising configuration parameters
    :returns: The controller's current sample after the loop ends
    :raises AssertionError: If ``settings.vec`` is None"""
    # Imported once per call rather than on every loop iteration
    from contextlib import nullcontext

    from divisor.registry import gfx_device as default_device

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    vec = settings.vec
    neg_pred_enabled = settings.neg_pred_enabled
    neg_txt = settings.neg_txt
    neg_txt_ids = settings.neg_txt_ids
    neg_vec = settings.neg_vec
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    true_gs = settings.true_gs
    timestep_to_start_cfg = settings.timestep_to_start_cfg
    image_proj = settings.image_proj
    neg_image_proj = settings.neg_image_proj
    ip_scale = settings.ip_scale if settings.ip_scale is not None else torch.tensor(1.0)
    neg_ip_scale = settings.neg_ip_scale if settings.neg_ip_scale is not None else torch.tensor(1.0)
    initial_layer_dropout = settings.initial_layer_dropout
    t5 = settings.t5
    clip = settings.clip

    # Single-element lists act as mutable cells shared with the closures below
    # this is ignored for schnell
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Tensor | None] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Tensor | None] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[dict | None] = [None]  # Cache state when prediction was generated
    controller_ref: list[ManualTimestepController | None] = [None]  # Reference to controller for closure access
    current_timestep_index: list[int] = [0]  # Track current timestep index for CFG

    model_ref: list[XFlux] = [model]
    # Ensure model is on the correct device (fixes meta device issue)
    target_device = img.device
    # Safely get model device, handling Mock objects in tests
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        # Fallback for Mock objects or models without parameters
        # Assume model is already on correct device if we can't determine it
        model_device = target_device
    if model_device != target_device:
        if model_device.type == "meta":
            # Meta parameters have no storage to copy, so allocation is all that can be done
            model_ref[0] = model.to_empty(device=target_device)
        else:
            # to_empty() does not copy storage; real weights must be moved with to()
            model_ref[0] = model.to(target_device)

    # Store embeddings in mutable containers so they can be updated when prompt changes
    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    assert vec is not None, "vec (CLIP embeddings) is required for XFlux1"
    current_vec: list[Tensor] = [vec]

    # Handle negative prompts: XFlux1 supports negative prompts but they're optional.
    # When enabled but incomplete, embed an empty prompt if both encoders are
    # available; otherwise negative prediction is switched off.
    if neg_pred_enabled and (neg_txt is None or neg_txt_ids is None or neg_vec is None):
        if t5 is not None and clip is not None:
            neg_inp = prepare(t5, clip, img, prompt="")
            neg_txt = neg_inp["txt"]
            neg_txt_ids = neg_inp["txt_ids"]
            neg_vec = neg_inp["vec"]
        else:
            neg_pred_enabled = False

    if neg_pred_enabled and neg_txt is not None and neg_txt_ids is not None and neg_vec is not None:
        current_neg_txt: list[Tensor] = [neg_txt]  # type: ignore
        current_neg_txt_ids: list[Tensor] = [neg_txt_ids]  # type: ignore
        current_neg_vec: list[Tensor] = [neg_vec]  # type: ignore
    else:
        current_neg_txt: list[Tensor] | None = None
        current_neg_txt_ids: list[Tensor] | None = None
        current_neg_vec: list[Tensor] | None = None
        neg_pred_enabled = False
    current_prompt: list[str | None] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(img, t5, clip, current_txt, current_txt_ids, current_vec, current_prompt, clear_prediction_cache, is_flux2=False)

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
        neg_pred_enabled=neg_pred_enabled,
        current_neg_txt=current_neg_txt,  # type: ignore
        current_neg_txt_ids=current_neg_txt_ids,  # type: ignore
        current_neg_vec=current_neg_vec,  # type: ignore
        true_gs=true_gs,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=None,
        img_cond_seq=None,
        img_cond_seq_ids=None,
        image_proj=image_proj,
        neg_image_proj=neg_image_proj,
        ip_scale=ip_scale,
        neg_ip_scale=neg_ip_scale,
    )
    add_set = StepStateXFlux1(
        timestep_to_start_cfg=timestep_to_start_cfg,
        current_timestep_index=current_timestep_index,
    )
    get_prediction = create_get_prediction_xflux1(pred_set, img_set, add_set)

    denoise_step_fn = create_denoise_step_fn(controller_ref, current_layer_dropout, previous_step_tensor, get_prediction)

    controller = ManualTimestepController(timesteps=timesteps, initial_sample=img, denoise_step_fn=denoise_step_fn, initial_guidance=state.guidance)
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    # Seed the controller with whichever optional values the initial state carries
    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Update timestep index for CFG check
        current_timestep_index[0] = getattr(state, "timestep_index", 0)

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if t5 is not None and clip is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedders not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            t5=t5,
            clip=clip,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            # Preview latent: current sample minus timestep-scaled prediction
            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Unpack requires float32, but we'll convert back to correct dtype after
            intermediate = unpack(intermediate.float(), state.height, state.width)

            # NOTE(review): this bare name is evaluated but never called, so no device
            # synchronisation happens before t1 is taken; presumably `gfx_sync()` was
            # intended - confirm its signature in divisor.registry before changing.
            gfx_sync
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            if default_device.type == "cuda":
                context = torch.autocast(device_type=default_device.type, dtype=torch.bfloat16)
            else:
                context = nullcontext()
            with context:
                # When autocast is disabled (non-CUDA), ensure intermediate is in correct dtype for VAE
                if default_device.type != "cuda":
                    # The encoder's parameter dtype is used as the VAE dtype
                    # Safely get encoder dtype, handling Mock objects in tests
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        # Fallback: use intermediate dtype if we can't get encoder dtype (for Mock objects in tests)
                        ae_dtype = intermediate.dtype
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)
                    intermediate_image = ae.decoder(z_adjusted)
                else:
                    intermediate_image = ae.decode(intermediate)
                if state.seed is not None:
                    controller.store_state_in_chain(current_seed=state.seed)
                with SaveFile() as saver:
                    saver.intermediate_image = intermediate_image  # set up image
                    saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                    saver.with_hyperchain()

    return controller.current_sample

Denoise using XFlux model with optional ManualTimestepController.

Parameters
  • model: XFlux model instance
  • settings: Denoising state containing all denoising configuration parameters