divisor.state

Divisor class definitions and configuration dataclasses.

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3
  4"""Divisor class definitions and configuration dataclasses."""
  5
  6from dataclasses import dataclass, field, replace
  7from typing import Any, Dict, List, Optional
  8
  9import torch
 10from torch import Tensor
 11
 12
 13@dataclass
 14class StepState:
 15    """Runtime state that changes at each denoising step."""
 16
 17    current_timestep: float
 18    current_sample: torch.Tensor
 19    timestep_index: int
 20    total_timesteps: int
 21    previous_timestep: float | None = None
 22
 23    def with_runtime_state(
 24        self,
 25        current_timestep: float | None = None,
 26        current_sample: torch.Tensor | None = None,
 27        timestep_index: int | None = None,
 28        total_timesteps: int | None = None,
 29        previous_timestep: float | None = None,
 30    ) -> "StepState":
 31        """Create a new class with updated runtime fields.
 32
 33        :param current_timestep: Current timestep value (if None, keeps existing)
 34        :param current_sample: Current sample tensor (if None, keeps existing)
 35        :param timestep_index: Current timestep index (if None, keeps existing)
 36        :param total_timesteps: Total number of timesteps (if None, keeps existing)
 37        :param previous_timestep: Previous timestep value (if None, keeps existing)
 38        :returns: New class with updated fields"""
 39        return replace(
 40            self,
 41            current_timestep=current_timestep if current_timestep is not None else self.current_timestep,
 42            current_sample=current_sample if current_sample is not None else self.current_sample,
 43            timestep_index=timestep_index if timestep_index is not None else self.timestep_index,
 44            total_timesteps=total_timesteps if total_timesteps is not None else self.total_timesteps,
 45            previous_timestep=previous_timestep if previous_timestep is not None else self.previous_timestep,
 46        )
 47
 48
 49@dataclass
 50class MenuState:
 51    """State of the input at a given timestep.\n
 52    This is the single source of truth for denoising configuration and runtime state.
 53    Use from_cli_args() to create from command-line arguments, and with_runtime_state()
 54    to update runtime fields during denoising."""
 55
 56    step_state: StepState  # Runtime state (changes every step)
 57    guidance: float  # Configuration (set at start, may change via controller)
 58    num_steps: int
 59    deterministic: bool = False
 60    height: int = 1024
 61    layer_dropout: List[int] | None = None
 62    neg_prompt: str | None = None
 63    prompt: str | None = None
 64    seed: int = 0
 65    use_previous_as_mask: bool = False
 66    vae_scale_offset: float = 0.0
 67    vae_shift_offset: float = 0.0
 68    variation_seed: int = 0
 69    variation_strength: float = 0.0
 70    width: int = 1024
 71
 72    # Convenience properties for backward compatibility
 73    @property
 74    def current_timestep(self) -> float:
 75        """Current timestep value."""
 76        return self.step_state.current_timestep
 77
 78    @property
 79    def current_sample(self) -> torch.Tensor:
 80        """Current sample tensor."""
 81        return self.step_state.current_sample
 82
 83    @property
 84    def timestep_index(self) -> int:
 85        """Current timestep index."""
 86        return self.step_state.timestep_index
 87
 88    @property
 89    def total_timesteps(self) -> int:
 90        """Total number of timesteps."""
 91        return self.step_state.total_timesteps
 92
 93    @property
 94    def previous_timestep(self) -> float | None:
 95        """Previous timestep value."""
 96        return self.step_state.previous_timestep
 97
 98    @classmethod
 99    def from_cli_args(
100        cls,
101        prompt: str,
102        width: int,
103        height: int,
104        num_steps: int,
105        guidance: float,
106        seed: int = 0,
107        neg_prompt: str | None = None,
108        current_sample: torch.Tensor | None = None,
109        timesteps: List[float] | None = None,
110        **kwargs,
111    ) -> "MenuState":
112        """Create input state from CLI arguments.
113
114        :param prompt: Text prompt
115        :param width: Image width
116        :param height: Image height
117        :param num_steps: Number of denoising steps
118        :param guidance: Guidance scale
119        :param seed: Random seed
120        :param neg_prompt: Negative prompt
121        :param current_sample: Initial sample tensor (if available, otherwise empty tensor)
122        :param timesteps: List of timesteps (if available, used to determine total_timesteps)
123        :param kwargs: Additional fields to set (e.g., deterministic, vae_shift_offset, etc.)
124        :returns: Input state initialized from CLI args"""
125        total_timesteps = len(timesteps) if timesteps else num_steps
126
127        step_state = StepState(
128            current_timestep=0.0,
129            previous_timestep=None,
130            current_sample=current_sample if current_sample is not None else torch.empty(0),
131            timestep_index=0,
132            total_timesteps=total_timesteps,
133        )
134
135        return cls(
136            step_state=step_state,
137            guidance=guidance,
138            width=width,
139            height=height,
140            seed=seed,
141            prompt=prompt,
142            num_steps=num_steps,
143            neg_prompt=neg_prompt,
144            deterministic=bool(torch.get_deterministic_debug_mode()) if "deterministic" not in kwargs else kwargs.pop("deterministic"),
145            **kwargs,
146        )
147
148    def with_runtime_state(
149        self,
150        current_timestep: float,
151        current_sample: torch.Tensor,
152        timestep_index: int,
153        total_timesteps: int | None = None,
154        previous_timestep: float | None = None,
155    ) -> "MenuState":
156        """Create a new input state with updated runtime state and timestep state.\n
157        :param current_timestep: Current timestep value
158        :param current_sample: Current sample tensor
159        :param timestep_index: Current timestep index
160        :param total_timesteps: Total number of timesteps (if changed)
161        :param previous_timestep: Previous timestep value
162        :returns: New input state with updated runtime fields"""
163        new_timestep = self.step_state.with_runtime_state(
164            current_timestep=current_timestep,
165            current_sample=current_sample,
166            timestep_index=timestep_index,
167            total_timesteps=total_timesteps,
168            previous_timestep=previous_timestep,
169        )
170        return replace(self, step_state=new_timestep)
171
172
@dataclass
class ImageEmbeddingState:
    """Image-side inputs bundled for prompt function creation.

    Only ``img_ids`` and ``img`` are required; every other field defaults to
    ``None``."""

    # Required image tensors.
    img_ids: Tensor
    img: Tensor

    # Optional image-conditioning tensors.
    img_cond: Optional[Tensor] = None
    img_cond_seq: Optional[Tensor] = None
    img_cond_seq_ids: Optional[Tensor] = None

    # Optional image-projection tensors and their scales
    # (presumably IP-adapter style inputs — TODO confirm against callers).
    image_proj: Optional[Tensor] = None
    neg_image_proj: Optional[Tensor] = None
    ip_scale: Optional[Tensor] = None
    neg_ip_scale: Optional[Tensor] = None
