divisor.state
Divisor class definitions and configuration dataclasses.
1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0 2# <!-- // /* d a r k s h a p e s */ --> 3 4"""Divisor class definitions and configuration dataclasses.""" 5 6from dataclasses import dataclass, field, replace 7from typing import Any, Dict, List, Optional 8 9import torch 10from torch import Tensor 11 12 13@dataclass 14class StepState: 15 """Runtime state that changes at each denoising step.""" 16 17 current_timestep: float 18 current_sample: torch.Tensor 19 timestep_index: int 20 total_timesteps: int 21 previous_timestep: float | None = None 22 23 def with_runtime_state( 24 self, 25 current_timestep: float | None = None, 26 current_sample: torch.Tensor | None = None, 27 timestep_index: int | None = None, 28 total_timesteps: int | None = None, 29 previous_timestep: float | None = None, 30 ) -> "StepState": 31 """Create a new class with updated runtime fields. 32 33 :param current_timestep: Current timestep value (if None, keeps existing) 34 :param current_sample: Current sample tensor (if None, keeps existing) 35 :param timestep_index: Current timestep index (if None, keeps existing) 36 :param total_timesteps: Total number of timesteps (if None, keeps existing) 37 :param previous_timestep: Previous timestep value (if None, keeps existing) 38 :returns: New class with updated fields""" 39 return replace( 40 self, 41 current_timestep=current_timestep if current_timestep is not None else self.current_timestep, 42 current_sample=current_sample if current_sample is not None else self.current_sample, 43 timestep_index=timestep_index if timestep_index is not None else self.timestep_index, 44 total_timesteps=total_timesteps if total_timesteps is not None else self.total_timesteps, 45 previous_timestep=previous_timestep if previous_timestep is not None else self.previous_timestep, 46 ) 47 48 49@dataclass 50class MenuState: 51 """State of the input at a given timestep.\n 52 This is the single source of truth for denoising configuration and 
runtime state. 53 Use from_cli_args() to create from command-line arguments, and with_runtime_state() 54 to update runtime fields during denoising.""" 55 56 step_state: StepState # Runtime state (changes every step) 57 guidance: float # Configuration (set at start, may change via controller) 58 num_steps: int 59 deterministic: bool = False 60 height: int = 1024 61 layer_dropout: List[int] | None = None 62 neg_prompt: str | None = None 63 prompt: str | None = None 64 seed: int = 0 65 use_previous_as_mask: bool = False 66 vae_scale_offset: float = 0.0 67 vae_shift_offset: float = 0.0 68 variation_seed: int = 0 69 variation_strength: float = 0.0 70 width: int = 1024 71 72 # Convenience properties for backward compatibility 73 @property 74 def current_timestep(self) -> float: 75 """Current timestep value.""" 76 return self.step_state.current_timestep 77 78 @property 79 def current_sample(self) -> torch.Tensor: 80 """Current sample tensor.""" 81 return self.step_state.current_sample 82 83 @property 84 def timestep_index(self) -> int: 85 """Current timestep index.""" 86 return self.step_state.timestep_index 87 88 @property 89 def total_timesteps(self) -> int: 90 """Total number of timesteps.""" 91 return self.step_state.total_timesteps 92 93 @property 94 def previous_timestep(self) -> float | None: 95 """Previous timestep value.""" 96 return self.step_state.previous_timestep 97 98 @classmethod 99 def from_cli_args( 100 cls, 101 prompt: str, 102 width: int, 103 height: int, 104 num_steps: int, 105 guidance: float, 106 seed: int = 0, 107 neg_prompt: str | None = None, 108 current_sample: torch.Tensor | None = None, 109 timesteps: List[float] | None = None, 110 **kwargs, 111 ) -> "MenuState": 112 """Create input state from CLI arguments. 
113 114 :param prompt: Text prompt 115 :param width: Image width 116 :param height: Image height 117 :param num_steps: Number of denoising steps 118 :param guidance: Guidance scale 119 :param seed: Random seed 120 :param neg_prompt: Negative prompt 121 :param current_sample: Initial sample tensor (if available, otherwise empty tensor) 122 :param timesteps: List of timesteps (if available, used to determine total_timesteps) 123 :param kwargs: Additional fields to set (e.g., deterministic, vae_shift_offset, etc.) 124 :returns: Input state initialized from CLI args""" 125 total_timesteps = len(timesteps) if timesteps else num_steps 126 127 step_state = StepState( 128 current_timestep=0.0, 129 previous_timestep=None, 130 current_sample=current_sample if current_sample is not None else torch.empty(0), 131 timestep_index=0, 132 total_timesteps=total_timesteps, 133 ) 134 135 return cls( 136 step_state=step_state, 137 guidance=guidance, 138 width=width, 139 height=height, 140 seed=seed, 141 prompt=prompt, 142 num_steps=num_steps, 143 neg_prompt=neg_prompt, 144 deterministic=bool(torch.get_deterministic_debug_mode()) if "deterministic" not in kwargs else kwargs.pop("deterministic"), 145 **kwargs, 146 ) 147 148 def with_runtime_state( 149 self, 150 current_timestep: float, 151 current_sample: torch.Tensor, 152 timestep_index: int, 153 total_timesteps: int | None = None, 154 previous_timestep: float | None = None, 155 ) -> "MenuState": 156 """Create a new input state with updated runtime state and timestep state.\n 157 :param current_timestep: Current timestep value 158 :param current_sample: Current sample tensor 159 :param timestep_index: Current timestep index 160 :param total_timesteps: Total number of timesteps (if changed) 161 :param previous_timestep: Previous timestep value 162 :returns: New input state with updated runtime fields""" 163 new_timestep = self.step_state.with_runtime_state( 164 current_timestep=current_timestep, 165 current_sample=current_sample, 166 
timestep_index=timestep_index, 167 total_timesteps=total_timesteps, 168 previous_timestep=previous_timestep, 169 ) 170 return replace(self, step_state=new_timestep) 171 172 173@dataclass 174class ImageEmbeddingState: 175 """Image-related configuration for prompt function creation.""" 176 177 img_ids: Tensor 178 img: Tensor 179 img_cond: Tensor | None = None 180 img_cond_seq: Tensor | None = None 181 img_cond_seq_ids: Tensor | None = None 182 image_proj: Tensor | None = None 183 neg_image_proj: Tensor | None = None 184 ip_scale: Tensor | None = None 185 neg_ip_scale: Tensor | None = None 186 187 188@dataclass 189class TextEmbeddingState: 190 """Base configuration class for text prompt function creation.""" 191 192 model_ref: List[Any] 193 state: Any 194 current_txt: List[Tensor] 195 current_txt_ids: List[Tensor] 196 current_vec: List[Tensor] 197 cached_prediction: List[Optional[Tensor]] = field(default_factory=lambda: [None]) 198 cached_prediction_state: List[Optional[Dict]] = field(default_factory=lambda: [None]) 199 neg_pred_enabled: bool = False 200 current_neg_txt: Tensor | str | None = "" 201 current_neg_txt_ids: Tensor | str | None = "" 202 current_neg_vec: Tensor | str | None = "" 203 true_gs: float | None = None 204 205 206@dataclass 207class StepStateXFlux1: 208 """Additional configuration for XFlux1-specific prediction settings.""" 209 210 timestep_to_start_cfg: int 211 current_timestep_index: List[int] 212 213 214@dataclass 215class InferenceState: 216 """Base configuration class for denoise function parameters.\n 217 Single source of truth for tensor-related operations""" 218 219 # Model and core inputs (required) 220 img: Tensor 221 img_ids: Tensor 222 txt: Tensor 223 txt_ids: Tensor 224 state: MenuState 225 ae: Any # AutoEncoder 226 timesteps: List[float] 227 vec: Tensor | None = None # CLIP embeddings (Flux1/XFlux1) 228 neg_pred_enabled: bool = False 229 neg_txt: Tensor | None = None 230 neg_txt_ids: Tensor | None = None 231 neg_vec: Tensor | None = 
None 232 true_gs: float = 1.0 233 234 # Text embedders for prompt changes 235 t5: Any | None = None # T5 embedder (Flux1/XFlux1) 236 clip: Any | None = None # CLIP embedder (Flux1/XFlux1) 237 text_embedder: Any | None = None # Mistral embedder (Flux2) 238 239 img_cond: Tensor | None = None # Channel-wise image conditioning (Flux1 only) 240 img_cond_seq: Tensor | None = None # Sequence-wise image conditioning 241 img_cond_seq_ids: Tensor | None = None 242 device: torch.device | None = None 243 initial_layer_dropout: List[int] | None = None 244 timestep_to_start_cfg: int = 0 245 image_proj: Tensor | None = None 246 neg_image_proj: Tensor | None = None 247 ip_scale: Tensor | None = None 248 neg_ip_scale: Tensor | None = None 249 250 251@dataclass 252class InferenceStateFlux2: 253 """Configuration for simple (non-interactive) Flux2 denoising tensor operations.""" 254 255 model: Any # Flux2 256 img: Tensor 257 img_ids: Tensor 258 txt: Tensor 259 txt_ids: Tensor 260 timesteps: List[float] 261 guidance: float = 4.0 262 img_cond_seq: Tensor | None = None 263 img_cond_seq_ids: Tensor | None = None
14@dataclass 15class StepState: 16 """Runtime state that changes at each denoising step.""" 17 18 current_timestep: float 19 current_sample: torch.Tensor 20 timestep_index: int 21 total_timesteps: int 22 previous_timestep: float | None = None 23 24 def with_runtime_state( 25 self, 26 current_timestep: float | None = None, 27 current_sample: torch.Tensor | None = None, 28 timestep_index: int | None = None, 29 total_timesteps: int | None = None, 30 previous_timestep: float | None = None, 31 ) -> "StepState": 32 """Create a new class with updated runtime fields. 33 34 :param current_timestep: Current timestep value (if None, keeps existing) 35 :param current_sample: Current sample tensor (if None, keeps existing) 36 :param timestep_index: Current timestep index (if None, keeps existing) 37 :param total_timesteps: Total number of timesteps (if None, keeps existing) 38 :param previous_timestep: Previous timestep value (if None, keeps existing) 39 :returns: New class with updated fields""" 40 return replace( 41 self, 42 current_timestep=current_timestep if current_timestep is not None else self.current_timestep, 43 current_sample=current_sample if current_sample is not None else self.current_sample, 44 timestep_index=timestep_index if timestep_index is not None else self.timestep_index, 45 total_timesteps=total_timesteps if total_timesteps is not None else self.total_timesteps, 46 previous_timestep=previous_timestep if previous_timestep is not None else self.previous_timestep, 47 )
Runtime state that changes at each denoising step.
24 def with_runtime_state( 25 self, 26 current_timestep: float | None = None, 27 current_sample: torch.Tensor | None = None, 28 timestep_index: int | None = None, 29 total_timesteps: int | None = None, 30 previous_timestep: float | None = None, 31 ) -> "StepState": 32 """Create a new class with updated runtime fields. 33 34 :param current_timestep: Current timestep value (if None, keeps existing) 35 :param current_sample: Current sample tensor (if None, keeps existing) 36 :param timestep_index: Current timestep index (if None, keeps existing) 37 :param total_timesteps: Total number of timesteps (if None, keeps existing) 38 :param previous_timestep: Previous timestep value (if None, keeps existing) 39 :returns: New class with updated fields""" 40 return replace( 41 self, 42 current_timestep=current_timestep if current_timestep is not None else self.current_timestep, 43 current_sample=current_sample if current_sample is not None else self.current_sample, 44 timestep_index=timestep_index if timestep_index is not None else self.timestep_index, 45 total_timesteps=total_timesteps if total_timesteps is not None else self.total_timesteps, 46 previous_timestep=previous_timestep if previous_timestep is not None else self.previous_timestep, 47 )
Create a new instance with updated runtime fields.
Parameters
- current_timestep: Current timestep value (if None, keeps existing)
- current_sample: Current sample tensor (if None, keeps existing)
- timestep_index: Current timestep index (if None, keeps existing)
- total_timesteps: Total number of timesteps (if None, keeps existing)
- previous_timestep: Previous timestep value (if None, keeps existing)
Returns
- New instance with updated fields
@dataclass
class MenuState:
    """State of the input at a given timestep.

    This is the single source of truth for denoising configuration and runtime state.
    Use from_cli_args() to create from command-line arguments, and with_runtime_state()
    to update runtime fields during denoising. Runtime fields live in ``step_state``; the
    read-only properties below forward to it."""

    step_state: StepState  # Runtime state (changes every step)
    guidance: float  # Configuration (set at start, may change via controller)
    num_steps: int
    deterministic: bool = False
    height: int = 1024
    layer_dropout: List[int] | None = None
    neg_prompt: str | None = None
    prompt: str | None = None
    seed: int = 0
    use_previous_as_mask: bool = False
    vae_scale_offset: float = 0.0
    vae_shift_offset: float = 0.0
    variation_seed: int = 0
    variation_strength: float = 0.0
    width: int = 1024

    # Convenience properties for backward compatibility
    @property
    def current_timestep(self) -> float:
        """Current timestep value."""
        return self.step_state.current_timestep

    @property
    def current_sample(self) -> torch.Tensor:
        """Current sample tensor."""
        return self.step_state.current_sample

    @property
    def timestep_index(self) -> int:
        """Current timestep index."""
        return self.step_state.timestep_index

    @property
    def total_timesteps(self) -> int:
        """Total number of timesteps."""
        return self.step_state.total_timesteps

    @property
    def previous_timestep(self) -> float | None:
        """Previous timestep value."""
        return self.step_state.previous_timestep

    @classmethod
    def from_cli_args(
        cls,
        prompt: str,
        width: int,
        height: int,
        num_steps: int,
        guidance: float,
        seed: int = 0,
        neg_prompt: str | None = None,
        current_sample: torch.Tensor | None = None,
        timesteps: List[float] | None = None,
        **kwargs,
    ) -> "MenuState":
        """Create input state from CLI arguments.

        The step state starts at timestep 0.0, index 0, with no previous timestep.

        :param prompt: Text prompt
        :param width: Image width
        :param height: Image height
        :param num_steps: Number of denoising steps
        :param guidance: Guidance scale
        :param seed: Random seed
        :param neg_prompt: Negative prompt
        :param current_sample: Initial sample tensor (if available, otherwise empty tensor)
        :param timesteps: List of timesteps (if non-empty, its length is used as total_timesteps;
            otherwise num_steps is used)
        :param kwargs: Additional fields to set (e.g., deterministic, vae_shift_offset, etc.).
            When ``deterministic`` is omitted it is True iff torch's deterministic debug mode is non-zero.
        :returns: Input state initialized from CLI args"""
        total_timesteps = len(timesteps) if timesteps else num_steps

        # Resolve ``deterministic`` before the constructor call. Previously it was popped inside the
        # argument list, which only worked because the pop happened before ``**kwargs`` was unpacked.
        if "deterministic" in kwargs:
            deterministic = kwargs.pop("deterministic")
        else:
            deterministic = bool(torch.get_deterministic_debug_mode())

        step_state = StepState(
            current_timestep=0.0,
            previous_timestep=None,
            # Empty placeholder so the field is always a tensor, even before sampling starts.
            current_sample=current_sample if current_sample is not None else torch.empty(0),
            timestep_index=0,
            total_timesteps=total_timesteps,
        )

        return cls(
            step_state=step_state,
            guidance=guidance,
            width=width,
            height=height,
            seed=seed,
            prompt=prompt,
            num_steps=num_steps,
            neg_prompt=neg_prompt,
            deterministic=deterministic,
            **kwargs,
        )

    def with_runtime_state(
        self,
        current_timestep: float,
        current_sample: torch.Tensor,
        timestep_index: int,
        total_timesteps: int | None = None,
        previous_timestep: float | None = None,
    ) -> "MenuState":
        """Create a new input state whose step state carries updated runtime fields.

        All configuration fields are copied unchanged; ``self`` is not modified.

        :param current_timestep: Current timestep value
        :param current_sample: Current sample tensor
        :param timestep_index: Current timestep index
        :param total_timesteps: Total number of timesteps (if None, keeps existing)
        :param previous_timestep: Previous timestep value (if None, keeps existing)
        :returns: New input state with updated runtime fields"""
        new_timestep = self.step_state.with_runtime_state(
            current_timestep=current_timestep,
            current_sample=current_sample,
            timestep_index=timestep_index,
            total_timesteps=total_timesteps,
            previous_timestep=previous_timestep,
        )
        return replace(self, step_state=new_timestep)
State of the input at a given timestep.
This is the single source of truth for denoising configuration and runtime state. Use from_cli_args() to create from command-line arguments, and with_runtime_state() to update runtime fields during denoising.
@property
def current_timestep(self) -> float:
    """Timestep value of the step being processed, forwarded from ``step_state``."""
    step = self.step_state
    return step.current_timestep
Current timestep value.
@property
def current_sample(self) -> torch.Tensor:
    """Sample tensor stored on ``step_state`` for the current step (returned as-is, not copied)."""
    step = self.step_state
    return step.current_sample
Current sample tensor.
@property
def timestep_index(self) -> int:
    """Index of the current step, forwarded from ``step_state``."""
    step = self.step_state
    return step.timestep_index
Current timestep index.
@property
def total_timesteps(self) -> int:
    """Total timestep count, forwarded from ``step_state``."""
    step = self.step_state
    return step.total_timesteps
Total number of timesteps.
94 @property 95 def previous_timestep(self) -> float | None: 96 """Previous timestep value.""" 97 return self.step_state.previous_timestep
Previous timestep value.
@classmethod
def from_cli_args(
    cls,
    prompt: str,
    width: int,
    height: int,
    num_steps: int,
    guidance: float,
    seed: int = 0,
    neg_prompt: str | None = None,
    current_sample: torch.Tensor | None = None,
    timesteps: List[float] | None = None,
    **kwargs,
) -> "MenuState":
    """Create input state from CLI arguments.

    The step state starts at timestep 0.0, index 0, with no previous timestep.

    :param prompt: Text prompt
    :param width: Image width
    :param height: Image height
    :param num_steps: Number of denoising steps
    :param guidance: Guidance scale
    :param seed: Random seed
    :param neg_prompt: Negative prompt
    :param current_sample: Initial sample tensor (if available, otherwise empty tensor)
    :param timesteps: List of timesteps (if non-empty, its length is used as total_timesteps;
        otherwise num_steps is used)
    :param kwargs: Additional fields to set (e.g., deterministic, vae_shift_offset, etc.).
        When ``deterministic`` is omitted it is True iff torch's deterministic debug mode is non-zero.
    :returns: Input state initialized from CLI args"""
    total_timesteps = len(timesteps) if timesteps else num_steps

    # Resolve ``deterministic`` before the constructor call. Previously it was popped inside the
    # argument list, which only worked because the pop happened before ``**kwargs`` was unpacked.
    if "deterministic" in kwargs:
        deterministic = kwargs.pop("deterministic")
    else:
        deterministic = bool(torch.get_deterministic_debug_mode())

    step_state = StepState(
        current_timestep=0.0,
        previous_timestep=None,
        # Empty placeholder so the field is always a tensor, even before sampling starts.
        current_sample=current_sample if current_sample is not None else torch.empty(0),
        timestep_index=0,
        total_timesteps=total_timesteps,
    )

    return cls(
        step_state=step_state,
        guidance=guidance,
        width=width,
        height=height,
        seed=seed,
        prompt=prompt,
        num_steps=num_steps,
        neg_prompt=neg_prompt,
        deterministic=deterministic,
        **kwargs,
    )
Create input state from CLI arguments.
Parameters
- prompt: Text prompt
- width: Image width
- height: Image height
- num_steps: Number of denoising steps
- guidance: Guidance scale
- seed: Random seed
- neg_prompt: Negative prompt
- current_sample: Initial sample tensor (if available, otherwise empty tensor)
- timesteps: List of timesteps (if available, used to determine total_timesteps)
- kwargs: Additional fields to set (e.g., deterministic, vae_shift_offset, etc.)
Returns
- Input state initialized from CLI args
def with_runtime_state(
    self,
    current_timestep: float,
    current_sample: torch.Tensor,
    timestep_index: int,
    total_timesteps: int | None = None,
    previous_timestep: float | None = None,
) -> "MenuState":
    """Return a copy of this input state whose ``step_state`` carries the new runtime values.

    Configuration fields are carried over unchanged and ``self`` is not modified. The
    ``step_state``'s own ``with_runtime_state`` handles the ``None`` arguments.

    :param current_timestep: Timestep value for the new step
    :param current_sample: Sample tensor for the new step
    :param timestep_index: Index of the new step
    :param total_timesteps: Total number of timesteps (if changed)
    :param previous_timestep: Timestep value of the preceding step
    :returns: New input state with updated runtime fields"""
    return replace(
        self,
        step_state=self.step_state.with_runtime_state(
            current_timestep=current_timestep,
            current_sample=current_sample,
            timestep_index=timestep_index,
            total_timesteps=total_timesteps,
            previous_timestep=previous_timestep,
        ),
    )
Create a new input state whose step state carries updated runtime fields.
Parameters
- current_timestep: Current timestep value
- current_sample: Current sample tensor
- timestep_index: Current timestep index
- total_timesteps: Total number of timesteps (if changed)
- previous_timestep: Previous timestep value (if None, keeps existing)
Returns
- New input state with updated runtime fields
174@dataclass 175class ImageEmbeddingState: 176 """Image-related configuration for prompt function creation.""" 177 178 img_ids: Tensor 179 img: Tensor 180 img_cond: Tensor | None = None 181 img_cond_seq: Tensor | None = None 182 img_cond_seq_ids: Tensor | None = None 183 image_proj: Tensor | None = None 184 neg_image_proj: Tensor | None = None 185 ip_scale: Tensor | None = None 186 neg_ip_scale: Tensor | None = None
Image-related configuration for prompt function creation.
189@dataclass 190class TextEmbeddingState: 191 """Base configuration class for text prompt function creation.""" 192 193 model_ref: List[Any] 194 state: Any 195 current_txt: List[Tensor] 196 current_txt_ids: List[Tensor] 197 current_vec: List[Tensor] 198 cached_prediction: List[Optional[Tensor]] = field(default_factory=lambda: [None]) 199 cached_prediction_state: List[Optional[Dict]] = field(default_factory=lambda: [None]) 200 neg_pred_enabled: bool = False 201 current_neg_txt: Tensor | str | None = "" 202 current_neg_txt_ids: Tensor | str | None = "" 203 current_neg_vec: Tensor | str | None = "" 204 true_gs: float | None = None
Base configuration class for text prompt function creation.
@dataclass
class StepStateXFlux1:
    """Extra prediction settings used only by the XFlux1 path."""

    timestep_to_start_cfg: int
    # Stored as a list rather than a bare int. NOTE(review): presumably a mutable holder
    # updated in place by callers — confirm.
    current_timestep_index: list[int]
Additional configuration for XFlux1-specific prediction settings.
@dataclass
class InferenceState:
    """Parameters shared by the denoise functions.

    Acts as the single source of truth for tensor-related operations."""

    # --- Model and core inputs (required; positional order matters) ---
    img: Tensor
    img_ids: Tensor
    txt: Tensor
    txt_ids: Tensor
    state: MenuState
    ae: Any  # AutoEncoder
    timesteps: list[float]

    # --- Optional text / guidance inputs ---
    vec: Tensor | None = None  # CLIP embeddings (Flux1/XFlux1)
    neg_pred_enabled: bool = False
    neg_txt: Tensor | None = None
    neg_txt_ids: Tensor | None = None
    neg_vec: Tensor | None = None
    true_gs: float = 1.0

    # --- Text embedders, used for prompt changes ---
    t5: Any | None = None  # T5 embedder (Flux1/XFlux1)
    clip: Any | None = None  # CLIP embedder (Flux1/XFlux1)
    text_embedder: Any | None = None  # Mistral embedder (Flux2)

    # --- Image conditioning and miscellaneous settings ---
    img_cond: Tensor | None = None  # Channel-wise image conditioning (Flux1 only)
    img_cond_seq: Tensor | None = None  # Sequence-wise image conditioning
    img_cond_seq_ids: Tensor | None = None
    device: torch.device | None = None
    initial_layer_dropout: list[int] | None = None
    timestep_to_start_cfg: int = 0
    image_proj: Tensor | None = None
    neg_image_proj: Tensor | None = None
    ip_scale: Tensor | None = None
    neg_ip_scale: Tensor | None = None
Base configuration class for denoise function parameters.
Single source of truth for tensor-related operations
252@dataclass 253class InferenceStateFlux2: 254 """Configuration for simple (non-interactive) Flux2 denoising tensor operations.""" 255 256 model: Any # Flux2 257 img: Tensor 258 img_ids: Tensor 259 txt: Tensor 260 txt_ids: Tensor 261 timesteps: List[float] 262 guidance: float = 4.0 263 img_cond_seq: Tensor | None = None 264 img_cond_seq_ids: Tensor | None = None
Configuration for simple (non-interactive) Flux2 denoising tensor operations.