divisor.controller
Interactive denoising with manual timestep control. Allows users to manually increment through timesteps one at a time.
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

"""
Interactive denoising with manual timestep control.
Allows users to manually increment through timesteps one at a time.
"""

import json
import math
from dataclasses import asdict, fields, is_dataclass
from typing import Any, Callable, Optional

import torch
from nnll.console import nfo
from nnll.hyperchain import HyperChain
from nnll.random import RNGState

from divisor.interaction_context import InteractionContext
from divisor.registry import gfx_device
from divisor.state import MenuState, StepState

rng = RNGState(device=gfx_device.type)
variation_rng = RNGState(device=gfx_device.type)


def time_shift(
    schedule_mu: float,
    schedule_sigma: float,
    original_timestep_tensor: torch.Tensor,
    desired_step_count: int,
    compression_factor: float = 1.0,
) -> torch.Tensor:
    """Adjustable noise schedule. Compress or stretch any schedule to match a dynamic step sequence length.\n
    Each timestep is rescaled onto a grid of ``desired_step_count`` points, clamped into [1e-8, 1 - 1e-8],
    then mapped through ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.
    :param schedule_mu: Original schedule parameter.
    :param schedule_sigma: Original schedule parameter.
    :param original_timestep_tensor: Tensor of original timesteps in [0,1].
    :param desired_step_count: Desired number of timesteps. Must be at least 1; 1 returns the input unchanged.
    :param compression_factor: >1 compresses (fewer steps), <1 stretches (more steps).
    :returns: Adjusted timestep tensor.
    :raises ValueError: If ``desired_step_count`` is less than 1."""

    if desired_step_count < 1:
        raise ValueError(f"desired_step_count must be >= 1, got {desired_step_count}")
    if desired_step_count == 1:
        return original_timestep_tensor

    last_index = desired_step_count - 1
    new_index = (original_timestep_tensor * last_index * compression_factor).clamp(0, last_index)
    # Keep t away from 0 so that 1 / t stays finite.
    scaled_timestep = (new_index / last_index).clamp(min=1e-8, max=1.0 - 1e-8)
    # Python floats let the result follow the input's device and dtype (no float32 CPU scalar tensors).
    exponential_mu = math.exp(schedule_mu)
    return exponential_mu / (exponential_mu + (1 / scaled_timestep - 1) ** schedule_sigma)


def serialize_state_for_chain(state: "MenuState", current_seed: int) -> str:
    """Serialize MenuState for HyperChain storage, excluding current_sample and adding current_seed.\n
    The nested step state (field ``step_state``; a field named ``timestep`` is also accepted) is flattened
    into the top level, which is the layout ``reconstruct_state_from_dict`` reads, and its ``current_sample``
    is dropped. Fields are read directly instead of via ``asdict`` so the sample tensor is never deep-copied.
    Values JSON cannot encode natively are stringified via ``default=str``.
    :param state: The MenuState to serialize
    :param current_seed: The current seed value to include instead of current_sample
    :returns: JSON string representation of the state"""

    state_dict = {}
    step_dict = {}
    for spec in fields(state):
        value = getattr(state, spec.name)
        is_instance = is_dataclass(value) and not isinstance(value, type)
        if spec.name in ("step_state", "timestep") and is_instance:
            # Flatten the step fields, leaving out the (possibly large) sample tensor.
            step_dict.update({inner.name: getattr(value, inner.name) for inner in fields(value) if inner.name != "current_sample"})
        elif is_instance:
            state_dict[spec.name] = asdict(value)
        else:
            state_dict[spec.name] = value
    # Step fields are appended after the top-level fields.
    state_dict.update(step_dict)
    state_dict["current_seed"] = current_seed
    return json.dumps(state_dict, default=str)


def reconstruct_state_from_dict(state_dict: dict, current_sample: torch.Tensor) -> "MenuState":
    """Reconstruct MenuState from dictionary and current sample tensor.\n
    :param state_dict: Flat dictionary containing state fields (the layout ``serialize_state_for_chain`` produces)
    :param current_sample: The current sample tensor to include in the state
    :returns: Reconstructed MenuState object
    :raises KeyError: If a required field is missing from ``state_dict``"""
    step_state = StepState(
        current_timestep=state_dict["current_timestep"],
        previous_timestep=state_dict["previous_timestep"],
        current_sample=current_sample,
        timestep_index=state_dict["timestep_index"],
        total_timesteps=state_dict["total_timesteps"],
    )
    return MenuState(
        step_state=step_state,
        guidance=state_dict["guidance"],
        layer_dropout=state_dict["layer_dropout"],
        width=state_dict["width"],
        height=state_dict["height"],
        seed=state_dict["seed"],
        prompt=state_dict["prompt"],
        num_steps=state_dict["num_steps"],
        vae_shift_offset=state_dict["vae_shift_offset"],
        vae_scale_offset=state_dict["vae_scale_offset"],
        use_previous_as_mask=state_dict["use_previous_as_mask"],
        variation_seed=state_dict["variation_seed"],
        variation_strength=state_dict["variation_strength"],
        deterministic=state_dict["deterministic"],
    )


