divisor.flux2.sampling

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3# adapted BFL Flux code from https://github.com/black-forest-labs/flux2
  4
  5import math
  6import time
  7from typing import Optional
  8
  9from PIL import Image
 10from einops import rearrange
 11from nnll.console import nfo
 12import torch
 13from torch import Tensor
 14import torchvision
 15
 16from divisor.cli_menu import route_choices
 17from divisor.controller import ManualTimestepController, rng, variation_rng
 18from divisor.denoise_step import (
 19    create_clear_prediction_cache,
 20    create_denoise_step_fn,
 21    create_get_prediction,
 22    create_recompute_text_embeddings,
 23)
 24from divisor.registry import gfx_device, gfx_dtype, gfx_sync
 25from divisor.save import SaveFile
 26from divisor.flux2.model import Flux2
 27from divisor.state import (
 28    InferenceState,
 29    InferenceStateFlux2,
 30    ImageEmbeddingState,
 31    TextEmbeddingState,
 32)
 33from divisor.interaction_context import InteractionContext
 34
 35
 36def compress_time(t_ids: Tensor) -> Tensor:
 37    assert t_ids.ndim == 1
 38    t_ids_max = torch.max(t_ids)
 39    t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)  # type: ignore
 40    t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
 41    t_remap[t_unique_sorted_ids] = torch.arange(len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype)
 42    t_ids_compressed = t_remap[t_ids]
 43    return t_ids_compressed
 44
 45
 46def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
 47    """
 48    using position ids to scatter tokens into place
 49    """
 50    x_list = []
 51    t_coords = []
 52    for data, pos in zip(x, x_ids):
 53        _, ch = data.shape  # noqa: F841
 54        t_ids = pos[:, 0].to(torch.int64)
 55        h_ids = pos[:, 1].to(torch.int64)
 56        w_ids = pos[:, 2].to(torch.int64)
 57
 58        t_ids_cmpr = compress_time(t_ids)
 59
 60        t = torch.max(t_ids_cmpr) + 1
 61        h = torch.max(h_ids) + 1
 62        w = torch.max(w_ids) + 1
 63
 64        flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids
 65
 66        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)  # type: ignore
 67        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
 68
 69        x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
 70        t_coords.append(torch.unique(t_ids, sorted=True))
 71    return x_list
 72
 73
 74def encode_image_refs(ae, img_ctx: list[Image.Image]):
 75    precision = gfx_dtype
 76    scale = 10
 77
 78    if len(img_ctx) > 1:
 79        limit_pixels = 1024**2
 80    elif len(img_ctx) == 1:
 81        limit_pixels = 2024**2
 82    else:
 83        limit_pixels = None
 84
 85    if not img_ctx:
 86        return None, None
 87
 88    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
 89    if not isinstance(img_ctx_prep, list):
 90        img_ctx_prep = [img_ctx_prep]
 91
 92    # Encode each reference image
 93    encoded_refs = []
 94    torch_device = gfx_device
 95    for img in img_ctx_prep:
 96        encoded = ae.encode(img[None].to(torch_device))[0]
 97        encoded_refs.append(encoded)
 98
 99    # Create time offsets for each reference
