divisor.flux2.sampling
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->
# adapted BFL Flux code from https://github.com/black-forest-labs/flux2

import math
import time
from contextlib import nullcontext
from typing import Optional

from PIL import Image
from einops import rearrange
from nnll.console import nfo
import torch
from torch import Tensor
import torchvision

from divisor.cli_menu import route_choices
from divisor.controller import ManualTimestepController, rng, variation_rng
from divisor.denoise_step import (
    create_clear_prediction_cache,
    create_denoise_step_fn,
    create_get_prediction,
    create_recompute_text_embeddings,
)
from divisor.registry import gfx_device, gfx_dtype, gfx_sync
from divisor.save import SaveFile
from divisor.flux2.model import Flux2
from divisor.state import (
    InferenceState,
    InferenceStateFlux2,
    ImageEmbeddingState,
    TextEmbeddingState,
)
from divisor.interaction_context import InteractionContext


def compress_time(t_ids: Tensor) -> Tensor:
    """Remap (possibly sparse) time ids onto the dense range ``0..k-1``.

    Each value is replaced by its rank among the sorted unique values,
    e.g. ``[10, 20, 10]`` -> ``[0, 1, 0]``.

    :param t_ids: 1-D tensor of integer time ids.
    :returns: Tensor of the same shape and dtype holding the compressed ids.
    :raises ValueError: if ``t_ids`` is not one-dimensional.
    """
    if t_ids.ndim != 1:
        raise ValueError(f"compress_time expects a 1-D tensor, got {t_ids.ndim} dims")
    # return_inverse gives each element's index into the sorted unique values,
    # i.e. its dense rank. Unlike a lookup table of size max(t_ids)+1 this is
    # correct for negative ids and does not scale with the largest id.
    _, compressed = torch.unique(t_ids, sorted=True, return_inverse=True)
    return compressed.to(t_ids.dtype)


def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
    """Use position ids to scatter tokens back into dense spatial grids.

    :param x: batched tokens; each item is indexed as ``(seq_len, channels)``.
    :param x_ids: batched position ids; columns 0, 1, 2 of each item are t, h, w.
    :returns: one ``(1, C, T, H, W)`` tensor per batch item; grid cells that no
        token maps to stay zero.
    """
    x_list = []
    for data, pos in zip(x, x_ids):
        _, ch = data.shape
        t_ids = compress_time(pos[:, 0].to(torch.int64))
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)

        # Grid extents as plain ints so they can size tensors and einops axes.
        t = int(t_ids.max()) + 1
        h = int(h_ids.max()) + 1
        w = int(w_ids.max()) + 1

        flat_ids = t_ids * w * h + h_ids * w + w_ids

        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
        x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
    return x_list


def encode_image_refs(ae, img_ctx: list[Image.Image]):
    """Encode reference images into a single conditioning token sequence.

    Images go through :func:`default_prep`, are encoded one at a time with
    ``ae.encode`` on ``gfx_device``, and are tokenized with position ids. The
    i-th reference gets time coordinate ``10 + 10 * i``.

    :param ae: autoencoder exposing ``encode``.
    :param img_ctx: reference images; may be empty.
    :returns: ``(ref_tokens, ref_ids)`` shaped ``(1, N, C)`` / ``(1, N, 4)``,
        with tokens cast to ``gfx_dtype``; ``(None, None)`` if ``img_ctx`` is empty.
    """
    if not img_ctx:
        return None, None

    scale = 10
    # A lone reference is allowed a larger pixel budget than each of several.
    # NOTE(review): 2024**2 (not 2048**2) is preserved as-is; confirm intent upstream.
    limit_pixels = 2024**2 if len(img_ctx) == 1 else 1024**2

    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(img_ctx_prep, list):
        img_ctx_prep = [img_ctx_prep]

    encoded_refs = [ae.encode(img[None].to(gfx_device))[0] for img in img_ctx_prep]

    # Per-reference time offsets, each as a 1-element tensor.
    t_off = [(scale + scale * t).view(-1) for t in torch.arange(0, len(encoded_refs))]

    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)

    # Concatenate all references along the sequence axis, then add a batch axis.
    ref_tokens = torch.cat(ref_tokens, dim=0).unsqueeze(0)  # (1, total_ref_tokens, C)
    ref_ids = torch.cat(ref_ids, dim=0).unsqueeze(0)  # (1, total_ref_tokens, 4)

    return ref_tokens.to(gfx_dtype), ref_ids


def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Build (t, h, w, l) position ids for one text sequence.

    h and w are dummy axes fixed at 0; l enumerates token positions.

    :param x: text tokens shaped ``(seq_len, dim)``.
    :param t_coord: optional 1-D tensor of time coordinates (default ``[0]``).
    :returns: ``x`` unchanged and an id tensor with 4 columns on ``x.device``.
    """
    seq_len, _ = x.shape
    t = torch.arange(1) if t_coord is None else t_coord
    dummy = torch.arange(1)
    x_ids = torch.cartesian_prod(t, dummy, dummy, torch.arange(seq_len))
    return x, x_ids.to(x.device)


def batched_wrapper(fn):
    """Lift a per-sample ``fn(x, t_coord) -> (x, ids)`` to a stacked batch.

    :param fn: per-sample processor such as :func:`prc_img` or :func:`prc_txt`.
    :returns: ``batched_prc(x, t_coord=None)`` applying ``fn`` to ``x[i]``
        (with ``t_coord[i]`` when given) and stacking both outputs on dim 0.
    """

    def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
        results = [fn(item, None if t_coord is None else t_coord[i]) for i, item in enumerate(x)]
        xs, x_ids = zip(*results)
        return torch.stack(xs), torch.stack(x_ids)

    return batched_prc


def listed_wrapper(fn):
    """Lift a per-item ``fn(x, t_coord) -> (x, ids)`` to lists of tensors.

    Outputs are returned as lists rather than stacked, so items need not share a shape.

    :param fn: per-item processor such as :func:`prc_img`.
    :returns: ``listed_prc(x, t_coord=None)`` returning ``(list_of_x, list_of_ids)``.
    """

    def listed_prc(
        x: list[Tensor],
        t_coord: list[Tensor] | None = None,
    ) -> tuple[list[Tensor], list[Tensor]]:
        results = [fn(item, None if t_coord is None else t_coord[i]) for i, item in enumerate(x)]
        xs, x_ids = zip(*results)
        return list(xs), list(x_ids)

    return listed_prc


def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Flatten a ``(C, H, W)`` latent into tokens with (t, h, w, l) position ids.

    :param x: latent shaped ``(C, H, W)``.
    :param t_coord: optional 1-D tensor of time coordinates (default ``[0]``).
    :returns: tokens shaped ``(H*W, C)`` and an id tensor with 4 columns.
    """
    _, h, w = x.shape
    t = torch.arange(1) if t_coord is None else t_coord
    x_ids = torch.cartesian_prod(t, torch.arange(h), torch.arange(w), torch.arange(1))
    tokens = rearrange(x, "c h w -> (h w) c")
    return tokens, x_ids.to(tokens.device)


listed_prc_img = listed_wrapper(prc_img)
batched_prc_img = batched_wrapper(prc_img)
batched_prc_txt = batched_wrapper(prc_txt)


def center_crop_to_multiple_of_x(img: Image.Image | list[Image.Image], x: int) -> Image.Image | list[Image.Image]:
    """Center-crop image(s) so width and height are multiples of ``x``."""
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(_img, x) for _img in img]  # type: ignore

    w, h = img.size
    new_w = (w // x) * x
    new_h = (h // x) * x
    left = (w - new_w) // 2
    top = (h - new_h) // 2
    return img.crop((left, top, left + new_w, top + new_h))


def cap_pixels(img: Image.Image | list[Image.Image], k):
    """Downscale image(s) so that ``w * h`` does not exceed ``k``.

    Aspect ratio is kept and LANCZOS resampling is used; images already within
    the budget are returned unchanged.
    """
    if isinstance(img, list):
        return [cap_pixels(_img, k) for _img in img]
    w, h = img.size
    pixel_count = w * h
    if pixel_count <= k:
        return img

    scale = math.sqrt(k / pixel_count)
    # Never request a zero-sized side, which PIL rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)


def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    """Validate image size: reject too-small sides or extreme aspect ratios.

    Despite the name, no resizing happens; valid images are returned as-is.

    :raises ValueError: if a side is below ``min_sidelength`` or the aspect
        ratio exceeds ``max_ar``.
    """
    if isinstance(img, list):
        return [cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) for _img in img]
    w, h = img.size
    if w < min_sidelength or h < min_sidelength:
        raise ValueError(f"Skipping due to minimal sidelength underschritten h {h} w {w}")
    if w / h > max_ar or h / w > max_ar:
        raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")
    return img


def to_rgb(img: Image.Image | list[Image.Image]):
    """Convert an image (or each image of a list) to RGB mode."""
    if isinstance(img, list):
        return [to_rgb(_img) for _img in img]
    return img.convert("RGB")


def default_images_prep(
    x: Image.Image | list[Image.Image],
) -> torch.Tensor | list[torch.Tensor]:
    """Convert PIL image(s) to tensors and rescale ToTensor's [0, 1] range to [-1, 1]."""
    if isinstance(x, list):
        return [default_images_prep(e) for e in x]  # type: ignore
    x_tensor = torchvision.transforms.ToTensor()(x)
    return 2 * x_tensor - 1


def default_prep(img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16) -> torch.Tensor | list[torch.Tensor]:
    """Standard reference-image pipeline: RGB, size check, pixel cap, crop, tensorize.

    :param img: image or list of images.
    :param limit_pixels: maximum pixel count, or ``None`` for no cap.
    :param ensure_multiple: crop both sides to a multiple of this value.
    :returns: tensor (or list of tensors) scaled to [-1, 1].
    """
    img_rgb = to_rgb(img)
    img_min = cap_min_pixels(img_rgb)  # type: ignore
    img_cap = cap_pixels(img_min, limit_pixels) if limit_pixels is not None else img_min  # type: ignore
    img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple)  # type: ignore
    return default_images_prep(img_crop)


def generalized_time_snr_shift(t: Tensor, mu: float, sigma: float) -> Tensor:
    """Shift timesteps as ``e^mu / (e^mu + (1/t - 1)^sigma)``.

    With float tensors, ``t == 0`` maps to 0 (1/0 -> inf) and ``t == 1`` maps to 1.
    """
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
    """Return ``num_steps + 1`` shifted timesteps running from 1 down to 0."""
    mu = compute_empirical_mu(image_seq_len, num_steps)
    timesteps = torch.linspace(1, 0, num_steps + 1)
    timesteps = generalized_time_snr_shift(timesteps, mu, 1.0)
    return timesteps.tolist()


def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Empirical timestep-shift ``mu`` for a sequence length and step count.

    Two linear fits in ``image_seq_len`` are used, one for 200 steps and one for
    10 steps; ``mu`` is linearly interpolated between them in ``num_steps``. For
    ``image_seq_len > 4300`` the 200-step fit is returned regardless of steps.
    """
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    m_200 = a2 * image_seq_len + b2
    if image_seq_len > 4300:
        return float(m_200)

    m_10 = a1 * image_seq_len + b1
    a = (m_200 - m_10) / 190.0
    b = m_200 - 200.0 * a
    return float(a * num_steps + b)


def denoise(settings: InferenceStateFlux2) -> Tensor:
    """Simple non-interactive denoising function for Flux2.\n
    :param settings: InferenceStateFlux2 containing all denoising configuration parameters
    :returns: Denoised image tensor
    :raises ValueError: if ``img_cond_seq`` is given without ``img_cond_seq_ids``"""
    model = settings.model
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids

    if img_cond_seq is not None and img_cond_seq_ids is None:
        raise ValueError("You need to provide either both or neither of the sequence conditioning")

    num_img_tokens = img.shape[1]
    # Position ids do not change between steps, so build the model's id input once.
    img_input_ids = img_ids if img_cond_seq is None else torch.cat((img_ids, img_cond_seq_ids), dim=1)

    guidance_vec = torch.full((img.shape[0],), settings.guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in zip(settings.timesteps[:-1], settings.timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        img_input = img if img_cond_seq is None else torch.cat((img, img_cond_seq), dim=1)
        pred = model(
            x=img_input,
            x_ids=img_input_ids,
            timesteps=t_vec,
            ctx=txt,
            ctx_ids=txt_ids,
            guidance=guidance_vec,
        )
        # Keep only predictions for the latent tokens; any appended conditioning
        # tokens come after them in the sequence.
        pred = pred[:, :num_img_tokens]
        img = img + (t_prev - t_curr) * pred

    return img


@torch.inference_mode()
def denoise_interactive(
    model: Flux2,
    settings: InferenceState,
):
    """Interactive denoising using Flux2 model with optional ManualTimestepController.\n
    :param model: Flux2 model instance
    :param settings: InferenceState containing all denoising configuration parameters
    :returns: the controller's final sample"""

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids
    # Imported at call time (as originally) so the registry's current value is used.
    from divisor.registry import gfx_device as default_device

    denoise_device = settings.device if settings.device is not None else default_device
    initial_layer_dropout = settings.initial_layer_dropout
    text_embedder = settings.text_embedder

    # Single-element lists act as mutable cells shared with the closures below.
    current_layer_dropout = [initial_layer_dropout]  # this is ignored for schnell
    previous_step_tensor: list[Optional[Tensor]] = [None]  # previous step's tensor, used for masking
    cached_prediction: list[Optional[Tensor]] = [None]  # avoids duplicate model calls
    cached_prediction_state: list[Optional[dict]] = [None]  # state when the prediction was cached
    controller_ref: list[Optional["ManualTimestepController"]] = [None]  # controller, for closure access

    model_ref: list[Flux2] = [model]
    target_device = img.device
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        # Mock objects or parameter-less models: assume already on the right device.
        model_device = target_device
    if model_device != target_device:
        # Module.to_empty() allocates *uninitialized* storage and would discard
        # loaded weights; it is only appropriate for meta-device models.
        if getattr(model_device, "type", None) == "meta":
            model_ref[0] = model.to_empty(device=target_device)
        else:
            model_ref[0] = model.to(target_device)

    # Embeddings live in mutable cells so they can be replaced when the prompt changes.
    # Flux2 uses ctx instead of txt and has no separate CLIP embeddings.
    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    current_vec: list[Optional[Tensor]] = [None]  # Flux2 doesn't use CLIP embeddings
    current_prompt: list[Optional[str]] = [state.prompt]  # used to detect prompt changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(
        img,
        None,  # t5 not used for Flux2
        None,  # clip not used for Flux2
        current_txt,
        current_txt_ids,
        current_vec,  # type: ignore
        current_prompt,
        clear_prediction_cache,
        is_flux2=True,
        text_embedder=text_embedder,
    )

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,  # type: ignore
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=None,  # img_cond not used in Flux2 (only img_cond_seq)
        img_cond_seq=img_cond_seq,
        img_cond_seq_ids=img_cond_seq_ids,
    )
    get_prediction = create_get_prediction(pred_set, img_set)

    denoise_step_fn = create_denoise_step_fn(
        controller_ref,
        current_layer_dropout,
        previous_step_tensor,
        get_prediction,
    )

    controller = ManualTimestepController(
        timesteps=timesteps,
        initial_sample=img,
        denoise_step_fn=denoise_step_fn,
        initial_guidance=state.guidance,
    )
    controller_ref[0] = controller

    # Prefer state.layer_dropout; fall back to initial_layer_dropout.
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Recompute embeddings when the prompt changed.
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if text_embedder is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # No embedder: just record the prompt to avoid re-checking every step.
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        # NOTE(review): original indentation was lost; everything that needs `ae`
        # or the preview tensor is placed inside this branch.
        if ae is not None and state.width is not None and state.height is not None:
            # The cached prediction is reused by denoise_step_fn when advancing.
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            intermediate = state.current_sample - state.current_timestep * pred_preview
            # scatter_ids yields one (1, C, T, H, W) grid per batch item. Concatenate,
            # drop the time axis, and keep the first item for decoding.
            scattered = scatter_ids(intermediate, img_ids)
            if scattered:
                intermediate = torch.cat(scattered).squeeze(2)[0]

            # The original bare `gfx_sync` expression never ran; call it so the timing
            # below reflects finished device work.
            if callable(gfx_sync):
                gfx_sync()
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            if denoise_device.type == "cuda":
                context = torch.autocast(device_type=denoise_device.type, dtype=torch.bfloat16)
            else:
                context = nullcontext()
            with context:
                if denoise_device.type != "cuda":
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        ae_dtype = intermediate.dtype  # For Tests
                    intermediate = intermediate.to(dtype=ae_dtype)

                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    # NOTE(review): if ae.decode applies scale/shift itself, they are applied twice here — confirm.
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)  # type: ignore
                    intermediate_image = ae.decode(z_adjusted).float()
                else:
                    intermediate_image = ae.decode(intermediate)
            if state.seed is not None:
                controller.store_state_in_chain(current_seed=state.seed)
            with SaveFile() as saver:
                saver.intermediate_image = intermediate_image
                saver.hyperchain = (controller.hyperchain,)
                saver.with_hyperchain()

    return controller.current_sample


def concatenate_images(
    images: list[Image.Image],
) -> Image.Image:
    """
    Concatenate a list of PIL images horizontally with center alignment and white background.

    :raises ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("concatenate_images requires at least one image")

    # A single image is returned as a copy.
    if len(images) == 1:
        return images[0].copy()

    images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]

    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)

    new_img = Image.new("RGB", (total_width, max_height), (255, 255, 255))

    # Paste left to right, vertically centered.
    x_offset = 0
    for img in images:
        y_offset = (max_height - img.height) // 2
        new_img.paste(img, (x_offset, y_offset))
        x_offset += img.width

    return new_img
def
compress_time(t_ids: torch.Tensor) -> torch.Tensor:
def compress_time(t_ids: Tensor) -> Tensor:
    """Remap (possibly sparse) time ids onto the dense range ``0..k-1``.

    Each value is replaced by its rank among the sorted unique values,
    e.g. ``[10, 20, 10]`` -> ``[0, 1, 0]``.

    :param t_ids: 1-D tensor of integer time ids.
    :returns: Tensor of the same shape and dtype holding the compressed ids.
    :raises ValueError: if ``t_ids`` is not one-dimensional.
    """
    if t_ids.ndim != 1:
        raise ValueError(f"compress_time expects a 1-D tensor, got {t_ids.ndim} dims")
    # return_inverse gives each element's index into the sorted unique values,
    # i.e. its dense rank. Unlike a lookup table of size max(t_ids)+1 this is
    # correct for negative ids and does not scale with the largest id.
    _, compressed = torch.unique(t_ids, sorted=True, return_inverse=True)
    return compressed.to(t_ids.dtype)
def
scatter_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
def scatter_ids(x: Tensor, x_ids: Tensor) -> list[Tensor]:
    """Use position ids to scatter tokens back into dense spatial grids.

    :param x: batched tokens; each item is indexed as ``(seq_len, channels)``.
    :param x_ids: batched position ids; columns 0, 1, 2 of each item are t, h, w.
    :returns: one ``(1, C, T, H, W)`` tensor per batch item; grid cells that no
        token maps to stay zero.
    """
    x_list = []
    for data, pos in zip(x, x_ids):
        _, ch = data.shape
        t_ids = compress_time(pos[:, 0].to(torch.int64))
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)

        # Grid extents as plain ints so they can size tensors and einops axes.
        t = int(t_ids.max()) + 1
        h = int(h_ids.max()) + 1
        w = int(w_ids.max()) + 1

        flat_ids = t_ids * w * h + h_ids * w + w_ids

        out = torch.zeros((t * h * w, ch), device=data.device, dtype=data.dtype)
        out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
        x_list.append(rearrange(out, "(t h w) c -> 1 c t h w", t=t, h=h, w=w))
    return x_list
Use position ids to scatter tokens back into place, producing one (1, C, T, H, W) grid per batch item.
def
encode_image_refs(ae, img_ctx: list[PIL.Image.Image]):
def encode_image_refs(ae, img_ctx: list[Image.Image]):
    """Encode reference images into a single conditioning token sequence.

    Images go through :func:`default_prep`, are encoded one at a time with
    ``ae.encode`` on ``gfx_device``, and are tokenized with position ids. The
    i-th reference gets time coordinate ``10 + 10 * i``.

    :param ae: autoencoder exposing ``encode``.
    :param img_ctx: reference images; may be empty.
    :returns: ``(ref_tokens, ref_ids)`` shaped ``(1, N, C)`` / ``(1, N, 4)``,
        with tokens cast to ``gfx_dtype``; ``(None, None)`` if ``img_ctx`` is empty.
    """
    if not img_ctx:
        return None, None

    scale = 10
    # A lone reference is allowed a larger pixel budget than each of several.
    # NOTE(review): 2024**2 (not 2048**2) is preserved as-is; confirm intent upstream.
    limit_pixels = 2024**2 if len(img_ctx) == 1 else 1024**2

    img_ctx_prep = default_prep(img=img_ctx, limit_pixels=limit_pixels)
    if not isinstance(img_ctx_prep, list):
        img_ctx_prep = [img_ctx_prep]

    encoded_refs = [ae.encode(img[None].to(gfx_device))[0] for img in img_ctx_prep]

    # Per-reference time offsets, each as a 1-element tensor.
    t_off = [(scale + scale * t).view(-1) for t in torch.arange(0, len(encoded_refs))]

    ref_tokens, ref_ids = listed_prc_img(encoded_refs, t_coord=t_off)

    # Concatenate all references along the sequence axis, then add a batch axis.
    ref_tokens = torch.cat(ref_tokens, dim=0).unsqueeze(0)  # (1, total_ref_tokens, C)
    ref_ids = torch.cat(ref_ids, dim=0).unsqueeze(0)  # (1, total_ref_tokens, 4)

    return ref_tokens.to(gfx_dtype), ref_ids
def
prc_txt( x: torch.Tensor, t_coord: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
def prc_txt(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Build (t, h, w, l) position ids for one text sequence.

    h and w are dummy axes fixed at 0; l enumerates token positions.

    :param x: text tokens shaped ``(seq_len, dim)``.
    :param t_coord: optional 1-D tensor of time coordinates (default ``[0]``).
    :returns: ``x`` unchanged and an id tensor with 4 columns on ``x.device``.
    """
    seq_len, _ = x.shape
    t = torch.arange(1) if t_coord is None else t_coord
    dummy = torch.arange(1)
    x_ids = torch.cartesian_prod(t, dummy, dummy, torch.arange(seq_len))
    return x, x_ids.to(x.device)
def
batched_wrapper(fn):
def batched_wrapper(fn):
    """Lift a per-sample ``fn(x, t_coord) -> (x, ids)`` to a stacked batch.

    :param fn: per-sample processor such as ``prc_img`` or ``prc_txt``.
    :returns: ``batched_prc(x, t_coord=None)`` applying ``fn`` to ``x[i]``
        (with ``t_coord[i]`` when given) and stacking both outputs on dim 0.
    """

    def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
        results = [fn(item, None if t_coord is None else t_coord[i]) for i, item in enumerate(x)]
        xs, x_ids = zip(*results)
        return torch.stack(xs), torch.stack(x_ids)

    return batched_prc
def
listed_wrapper(fn):
def listed_wrapper(fn):
    """Lift a per-item ``fn(x, t_coord) -> (x, ids)`` to lists of tensors.

    Outputs are returned as lists rather than stacked, so items need not share a shape.

    :param fn: per-item processor such as ``prc_img``.
    :returns: ``listed_prc(x, t_coord=None)`` returning ``(list_of_x, list_of_ids)``.
    """

    def listed_prc(
        x: list[Tensor],
        t_coord: list[Tensor] | None = None,
    ) -> tuple[list[Tensor], list[Tensor]]:
        results = [fn(item, None if t_coord is None else t_coord[i]) for i, item in enumerate(x)]
        xs, x_ids = zip(*results)
        return list(xs), list(x_ids)

    return listed_prc
def
prc_img( x: torch.Tensor, t_coord: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
def prc_img(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Flatten a ``(C, H, W)`` latent into tokens with (t, h, w, l) position ids.

    :param x: latent shaped ``(C, H, W)``.
    :param t_coord: optional 1-D tensor of time coordinates (default ``[0]``).
    :returns: tokens shaped ``(H*W, C)`` and an id tensor with 4 columns.
    """
    _, h, w = x.shape
    t = torch.arange(1) if t_coord is None else t_coord
    x_ids = torch.cartesian_prod(t, torch.arange(h), torch.arange(w), torch.arange(1))
    tokens = rearrange(x, "c h w -> (h w) c")
    return tokens, x_ids.to(tokens.device)
def
listed_prc_img( x: list[torch.Tensor], t_coord: list[torch.Tensor] | None = None) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
def listed_prc(
    x: list[Tensor],
    t_coord: list[Tensor] | None = None,
) -> tuple[list[Tensor], list[Tensor]]:
    """Apply the wrapped per-item ``fn`` to each tensor of a list.

    This is the inner function returned by ``listed_wrapper``; ``fn`` is the
    processor captured from that enclosing scope.

    :param x: list of inputs; items need not share a shape.
    :param t_coord: optional list of per-item time coordinates.
    :returns: ``(list_of_x, list_of_ids)``.
    """
    results = [fn(item, None if t_coord is None else t_coord[i]) for i, item in enumerate(x)]
    xs, x_ids = zip(*results)
    return list(xs), list(x_ids)
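Apply prc_img to each tensor of a list (with an optional per-item time coordinate) and return the lists of tokens and position ids.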
def
batched_prc_img( x: torch.Tensor, t_coord: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Apply the wrapped per-sample ``fn`` to every batch item and stack.

    This is the inner function returned by ``batched_wrapper``; ``fn`` is the
    processor captured from that enclosing scope.

    :param x: batched input, iterated along dim 0.
    :param t_coord: optional per-item time coordinates.
    :returns: stacked outputs and stacked position ids.
    """
    results = [fn(item, None if t_coord is None else t_coord[i]) for i, item in enumerate(x)]
    xs, x_ids = zip(*results)
    return torch.stack(xs), torch.stack(x_ids)
Apply prc_img to each item of a batch and stack the resulting tokens and position ids.
def
batched_prc_txt( x: torch.Tensor, t_coord: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]:
def batched_prc(x: Tensor, t_coord: Tensor | None = None) -> tuple[Tensor, Tensor]:
    """Apply the wrapped per-sample ``fn`` to every batch item and stack.

    This is the inner function returned by ``batched_wrapper``; ``fn`` is the
    processor captured from that enclosing scope.

    :param x: batched input, iterated along dim 0.
    :param t_coord: optional per-item time coordinates.
    :returns: stacked outputs and stacked position ids.
    """
    results = [fn(item, None if t_coord is None else t_coord[i]) for i, item in enumerate(x)]
    xs, x_ids = zip(*results)
    return torch.stack(xs), torch.stack(x_ids)
Apply prc_txt to each item of a batch and stack the resulting text tokens and position ids.
def
center_crop_to_multiple_of_x( img: PIL.Image.Image | list[PIL.Image.Image], x: int) -> PIL.Image.Image | list[PIL.Image.Image]:
def center_crop_to_multiple_of_x(img: Image.Image | list[Image.Image], x: int) -> Image.Image | list[Image.Image]:
    """Center-crop image(s) so width and height are multiples of ``x``.

    :param img: image or list of images (lists are processed element-wise).
    :param x: required multiple for both sides.
    :returns: the cropped image, or a list of cropped images.
    """
    if isinstance(img, list):
        return [center_crop_to_multiple_of_x(_img, x) for _img in img]  # type: ignore

    w, h = img.size
    new_w = (w // x) * x
    new_h = (h // x) * x
    # Split the trimmed margin evenly (extra pixel goes to the right/bottom).
    left = (w - new_w) // 2
    top = (h - new_h) // 2
    return img.crop((left, top, left + new_w, top + new_h))
def
cap_pixels(img: PIL.Image.Image | list[PIL.Image.Image], k):
def cap_pixels(img: Image.Image | list[Image.Image], k):
    """Downscale image(s) so that ``w * h`` does not exceed ``k``.

    Aspect ratio is kept and LANCZOS resampling is used; images already within
    the budget are returned unchanged.

    :param img: image or list of images.
    :param k: maximum number of pixels.
    """
    if isinstance(img, list):
        return [cap_pixels(_img, k) for _img in img]
    w, h = img.size
    pixel_count = w * h
    if pixel_count <= k:
        return img

    scale = math.sqrt(k / pixel_count)
    # Never request a zero-sized side, which PIL rejects.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
def
cap_min_pixels( img: PIL.Image.Image | list[PIL.Image.Image], max_ar=8, min_sidelength=64):
def cap_min_pixels(img: Image.Image | list[Image.Image], max_ar=8, min_sidelength=64):
    """Validate image size: reject too-small sides or extreme aspect ratios.

    Despite the name, no resizing happens; valid images are returned as-is.

    :param img: image or list of images.
    :param max_ar: maximum allowed aspect ratio in either orientation.
    :param min_sidelength: minimum allowed width and height.
    :raises ValueError: if a side is below ``min_sidelength`` or the aspect
        ratio exceeds ``max_ar``.
    """
    if isinstance(img, list):
        return [cap_min_pixels(_img, max_ar=max_ar, min_sidelength=min_sidelength) for _img in img]
    w, h = img.size
    if w < min_sidelength or h < min_sidelength:
        raise ValueError(f"Skipping due to minimal sidelength underschritten h {h} w {w}")
    if w / h > max_ar or h / w > max_ar:
        raise ValueError(f"Skipping due to maximal ar overschritten h {h} w {w}")
    return img
def
to_rgb(img: PIL.Image.Image | list[PIL.Image.Image]):
def
default_images_prep( x: PIL.Image.Image | list[PIL.Image.Image]) -> torch.Tensor | list[torch.Tensor]:
def
default_prep( img: PIL.Image.Image | list[PIL.Image.Image], limit_pixels: int | None, ensure_multiple: int = 16) -> torch.Tensor | list[torch.Tensor]:
def default_prep(img: Image.Image | list[Image.Image], limit_pixels: int | None, ensure_multiple: int = 16) -> torch.Tensor | list[torch.Tensor]:
    """Standard reference-image pipeline: RGB, size check, pixel cap, crop, tensorize.

    :param img: image or list of images.
    :param limit_pixels: maximum pixel count, or ``None`` for no cap.
    :param ensure_multiple: crop both sides to a multiple of this value.
    :returns: tensor (or list of tensors) as produced by ``default_images_prep``.
    """
    img_rgb = to_rgb(img)
    img_min = cap_min_pixels(img_rgb)  # type: ignore
    img_cap = cap_pixels(img_min, limit_pixels) if limit_pixels is not None else img_min  # type: ignore
    img_crop = center_crop_to_multiple_of_x(img_cap, ensure_multiple)  # type: ignore
    return default_images_prep(img_crop)
def
generalized_time_snr_shift(t: torch.Tensor, mu: float, sigma: float) -> torch.Tensor:
def
get_schedule(num_steps: int, image_seq_len: int) -> list[float]:
def
compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Empirical timestep-shift ``mu`` for a sequence length and step count.

    Two linear fits in ``image_seq_len`` are used, one for 200 steps and one for
    10 steps; ``mu`` is linearly interpolated between them in ``num_steps``. For
    ``image_seq_len > 4300`` the 200-step fit is returned regardless of steps.
    """
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    m_200 = a2 * image_seq_len + b2
    if image_seq_len > 4300:
        return float(m_200)

    m_10 = a1 * image_seq_len + b1
    a = (m_200 - m_10) / 190.0
    b = m_200 - 200.0 * a
    return float(a * num_steps + b)
def denoise(settings: InferenceStateFlux2) -> Tensor:
    """Simple non-interactive denoising function for Flux2.\n
    :param settings: InferenceStateFlux2 containing all denoising configuration parameters
    :returns: Denoised image tensor
    :raises ValueError: if ``img_cond_seq`` is given without ``img_cond_seq_ids``"""
    model = settings.model
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids

    if img_cond_seq is not None and img_cond_seq_ids is None:
        raise ValueError("You need to provide either both or neither of the sequence conditioning")

    num_img_tokens = img.shape[1]
    # Position ids do not change between steps, so build the model's id input once.
    img_input_ids = img_ids if img_cond_seq is None else torch.cat((img_ids, img_cond_seq_ids), dim=1)

    guidance_vec = torch.full((img.shape[0],), settings.guidance, device=img.device, dtype=img.dtype)
    for t_curr, t_prev in zip(settings.timesteps[:-1], settings.timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        img_input = img if img_cond_seq is None else torch.cat((img, img_cond_seq), dim=1)
        pred = model(
            x=img_input,
            x_ids=img_input_ids,
            timesteps=t_vec,
            ctx=txt,
            ctx_ids=txt_ids,
            guidance=guidance_vec,
        )
        # Keep only predictions for the latent tokens; any appended conditioning
        # tokens come after them in the sequence.
        pred = pred[:, :num_img_tokens]
        img = img + (t_prev - t_curr) * pred

    return img
Simple non-interactive denoising function for Flux2.
Parameters
- settings: InferenceStateFlux2 containing all denoising configuration parameters
Returns
- Denoised image tensor
@torch.inference_mode()
def denoise_interactive(model: divisor.flux2.model.Flux2, settings: divisor.state.InferenceState):
def _place_model_on_device(model: Flux2, target_device: torch.device) -> Flux2:
    """Return ``model`` on ``target_device``, moving it only when its parameters live elsewhere.

    Objects whose parameters cannot be inspected (e.g. Mock objects, parameterless
    models) are returned unchanged and assumed to be placed correctly already.
    """
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        # Fallback for Mock objects or models without parameters
        return model
    if model_device == target_device:
        return model
    if model_device.type == "meta":
        # Meta tensors carry no data; allocating fresh (uninitialized) storage is the only option
        return model.to_empty(device=target_device)
    # Module.to copies the loaded weights; to_empty would replace them with uninitialized memory
    return model.to(target_device)


def _tokens_to_spatial_latent(tokens: Tensor, token_ids: Tensor) -> Tensor:
    """Scatter a packed token sequence back into a spatial latent for VAE decoding.

    :param tokens: Packed latent tokens, one token sequence per batch entry
    :param token_ids: Position ids matching ``tokens``, as consumed by ``scatter_ids``
    :returns: The batch of scattered latents with any time axis reduced to its first
        slice, or ``tokens`` unchanged when nothing was scattered
    """
    scattered = scatter_ids(tokens, token_ids)
    if not scattered:
        return tokens
    # Keep the batch axis: decoding a single indexed entry would hand the VAE a 3D tensor
    latent = torch.cat(scattered)
    if latent.dim() == 5:
        # Entries are (1, C, T, H, W) per the scatter_ids usage here; keep the first time slice
        latent = latent[:, :, 0]
    return latent


@torch.inference_mode()
def denoise_interactive(
    model: Flux2,
    settings: InferenceState,
):
    """Interactive denoising using Flux2 model with optional ManualTimestepController.

    Each loop iteration lets the user adjust the controller state through the CLI
    menu (``route_choices``). When an autoencoder and a resolution are available, it
    then decodes a preview of the fully denoised latent and saves it together with
    the controller's hyperchain.

    :param model: Flux2 model instance
    :param settings: InferenceState containing all denoising configuration parameters
    :returns: The controller's current sample once the timestep schedule is complete
    """
    from contextlib import nullcontext

    # Imported at call time, presumably so a device re-assigned in the registry after
    # module import is honoured (the module-level name is bound only once)
    from divisor.registry import gfx_device as default_device

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    img_cond_seq = settings.img_cond_seq
    img_cond_seq_ids = settings.img_cond_seq_ids
    denoise_device = settings.device if settings.device is not None else default_device
    initial_layer_dropout = settings.initial_layer_dropout
    text_embedder = settings.text_embedder

    # Single-element lists act as mutable cells shared with the closures built below
    current_layer_dropout = [initial_layer_dropout]  # this is ignored for schnell
    previous_step_tensor: list[Optional[Tensor]] = [None]  # previous step's tensor, used for masking
    cached_prediction: list[Optional[Tensor]] = [None]  # avoids duplicate model calls
    cached_prediction_state: list[Optional[dict]] = [None]  # state the cached prediction was made with
    controller_ref: list[Optional["ManualTimestepController"]] = [None]  # filled once the controller exists

    model_ref: list[Flux2] = [_place_model_on_device(model, img.device)]

    # Embeddings live in mutable cells so they can be replaced when the prompt changes.
    # Flux2 uses ctx instead of txt and has no separate CLIP embedding.
    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    current_vec: list[Optional[Tensor]] = [None]
    current_prompt: list[Optional[str]] = [state.prompt]  # tracked to detect prompt changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(
        img,
        None,  # t5 not used for Flux2
        None,  # clip not used for Flux2
        current_txt,
        current_txt_ids,
        current_vec,  # type: ignore
        current_prompt,
        clear_prediction_cache,
        is_flux2=True,
        text_embedder=text_embedder,
    )

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,  # type: ignore
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=None,  # img_cond not used in Flux2 (only img_cond_seq)
        img_cond_seq=img_cond_seq,
        img_cond_seq_ids=img_cond_seq_ids,
    )
    get_prediction = create_get_prediction(pred_set, img_set)

    denoise_step_fn = create_denoise_step_fn(
        controller_ref,
        current_layer_dropout,
        previous_step_tensor,
        get_prediction,
    )

    controller = ManualTimestepController(
        timesteps=timesteps,
        initial_sample=img,
        denoise_step_fn=denoise_step_fn,
        initial_guidance=state.guidance,
    )
    controller_ref[0] = controller  # make the controller visible to the step closure

    # Prefer state.layer_dropout; fall back to the initial setting
    controller.set_layer_dropout(state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout)

    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Recompute embeddings when the prompt changed since they were last built
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if text_embedder is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # No embedder available: record the prompt so the check stops re-firing
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()

        if ae is not None and state.width is not None and state.height is not None:
            # Reuses the cached prediction when available; denoise_step_fn reuses it when advancing.
            # Always use state.layer_dropout from the controller for consistency.
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            # Jump straight to t=0 with the same Euler update used by ``denoise``,
            # then convert the token sequence back to spatial form for the VAE
            intermediate = _tokens_to_spatial_latent(
                state.current_sample - state.current_timestep * pred_preview,
                img_ids,
            )

            # Synchronize before reading the clock so the timing covers the queued GPU work
            if callable(gfx_sync):
                gfx_sync()
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            if denoise_device.type == "cuda":
                context = torch.autocast(device_type=denoise_device.type, dtype=torch.bfloat16)
            else:
                context = nullcontext()
            with context:
                if denoise_device.type != "cuda":
                    # Without autocast, match the latent dtype to the autoencoder weights
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        ae_dtype = intermediate.dtype  # For Tests
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)  # type: ignore
                    intermediate_image = ae.decode(z_adjusted).float()
                else:
                    intermediate_image = ae.decode(intermediate)

            if state.seed is not None:
                controller.store_state_in_chain(current_seed=state.seed)
            with SaveFile() as saver:
                saver.intermediate_image = intermediate_image  # set up image
                saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                saver.with_hyperchain()

    return controller.current_sample
Interactive denoising using Flux2 model with optional ManualTimestepController.
Parameters
- model: Flux2 model instance
- settings: InferenceState containing all denoising configuration parameters
def concatenate_images(images: list[PIL.Image.Image]) -> PIL.Image.Image:
def concatenate_images(
    images: list[Image.Image],
) -> Image.Image:
    """
    Concatenate a list of PIL images horizontally with center alignment and white background.

    Images are placed left to right. Each image is vertically centered on a canvas
    as tall as the tallest input. Non-RGB inputs are converted to RGB before
    pasting. A single image is returned as a plain copy, without mode conversion.

    :param images: Images to join, in left-to-right order
    :returns: A new RGB image, or a copy of the only image when exactly one is given
    :raises ValueError: if ``images`` is empty
    """
    # Fail early with a clear message instead of inside max() on an empty sequence
    if not images:
        raise ValueError("concatenate_images requires at least one image")

    # If only one image, return a copy of it
    if len(images) == 1:
        return images[0].copy()

    # Convert all images to RGB if not already
    rgb_images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]

    # Calculate dimensions for horizontal concatenation
    total_width = sum(img.width for img in rgb_images)
    max_height = max(img.height for img in rgb_images)

    # Create new image with white background
    background_color = (255, 255, 255)
    new_img = Image.new("RGB", (total_width, max_height), background_color)

    # Paste images left to right, each centered vertically
    x_offset = 0
    for img in rgb_images:
        y_offset = (max_height - img.height) // 2
        new_img.paste(img, (x_offset, y_offset))
        x_offset += img.width

    return new_img
Concatenate a list of PIL images horizontally with center alignment and white background.