186
187
188@dataclass
189class TextEmbeddingState:
190    """Base configuration class for text prompt function creation."""
191
192    model_ref: List[Any]
193    state: Any
194    current_txt: List[Tensor]
195    current_txt_ids: List[Tensor]
196    current_vec: List[Tensor]
197    cached_prediction: List[Optional[Tensor]] = field(default_factory=lambda: [None])
198    cached_prediction_state: List[Optional[Dict]] = field(default_factory=lambda: [None])
199    neg_pred_enabled: bool = False
200    current_neg_txt: Tensor | str | None = ""
201    current_neg_txt_ids: Tensor | str | None = ""
202    current_neg_vec: Tensor | str | None = ""
203    true_gs: float | None = None
204
205
@dataclass
class StepStateXFlux1:
    """Extra prediction settings that apply only to XFlux1.

    Both fields are required and have no defaults."""

    # Timestep (index) from which CFG starts — presumably compared against
    # the current timestep index; verify against the caller.
    timestep_to_start_cfg: int
    # Current timestep index, stored in a list.
    current_timestep_index: List[int]
212
213
@dataclass
class InferenceState:
    """Parameter bundle for the denoise function.

    Gathers the tensors, embedders and settings used by tensor-related
    operations. The first seven fields are required; the rest are optional."""

    # --- Required core inputs -------------------------------------------
    img: Tensor
    img_ids: Tensor
    txt: Tensor
    txt_ids: Tensor
    state: MenuState
    ae: Any  # AutoEncoder
    timesteps: List[float]

    # --- Optional embeddings and negative-prompt inputs -----------------
    vec: Optional[Tensor] = None  # CLIP embeddings (Flux1/XFlux1)
    neg_pred_enabled: bool = False
    neg_txt: Optional[Tensor] = None
    neg_txt_ids: Optional[Tensor] = None
    neg_vec: Optional[Tensor] = None
    true_gs: float = 1.0

    # --- Text embedders, kept for prompt changes ------------------------
    t5: Optional[Any] = None  # T5 embedder (Flux1/XFlux1)
    clip: Optional[Any] = None  # CLIP embedder (Flux1/XFlux1)
    text_embedder: Optional[Any] = None  # Mistral embedder (Flux2)

    # --- Image conditioning and miscellaneous settings ------------------
    img_cond: Optional[Tensor] = None  # Channel-wise image conditioning (Flux1 only)
    img_cond_seq: Optional[Tensor] = None  # Sequence-wise image conditioning
    img_cond_seq_ids: Optional[Tensor] = None
    device: Optional[torch.device] = None
    initial_layer_dropout: Optional[List[int]] = None
    timestep_to_start_cfg: int = 0
    image_proj: Optional[Tensor] = None
    neg_image_proj: Optional[Tensor] = None
    ip_scale: Optional[Tensor] = None
    neg_ip_scale: Optional[Tensor] = None