100    t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
101    t_off = [t.view(-1) for t in t_off]
102
103    # Process with position IDs
104    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)
105
106    # Concatenate all references along sequence dimension
107    ref_tokens = torch.cat(ref_tokens, dim=0)  # (total_ref_tokens, C)
108    ref_ids = torch.cat(ref_ids, dim=0)  # (total_ref_tokens, 4)
109
110    # Add batch dimension
111    ref_tokens = ref_tokens.unsqueeze(0)  # (1, total_ref_tokens, C)
112    ref_ids = ref_ids.unsqueeze(0)  # (1, total_ref_tokens, 4)
113
114    return ref_tokens.to(precision), ref_ids
115
116
117def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
118    _l, _ = x.shape  # noqa: F841
119
120    coords = {
121        "t": torch.arange(1) if t_coord is None else t_coord,
122        "h": torch.arange(1),  # dummy dimension
123        "w": torch.arange(1),  # dummy dimension
124        "l": torch.arange(_l),
125    }
126    x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
127    return x, x_ids.to(x.device)
128
129
130def batched_wrapper(fn):
131    def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
132        results = []
133        for i in range(len(x)):
134            results.append(
135                fn(
136                    x[i],
137                    t_coord[i] if t_coord is not None else None,
138                )
139            )
140        x, x_ids = zip(*results)  # type: ignore
141        return torch.stack(x), torch.stack(x_ids)  # type: ignore
142
143    return batched_prc
144
145
146def listed_wrapper(fn):
147    def listed_prc(
148        x: list[Tensor],
149        t_coord: list[Tensor] | None = None,
150    ) -> tuple[list[Tensor], list[Tensor]]:
151        results = []
152        for i in range(len(x)):
153            results.append(
154                fn(
155                    x[i],
156                    t_coord[i] if t_coord is not None else None,
157                )
158            )
159        x, x_ids = zip(*results)  # type: ignore
160        return list(x), list(x_ids)
161
162    return listed_prc
163
164
def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Flatten a (c, h, w) latent into tokens and build its 4-column position ids.

    :param x: 3-D tensor unpacked as (channels, height, width)
    :param t_coord: Optional values for the time column; defaults to a single 0
    :returns: Tokens shaped (h * w, c), and ids (time, row, column, 0) moved to the tokens' device
    """
    _, height, width = x.shape

    t_axis = torch.arange(1) if t_coord is None else t_coord
    h_axis = torch.arange(height)
    w_axis = torch.arange(width)
    l_axis = torch.arange(1)

    x_ids = torch.cartesian_prod(t_axis, h_axis, w_axis, l_axis)
    tokens = rearrange(x, "c h w -> (h w) c")
    return tokens, x_ids.to(tokens.device)
176
177
# Ready-made collection variants of the per-item helpers above:
# the "listed" form maps over a Python list and returns lists,
# the "batched" forms map over the first dimension and stack the results.
listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)
181
182
def center_crop_to_multiple_of_x(img: Image.Image | list[Image.Image], x: int) -> Image.Image | list[Image.Image]:
    """Center-crop an image (or each image in a list) so both sides are multiples of ``x``.

    :param img: Image or list of images
    :param x: Divisor both output sides must be a multiple of
    :returns: Cropped image, or a list of cropped images when a list was given
    """
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(single, x) for single in img]  # type: ignore

    width, height = img.size
    # Largest multiples of x that still fit inside the image
    crop_w = width - width % x
    crop_h = height - height % x

    # Split the leftover pixels evenly between the two opposite edges
    left = (width - crop_w) // 2
    top = (height - crop_h) // 2

    return img.crop((left, top, left + crop_w, top + crop_h))
198
199
def cap_pixels(img: Image.Image | list[Image.Image], k):
    """Downscale an image (or each image in a list) so it holds at most ``k`` pixels.

    :param img: Image or list of images
    :param k: Maximum allowed pixel count
    :returns: The same image object when already within budget, otherwise a LANCZOS-resized image
    """
    if isinstance(img, list):
        return [cap_pixels(single, k) for single in img]

    width, height = img.size
    total = width * height

    if total > k:
        # Uniform factor that brings the area down to k; int() truncation keeps it at or below k
        shrink = math.sqrt(k / total)
        target = (int(width * shrink), int(height * shrink))
        return img.resize(target, Image.Resampling.LANCZOS)

    return img
215
216
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    """Validate that an image (or each image in a list) is neither too small nor too elongated.

    :param img: Image or list of images
    :param max_ar: Largest allowed ratio between the two sides
    :param min_sidelength: Smallest allowed side length in pixels
    :returns: The input unchanged (a new list when a list was given)
    :raises ValueError: If a side is shorter than ``min_sidelength`` or the aspect ratio exceeds ``max_ar``
    """
    if isinstance(img, list):
        return [cap_min_pixels(single, max_ar=max_ar, min_sidelength=min_sidelength) for single in img]

    width, height = img.size

    too_small = width < min_sidelength or height < min_sidelength
    if too_small:
        raise ValueError(f"Skipping due to minimal sidelength underschritten h {height} w {width}")

    too_elongated = width / height > max_ar or height / width > max_ar
    if too_elongated:
        raise ValueError(f"Skipping due to maximal ar overschritten h {height} w {width}")

    return img
226
227
def to_rgb(img: Image.Image | list[Image.Image]):
    """Convert an image (or each image in a list) to RGB mode.

    :param img: Image or list of images
    :returns: Result of ``convert("RGB")``, or a list of such results when a list was given
    """
    if not isinstance(img, list):
        return img.convert("RGB")
    return [to_rgb(single) for single in img]
237
238
def default_images_prep(
    x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
    """Turn an image (or each image in a list) into a tensor computed as ``2 * ToTensor(x) - 1``.

    :param x: Image or list of images
    :returns: Tensor, or a list of tensors when a list was given
    """
    if isinstance(x, list):
        return [default_images_prep(item) for item in x]  # type: ignore

    to_tensor = torchvision.transforms.ToTensor()
    return 2 * to_tensor(x) - 1
246
247
def default_prep(img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16) -> torch.Tensor | list[torch.Tensor]:
    """Run the standard preprocessing chain on an image or list of images.

    Steps, in order: RGB conversion, size/aspect validation, optional pixel cap,
    center crop to a multiple of ``ensure_multiple``, tensor conversion.

    :param img: Image or list of images
    :param limit_pixels: Maximum pixel count per image, or ``None`` to skip the cap
    :param ensure_multiple: Both sides are cropped to a multiple of this value
    :returns: Tensor, or a list of tensors when a list was given
    :raises ValueError: Propagated from the size/aspect validation step
    """
    prepared = cap_min_pixels(to_rgb(img))  # type: ignore
    if limit_pixels is not None:
        prepared = cap_pixels(prepared, limit_pixels)  # type: ignore
    cropped = center_crop_to_multiple_of_x(prepared, ensure_multiple)  # type: ignore
    return default_images_prep(cropped)
258
259
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
    """Apply the shift ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` to timesteps.

    :param t: Timesteps to shift
    :param mu: Shift strength; enters the formula as ``exp(mu)``
    :param sigma: Exponent applied to ``1 / t - 1``
    :returns: Shifted timesteps
    """
    shift = math.exp(mu)
    return shift / (shift + (1 / t - 1) ** sigma)
262
263
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
    """Build the shifted timestep schedule running from 1 down to 0.

    :param num_steps: Number of denoising steps; the schedule has ``num_steps + 1`` entries
    :param image_seq_len: Image token count, used to pick the shift strength
    :returns: Timesteps as a plain list of floats
    """
    shift_mu = compute_empirical_mu(image_seq_len, num_steps)
    linear = torch.linspace(1, 0, num_steps + 1)
    return generalized_time_snr_shift(linear, shift_mu, 1.0).tolist()
269
270
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute the schedule shift ``mu`` from two fitted lines over the sequence length.

    Above 4300 tokens only the second line is used. Otherwise the value is linearly
    interpolated/extrapolated over ``num_steps`` between the first line (at 10 steps)
    and the second line (at 200 steps).

    :param image_seq_len: Number of image tokens
    :param num_steps: Number of denoising steps
    :returns: Shift strength as a float
    """
    slope_10, offset_10 = 8.73809524e-05, 1.89833333
    slope_200, offset_200 = 0.00016927, 0.45666666

    mu_at_200 = slope_200 * image_seq_len + offset_200
    if image_seq_len > 4300:
        return float(mu_at_200)

    mu_at_10 = slope_10 * image_seq_len + offset_10

    # Line through (10, mu_at_10) and (200, mu_at_200), evaluated at num_steps
    step_slope = (mu_at_200 - mu_at_10) / 190.0
    step_offset = mu_at_200 - 200.0 * step_slope
    return float(step_slope * num_steps + step_offset)
287
288
def denoise(settings: InferenceStateFlux2) -> Tensor:
    """Run the plain, non-interactive Flux2 denoising loop.

    :param settings: InferenceStateFlux2 carrying the model, latents, ids, text context, timesteps and guidance
    :returns: Latents after stepping through every consecutive pair of timesteps
    """
    model = settings.model
    latents = settings.img
    latent_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    timesteps = settings.timesteps
    guidance = settings.guidance
    cond_seq = settings.img_cond_seq
    cond_seq_ids = settings.img_cond_seq_ids

    guidance_vec = torch.full((latents.shape[0],), guidance, device=latents.device, dtype=latents.dtype)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((latents.shape[0],), t_curr, dtype=latents.dtype, device=latents.device)

        if cond_seq is None:
            model_in, model_in_ids = latents, latent_ids
        else:
            assert cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning"
            # Conditioning tokens ride along after the image tokens
            model_in = torch.cat((latents, cond_seq), dim=1)
            model_in_ids = torch.cat((latent_ids, cond_seq_ids), dim=1)

        pred = model(
            x=model_in,
            x_ids=model_in_ids,
            timesteps=t_vec,
            ctx=txt,
            ctx_ids=txt_ids,
            guidance=guidance_vec,
        )
        if model_in_ids is not None:
            # Keep only the prediction for the image tokens, dropping any conditioning tail
            pred = pred[:, : latents.shape[1]]

        latents = latents + (t_prev - t_curr) * pred

    return latents
326
327
@torch.inference_mode()
def denoise_interactive(
    model: Flux2,
    settings: InferenceState,
):
    """Interactive denoising using Flux2 model with a ManualTimestepController.

    Builds the prediction/step closures, seeds a controller from ``settings.state``, then loops
    until the controller reports completion. Each pass routes the user's choice, and, when an
    autoencoder and resolution are available, decodes a preview and saves it with the hyperchain.

    :param model: Flux2 model instance
    :param settings: InferenceState containing all denoising configuration parameters
    :returns: ``controller.current_sample`` once the controller is complete"""

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids
    # NOTE(review): the listing's import block already imports gfx_device from divisor.registry;
    # this local import only adds the alias
    from divisor.registry import gfx_device as default_device

    denoise_device = settings.device if settings.device is not None else default_device
    initial_layer_dropout = settings.initial_layer_dropout
    text_embedder = settings.text_embedder

    # Single-element lists act as mutable cells shared with the closures created below
    # this is ignored for schnell
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Optional[Tensor]] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Optional[Tensor]] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[Optional[dict]] = [None]  # Cache state when prediction was generated
    controller_ref: list[Optional["ManualTimestepController"]] = [None]  # Reference to controller for closure access

    model_ref: list[Flux2] = [model]
    target_device = img.device
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        # Fallback for Mock objects or models without parameters
        # Assume model is already on correct device if we can't determine it
        model_device = target_device
    if model_device != target_device:
        # NOTE(review): torch's Module.to_empty moves parameters without copying their values;
        # confirm the weights are (re)loaded elsewhere or that ``.to`` was intended
        model_ref[0] = model.to_empty(device=target_device)

    # Store embeddings in mutable containers so they can be updated when prompt changes
    # Flux2 uses ctx instead of txt, and doesn't have separate CLIP embeddings
    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    current_vec: list[Optional[Tensor]] = [None]  # Flux2 doesn't use CLIP embeddings
    current_prompt: list[Optional[str]] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(
        img,
        None,  # t5 not used for Flux2
        None,  # clip not used for Flux2
        current_txt,
        current_txt_ids,
        current_vec,  # type: ignore
        current_prompt,
        clear_prediction_cache,
        is_flux2=True,
        text_embedder=text_embedder,
    )

    # Bundle the shared cells and image inputs handed to the prediction factory
    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,  # type: ignore
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=None,  # img_cond not used in Flux2 (only img_cond_seq)
        img_cond_seq=img_cond_seq,
        img_cond_seq_ids=img_cond_seq_ids,
    )
    get_prediction = create_get_prediction(pred_set, img_set)

    denoise_step_fn = create_denoise_step_fn(
        controller_ref,
        current_layer_dropout,
        previous_step_tensor,
        get_prediction,
    )

    controller = ManualTimestepController(
        timesteps=timesteps,
        initial_sample=img,
        denoise_step_fn=denoise_step_fn,
        initial_guidance=state.guidance,
    )
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    # Copy the optional fields of the initial state into the controller
    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if text_embedder is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedder not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        # From here on ``state`` is whatever route_choices hands back
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            # Preview estimate: current sample minus timestep-scaled prediction
            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Flux2 uses scatter_ids to convert back to spatial format
            # The intermediate is already in the correct format (sequence of tokens)
            # We need to scatter it back to spatial dimensions for VAE decoding
            scattered = scatter_ids(intermediate, img_ids)
            if len(scattered) > 0:
                intermediate_list = torch.cat(scattered).squeeze(2)
                # scatter_ids returns list of tensors with shape (1, C, T, H, W)
                # We need (1, C, H, W) for VAE, so we take the first time slice or squeeze
                # NOTE(review): after the squeeze above, indexing [0] drops the batch axis and the
                # second squeeze(2) then targets the last remaining axis; the dim() == 5 branch
                # looks unreachable from here. Confirm the shape ae.decode expects.
                intermediate = intermediate_list[0].squeeze(2)  # Remove time dimension if present
                if intermediate.dim() == 5:
                    intermediate = intermediate[:, :, 0, :, :]  # Take first time slice

            # NOTE(review): bare name, evaluated and discarded. If gfx_sync is a callable meant to
            # synchronize the device before timing, it is never invoked here; confirm.
            gfx_sync
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            # Autocast to bfloat16 on CUDA only; other devices run without an autocast context
            if denoise_device.type == "cuda":
                context = torch.autocast(device_type=denoise_device.type, dtype=torch.bfloat16)
            else:
                from contextlib import nullcontext

                context = nullcontext()
            with context:
                if denoise_device.type != "cuda":
                    # Without autocast, match the latent dtype to the autoencoder's parameters
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        ae_dtype = intermediate.dtype  # For Tests
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)  # type: ignore
                    # NOTE(review): only this branch casts the decoded image with .float(); confirm intended
                    intermediate_image = ae.decode(z_adjusted).float()
                else:
                    intermediate_image = ae.decode(intermediate)
                if state.seed is not None:
                    controller.store_state_in_chain(current_seed=state.seed)
                with SaveFile() as saver:
                    saver.intermediate_image = intermediate_image  # set up image
                    # NOTE(review): the trailing comma makes this a one-element tuple; confirm SaveFile expects that
                    saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                    saver.with_hyperchain()

    return controller.current_sample
531
532
def concatenate_images(
    images: list[Image.Image],
) -> Image.Image:
    """
    Place PIL images side by side on a white canvas, vertically centered.

    A single-image list short-circuits to a copy of that image, returned without mode conversion.
    """
    if len(images) == 1:
        return images[0].copy()

    # Everything pasted onto the canvas must be RGB
    rgb_images = [im if im.mode == "RGB" else im.convert("RGB") for im in images]

    # Canvas is as wide as all images together and as tall as the tallest one
    canvas_w = sum(im.width for im in rgb_images)
    canvas_h = max(im.height for im in rgb_images)
    canvas = Image.new("RGB", (canvas_w, canvas_h), (255, 255, 255))

    # Walk left to right, centering each image vertically
    cursor_x = 0
    for im in rgb_images:
        canvas.paste(im, (cursor_x, (canvas_h - im.height) // 2))
        cursor_x += im.width

    return canvas
def compress_time(t_ids: torch.Tensor) -> torch.Tensor:
37def compress_time(t_ids: Tensor) -> Tensor:
38    assert t_ids.ndim == 1
39    t_ids_max = torch.max(t_ids)
40    t_remap = torch.zeros((t_ids_max + 1,), device=t_ids.device, dtype=t_ids.dtype)  # type: ignore
41    t_unique_sorted_ids = torch.unique(t_ids, sorted=True)
42    t_remap[t_unique_sorted_ids] = torch.arange(len(t_unique_sorted_ids), device=t_ids.device, dtype=t_ids.dtype)
43    t_ids_compressed = t_remap[t_ids]
44    return t_ids_compressed
def scatter_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
47def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
48    """
49    using position ids to scatter tokens into place
50    """
51    x_list = []
52    t_coords = []
53    for data, pos in zip(x, x_ids):
54        _, ch = data.shape  # noqa: F841
55        t_ids = pos[:, 0].to(torch.int64)
56        h_ids = pos[:, 1].to(torch.int64)
57        w_ids = pos[:, 2].to(torch.int64)
58
59        t_ids_cmpr = compress_time(t_ids)
60
61        t = torch.max(t_ids_cmpr) + 1
62        h = torch.max(h_ids) + 1
63        w = torch.max(w_ids) + 1
64
65        flat_ids = t_ids_cmpr * w * h + h_ids * w + w_ids
66
67        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)  # type: ignore
68        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
69
70        x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
71        t_coords.append(torch.unique(t_ids, sorted=True))
72    return x_list

Use position ids to scatter tokens back into place.

def encode_image_refs(ae, img_ctx: list[PIL.Image.Image]):
 75def encode_image_refs(ae, img_ctx: list[Image.Image]):
 76    precision = gfx_dtype
 77    scale = 10
 78
 79    if len(img_ctx) > 1:
 80        limit_pixels = 1024**2
 81    elif len(img_ctx) == 1:
 82        limit_pixels = 2024**2
 83    else:
 84        limit_pixels = None
 85
 86    if not img_ctx:
 87        return None, None
 88
 89    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
 90    if not isinstance(img_ctx_prep, list):
 91        img_ctx_prep = [img_ctx_prep]
 92
 93    # Encode each reference image
 94    encoded_refs = []
 95    torch_device = gfx_device
 96    for img in img_ctx_prep:
 97        encoded = ae.encode(img[None].to(torch_device))[0]
 98        encoded_refs.append(encoded)
 99
100    # Create time offsets for each reference
101    t_off = [scale + scale * t for t in torch.arange(0, len(encoded_refs))]
102    t_off = [t.view(-1) for t in t_off]
103
104    # Process with position IDs
105    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)
106
107    # Concatenate all references along sequence dimension
108    ref_tokens = torch.cat(ref_tokens, dim=0)  # (total_ref_tokens, C)
109    ref_ids = torch.cat(ref_ids, dim=0)  # (total_ref_tokens, 4)
110
111    # Add batch dimension
112    ref_tokens = ref_tokens.unsqueeze(0)  # (1, total_ref_tokens, C)
113    ref_ids = ref_ids.unsqueeze(0)  # (1, total_ref_tokens, 4)
114
115    return ref_tokens.to(precision), ref_ids
def prc_txt( x: torch.Tensor, t_coord: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
118def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
119    _l, _ = x.shape  # noqa: F841
120
121    coords = {
122        "t": torch.arange(1) if t_coord is None else t_coord,
123        "h": torch.arange(1),  # dummy dimension
124        "w": torch.arange(1),  # dummy dimension
125        "l": torch.arange(_l),
126    }
127    x_ids = torch.cartesian_prod(coords["t"], coords["h"], coords["w"], coords["l"])
128    return x, x_ids.to(x.device)
def batched_wrapper(fn):
131def batched_wrapper(fn):
132    def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
133        results = []
134        for i in range(len(x)):
135            results.append(
136                fn(
137                    x[i],
138                    t_coord[i] if t_coord is not None else None,
139                )
140            )
141        x, x_ids = zip(*results)  # type: ignore
142        return torch.stack(x), torch.stack(x_ids)  # type: ignore
143
144    return batched_prc
def listed_wrapper(fn):
147def listed_wrapper(fn):
148    def listed_prc(
149        x: list[Tensor],
150        t_coord: list[Tensor] | None = None,
151    ) -> tuple[list[Tensor], list[Tensor]]:
152        results = []
153        for i in range(len(x)):
154            results.append(
155                fn(
156                    x[i],
157                    t_coord[i] if t_coord is not None else None,
158                )
159            )
160        x, x_ids = zip(*results)  # type: ignore
161        return list(x), list(x_ids)
162
163    return listed_prc
def prc_img( x: torch.Tensor, t_coord: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
166def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
167    _, h, w = x.shape  # noqa: F841
168    x_coords = {
169        "t": torch.arange(1) if t_coord is None else t_coord,
170        "h": torch.arange(h),
171        "w": torch.arange(w),
172        "l": torch.arange(1),
173    }
174    x_ids = torch.cartesian_prod(x_coords["t"], x_coords["h"], x_coords["w"], x_coords["l"])
175    x = rearrange(x, "c h w -> (h w) c")
176    return x, x_ids.to(x.device)
def listed_prc_img( x: list[torch.Tensor], t_coord: list[torch.Tensor] | None = None) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
148    def listed_prc(
149        x: list[Tensor],
150        t_coord: list[Tensor] | None = None,
151    ) -> tuple[list[Tensor], list[Tensor]]:
152        results = []
153        for i in range(len(x)):
154            results.append(
155                fn(
156                    x[i],
157                    t_coord[i] if t_coord is not None else None,
158                )
159            )
160        x, x_ids = zip(*results)  # type: ignore
161        return list(x), list(x_ids)

Apply prc_img to every tensor in a list and return the resulting tokens and position ids as two lists.

def batched_prc_img( x: torch.Tensor, t_coord: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
132    def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
133        results = []
134        for i in range(len(x)):
135            results.append(
136                fn(
137                    x[i],
138                    t_coord[i] if t_coord is not None else None,
139                )
140            )
141        x, x_ids = zip(*results)  # type: ignore
142        return torch.stack(x), torch.stack(x_ids)  # type: ignore

Apply prc_img to every item along the first dimension and stack the resulting tokens and position ids.

def batched_prc_txt( x: torch.Tensor, t_coord: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
132    def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
133        results = []
134        for i in range(len(x)):
135            results.append(
136                fn(
137                    x[i],
138                    t_coord[i] if t_coord is not None else None,
139                )
140            )
141        x, x_ids = zip(*results)  # type: ignore
142        return torch.stack(x), torch.stack(x_ids)  # type: ignore

Apply prc_txt to every item along the first dimension and stack the resulting tokens and position ids.

def center_crop_to_multiple_of_x( img: PIL.Image.Image | list[PIL.Image.Image], x: int) -> PIL.Image.Image | list[PIL.Image.Image]:
184def center_crop_to_multiple_of_x(img: Image.Image | list[Image.Image], x: int) -> Image.Image | list[Image.Image]:
185    if isinstance(img, list):
186        return [center_crop_to_multiple_of_x(_img, x) for _img in img]  # type: ignore
187
188    w, h = img.size
189    new_w = (w // x) * x
190    new_h = (h // x) * x
191
192    left = (w - new_w) // 2
193    top = (h - new_h) // 2
194    right = left + new_w
195    bottom = top + new_h
196
197    resized = img.crop((left, top, right, bottom))
198    return resized
def cap_pixels(img: PIL.Image.Image | list[PIL.Image.Image], k):
201def cap_pixels(img: Image.Image | list[Image.Image], k):
202    if isinstance(img, list):
203        return [cap_pixels(_img, k) for _img in img]
204    w, h = img.size
205    pixel_count = w * h
206
207    if pixel_count <= k:
208        return img
209
210    # Scaling factor to reduce total pixels below K
211    scale = math.sqrt(k / pixel_count)
212    new_w = int(w * scale)
213    new_h = int(h * scale)
214
215    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
def cap_min_pixels( img: PIL.Image.Image | list[PIL.Image.Image], max_ar=8, min_sidelength=64):
218def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
219    if isinstance(img, list):
220        return [cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) for _img in img]
221    w, h = img.size
222    if w < min_sidelength or h < min_sidelength:
223        raise ValueError(f"Skipping due to minimal sidelength underschritten h {h} w {w}")
224    if w / h > max_ar or h / w > max_ar:
225        raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")
226    return img
def to_rgb(img: PIL.Image.Image | list[PIL.Image.Image]):
229def to_rgb(img: Image.Image | list[Image.Image]):
230    if isinstance(img, list):
231        return [
232            to_rgb(
233                _img,
234            )
235            for _img in img
236        ]
237    return img.convert("RGB")
def default_images_prep( x: PIL.Image.Image | list[PIL.Image.Image]) -> torch.Tensor | list[torch.Tensor]:
240def default_images_prep(
241    x: Image.Image | list[Image.Image],
242) -> torch.Tensor | list[torch.Tensor]:
243    if isinstance(x, list):
244        return [default_images_prep(e) for e in x]  # type: ignore
245    x_tensor = torchvision.transforms.ToTensor()(x)
246    return 2 * x_tensor - 1
def default_prep( img: PIL.Image.Image | list[PIL.Image.Image], limit_pixels: int | None, ensure_multiple: int = 16) -> torch.Tensor | list[torch.Tensor]:
def default_prep(img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16) -> torch.Tensor | list[torch.Tensor]:
    """Run the default image preprocessing chain and return tensor(s).\n
    Steps, in order: `to_rgb`, `cap_min_pixels`, optionally `cap_pixels`,
    `center_crop_to_multiple_of_x`, then `default_images_prep`.\n
    :param img: PIL image or list of PIL images
    :param limit_pixels: Pixel budget passed to `cap_pixels`; `None` skips that step
    :param ensure_multiple: Value forwarded to `center_crop_to_multiple_of_x`
    :returns: Prepared tensor, or list of tensors when a list was given"""
    prepared = cap_min_pixels(to_rgb(img))  # type: ignore
    if limit_pixels is not None:
        prepared = cap_pixels(prepared, limit_pixels)  # type: ignore
    cropped = center_crop_to_multiple_of_x(prepared, ensure_multiple)  # type: ignore
    return default_images_prep(cropped)
def generalized_time_snr_shift(t: torch.Tensor, mu: float, sigma: float) -> torch.Tensor:
def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
    """Apply the time shift `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`.\n
    :param t: Timestep value(s) to shift
    :param mu: Shift parameter; enters the formula as `exp(mu)`
    :param sigma: Exponent applied to `(1 / t - 1)`
    :returns: Shifted timestep value(s), same shape as `t`"""
    shift = math.exp(mu)
    odds = (1 / t - 1) ** sigma
    return shift / (shift + odds)
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
    """Build the shifted timestep schedule running from 1 down to 0.\n
    :param num_steps: Number of denoising steps; the schedule holds `num_steps + 1` values
    :param image_seq_len: Image sequence length used to pick the shift via `compute_empirical_mu`
    :returns: Timesteps as a plain list of floats"""
    grid = torch.linspace(1, 0, num_steps + 1)
    shifted = generalized_time_snr_shift(grid, compute_empirical_mu(image_seq_len, num_steps), 1.0)
    return shifted.tolist()
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute the schedule shift `mu` from sequence length and step count.\n
    Two linear fits in `image_seq_len` are evaluated. Above 4300 tokens the second fit
    is returned directly; otherwise `mu` lies on the straight line in `num_steps` that
    passes through the first fit at 10 steps and the second fit at 200 steps.\n
    :param image_seq_len: Length of the image token sequence
    :param num_steps: Number of denoising steps
    :returns: Shift value `mu` as a float"""
    slope_10, offset_10 = 8.73809524e-05, 1.89833333
    slope_200, offset_200 = 0.00016927, 0.45666666

    mu_200 = slope_200 * image_seq_len + offset_200
    if image_seq_len > 4300:
        # Long sequences ignore the step count entirely
        return float(mu_200)

    mu_10 = slope_10 * image_seq_len + offset_10

    # Line through (10, mu_10) and (200, mu_200), evaluated at num_steps
    gradient = (mu_200 - mu_10) / 190.0
    intercept = mu_200 - 200.0 * gradient
    return float(gradient * num_steps + intercept)
def denoise(settings: divisor.state.InferenceStateFlux2) -> torch.Tensor:
def denoise(settings: InferenceStateFlux2) -> Tensor:
    """Simple non-interactive denoising function for Flux2.\n
    Walks consecutive timestep pairs and applies the update
    `img = img + (t_prev - t_curr) * pred` after each model call. When sequence
    conditioning is supplied, it is appended to the model input along dim 1 and the
    prediction is cut back to the image tokens.\n
    :param settings: InferenceStateFlux2 containing all denoising configuration parameters
    :returns: Denoised image tensor
    :raises AssertionError: If `img_cond_seq` is given without `img_cond_seq_ids`"""
    model = settings.model
    img = settings.img
    txt = settings.txt
    txt_ids = settings.txt_ids
    timesteps = settings.timesteps
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids

    # The position ids never change between steps, so validate and concatenate them once
    # instead of on every iteration.
    img_input_ids = settings.img_ids
    if img_cond_seq is not None:
        assert img_cond_seq_ids is not None, "You need to provide either both or neither of the sequence conditioning"
        img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)

    guidance_vec = torch.full((img.shape[0],), settings.guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        # The latent changes every step, so the conditioning tokens are re-appended each time
        img_input = img if img_cond_seq is None else torch.cat((img, img_cond_seq), dim=1)
        pred = model(
            x=img_input,
            x_ids=img_input_ids,
            timesteps=t_vec,
            ctx=txt,
            ctx_ids=txt_ids,
            guidance=guidance_vec,
        )
        if img_input_ids is not None:
            # Keep only the leading image tokens (drops any appended conditioning tokens)
            pred = pred[:, : img.shape[1]]

        img = img + (t_prev - t_curr) * pred

    return img

Simple non-interactive denoising function for Flux2.

Parameters
  • settings: InferenceStateFlux2 containing all denoising configuration parameters

Returns
  • Denoised image tensor
@torch.inference_mode()
def denoise_interactive( model: divisor.flux2.model.Flux2, settings: divisor.state.InferenceState):
@torch.inference_mode()
def denoise_interactive(
    model: Flux2,
    settings: InferenceState,
):
    """Interactive denoising using Flux2 model with optional ManualTimestepController.\n
    Builds the prediction/step closures, seeds a ManualTimestepController from
    `settings.state`, then loops until the controller reports completion. Each pass
    routes the user's choice through `route_choices` and, when an autoencoder and a
    resolution are available, decodes a preview and saves it through `SaveFile`.\n
    :param model: Flux2 model instance
    :param settings: InferenceState containing all denoising configuration parameters
    :returns: The controller's current sample once the loop has finished"""
    from contextlib import nullcontext

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids
    from divisor.registry import gfx_device as default_device

    denoise_device = settings.device if settings.device is not None else default_device
    initial_layer_dropout = settings.initial_layer_dropout
    text_embedder = settings.text_embedder

    # Single-element lists act as mutable cells shared with the closures created below
    # this is ignored for schnell
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Optional[Tensor]] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Optional[Tensor]] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[Optional[dict]] = [None]  # Cache state when prediction was generated
    controller_ref: list[Optional["ManualTimestepController"]] = [None]  # Reference to controller for closure access

    model_ref: list[Flux2] = [model]
    target_device = img.device
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        # Fallback for Mock objects or models without parameters
        # Assume model is already on correct device if we can't determine it
        model_device = target_device
    if model_device != target_device:
        # NOTE(review): to_empty() allocates parameter storage without copying values;
        # confirm this is intended for models that already hold loaded weights.
        model_ref[0] = model.to_empty(device=target_device)

    # Store embeddings in mutable containers so they can be updated when prompt changes
    # Flux2 uses ctx instead of txt, and doesn't have separate CLIP embeddings
    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    current_vec: list[Optional[Tensor]] = [None]  # Flux2 doesn't use CLIP embeddings
    current_prompt: list[Optional[str]] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(
        img,
        None,  # t5 not used for Flux2
        None,  # clip not used for Flux2
        current_txt,
        current_txt_ids,
        current_vec,  # type: ignore
        current_prompt,
        clear_prediction_cache,
        is_flux2=True,
        text_embedder=text_embedder,
    )

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,  # type: ignore
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=None,  # img_cond not used in Flux2 (only img_cond_seq)
        img_cond_seq=img_cond_seq,
        img_cond_seq_ids=img_cond_seq_ids,
    )
    get_prediction = create_get_prediction(pred_set, img_set)

    denoise_step_fn = create_denoise_step_fn(
        controller_ref,
        current_layer_dropout,
        previous_step_tensor,
        get_prediction,
    )

    controller = ManualTimestepController(
        timesteps=timesteps,
        initial_sample=img,
        denoise_step_fn=denoise_step_fn,
        initial_guidance=state.guidance,
    )
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    # Copy the optional initial settings into the controller
    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if text_embedder is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedder not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            # Preview estimate: current sample minus timestep-scaled prediction
            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Flux2 uses scatter_ids to convert back to spatial format
            # The intermediate is already in the correct format (sequence of tokens)
            # We need to scatter it back to spatial dimensions for VAE decoding
            scattered = scatter_ids(intermediate, img_ids)
            if len(scattered) > 0:
                intermediate_list = torch.cat(scattered).squeeze(2)
                # scatter_ids returns list of tensors with shape (1, C, T, H, W)
                # We need (1, C, H, W) for VAE, so we take the first time slice or squeeze
                # NOTE(review): indexing [0] after the first squeeze drops a leading
                # dimension; confirm the resulting rank is what ae.decode expects.
                intermediate = intermediate_list[0].squeeze(2)  # Remove time dimension if present
                if intermediate.dim() == 5:
                    intermediate = intermediate[:, :, 0, :, :]  # Take first time slice

            # The bare `gfx_sync` expression was a no-op. Call it (presumably the device
            # synchronize hook from divisor.registry) before the step time is read.
            if callable(gfx_sync):
                gfx_sync()
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            # bfloat16 autocast only on CUDA; other devices decode without autocast
            if denoise_device.type == "cuda":
                context = torch.autocast(device_type=denoise_device.type, dtype=torch.bfloat16)
            else:
                context = nullcontext()
            with context:
                if denoise_device.type != "cuda":
                    # Without autocast, cast the latent to the autoencoder's parameter dtype
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        ae_dtype = intermediate.dtype  # For Tests
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)  # type: ignore
                    intermediate_image = ae.decode(z_adjusted).float()
                else:
                    intermediate_image = ae.decode(intermediate)
                if state.seed is not None:
                    controller.store_state_in_chain(current_seed=state.seed)
                with SaveFile() as saver:
                    saver.intermediate_image = intermediate_image  # set up image
                    saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                    saver.with_hyperchain()

    return controller.current_sample

Interactive denoising using Flux2 model with optional ManualTimestepController.

Parameters
  • model: Flux2 model instance
  • settings: InferenceState containing all denoising configuration parameters
def concatenate_images(images: list[PIL.Image.Image]) -> PIL.Image.Image:
def concatenate_images(
    images: list[Image.Image],
) -> Image.Image:
    """
    Concatenate a list of PIL images horizontally with center alignment and white background.

    A single-image list returns a copy of that image as-is (no RGB conversion).
    With several images, each is converted to RGB when needed, placed left to right,
    and centered vertically on a canvas as tall as the tallest image.

    :param images: Non-empty list of PIL images
    :returns: A new image holding all inputs side by side
    :raises ValueError: If `images` is empty
    """
    # Fail with a clear message instead of the opaque "max() arg is an empty sequence"
    if not images:
        raise ValueError("concatenate_images requires at least one image")

    # If only one image, return a copy of it
    if len(images) == 1:
        return images[0].copy()

    # Convert all images to RGB if not already
    images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]

    # Calculate dimensions for horizontal concatenation
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)

    # Create new image with white background
    background_color = (255, 255, 255)
    new_img = Image.new("RGB", (total_width, max_height), background_color)

    # Paste images with center alignment
    x_offset = 0
    for img in images:
        y_offset = (max_height - img.height) // 2
        new_img.paste(img, (x_offset, y_offset))
        x_offset += img.width

    return new_img

Concatenate a list of PIL images horizontally with center alignment and white background.