class ManualTimestepController:
    """Controller for manually stepping through denoising timesteps.\n
    Instead of automatically processing all timesteps, this allows
    the user to increment timesteps one at a time, with the ability
    to intervene between steps."""

    # Class-level default shared by every controller; __init__ can give an instance its own chain.
    hyperchain: HyperChain = HyperChain()

    def __init__(
        self,
        timesteps: list[float],
        initial_sample: Any,
        denoise_step_fn: Callable[[Any, float, float, float], Any],
        schedule_mu: float = 0.0,
        schedule_sigma: float = 1.0,
        initial_guidance: float = 7.5,
        hyperchain: Optional["HyperChain"] = None,
    ) -> None:
        """Manipulate denoising process.\n
        :param timesteps: List of timestep values to process (typically from 1.0 to 0.0). Stored by reference; a copy is kept in ``original_timesteps``.
        :param initial_sample: The initial noisy sample to start denoising from
        :param denoise_step_fn: Function that performs one denoising step. Signature: (sample, t_curr, t_prev, guidance) -> new_sample
        :param schedule_mu: Schedule parameter for time_shift (default: 0.0)
        :param schedule_sigma: Schedule parameter for time_shift (default: 1.0)
        :param initial_guidance: Initial guidance (CFG) value (default: 7.5)
        :param hyperchain: Optional HyperChain instance for storing state history. When None (default), the class-level chain shared by all controllers is used."""
        if hyperchain is not None:
            # Instance attribute shadows the shared class-level chain.
            self.hyperchain = hyperchain
        self.timesteps = timesteps
        self.original_timesteps = timesteps.copy()
        self.denoise_step_fn = denoise_step_fn
        self.current_index = 0
        self.current_sample = initial_sample
        self.state_history: list[MenuState] = []
        self.schedule_mu = schedule_mu
        self.schedule_sigma = schedule_sigma
        self.guidance = initial_guidance
        # Seeded with the initial value; step() appends one entry per step.
        self.guidance_history: list[float] = [initial_guidance]
        self.layer_dropout: Optional[list[int]] = None
        self.layer_dropout_history: list[Optional[list[int]]] = [None]
        self.width: Optional[int] = None
        self.height: Optional[int] = None
        self.seed: Optional[int] = None
        self.prompt: Optional[str] = None
        self.num_steps: Optional[int] = None
        self.vae_shift_offset: float = 0.0
        self.vae_scale_offset: float = 0.0
        self.use_previous_as_mask: bool = False
        self.variation_seed: Optional[int] = None
        self.variation_strength: float = 0.0
        self.deterministic: bool = False
        # Cumulative number of timesteps moved back by rewind().
        self.rewind_steps: int = 0

    @property
    def is_complete(self) -> bool:
        """Check if all timesteps have been processed.\n
        :returns: True once ``current_index`` has reached the last entry of ``timesteps``."""
        return self.current_index >= len(self.timesteps) - 1

    @property
    def current_state(self) -> MenuState:
        """Get the current state of the denoising process.\n
        :returns: A new MenuState whose step state pairs the timestep at ``current_index`` with the next
            one (``None`` at the last index) and carries ``current_sample``."""
        t_curr = self.timesteps[self.current_index]
        t_prev = self.timesteps[self.current_index + 1] if self.current_index + 1 < len(self.timesteps) else None

        step_state = StepState(
            current_timestep=t_curr,
            previous_timestep=t_prev,
            current_sample=self.current_sample,
            timestep_index=self.current_index,
            total_timesteps=len(self.timesteps),
        )

        return MenuState(
            step_state=step_state,
            guidance=self.guidance,
            layer_dropout=self.layer_dropout,
            width=self.width,
            height=self.height,
            seed=self.seed,
            prompt=self.prompt,
            num_steps=self.num_steps,
            vae_shift_offset=self.vae_shift_offset,
            vae_scale_offset=self.vae_scale_offset,
            use_previous_as_mask=self.use_previous_as_mask,
            variation_seed=self.variation_seed,
            variation_strength=self.variation_strength,
            deterministic=self.deterministic,
        )

    def step(self) -> MenuState:
        """Manually increment to the next timestep and perform one denoising step. Uses the current guidance value.\n
        The resulting state, guidance and layer dropout are appended to their history lists.
        :returns: The new state after the step.
        :raises ValueError: If all timesteps have already been processed."""
        if self.is_complete:
            raise ValueError("All timesteps have been processed. Cannot step further.")

        t_curr = self.timesteps[self.current_index]
        t_prev = self.timesteps[self.current_index + 1]

        self.current_sample = self.denoise_step_fn(self.current_sample, t_curr, t_prev, self.guidance)
        self.current_index += 1

        state = self.current_state
        self.state_history.append(state)
        self.guidance_history.append(self.guidance)
        self.layer_dropout_history.append(self.layer_dropout)

        return state

    def rewind(self, num_steps: int = 1) -> None:
        """Move the controller back ``num_steps`` timesteps, stopping at the first timestep.\n
        Only the timestep index moves; ``current_sample`` is left unchanged.
        :param num_steps: Number of timesteps to rewind. Negative values are rejected with a console message."""

        if num_steps < 0:
            nfo("Rewind count must be non‑negative")
            return

        target_index = max(0, self.current_index - num_steps)
        # Count only the steps actually rewound; requests past index 0 are clamped.
        self.rewind_steps += self.current_index - target_index
        self.current_index = target_index

    def set_guidance(self, guidance: float) -> None:
        """Set the guidance value for the next denoising step.\n
        :param guidance: New guidance (CFG) value to use. Typically ranges from 1.0 to 20.0. Higher values provide stronger adherence to the conditioning."""
        self.guidance = guidance

    def set_layer_dropout(self, layer_dropout: Optional[list[int]]) -> None:
        """Set the layer dropout configuration for the next denoising step.\n
        The list is copied so later mutation of the caller's list cannot alter the controller
        or the entries step() records in ``layer_dropout_history``.
        :param layer_dropout: List of block indices to skip during inference, or None to skip none. Blocks are indexed starting from 0."""
        self.layer_dropout = None if layer_dropout is None else list(layer_dropout)

    def set_resolution(self, width: int, height: int) -> None:
        """Set the width and height resolution for the denoising process.\n
        :param width: Width in pixels
        :param height: Height in pixels"""
        self.width = width
        self.height = height

    def set_seed(self, seed: int) -> None:
        """Set the seed value for the denoising process.\n
        :param seed: Seed value for random number generation"""
        self.seed = seed

    def set_prompt(self, prompt: str) -> None:
        """Set the prompt text for the denoising process.\n
        :param prompt: Prompt text"""
        self.prompt = prompt

    def set_num_steps(self, num_steps: int) -> None:
        """Set the number of steps for the denoising process.\n
        :param num_steps: Number of denoising steps"""
        self.num_steps = num_steps

    def set_vae_shift_offset(self, offset: float) -> None:
        """Set the VAE shift offset for autoencoder decode.\n
        :param offset: Offset to add to shift_factor in autoencoder decode"""
        self.vae_shift_offset = offset

    def set_vae_scale_offset(self, offset: float) -> None:
        """Set the VAE scale offset for autoencoder decode.\n
        :param offset: Offset to add to scale_factor in autoencoder decode"""
        self.vae_scale_offset = offset

    def set_use_previous_as_mask(self, use_mask: bool) -> None:
        """Set whether to use previous step tensor as mask.\n
        :param use_mask: Whether to use previous step tensor as mask for next step"""
        self.use_previous_as_mask = use_mask

    def set_variation_seed(self, seed: int | None = None) -> None:
        """Set the variation seed for adding variation noise.\n
        :param seed: Variation seed value, or None to disable"""
        self.variation_seed = seed

    def set_variation_strength(self, strength: float) -> None:
        """Set the variation strength, clamped to the range 0.0 to 1.0.\n
        :param strength: Variation strength, where 0.0 = no variation, 1.0 = full variation"""
        self.variation_strength = max(0.0, min(1.0, strength))

    def set_deterministic(self, deterministic: bool) -> None:
        """Set deterministic mode for PyTorch operations.\n
        :param deterministic: Whether to use deterministic algorithms (False = non-deterministic, True = deterministic)"""

        self.deterministic = deterministic

    def store_state_in_chain(self, current_seed: int | None = None, serialized_state_int: int | None = None) -> Optional[Any]:
        """Store the current MenuState in HyperChain, excluding current_sample and adding current_seed.\n
        :param current_seed: The current seed value to include instead of current_sample. Required if serialized_state_int is None.
        :param serialized_state_int: Optional pre-serialized state, given as the big-endian integer form of its UTF-8 bytes. If provided, current_seed is ignored.
        :returns: Whatever ``self.hyperchain.add_block`` returns for the stored string.
        :raises ValueError: If neither current_seed nor serialized_state_int is provided."""

        if serialized_state_int is not None:
            json_bytes = serialized_state_int.to_bytes((serialized_state_int.bit_length() + 7) // 8, byteorder="big", signed=False)
            serialized = json_bytes.decode("utf-8")
        else:
            if current_seed is None:
                raise ValueError("Either current_seed or serialized_state_int must be provided")
            state = self.current_state
            serialized = serialize_state_for_chain(state, current_seed)

        return self.hyperchain.add_block(serialized)


def update_state_and_cache(
    controller: ManualTimestepController,
    setter_func: Callable,
    value: Any,
    interaction_context: InteractionContext,
    success_message: str,
) -> MenuState:
    """Generic state update helper that sets value, clears cache, and refreshes state.\n
    :param controller: ManualTimestepController instance
    :param setter_func: Controller setter method to call with ``value``
    :param value: Value to set
    :param interaction_context: Context whose ``clear_prediction_cache`` is called after the update
    :param success_message: Message to display (via ``nfo``) on success
    :returns: Updated MenuState"""
    setter_func(value)
    interaction_context.clear_prediction_cache()
    state = controller.current_state
    nfo(success_message)
    return state
def time_shift(
    schedule_mu: float,
    schedule_sigma: float,
    original_timestep_tensor: torch.Tensor,
    desired_step_count: int,
    compression_factor: float = 1.0,
) -> torch.Tensor:
    """Adjustable noise schedule. Compress or stretch any schedule to match a dynamic step sequence length.\n
    Each timestep is rescaled onto a grid of ``desired_step_count`` points, clamped into [1e-8, 1 - 1e-8],
    then mapped through ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.
    :param schedule_mu: Original schedule parameter.
    :param schedule_sigma: Original schedule parameter.
    :param original_timestep_tensor: Tensor of original timesteps in [0,1].
    :param desired_step_count: Desired number of timesteps. Must be at least 1; 1 returns the input unchanged.
    :param compression_factor: >1 compresses (fewer steps), <1 stretches (more steps).
    :returns: Adjusted timestep tensor.
    :raises ValueError: If ``desired_step_count`` is less than 1."""
    import math

    if desired_step_count < 1:
        raise ValueError(f"desired_step_count must be >= 1, got {desired_step_count}")
    if desired_step_count == 1:
        return original_timestep_tensor

    last_index = desired_step_count - 1
    new_index = (original_timestep_tensor * last_index * compression_factor).clamp(0, last_index)
    # Keep t away from 0 so that 1 / t stays finite.
    scaled_timestep = (new_index / last_index).clamp(min=1e-8, max=1.0 - 1e-8)
    # Python floats let the result follow the input's device and dtype (no float32 CPU scalar tensors).
    exponential_mu = math.exp(schedule_mu)
    return exponential_mu / (exponential_mu + (1 / scaled_timestep - 1) ** schedule_sigma)
Adjustable noise schedule. Compress or stretch any schedule to match a dynamic step sequence length.
Parameters
- schedule_mu: Original schedule parameter.
- schedule_sigma: Original schedule parameter.
- original_timestep_tensor: Tensor of original timesteps in [0,1].
- desired_step_count: Desired number of timesteps.
- compression_factor: >1 compresses (fewer steps), <1 stretches (more steps).
Returns
- Adjusted timestep tensor.
def serialize_state_for_chain(state: "MenuState", current_seed: int) -> str:
    """Serialize MenuState for HyperChain storage, excluding current_sample and adding current_seed.\n
    The nested step state (field ``step_state``; a field named ``timestep`` is also accepted) is flattened
    into the top level, which is the layout ``reconstruct_state_from_dict`` reads, and its ``current_sample``
    is dropped. Fields are read directly instead of via ``asdict`` so the sample tensor is never deep-copied.
    Values JSON cannot encode natively are stringified via ``default=str``.
    :param state: The MenuState to serialize
    :param current_seed: The current seed value to include instead of current_sample
    :returns: JSON string representation of the state"""
    from dataclasses import asdict, fields, is_dataclass

    state_dict = {}
    step_dict = {}
    for spec in fields(state):
        value = getattr(state, spec.name)
        is_instance = is_dataclass(value) and not isinstance(value, type)
        if spec.name in ("step_state", "timestep") and is_instance:
            # Flatten the step fields, leaving out the (possibly large) sample tensor.
            step_dict.update({inner.name: getattr(value, inner.name) for inner in fields(value) if inner.name != "current_sample"})
        elif is_instance:
            state_dict[spec.name] = asdict(value)
        else:
            state_dict[spec.name] = value
    # Step fields are appended after the top-level fields.
    state_dict.update(step_dict)
    state_dict["current_seed"] = current_seed
    return json.dumps(state_dict, default=str)
Serialize MenuState for HyperChain storage, excluding current_sample and adding current_seed.
Parameters
- state: The MenuState to serialize
- current_seed: The current seed value to include instead of current_sample
Returns
- JSON string representation of the state
def reconstruct_state_from_dict(state_dict: dict, current_sample: torch.Tensor) -> "MenuState":
    """Rebuild a MenuState from a flat field mapping plus the live sample tensor.\n
    The step-level keys feed a StepState (together with ``current_sample``); every remaining
    listed key is passed straight to MenuState.
    :param state_dict: Flat mapping holding every StepState and MenuState field except the sample
    :param current_sample: Sample tensor to attach to the rebuilt step state
    :returns: The reconstructed MenuState
    :raises KeyError: If any expected field is absent from ``state_dict``"""
    step_keys = ("current_timestep", "previous_timestep", "timestep_index", "total_timesteps")
    menu_keys = (
        "guidance",
        "layer_dropout",
        "width",
        "height",
        "seed",
        "prompt",
        "num_steps",
        "vae_shift_offset",
        "vae_scale_offset",
        "use_previous_as_mask",
        "variation_seed",
        "variation_strength",
        "deterministic",
    )
    # Lookups run in declaration order, so a missing key raises for the same field as before.
    rebuilt_step = StepState(current_sample=current_sample, **{key: state_dict[key] for key in step_keys})
    return MenuState(step_state=rebuilt_step, **{key: state_dict[key] for key in menu_keys})
Reconstruct MenuState from dictionary and current sample tensor.
Parameters
- state_dict: Dictionary containing state fields
- current_sample: The current sample tensor to include in the state
Returns
- Reconstructed MenuState object
class ManualTimestepController:
    """Controller for manually stepping through denoising timesteps.\n
    Instead of automatically processing all timesteps, this allows
    the user to increment timesteps one at a time, with the ability
    to intervene between steps."""

    # Class-level default shared by every controller; __init__ can give an instance its own chain.
    hyperchain: HyperChain = HyperChain()

    def __init__(
        self,
        timesteps: list[float],
        initial_sample: Any,
        denoise_step_fn: Callable[[Any, float, float, float], Any],
        schedule_mu: float = 0.0,
        schedule_sigma: float = 1.0,
        initial_guidance: float = 7.5,
        hyperchain: Optional["HyperChain"] = None,
    ) -> None:
        """Manipulate denoising process.\n
        :param timesteps: List of timestep values to process (typically from 1.0 to 0.0). Stored by reference; a copy is kept in ``original_timesteps``.
        :param initial_sample: The initial noisy sample to start denoising from
        :param denoise_step_fn: Function that performs one denoising step. Signature: (sample, t_curr, t_prev, guidance) -> new_sample
        :param schedule_mu: Schedule parameter for time_shift (default: 0.0)
        :param schedule_sigma: Schedule parameter for time_shift (default: 1.0)
        :param initial_guidance: Initial guidance (CFG) value (default: 7.5)
        :param hyperchain: Optional HyperChain instance for storing state history. When None (default), the class-level chain shared by all controllers is used."""
        if hyperchain is not None:
            # Instance attribute shadows the shared class-level chain.
            self.hyperchain = hyperchain
        self.timesteps = timesteps
        self.original_timesteps = timesteps.copy()
        self.denoise_step_fn = denoise_step_fn
        self.current_index = 0
        self.current_sample = initial_sample
        self.state_history: list[MenuState] = []
        self.schedule_mu = schedule_mu
        self.schedule_sigma = schedule_sigma
        self.guidance = initial_guidance
        # Seeded with the initial value; step() appends one entry per step.
        self.guidance_history: list[float] = [initial_guidance]
        self.layer_dropout: Optional[list[int]] = None
        self.layer_dropout_history: list[Optional[list[int]]] = [None]
        self.width: Optional[int] = None
        self.height: Optional[int] = None
        self.seed: Optional[int] = None
        self.prompt: Optional[str] = None
        self.num_steps: Optional[int] = None
        self.vae_shift_offset: float = 0.0
        self.vae_scale_offset: float = 0.0
        self.use_previous_as_mask: bool = False
        self.variation_seed: Optional[int] = None
        self.variation_strength: float = 0.0
        self.deterministic: bool = False
        # Cumulative number of timesteps moved back by rewind().
        self.rewind_steps: int = 0

    @property
    def is_complete(self) -> bool:
        """Check if all timesteps have been processed.\n
        :returns: True once ``current_index`` has reached the last entry of ``timesteps``."""
        return self.current_index >= len(self.timesteps) - 1

    @property
    def current_state(self) -> MenuState:
        """Get the current state of the denoising process.\n
        :returns: A new MenuState whose step state pairs the timestep at ``current_index`` with the next
            one (``None`` at the last index) and carries ``current_sample``."""
        t_curr = self.timesteps[self.current_index]
        t_prev = self.timesteps[self.current_index + 1] if self.current_index + 1 < len(self.timesteps) else None

        step_state = StepState(
            current_timestep=t_curr,
            previous_timestep=t_prev,
            current_sample=self.current_sample,
            timestep_index=self.current_index,
            total_timesteps=len(self.timesteps),
        )

        return MenuState(
            step_state=step_state,
            guidance=self.guidance,
            layer_dropout=self.layer_dropout,
            width=self.width,
            height=self.height,
            seed=self.seed,
            prompt=self.prompt,
            num_steps=self.num_steps,
            vae_shift_offset=self.vae_shift_offset,
            vae_scale_offset=self.vae_scale_offset,
            use_previous_as_mask=self.use_previous_as_mask,
            variation_seed=self.variation_seed,
            variation_strength=self.variation_strength,
            deterministic=self.deterministic,
        )

    def step(self) -> MenuState:
        """Manually increment to the next timestep and perform one denoising step. Uses the current guidance value.\n
        The resulting state, guidance and layer dropout are appended to their history lists.
        :returns: The new state after the step.
        :raises ValueError: If all timesteps have already been processed."""
        if self.is_complete:
            raise ValueError("All timesteps have been processed. Cannot step further.")

        t_curr = self.timesteps[self.current_index]
        t_prev = self.timesteps[self.current_index + 1]

        self.current_sample = self.denoise_step_fn(self.current_sample, t_curr, t_prev, self.guidance)
        self.current_index += 1

        state = self.current_state
        self.state_history.append(state)
        self.guidance_history.append(self.guidance)
        self.layer_dropout_history.append(self.layer_dropout)

        return state

    def rewind(self, num_steps: int = 1) -> None:
        """Move the controller back ``num_steps`` timesteps, stopping at the first timestep.\n
        Only the timestep index moves; ``current_sample`` is left unchanged.
        :param num_steps: Number of timesteps to rewind. Negative values are rejected with a console message."""

        if num_steps < 0:
            nfo("Rewind count must be non‑negative")
            return

        target_index = max(0, self.current_index - num_steps)
        # Count only the steps actually rewound; requests past index 0 are clamped.
        self.rewind_steps += self.current_index - target_index
        self.current_index = target_index

    def set_guidance(self, guidance: float) -> None:
        """Set the guidance value for the next denoising step.\n
        :param guidance: New guidance (CFG) value to use. Typically ranges from 1.0 to 20.0. Higher values provide stronger adherence to the conditioning."""
        self.guidance = guidance

    def set_layer_dropout(self, layer_dropout: Optional[list[int]]) -> None:
        """Set the layer dropout configuration for the next denoising step.\n
        The list is copied so later mutation of the caller's list cannot alter the controller
        or the entries step() records in ``layer_dropout_history``.
        :param layer_dropout: List of block indices to skip during inference, or None to skip none. Blocks are indexed starting from 0."""
        self.layer_dropout = None if layer_dropout is None else list(layer_dropout)

    def set_resolution(self, width: int, height: int) -> None:
        """Set the width and height resolution for the denoising process.\n
        :param width: Width in pixels
        :param height: Height in pixels"""
        self.width = width
        self.height = height

    def set_seed(self, seed: int) -> None:
        """Set the seed value for the denoising process.\n
        :param seed: Seed value for random number generation"""
        self.seed = seed

    def set_prompt(self, prompt: str) -> None:
        """Set the prompt text for the denoising process.\n
        :param prompt: Prompt text"""
        self.prompt = prompt

    def set_num_steps(self, num_steps: int) -> None:
        """Set the number of steps for the denoising process.\n
        :param num_steps: Number of denoising steps"""
        self.num_steps = num_steps

    def set_vae_shift_offset(self, offset: float) -> None:
        """Set the VAE shift offset for autoencoder decode.\n
        :param offset: Offset to add to shift_factor in autoencoder decode"""
        self.vae_shift_offset = offset

    def set_vae_scale_offset(self, offset: float) -> None:
        """Set the VAE scale offset for autoencoder decode.\n
        :param offset: Offset to add to scale_factor in autoencoder decode"""
        self.vae_scale_offset = offset

    def set_use_previous_as_mask(self, use_mask: bool) -> None:
        """Set whether to use previous step tensor as mask.\n
        :param use_mask: Whether to use previous step tensor as mask for next step"""
        self.use_previous_as_mask = use_mask

    def set_variation_seed(self, seed: int | None = None) -> None:
        """Set the variation seed for adding variation noise.\n
        :param seed: Variation seed value, or None to disable"""
        self.variation_seed = seed

    def set_variation_strength(self, strength: float) -> None:
        """Set the variation strength, clamped to the range 0.0 to 1.0.\n
        :param strength: Variation strength, where 0.0 = no variation, 1.0 = full variation"""
        self.variation_strength = max(0.0, min(1.0, strength))

    def set_deterministic(self, deterministic: bool) -> None:
        """Set deterministic mode for PyTorch operations.\n
        :param deterministic: Whether to use deterministic algorithms (False = non-deterministic, True = deterministic)"""

        self.deterministic = deterministic

    def store_state_in_chain(self, current_seed: int | None = None, serialized_state_int: int | None = None) -> Optional[Any]:
        """Store the current MenuState in HyperChain, excluding current_sample and adding current_seed.\n
        :param current_seed: The current seed value to include instead of current_sample. Required if serialized_state_int is None.
        :param serialized_state_int: Optional pre-serialized state, given as the big-endian integer form of its UTF-8 bytes. If provided, current_seed is ignored.
        :returns: Whatever ``self.hyperchain.add_block`` returns for the stored string.
        :raises ValueError: If neither current_seed nor serialized_state_int is provided."""

        if serialized_state_int is not None:
            json_bytes = serialized_state_int.to_bytes((serialized_state_int.bit_length() + 7) // 8, byteorder="big", signed=False)
            serialized = json_bytes.decode("utf-8")
        else:
            if current_seed is None:
                raise ValueError("Either current_seed or serialized_state_int must be provided")
            state = self.current_state
            serialized = serialize_state_for_chain(state, current_seed)

        return self.hyperchain.add_block(serialized)
Controller for manually stepping through denoising timesteps.
Instead of automatically processing all timesteps, this allows the user to increment timesteps one at a time, with the ability to intervene between steps.
def __init__(
    self,
    timesteps: list[float],
    initial_sample: Any,
    denoise_step_fn: Callable[[Any, float, float, float], Any],
    schedule_mu: float = 0.0,
    schedule_sigma: float = 1.0,
    initial_guidance: float = 7.5,
    hyperchain: Optional["HyperChain"] = None,
) -> None:
    """Manipulate denoising process.\n
    :param timesteps: List of timestep values to process (typically from 1.0 to 0.0). Stored by reference; a copy is kept in ``original_timesteps``.
    :param initial_sample: The initial noisy sample to start denoising from
    :param denoise_step_fn: Function that performs one denoising step. Signature: (sample, t_curr, t_prev, guidance) -> new_sample
    :param schedule_mu: Schedule parameter for time_shift (default: 0.0)
    :param schedule_sigma: Schedule parameter for time_shift (default: 1.0)
    :param initial_guidance: Initial guidance (CFG) value (default: 7.5)
    :param hyperchain: Optional HyperChain instance for storing state history. When None (default), the class-level chain shared by all controllers is used."""
    if hyperchain is not None:
        # Instance attribute shadows the shared class-level chain.
        self.hyperchain = hyperchain
    self.timesteps = timesteps
    self.original_timesteps = timesteps.copy()
    self.denoise_step_fn = denoise_step_fn
    self.current_index = 0
    self.current_sample = initial_sample
    self.state_history: list[MenuState] = []
    self.schedule_mu = schedule_mu
    self.schedule_sigma = schedule_sigma
    self.guidance = initial_guidance
    # Seeded with the initial value; step() appends one entry per step.
    self.guidance_history: list[float] = [initial_guidance]
    self.layer_dropout: Optional[list[int]] = None
    self.layer_dropout_history: list[Optional[list[int]]] = [None]
    self.width: Optional[int] = None
    self.height: Optional[int] = None
    self.seed: Optional[int] = None
    self.prompt: Optional[str] = None
    self.num_steps: Optional[int] = None
    self.vae_shift_offset: float = 0.0
    self.vae_scale_offset: float = 0.0
    self.use_previous_as_mask: bool = False
    self.variation_seed: Optional[int] = None
    self.variation_strength: float = 0.0
    self.deterministic: bool = False
    # Cumulative number of timesteps moved back by rewind().
    self.rewind_steps: int = 0
Manipulate denoising process.
Parameters
- timesteps: List of timestep values to process (typically from 1.0 to 0.0)
- initial_sample: The initial noisy sample to start denoising from
- denoise_step_fn: Function that performs one denoising step. Signature: (sample, t_curr, t_prev, guidance) -> new_sample
- schedule_mu: Schedule parameter for time_shift (default: 0.0)
- schedule_sigma: Schedule parameter for time_shift (default: 1.0)
- initial_guidance: Initial guidance (CFG) value (default: 7.5)
- hyperchain: Optional HyperChain instance for storing state history (default: None)
@property
def is_complete(self) -> bool:
    """Report whether no further step can be taken.\n
    :returns: True when ``current_index`` points at the final entry of ``timesteps`` (or beyond)."""
    return self.current_index + 1 >= len(self.timesteps)
Check if all timesteps have been processed.
@property
def current_state(self) -> MenuState:
    """Snapshot the controller's settings and position as a MenuState.\n
    The step portion pairs the timestep at ``current_index`` with the one after it
    (``None`` when the index is the last position) and carries ``current_sample``.
    :returns: A freshly built MenuState mirroring the controller's attributes."""
    index = self.current_index
    schedule = self.timesteps
    has_next = index + 1 < len(schedule)

    snapshot = StepState(
        current_timestep=schedule[index],
        previous_timestep=schedule[index + 1] if has_next else None,
        current_sample=self.current_sample,
        timestep_index=index,
        total_timesteps=len(schedule),
    )

    return MenuState(
        step_state=snapshot,
        guidance=self.guidance,
        layer_dropout=self.layer_dropout,
        width=self.width,
        height=self.height,
        seed=self.seed,
        prompt=self.prompt,
        num_steps=self.num_steps,
        vae_shift_offset=self.vae_shift_offset,
        vae_scale_offset=self.vae_scale_offset,
        use_previous_as_mask=self.use_previous_as_mask,
        variation_seed=self.variation_seed,
        variation_strength=self.variation_strength,
        deterministic=self.deterministic,
    )
Get the current state of the denoising process.
def step(self) -> MenuState:
    """Advance one timestep by running a single denoising step with the current guidance.\n
    The sample is passed to ``denoise_step_fn`` with the timestep at ``current_index`` and the next one;
    the index then advances, and the new state, guidance and layer dropout are appended to their histories.
    :returns: The state after advancing.
    :raises ValueError: If there is no further timestep to step to."""
    if self.is_complete:
        raise ValueError("All timesteps have been processed. Cannot step further.")

    position = self.current_index
    start_t, end_t = self.timesteps[position], self.timesteps[position + 1]
    self.current_sample = self.denoise_step_fn(self.current_sample, start_t, end_t, self.guidance)
    self.current_index += 1

    new_state = self.current_state
    self.state_history.append(new_state)
    self.guidance_history.append(self.guidance)
    self.layer_dropout_history.append(self.layer_dropout)
    return new_state
Manually increment to the next timestep and perform one denoising step. Uses the current guidance value.
Returns
- The new state after the step.
Raises
- ValueError: If all timesteps have already been processed.
def rewind(self, num_steps: int = 1) -> None:
    """Move the controller back ``num_steps`` timesteps, stopping at the first timestep.\n
    Only the timestep index moves; ``current_sample`` is left unchanged.
    NOTE(review): restoring the matching sample appears to be left to the caller — confirm.
    :param num_steps: Number of timesteps to rewind. Negative values are rejected with a console message."""

    if num_steps < 0:
        nfo("Rewind count must be non‑negative")
        return

    target_index = max(0, self.current_index - num_steps)
    # Count only the steps actually rewound; requests past index 0 are clamped.
    self.rewind_steps += self.current_index - target_index
    self.current_index = target_index
Move the controller back num_steps timesteps (if possible).
Parameters
- num_steps: Number of timesteps to rewind
def set_guidance(self, guidance: float) -> None:
    """Choose the guidance (CFG) strength used by upcoming denoising steps.\n
    :param guidance: New guidance value; commonly around 1.0 to 20.0, with larger values following the conditioning more closely."""
    self.guidance = guidance
Set the guidance value for the next denoising step.
Parameters
- guidance: New guidance (CFG) value to use. Typically ranges from 1.0 to 20.0. Higher values provide stronger adherence to the conditioning.
def set_layer_dropout(self, layer_dropout: Optional[list[int]]) -> None:
    """Set the layer dropout configuration for the next denoising step.\n
    The list is copied so later mutation of the caller's list cannot alter the controller
    or the entries step() records in ``layer_dropout_history``.
    :param layer_dropout: List of block indices to skip during inference, or None to skip none. Blocks are indexed starting from 0."""
    self.layer_dropout = None if layer_dropout is None else list(layer_dropout)
Set the layer dropout configuration for the next denoising step.
Parameters
- layer_dropout: List of block indices to skip during inference, or None to skip none. Blocks are indexed starting from 0.
def set_resolution(self, width: int, height: int) -> None:
    """Record the pixel dimensions used by the denoising process.\n
    :param width: Horizontal size in pixels
    :param height: Vertical size in pixels"""
    self.width, self.height = width, height
Set the width and height resolution for the denoising process.
Parameters
- width: Width in pixels
- height: Height in pixels
def set_seed(self, seed: int) -> None:
    """Record the seed used for random number generation during denoising.\n
    :param seed: Integer seed value"""
    self.seed = seed
Set the seed value for the denoising process.
Parameters
- seed: Seed value for random number generation
def set_prompt(self, prompt: str) -> None:
    """Record the text prompt that drives the denoising process.\n
    :param prompt: Prompt string"""
    self.prompt = prompt
Set the prompt text for the denoising process.
Parameters
- prompt: Prompt text
def set_num_steps(self, num_steps: int) -> None:
    """Record how many denoising steps the process should use.\n
    :param num_steps: Step count"""
    self.num_steps = num_steps
Set the number of steps for the denoising process.
Parameters
- num_steps: Number of denoising steps
def set_vae_shift_offset(self, offset: float) -> None:
    """Record the extra offset applied to ``shift_factor`` when the autoencoder decodes.\n
    :param offset: Amount added to shift_factor in autoencoder decode"""
    self.vae_shift_offset = offset
Set the VAE shift offset for autoencoder decode.
Parameters
- offset: Offset to add to shift_factor in autoencoder decode
def set_vae_scale_offset(self, offset: float) -> None:
    """Record the extra offset applied to ``scale_factor`` when the autoencoder decodes.\n
    :param offset: Amount added to scale_factor in autoencoder decode"""
    self.vae_scale_offset = offset
Set the VAE scale offset for autoencoder decode.
Parameters
- offset: Offset to add to scale_factor in autoencoder decode
def set_use_previous_as_mask(self, use_mask: bool) -> None:
    """Toggle whether the previous step's tensor serves as a mask for the next step.\n
    :param use_mask: True to use the previous step tensor as the mask"""
    self.use_previous_as_mask = use_mask
Set whether to use previous step tensor as mask.
Parameters
- use_mask: Whether to use previous step tensor as mask for next step
264 def set_variation_seed(self, seed: int | None = None) -> None: 265 """Set the variation seed for adding variation noise.\n 266 :param seed: Variation seed value, or None to disable""" 267 self.variation_seed = seed
Set the variation seed for adding variation noise.
Parameters
- seed: Variation seed value, or None to disable
def set_variation_strength(self, strength: float) -> None:
    """Set how strongly variation noise is blended in, clamped into [0.0, 1.0].\n
    :param strength: Requested strength; 0.0 means no variation and 1.0 full variation. Values outside the range are clamped."""
    capped_above = min(1.0, strength)
    self.variation_strength = max(0.0, capped_above)
Set the variation strength (0.0 to 1.0).
Parameters
- strength: Variation strength, where 0.0 = no variation, 1.0 = full variation
def set_deterministic(self, deterministic: bool) -> None:
    """Record whether PyTorch should run with deterministic algorithms.\n
    Only the flag is stored here; where it is applied is not visible in this method.\n
    :param deterministic: True for deterministic algorithms, False for non-deterministic ones"""

    self.deterministic = deterministic
Set deterministic mode for PyTorch operations.
Parameters
- deterministic: Whether to use deterministic algorithms (False = non-deterministic, True = deterministic)
280 def store_state_in_chain(self, current_seed: int | None = None, serialized_state_int: int | None = None) -> Optional[Any]: 281 """Store the current MenuState in HyperChain, excluding current_sample and adding current_seed. 282 283 :param current_seed: The current seed value to include instead of current_sample. Required if serialized_state_int is None. 284 :param serialized_state_int: Optional pre-serialized state as integer. If provided, current_seed is ignored. 285 :returns: The created Block if hyperchain is configured, None otherwise 286 """ 287 288 if serialized_state_int is not None: 289 json_bytes = serialized_state_int.to_bytes((serialized_state_int.bit_length() + 7) // 8, byteorder="big", signed=False) 290 serialized = json_bytes.decode("utf-8") 291 else: 292 if current_seed is None: 293 raise ValueError("Either current_seed or serialized_state_int must be provided") 294 state = self.current_state 295 serialized = serialize_state_for_chain(state, current_seed) 296 297 return self.hyperchain.add_block(serialized)
Store the current MenuState in HyperChain, excluding current_sample and adding current_seed.
Parameters
- current_seed: The current seed value to include instead of current_sample. Required if serialized_state_int is None.
- serialized_state_int: Optional pre-serialized state as integer. If provided, current_seed is ignored.
:returns: The created Block if hyperchain is configured, None otherwise
def update_state_and_cache(
    controller: ManualTimestepController,
    setter_func: Callable,
    value: Any,
    interaction_context: InteractionContext,
    success_message: str,
) -> MenuState:
    """Generic state update helper that sets value, clears cache, and refreshes state.\n
    :param controller: ManualTimestepController instance whose ``current_state`` is returned
    :param setter_func: Controller setter method to call with ``value``
    :param value: Value to set
    :param interaction_context: InteractionContext whose prediction cache is cleared after the setter runs
    :param success_message: Message passed to ``nfo`` once the update has been applied
    :returns: Updated MenuState read from ``controller.current_state``"""
    setter_func(value)
    # Drop cached predictions after the change (presumably they were computed under the old setting).
    interaction_context.clear_prediction_cache()
    state = controller.current_state
    nfo(success_message)
    return state
Generic state update helper that sets value, clears cache, and refreshes state.
Parameters
- controller: ManualTimestepController instance
- setter_func: Controller setter method to call
- value: Value to set
- interaction_context: InteractionContext whose prediction cache is cleared
- success_message: Message to display on success
:returns: Updated MenuState