249
250
@dataclass
class InferenceStateFlux2:
    """Tensor inputs for simple (non-interactive) Flux2 denoising.

    ``model`` through ``timesteps`` are required; ``guidance`` defaults to
    ``4.0`` and the sequence-conditioning tensors default to ``None``."""

    # Required model and inputs.
    model: Any  # Flux2
    img: Tensor
    img_ids: Tensor
    txt: Tensor
    txt_ids: Tensor
    timesteps: List[float]

    # Optional settings.
    guidance: float = 4.0
    img_cond_seq: Optional[Tensor] = None
    img_cond_seq_ids: Optional[Tensor] = None
@dataclass
class StepState:
14@dataclass
15class StepState:
16    """Runtime state that changes at each denoising step."""
17
18    current_timestep: float
19    current_sample: torch.Tensor
20    timestep_index: int
21    total_timesteps: int
22    previous_timestep: float | None = None
23
24    def with_runtime_state(
25        self,
26        current_timestep: float | None = None,
27        current_sample: torch.Tensor | None = None,
28        timestep_index: int | None = None,
29        total_timesteps: int | None = None,
30        previous_timestep: float | None = None,
31    ) -> "StepState":
32        """Create a new class with updated runtime fields.
33
34        :param current_timestep: Current timestep value (if None, keeps existing)
35        :param current_sample: Current sample tensor (if None, keeps existing)
36        :param timestep_index: Current timestep index (if None, keeps existing)
37        :param total_timesteps: Total number of timesteps (if None, keeps existing)
38        :param previous_timestep: Previous timestep value (if None, keeps existing)
39        :returns: New class with updated fields"""
40        return replace(
41            self,
42            current_timestep=current_timestep if current_timestep is not None else self.current_timestep,
43            current_sample=current_sample if current_sample is not None else self.current_sample,
44            timestep_index=timestep_index if timestep_index is not None else self.timestep_index,
45            total_timesteps=total_timesteps if total_timesteps is not None else self.total_timesteps,
46            previous_timestep=previous_timestep if previous_timestep is not None else self.previous_timestep,
47        )

Runtime state that changes at each denoising step.

StepState( current_timestep: float, current_sample: torch.Tensor, timestep_index: int, total_timesteps: int, previous_timestep: float | None = None)
current_timestep: float
current_sample: torch.Tensor
timestep_index: int
total_timesteps: int
previous_timestep: float | None = None
def with_runtime_state( self, current_timestep: float | None = None, current_sample: torch.Tensor | None = None, timestep_index: int | None = None, total_timesteps: int | None = None, previous_timestep: float | None = None) -> StepState:
24    def with_runtime_state(
25        self,
26        current_timestep: float | None = None,
27        current_sample: torch.Tensor | None = None,
28        timestep_index: int | None = None,
29        total_timesteps: int | None = None,
30        previous_timestep: float | None = None,
31    ) -> "StepState":
32        """Create a new class with updated runtime fields.
33
34        :param current_timestep: Current timestep value (if None, keeps existing)
35        :param current_sample: Current sample tensor (if None, keeps existing)
36        :param timestep_index: Current timestep index (if None, keeps existing)
37        :param total_timesteps: Total number of timesteps (if None, keeps existing)
38        :param previous_timestep: Previous timestep value (if None, keeps existing)
39        :returns: New class with updated fields"""
40        return replace(
41            self,
42            current_timestep=current_timestep if current_timestep is not None else self.current_timestep,
43            current_sample=current_sample if current_sample is not None else self.current_sample,
44            timestep_index=timestep_index if timestep_index is not None else self.timestep_index,
45            total_timesteps=total_timesteps if total_timesteps is not None else self.total_timesteps,
46            previous_timestep=previous_timestep if previous_timestep is not None else self.previous_timestep,
47        )

