divisor.xflux1.sampling
# SPDX-License-Identifier:Apache-2.0
# original XFlux code from https://github.com/TencentARC/FluxKits

import time
from typing import Callable

import torch
from nnll.console import nfo
from torch import Tensor

from divisor.cli_menu import route_choices
from divisor.controller import ManualTimestepController, rng, variation_rng
from divisor.denoise_step import (
    create_clear_prediction_cache,
    create_denoise_step_fn,
    create_recompute_text_embeddings,
)
from divisor.flux1.sampling import prepare, unpack
from divisor.interaction_context import InteractionContext
from divisor.registry import gfx_device, gfx_sync
from divisor.save import SaveFile
from divisor.state import (
    ImageEmbeddingState,
    InferenceState,
    StepStateXFlux1,
    TextEmbeddingState,
)
from divisor.xflux1.model import XFlux


def create_get_prediction_xflux1(
    pred_set: TextEmbeddingState,
    img_set: ImageEmbeddingState,
    add_set: StepStateXFlux1,
) -> Callable[[Tensor, float, float, list[int] | None], Tensor]:
    """Create a function to generate model prediction with caching for XFlux1.\n
    :param pred_set: Model reference, current text/CLIP embeddings, prediction cache slots and negative-prompt (CFG) settings
    :param img_set: Image tensor, image ids and the image_proj / ip_scale inputs (positive and negative)
    :param add_set: Current timestep index and the timestep index at which CFG starts
    :return: Function that generates predictions with caching"""

    def get_prediction(
        sample: Tensor,
        t_curr: float,
        guidance_val: float,
        layer_dropouts_val: list[int] | None,
    ) -> Tensor:
        """Generate model prediction, reusing cached prediction if state hasn't changed.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param guidance_val: Guidance value
        :param layer_dropouts_val: Layer dropout configuration (only used as part of the cache key here)
        :returns: Model prediction"""
        # Create a simple hash of the sample tensor for cache key.
        # Only the shape and the first element are fingerprinted, so samples that
        # differ elsewhere map to the same key.
        if sample.numel() > 0:
            first_val = float(sample.flatten()[0].item())
        else:
            first_val = 0.0
        sample_hash = hash((sample.shape, first_val))

        # Check if we can reuse cached prediction
        current_state = {
            "sample_hash": sample_hash,
            "t_curr": t_curr,
            "guidance": guidance_val,
            "layer_dropout": layer_dropouts_val,
        }

        if (
            pred_set.cached_prediction[0] is not None
            and pred_set.cached_prediction_state[0] is not None
            and pred_set.cached_prediction_state[0]["sample_hash"] == current_state["sample_hash"]
            and pred_set.cached_prediction_state[0]["t_curr"] == current_state["t_curr"]
            and pred_set.cached_prediction_state[0]["guidance"] == current_state["guidance"]
            and pred_set.cached_prediction_state[0]["layer_dropout"] == current_state["layer_dropout"]
        ):
            return pred_set.cached_prediction[0]

        # Generate new prediction
        # Fall back to the sample dtype when the model exposes no usable parameters.
        try:
            model_dtype = next(pred_set.model_ref[0].parameters()).dtype
        except (TypeError, StopIteration, AttributeError):
            model_dtype = sample.dtype
        use_autocast = gfx_device.type == "cuda"

        # Off CUDA the inputs are cast to the model dtype by hand.
        if not use_autocast:
            sample = sample.to(dtype=model_dtype)

        t_vec = torch.full((sample.shape[0],), t_curr, dtype=sample.dtype, device=sample.device)
        # The guidance vector takes its batch size, device and dtype from img_set.img, not from `sample`.
        guidance_vec = torch.full((img_set.img.shape[0],), guidance_val, device=img_set.img.device, dtype=img_set.img.dtype)

        if not use_autocast:
            img_input = sample.to(dtype=model_dtype)
            txt_input = pred_set.current_txt[0].to(dtype=model_dtype)
            vec_input = pred_set.current_vec[0].to(dtype=model_dtype)
            t_vec = t_vec.to(dtype=model_dtype)
            guidance_vec = guidance_vec.to(dtype=model_dtype)
        else:
            img_input = sample
            txt_input = pred_set.current_txt[0]
            vec_input = pred_set.current_vec[0]

        # Get current timestep index for CFG check
        timestep_index = add_set.current_timestep_index[0]

        # image_proj / ip_scale are forwarded only when the model instance's
        # __dict__ has an entry of the same name.
        kwargs = {}
        if "image_proj" in pred_set.model_ref[0].__dict__:
            kwargs = {"image_proj": img_set.image_proj}
        if "ip_scale" in pred_set.model_ref[0].__dict__:
            kwargs.setdefault("ip_scale", img_set.ip_scale)

        # Generate positive prediction
        pred = pred_set.model_ref[0](
            img=img_input,
            img_ids=img_set.img_ids,
            txt=txt_input,
            txt_ids=pred_set.current_txt_ids[0],
            y=vec_input,
            timesteps=t_vec,
            guidance=guidance_vec,
            **kwargs,
        )

        # Apply CFG if enabled and past start timestep
        if (
            pred_set.neg_pred_enabled
            and all([pred_set.current_neg_txt, pred_set.current_neg_txt_ids, pred_set.current_neg_vec])
            and timestep_index >= add_set.timestep_to_start_cfg
        ):
            if not use_autocast:
                neg_txt_input = pred_set.current_neg_txt[0].to(dtype=model_dtype)  # type: ignore
                neg_vec_input = pred_set.current_neg_vec[0].to(dtype=model_dtype)  # type: ignore
            else:
                neg_txt_input = pred_set.current_neg_txt[0]  # type: ignore
                neg_vec_input = pred_set.current_neg_vec[0]  # type: ignore

            # Unlike the positive pass, image_proj / ip_scale are passed unconditionally here.
            neg_pred = pred_set.model_ref[0](
                img=img_input,
                img_ids=img_set.img_ids,
                txt=neg_txt_input,
                txt_ids=pred_set.current_neg_txt_ids[0],  # type: ignore
                y=neg_vec_input,
                timesteps=t_vec,
                guidance=guidance_vec,
                image_proj=img_set.neg_image_proj,
                ip_scale=img_set.neg_ip_scale,
            )
            # Extrapolate from the negative prediction towards the positive one by true_gs.
            pred = neg_pred + pred_set.true_gs * (pred - neg_pred)

        # Cache the prediction
        pred_set.cached_prediction[0] = pred
        pred_set.cached_prediction_state[0] = current_state

        return pred

    return get_prediction


@torch.inference_mode()
def denoise(
    model: XFlux,
    settings: InferenceState,
):
    """Denoise using XFlux model, driven step by step through a ManualTimestepController.\n
    :param model: XFlux model instance
    :param settings: Denoising state containing all denoising configuration parameters
    :return: The controller's current sample once the controller reports completion"""

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    vec = settings.vec
    neg_pred_enabled = settings.neg_pred_enabled
    neg_txt = settings.neg_txt
    neg_txt_ids = settings.neg_txt_ids
    neg_vec = settings.neg_vec
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    true_gs = settings.true_gs
    timestep_to_start_cfg = settings.timestep_to_start_cfg
    image_proj = settings.image_proj
    neg_image_proj = settings.neg_image_proj
    # A missing scale defaults to a scalar tensor of 1.0.
    ip_scale = settings.ip_scale if settings.ip_scale is not None else torch.tensor(1.0)
    neg_ip_scale = settings.neg_ip_scale if settings.neg_ip_scale is not None else torch.tensor(1.0)
    initial_layer_dropout = settings.initial_layer_dropout
    t5 = settings.t5
    clip = settings.clip

    # Single-element lists act as mutable cells shared with the closures created below.
    # this is ignored for schnell
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Tensor | None] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Tensor | None] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[dict | None] = [None]  # Cache state when prediction was generated
    controller_ref: list[ManualTimestepController | None] = [None]  # Reference to controller for closure access
    current_timestep_index: list[int] = [0]  # Track current timestep index for CFG

    model_ref: list[XFlux] = [model]
    # Ensure model is on the correct device (fixes meta device issue)
    target_device = img.device
    # Safely get model device, handling Mock objects in tests
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        # Fallback for Mock objects or models without parameters
        # Assume model is already on correct device if we can't determine it
        model_device = target_device
    # NOTE(review): to_empty() allocates storage on the target device without copying
    # values. That suits a meta-device model, but a model holding real weights on a
    # different device would lose them here — confirm this is only reached for meta.
    if model_device != target_device:
        model_ref[0] = model.to_empty(device=target_device)

    # Store embeddings in mutable containers so they can be updated when prompt changes
    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    assert vec is not None, "vec (CLIP embeddings) is required for XFlux1"
    current_vec: list[Tensor] = [vec]

    # Handle negative prompts: XFlux1 supports negative prompts but they're optional
    # If neg_pred_enabled is True, negative prompts must be provided
    if neg_pred_enabled:
        if neg_txt is None or neg_txt_ids is None or neg_vec is None:
            # Generate empty negative prompts if not provided but neg_pred_enabled is True
            if t5 is not None and clip is not None:
                # Use empty string to generate embeddings
                neg_inp = prepare(t5, clip, img, prompt="")
                neg_txt = neg_inp["txt"]
                neg_txt_ids = neg_inp["txt_ids"]
                neg_vec = neg_inp["vec"]
            else:
                # If embedders not available, disable negative prompts
                neg_pred_enabled = False

    if neg_pred_enabled and neg_txt is not None and neg_txt_ids is not None and neg_vec is not None:
        current_neg_txt: list[Tensor] = [neg_txt]  # type: ignore
        current_neg_txt_ids: list[Tensor] = [neg_txt_ids]  # type: ignore
        current_neg_vec: list[Tensor] = [neg_vec]  # type: ignore
    else:
        current_neg_txt: list[Tensor] | None = None
        current_neg_txt_ids: list[Tensor] | None = None
        current_neg_vec: list[Tensor] | None = None
        neg_pred_enabled = False
    current_prompt: list[str | None] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(img, t5, clip, current_txt, current_txt_ids, current_vec, current_prompt, clear_prediction_cache, is_flux2=False)

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
        neg_pred_enabled=neg_pred_enabled,
        current_neg_txt=current_neg_txt,  # type: ignore
        current_neg_txt_ids=current_neg_txt_ids,  # type: ignore
        current_neg_vec=current_neg_vec,  # type: ignore
        true_gs=true_gs,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=None,
        img_cond_seq=None,
        img_cond_seq_ids=None,
        image_proj=image_proj,
        neg_image_proj=neg_image_proj,
        ip_scale=ip_scale,
        neg_ip_scale=neg_ip_scale,
    )
    add_set = StepStateXFlux1(
        timestep_to_start_cfg=timestep_to_start_cfg,
        current_timestep_index=current_timestep_index,
    )
    get_prediction = create_get_prediction_xflux1(pred_set, img_set, add_set)

    denoise_step_fn = create_denoise_step_fn(controller_ref, current_layer_dropout, previous_step_tensor, get_prediction)

    controller = ManualTimestepController(timesteps=timesteps, initial_sample=img, denoise_step_fn=denoise_step_fn, initial_guidance=state.guidance)
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    layer_dropout_to_set = state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout
    controller.set_layer_dropout(layer_dropout_to_set)

    # Seed the controller with whatever the initial state already specifies.
    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Update timestep index for CFG check
        current_timestep_index[0] = state.timestep_index if hasattr(state, "timestep_index") else 0

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if t5 is not None and clip is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedders not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            t5=t5,
            clip=clip,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()
        if ae is not None and state.width is not None and state.height is not None:
            # Reuse cached prediction if available, otherwise generate it
            # This will be cached and reused in denoise_step_fn when advancing
            # Always use state.layer_dropout from controller to ensure consistency
            pred_preview = get_prediction(
                state.current_sample,
                state.current_timestep,
                state.guidance,
                state.layer_dropout,
            )

            # One-shot estimate of the final latent from the current sample and prediction.
            intermediate = state.current_sample - state.current_timestep * pred_preview
            # Unpack requires float32, but we'll convert back to correct dtype after
            intermediate = unpack(intermediate.float(), state.height, state.width)

            from divisor.registry import gfx_device as default_device

            # NOTE(review): this is a bare name expression and has no effect; it looks
            # like a device synchronisation call was intended before reading the
            # timer — confirm the signature of gfx_sync and call it.
            gfx_sync
            t1 = time.perf_counter()

            nfo(f"Step time: {t1 - t0:.1f}s")

            if default_device.type == "cuda":
                context = torch.autocast(device_type=default_device.type, dtype=torch.bfloat16)
            else:
                from contextlib import nullcontext

                context = nullcontext()
            with context:
                # When autocast is disabled (MPS), ensure intermediate is in correct dtype for VAE
                if default_device.type != "cuda":
                    # Get VAE encoder dtype to ensure intermediate matches (bfloat16)
                    # Safely get encoder dtype, handling Mock objects in tests
                    try:
                        ae_dtype = next(ae.encoder.parameters()).dtype
                    except (TypeError, StopIteration, AttributeError):
                        # Fallback: use intermediate dtype if we can't get encoder dtype (for Mock objects in tests)
                        ae_dtype = intermediate.dtype
                    intermediate = intermediate.to(dtype=ae_dtype)

                # Apply VAE shift/scale offset by manually adjusting the decode operation
                if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
                    # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
                    z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)
                    intermediate_image = ae.decoder(z_adjusted)
                else:
                    intermediate_image = ae.decode(intermediate)
            # NOTE(review): nesting of the two statements below was reconstructed from a
            # flattened listing; they are kept inside the preview branch because
            # intermediate_image is only assigned there — confirm against the repository.
            if state.seed is not None:
                controller.store_state_in_chain(current_seed=state.seed)
            with SaveFile() as saver:
                saver.intermediate_image = intermediate_image  # set up image
                # NOTE(review): the trailing comma makes this a one-element tuple — confirm SaveFile expects that.
                saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
                saver.with_hyperchain()

    return controller.current_sample
def
create_get_prediction_xflux1( pred_set: divisor.state.TextEmbeddingState, img_set: divisor.state.ImageEmbeddingState, add_set: divisor.state.StepStateXFlux1) -> Callable[[torch.Tensor, float, float, Optional[list[int]]], torch.Tensor]:
def create_get_prediction_xflux1(
    pred_set: TextEmbeddingState,
    img_set: ImageEmbeddingState,
    add_set: StepStateXFlux1,
) -> Callable[[Tensor, float, float, list[int] | None], Tensor]:
    """Build the cached prediction callable used by the XFlux1 denoise loop.\n
    :param pred_set: Model reference, current text/CLIP embeddings, prediction cache slots and negative-prompt (CFG) settings
    :param img_set: Image tensor, image ids and the image_proj / ip_scale inputs (positive and negative)
    :param add_set: Current timestep index and the timestep index at which CFG starts
    :return: Function mapping (sample, timestep, guidance, layer dropout) to a model prediction, with single-entry caching"""

    # Fields compared, in this order, to decide whether the cached prediction is still valid.
    cache_fields = ("sample_hash", "t_curr", "guidance", "layer_dropout")

    def get_prediction(
        sample: Tensor,
        t_curr: float,
        guidance_val: float,
        layer_dropouts_val: list[int] | None,
    ) -> Tensor:
        """Return the model prediction for the given inputs, served from cache when they are unchanged.\n
        :param sample: Current sample tensor
        :param t_curr: Current timestep
        :param guidance_val: Guidance value
        :param layer_dropouts_val: Layer dropout configuration (only used as part of the cache key here)
        :returns: Model prediction, combined with the negative prediction when CFG is active"""
        # Cheap fingerprint: shape plus the first element only, so samples that
        # differ elsewhere share a key.
        leading_value = float(sample.flatten()[0].item()) if sample.numel() > 0 else 0.0
        fingerprint = {
            "sample_hash": hash((sample.shape, leading_value)),
            "t_curr": t_curr,
            "guidance": guidance_val,
            "layer_dropout": layer_dropouts_val,
        }

        stored_prediction = pred_set.cached_prediction[0]
        stored_fingerprint = pred_set.cached_prediction_state[0]
        if (
            stored_prediction is not None
            and stored_fingerprint is not None
            and all(stored_fingerprint[field] == fingerprint[field] for field in cache_fields)
        ):
            return stored_prediction

        model = pred_set.model_ref[0]
        # Fall back to the sample dtype when the model exposes no usable parameters.
        try:
            model_dtype = next(model.parameters()).dtype
        except (TypeError, StopIteration, AttributeError):
            model_dtype = sample.dtype
        autocast_active = gfx_device.type == "cuda"

        # The guidance vector takes its batch size, device and dtype from img_set.img.
        reference_img = img_set.img
        guidance_vec = torch.full(
            (reference_img.shape[0],),
            guidance_val,
            device=reference_img.device,
            dtype=reference_img.dtype,
        )

        if autocast_active:
            img_input = sample
            txt_input = pred_set.current_txt[0]
            vec_input = pred_set.current_vec[0]
        else:
            # Off CUDA every model input is cast to the model dtype by hand.
            img_input = sample.to(dtype=model_dtype)
            txt_input = pred_set.current_txt[0].to(dtype=model_dtype)
            vec_input = pred_set.current_vec[0].to(dtype=model_dtype)
            guidance_vec = guidance_vec.to(dtype=model_dtype)

        t_vec = torch.full((img_input.shape[0],), t_curr, dtype=img_input.dtype, device=img_input.device)

        # Read the step index now; it decides below whether CFG applies.
        step_index = add_set.current_timestep_index[0]

        # image_proj / ip_scale are forwarded only when the model instance's
        # __dict__ has an entry of the same name.
        model_attrs = model.__dict__
        optional_inputs = {}
        if "image_proj" in model_attrs:
            optional_inputs["image_proj"] = img_set.image_proj
        if "ip_scale" in model_attrs:
            optional_inputs["ip_scale"] = img_set.ip_scale

        prediction = model(
            img=img_input,
            img_ids=img_set.img_ids,
            txt=txt_input,
            txt_ids=pred_set.current_txt_ids[0],
            y=vec_input,
            timesteps=t_vec,
            guidance=guidance_vec,
            **optional_inputs,
        )

        cfg_active = (
            pred_set.neg_pred_enabled
            and all([pred_set.current_neg_txt, pred_set.current_neg_txt_ids, pred_set.current_neg_vec])
            and step_index >= add_set.timestep_to_start_cfg
        )
        if cfg_active:
            neg_txt_input = pred_set.current_neg_txt[0]  # type: ignore
            neg_vec_input = pred_set.current_neg_vec[0]  # type: ignore
            if not autocast_active:
                neg_txt_input = neg_txt_input.to(dtype=model_dtype)
                neg_vec_input = neg_vec_input.to(dtype=model_dtype)

            # Unlike the positive pass, image_proj / ip_scale are passed unconditionally here.
            negative_prediction = model(
                img=img_input,
                img_ids=img_set.img_ids,
                txt=neg_txt_input,
                txt_ids=pred_set.current_neg_txt_ids[0],  # type: ignore
                y=neg_vec_input,
                timesteps=t_vec,
                guidance=guidance_vec,
                image_proj=img_set.neg_image_proj,
                ip_scale=img_set.neg_ip_scale,
            )
            # Extrapolate from the negative prediction towards the positive one by true_gs.
            prediction = negative_prediction + pred_set.true_gs * (prediction - negative_prediction)

        # Remember the result together with the inputs that produced it.
        pred_set.cached_prediction[0] = prediction
        pred_set.cached_prediction_state[0] = fingerprint

        return prediction

    return get_prediction
Create a function to generate model prediction with caching for XFlux1.
Parameters
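- pred_set: TextEmbeddingState holding the model reference, current text/CLIP embeddings, prediction cache and negative-prompt (CFG) settings
- img_set: ImageEmbeddingState holding the image tensor, image ids and the image_proj / ip_scale inputs
- add_set: StepStateXFlux1 holding the current timestep index and the timestep index at which CFG starts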
Returns
Function that generates predictions with caching
@torch.inference_mode()
def
denoise( model: divisor.xflux1.model.XFlux, settings: divisor.state.InferenceState):
def _place_model_on_device(model: XFlux, target_device: torch.device) -> XFlux:
    """Return `model` placed on `target_device`, keeping its weights whenever it has any.\n
    :param model: Model whose first parameter decides where it currently lives
    :param target_device: Device the inputs live on
    :return: The model on `target_device`; the same object when no move is needed or its device cannot be determined"""
    try:
        model_device = next(model.parameters()).device
    except (TypeError, StopIteration, AttributeError):
        # Mock objects or parameter-less models: assume the model is already placed correctly.
        return model
    if model_device == target_device:
        return model
    if model_device.type == "meta":
        # Meta tensors carry no data, so storage can only be allocated, not copied.
        return model.to_empty(device=target_device)
    try:
        # Real weights must be copied; to_empty() would replace them with uninitialised storage.
        return model.to(device=target_device)
    except NotImplementedError:
        # Raised when some tensors are still on the meta device; fall back to allocation only.
        return model.to_empty(device=target_device)


def _decode_preview(ae, intermediate: Tensor, state, device: torch.device) -> Tensor:
    """Decode an unpacked latent into a preview image, honouring the VAE shift/scale offsets.\n
    :param ae: Autoencoder exposing `decode`, `decoder`, `encoder`, `scale_factor` and `shift_factor`
    :param intermediate: Unpacked latent to decode
    :param state: Controller state providing `vae_shift_offset` and `vae_scale_offset`
    :param device: Device whose type selects autocast (CUDA) or manual dtype casting (everything else)
    :return: Decoded preview image"""
    from contextlib import nullcontext

    on_cuda = device.type == "cuda"
    context = torch.autocast(device_type=device.type, dtype=torch.bfloat16) if on_cuda else nullcontext()
    with context:
        if not on_cuda:
            # Without autocast the latent has to match the VAE dtype explicitly.
            try:
                ae_dtype = next(ae.encoder.parameters()).dtype
            except (TypeError, StopIteration, AttributeError):
                # Mock objects in tests: keep the latent dtype as it is.
                ae_dtype = intermediate.dtype
            intermediate = intermediate.to(dtype=ae_dtype)

        if state.vae_shift_offset != 0.0 or state.vae_scale_offset != 0.0:
            # Decode with offset: z = z / (scale_factor + scale_offset) + (shift_factor + shift_offset)
            z_adjusted = intermediate / (ae.scale_factor + state.vae_scale_offset) + (ae.shift_factor + state.vae_shift_offset)
            return ae.decoder(z_adjusted)
        return ae.decode(intermediate)


@torch.inference_mode()
def denoise(
    model: XFlux,
    settings: InferenceState,
):
    """Denoise using the XFlux model, driven step by step through a ManualTimestepController.\n
    Each loop iteration hands control to `route_choices`, then (when an autoencoder and a
    resolution are available) decodes a preview of the current sample and saves it through `SaveFile`.\n
    :param model: XFlux model instance
    :param settings: Denoising state containing all denoising configuration parameters
    :return: The controller's current sample once the controller reports completion"""

    # Extract settings for easier access
    img = settings.img
    img_ids = settings.img_ids
    txt = settings.txt
    txt_ids = settings.txt_ids
    vec = settings.vec
    neg_pred_enabled = settings.neg_pred_enabled
    neg_txt = settings.neg_txt
    neg_txt_ids = settings.neg_txt_ids
    neg_vec = settings.neg_vec
    state = settings.state
    ae = settings.ae
    timesteps = settings.timesteps
    true_gs = settings.true_gs
    timestep_to_start_cfg = settings.timestep_to_start_cfg
    image_proj = settings.image_proj
    neg_image_proj = settings.neg_image_proj
    # A missing scale defaults to a scalar tensor of 1.0.
    ip_scale = settings.ip_scale if settings.ip_scale is not None else torch.tensor(1.0)
    neg_ip_scale = settings.neg_ip_scale if settings.neg_ip_scale is not None else torch.tensor(1.0)
    initial_layer_dropout = settings.initial_layer_dropout
    t5 = settings.t5
    clip = settings.clip

    # Single-element lists act as mutable cells shared with the closures created below.
    # this is ignored for schnell
    current_layer_dropout = [initial_layer_dropout]
    previous_step_tensor: list[Tensor | None] = [None]  # Store previous step's tensor for masking
    cached_prediction: list[Tensor | None] = [None]  # Cache prediction to avoid duplicate model calls
    cached_prediction_state: list[dict | None] = [None]  # Cache state when prediction was generated
    controller_ref: list[ManualTimestepController | None] = [None]  # Reference to controller for closure access
    current_timestep_index: list[int] = [0]  # Track current timestep index for CFG

    # Ensure model is on the same device as the inputs (fixes meta device issue)
    model_ref: list[XFlux] = [_place_model_on_device(model, img.device)]

    # Store embeddings in mutable containers so they can be updated when prompt changes
    current_txt: list[Tensor] = [txt]
    current_txt_ids: list[Tensor] = [txt_ids]
    assert vec is not None, "vec (CLIP embeddings) is required for XFlux1"
    current_vec: list[Tensor] = [vec]

    # Negative prompts are optional. When CFG is requested without complete negative
    # embeddings, build them from an empty prompt, or switch CFG off if no embedders exist.
    if neg_pred_enabled and (neg_txt is None or neg_txt_ids is None or neg_vec is None):
        if t5 is not None and clip is not None:
            neg_inp = prepare(t5, clip, img, prompt="")
            neg_txt = neg_inp["txt"]
            neg_txt_ids = neg_inp["txt_ids"]
            neg_vec = neg_inp["vec"]
        else:
            neg_pred_enabled = False

    if neg_pred_enabled and neg_txt is not None and neg_txt_ids is not None and neg_vec is not None:
        current_neg_txt: list[Tensor] | None = [neg_txt]
        current_neg_txt_ids: list[Tensor] | None = [neg_txt_ids]
        current_neg_vec: list[Tensor] | None = [neg_vec]
    else:
        current_neg_txt = None
        current_neg_txt_ids = None
        current_neg_vec = None
        neg_pred_enabled = False
    current_prompt: list[str | None] = [state.prompt]  # Track current prompt to detect changes

    clear_prediction_cache = create_clear_prediction_cache(cached_prediction, cached_prediction_state)

    recompute_text_embeddings = create_recompute_text_embeddings(
        img,
        t5,
        clip,
        current_txt,
        current_txt_ids,
        current_vec,
        current_prompt,
        clear_prediction_cache,
        is_flux2=False,
    )

    pred_set = TextEmbeddingState(
        model_ref=model_ref,
        state=state,
        current_txt=current_txt,
        current_txt_ids=current_txt_ids,
        current_vec=current_vec,
        cached_prediction=cached_prediction,
        cached_prediction_state=cached_prediction_state,
        neg_pred_enabled=neg_pred_enabled,
        current_neg_txt=current_neg_txt,  # type: ignore
        current_neg_txt_ids=current_neg_txt_ids,  # type: ignore
        current_neg_vec=current_neg_vec,  # type: ignore
        true_gs=true_gs,
    )
    img_set = ImageEmbeddingState(
        img_ids=img_ids,
        img=img,
        img_cond=None,
        img_cond_seq=None,
        img_cond_seq_ids=None,
        image_proj=image_proj,
        neg_image_proj=neg_image_proj,
        ip_scale=ip_scale,
        neg_ip_scale=neg_ip_scale,
    )
    add_set = StepStateXFlux1(
        timestep_to_start_cfg=timestep_to_start_cfg,
        current_timestep_index=current_timestep_index,
    )
    get_prediction = create_get_prediction_xflux1(pred_set, img_set, add_set)

    denoise_step_fn = create_denoise_step_fn(controller_ref, current_layer_dropout, previous_step_tensor, get_prediction)

    controller = ManualTimestepController(
        timesteps=timesteps,
        initial_sample=img,
        denoise_step_fn=denoise_step_fn,
        initial_guidance=state.guidance,
    )
    controller_ref[0] = controller  # Store reference for closure access

    # Use state.layer_dropout if available, otherwise fall back to initial_layer_dropout
    controller.set_layer_dropout(state.layer_dropout if state.layer_dropout is not None else initial_layer_dropout)

    # Seed the controller with whatever the initial state already specifies.
    if state.width is not None and state.height is not None:
        controller.set_resolution(state.width, state.height)
    if state.seed is not None:
        controller.set_seed(state.seed)
    if state.prompt is not None:
        controller.set_prompt(state.prompt)
    if state.num_steps is not None:
        controller.set_num_steps(state.num_steps)
    controller.set_vae_shift_offset(state.vae_shift_offset)
    controller.set_vae_scale_offset(state.vae_scale_offset)
    controller.set_use_previous_as_mask(state.use_previous_as_mask)

    # Interactive loop
    while not controller.is_complete:
        state = controller.current_state

        # Update timestep index for CFG check
        current_timestep_index[0] = getattr(state, "timestep_index", 0)

        # Check if prompt changed and recompute embeddings if needed
        if state.prompt is not None and state.prompt != current_prompt[0]:
            if t5 is not None and clip is not None:
                recompute_text_embeddings(state.prompt)
            else:
                # If embedders not available, update current_prompt to avoid repeated checks
                current_prompt[0] = state.prompt

        interaction_context = InteractionContext(
            clear_prediction_cache=clear_prediction_cache,
            rng=rng,
            variation_rng=variation_rng,
            ae=ae,
            t5=t5,
            clip=clip,
            recompute_text_embeddings=recompute_text_embeddings,
        )
        state = route_choices(
            controller,
            state,
            interaction_context,
        )

        # Generate preview
        t0 = time.perf_counter()
        if state.seed is not None:
            rng.next_seed(state.seed)
        else:
            state.seed = rng.next_seed()

        if ae is None or state.width is None or state.height is None:
            # Nothing to decode or save without an autoencoder and a resolution.
            continue

        # Reuse cached prediction if available, otherwise generate it.
        # The result is cached and reused in denoise_step_fn when advancing.
        # Always use state.layer_dropout from controller to ensure consistency.
        pred_preview = get_prediction(
            state.current_sample,
            state.current_timestep,
            state.guidance,
            state.layer_dropout,
        )

        # One-shot estimate of the final latent from the current sample and prediction.
        intermediate = state.current_sample - state.current_timestep * pred_preview
        # Unpack requires float32; _decode_preview casts to the VAE dtype where needed.
        intermediate = unpack(intermediate.float(), state.height, state.width)

        # Looked up at use time rather than through the module-level binding —
        # presumably so a reassigned/patched registry value is honoured; verify.
        from divisor.registry import gfx_device as default_device

        # NOTE(review): this is a bare name expression and has no effect; it looks like
        # a device synchronisation call was intended before reading the timer. Left
        # unchanged because the signature of gfx_sync is not visible here — confirm and call it.
        gfx_sync
        t1 = time.perf_counter()

        nfo(f"Step time: {t1 - t0:.1f}s")

        intermediate_image = _decode_preview(ae, intermediate, state, default_device)

        # Saving stays inside the preview path so intermediate_image is always bound.
        if state.seed is not None:
            controller.store_state_in_chain(current_seed=state.seed)
        with SaveFile() as saver:
            saver.intermediate_image = intermediate_image  # set up image
            # NOTE(review): the trailing comma makes this a one-element tuple — confirm SaveFile expects that.
            saver.hyperchain = (controller.hyperchain,)  # set up hyperchain
            saver.with_hyperchain()

    return controller.current_sample
Denoise using the XFlux model, driven step by step through a ManualTimestepController.
Parameters
- model: XFlux model instance
- settings: Denoising state containing all denoising configuration parameters