Create a new instance with updated runtime fields.

Parameters
  • current_timestep: Current timestep value (if None, keeps existing)
  • current_sample: Current sample tensor (if None, keeps existing)
  • timestep_index: Current timestep index (if None, keeps existing)
  • total_timesteps: Total number of timesteps (if None, keeps existing)
  • previous_timestep: Previous timestep value (if None, keeps existing)

Returns
  • New instance with updated fields
@dataclass
class ImageEmbeddingState:
174@dataclass
175class ImageEmbeddingState:
176    """Image-related configuration for prompt function creation."""
177
178    img_ids: Tensor
179    img: Tensor
180    img_cond: Tensor | None = None
181    img_cond_seq: Tensor | None = None
182    img_cond_seq_ids: Tensor | None = None
183    image_proj: Tensor | None = None
184    neg_image_proj: Tensor | None = None
185    ip_scale: Tensor | None = None
186    neg_ip_scale: Tensor | None = None

Image-related configuration for prompt function creation.

ImageEmbeddingState( img_ids: torch.Tensor, img: torch.Tensor, img_cond: torch.Tensor | None = None, img_cond_seq: torch.Tensor | None = None, img_cond_seq_ids: torch.Tensor | None = None, image_proj: torch.Tensor | None = None, neg_image_proj: torch.Tensor | None = None, ip_scale: torch.Tensor | None = None, neg_ip_scale: torch.Tensor | None = None)
img_ids: torch.Tensor
img: torch.Tensor
img_cond: torch.Tensor | None = None
img_cond_seq: torch.Tensor | None = None
img_cond_seq_ids: torch.Tensor | None = None
image_proj: torch.Tensor | None = None
neg_image_proj: torch.Tensor | None = None
ip_scale: torch.Tensor | None = None
neg_ip_scale: torch.Tensor | None = None
@dataclass
class TextEmbeddingState:
189@dataclass
190class TextEmbeddingState:
191    """Base configuration class for text prompt function creation."""
192
193    model_ref: List[Any]
194    state: Any
195    current_txt: List[Tensor]
196    current_txt_ids: List[Tensor]
197    current_vec: List[Tensor]
198    cached_prediction: List[Optional[Tensor]] = field(default_factory=lambda: [None])
199    cached_prediction_state: List[Optional[Dict]] = field(default_factory=lambda: [None])
200    neg_pred_enabled: bool = False
201    current_neg_txt: Tensor | str | None = ""
202    current_neg_txt_ids: Tensor | str | None = ""
203    current_neg_vec: Tensor | str | None = ""
204    true_gs: float | None = None

Base configuration class for text prompt function creation.

TextEmbeddingState( model_ref: List[Any], state: Any, current_txt: List[torch.Tensor], current_txt_ids: List[torch.Tensor], current_vec: List[torch.Tensor], cached_prediction: List[Optional[torch.Tensor]] = <factory>, cached_prediction_state: List[Optional[Dict]] = <factory>, neg_pred_enabled: bool = False, current_neg_txt: torch.Tensor | str | None = '', current_neg_txt_ids: torch.Tensor | str | None = '', current_neg_vec: torch.Tensor | str | None = '', true_gs: float | None = None)
model_ref: List[Any]
state: Any
current_txt: List[torch.Tensor]
current_txt_ids: List[torch.Tensor]
current_vec: List[torch.Tensor]
cached_prediction: List[Optional[torch.Tensor]]
cached_prediction_state: List[Optional[Dict]]
neg_pred_enabled: bool = False
current_neg_txt: torch.Tensor | str | None = ''
current_neg_txt_ids: torch.Tensor | str | None = ''
current_neg_vec: torch.Tensor | str | None = ''
true_gs: float | None = None
@dataclass
class StepStateXFlux1:
207@dataclass
208class StepStateXFlux1:
209    """Additional configuration for XFlux1-specific prediction settings."""
210
211    timestep_to_start_cfg: int
212    current_timestep_index: List[int]

Additional configuration for XFlux1-specific prediction settings.

StepStateXFlux1(timestep_to_start_cfg: int, current_timestep_index: List[int])
timestep_to_start_cfg: int
current_timestep_index: List[int]
@dataclass
class InferenceState:
215@dataclass
216class InferenceState:
217    """Base configuration class for denoise function parameters.\n
218    Single source of truth for tensor-related operations"""
219
220    # Model and core inputs (required)
221    img: Tensor
222    img_ids: Tensor
223    txt: Tensor
224    txt_ids: Tensor
225    state: MenuState
226    ae: Any  # AutoEncoder
227    timesteps: List[float]
228    vec: Tensor | None = None  # CLIP embeddings (Flux1/XFlux1)
229    neg_pred_enabled: bool = False
230    neg_txt: Tensor | None = None
231    neg_txt_ids: Tensor | None = None
232    neg_vec: Tensor | None = None
233    true_gs: float = 1.0
234
235    # Text embedders for prompt changes
236    t5: Any | None = None  # T5 embedder (Flux1/XFlux1)
237    clip: Any | None = None  # CLIP embedder (Flux1/XFlux1)
238    text_embedder: Any | None = None  # Mistral embedder (Flux2)
239
240    img_cond: Tensor | None = None  # Channel-wise image conditioning (Flux1 only)
241    img_cond_seq: Tensor | None = None  # Sequence-wise image conditioning
242    img_cond_seq_ids: Tensor | None = None
243    device: torch.device | None = None
244    initial_layer_dropout: List[int] | None = None
245    timestep_to_start_cfg: int = 0
246    image_proj: Tensor | None = None
247    neg_image_proj: Tensor | None = None
248    ip_scale: Tensor | None = None
249    neg_ip_scale: Tensor | None = None

Base configuration class for denoise function parameters.

Single source of truth for tensor-related operations

InferenceState( img: torch.Tensor, img_ids: torch.Tensor, txt: torch.Tensor, txt_ids: torch.Tensor, state: MenuState, ae: Any, timesteps: List[float], vec: torch.Tensor | None = None, neg_pred_enabled: bool = False, neg_txt: torch.Tensor | None = None, neg_txt_ids: torch.Tensor | None = None, neg_vec: torch.Tensor | None = None, true_gs: float = 1.0, t5: typing.Any | None = None, clip: typing.Any | None = None, text_embedder: typing.Any | None = None, img_cond: torch.Tensor | None = None, img_cond_seq: torch.Tensor | None = None, img_cond_seq_ids: torch.Tensor | None = None, device: torch.device | None = None, initial_layer_dropout: Optional[List[int]] = None, timestep_to_start_cfg: int = 0, image_proj: torch.Tensor | None = None, neg_image_proj: torch.Tensor | None = None, ip_scale: torch.Tensor | None = None, neg_ip_scale: torch.Tensor | None = None)
img: torch.Tensor
img_ids: torch.Tensor
txt: torch.Tensor
txt_ids: torch.Tensor
state: MenuState
ae: Any
timesteps: List[float]
vec: torch.Tensor | None = None
neg_pred_enabled: bool = False
neg_txt: torch.Tensor | None = None
neg_txt_ids: torch.Tensor | None = None
neg_vec: torch.Tensor | None = None
true_gs: float = 1.0
t5: typing.Any | None = None
clip: typing.Any | None = None
text_embedder: typing.Any | None = None
img_cond: torch.Tensor | None = None
img_cond_seq: torch.Tensor | None = None
img_cond_seq_ids: torch.Tensor | None = None
device: torch.device | None = None
initial_layer_dropout: Optional[List[int]] = None
timestep_to_start_cfg: int = 0
image_proj: torch.Tensor | None = None
neg_image_proj: torch.Tensor | None = None
ip_scale: torch.Tensor | None = None
neg_ip_scale: torch.Tensor | None = None
@dataclass
class InferenceStateFlux2:
252@dataclass
253class InferenceStateFlux2:
254    """Configuration for simple (non-interactive) Flux2 denoising tensor operations."""
255
256    model: Any  # Flux2
257    img: Tensor
258    img_ids: Tensor
259    txt: Tensor
260    txt_ids: Tensor
261    timesteps: List[float]
262    guidance: float = 4.0
263    img_cond_seq: Tensor | None = None
264    img_cond_seq_ids: Tensor | None = None

Configuration for simple (non-interactive) Flux2 denoising tensor operations.

InferenceStateFlux2( model: Any, img: torch.Tensor, img_ids: torch.Tensor, txt: torch.Tensor, txt_ids: torch.Tensor, timesteps: List[float], guidance: float = 4.0, img_cond_seq: torch.Tensor | None = None, img_cond_seq_ids: torch.Tensor | None = None)
model: Any
img: torch.Tensor
img_ids: torch.Tensor
txt: torch.Tensor
txt_ids: torch.Tensor
timesteps: List[float]
guidance: float = 4.0
img_cond_seq: torch.Tensor | None = None
img_cond_seq_ids: torch.Tensor | None = None