divisor.spec

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3# adapted BFL Flux code from https://github.com/black-forest-labs/flux and https://github.com/Gen-Verse/MMaDA
  4
  5
  6from dataclasses import replace
  7from dataclasses import dataclass
  8from typing import Any
  9
 10from nnll.console import nfo
 11import torch
 12
 13from divisor.registry import build_available_models
 14from divisor.flux1.autoencoder import AutoEncoderParams as AutoEncoder1Params
 15from divisor.flux1.model import FluxLoraWrapper, FluxParams
 16from divisor.flux2.autoencoder import AutoEncoderParams as AutoEncoder2Params
 17from divisor.flux2.model import Flux2Params
 18from divisor.xflux1.model import XFluxParams
 19from divisor.mmada.modeling_mmada import MMadaConfig as MMaDAParams
 20
 21
 22@dataclass
 23class CompatibilitySpec:
 24    repo_id: str
 25    file_name: str
 26
 27
 28@dataclass
 29class InitialParamsFlux:
 30    num_steps: int
 31    max_length: int
 32    guidance: float
 33    shift: bool
 34    width: int = 1360
 35    height: int = 768
 36
 37
 38@dataclass
 39class InitialParamsMMaDA:
 40    """Default initialization parameters for MMaDA models."""
 41
 42    steps: int
 43    gen_length: int
 44    block_length: int
 45    temperature: float
 46    cfg_scale: float
 47    remasking_strategy: str
 48    mask_id: int
 49    max_position_embeddings: int
 50    max_text_len: int
 51
 52
 53@dataclass
 54class InitialParamsAceStep:
 55    infer_steps: int
 56    guidance_scale: float
 57    scheduler_type: str
 58    cfg_type: str
 59
 60
 61@dataclass
 62class AutoencoderTinyParams:
 63    """"""
 64
 65
 66@dataclass
 67class AceStepParams:
 68    """"""
 69
 70    attention_head_dim: int
 71    in_channels: int
 72    inner_dim: int
 73    max_position: int
 74    mlp_ratio: float
 75    num_attention_heads: int
 76    num_layers: int
 77    rope_theta: float
 78    speaker_embedding_dim: int
 79    text_embedding_dim: int
 80
 81
 82@dataclass
 83class ModelSpec:
 84    repo_id: str
 85    params: FluxParams | AutoEncoder1Params | XFluxParams | Flux2Params | MMaDAParams | AutoEncoder2Params | AutoencoderTinyParams | FluxLoraWrapper
 86    file_name: str
 87    init: InitialParamsFlux | InitialParamsMMaDA | None = None
 88
 89
 90flux_configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]] = {
 91    "model.dit.flux1-dev": {
 92        "*": ModelSpec(
 93            repo_id="black-forest-labs/FLUX.1-dev",
 94            file_name="flux1-dev.safetensors",
 95            init=InitialParamsFlux(
 96                num_steps=28,
 97                max_length=512,
 98                guidance=4.0,
 99                shift=True,
100            ),
101            params=FluxParams(
102                in_channels=64,
103                vec_in_dim=768,
104                context_in_dim=4096,
105                hidden_size=3072,
106                mlp_ratio=4.0,
107                num_heads=24,
108                depth=19,
109                depth_single_blocks=38,
110                axes_dim=[16, 56, 56],
111                theta=10_000,
112                qkv_bias=True,
113                guidance_embed=True,
114            ),
115        ),
116        "@@fp8-e5m2-sai": CompatibilitySpec(
117            repo_id="Kijai/flux-fp8",
118            file_name="flux1-dev-fp8-e5m2.safetensors",
119        ),
120        "*@fp8-e4m3fn-sai": CompatibilitySpec(
121            repo_id="Kijai/flux-fp8",
122            file_name="flux1-dev-fp8-e4m3fn.safetensors",
123        ),
124        "*@fp8-sai": CompatibilitySpec(
125            repo_id="XLabs-AI/flux-dev-fp8",
126            file_name="flux-dev-fp8.safetensors",
127        ),
128        "mini": ModelSpec(
129            repo_id="TencentARC/flux-mini",
130            file_name="flux-mini.safetensors",
131            init=InitialParamsFlux(
132                num_steps=25,
133                max_length=512,
134                guidance=3.5,
135                shift=True,
136            ),
137            params=XFluxParams(
138                in_channels=64,
139                vec_in_dim=768,
140                context_in_dim=4096,
141                hidden_size=3072,
142                mlp_ratio=4.0,
143                num_heads=24,
144                depth=5,
145                depth_single_blocks=10,
146                axes_dim=[16, 56, 56],
147                theta=10_000,
148                qkv_bias=True,
149                guidance_embed=True,
150            ),
151        ),
152    },
153    "model.vae.flux1-dev": {
154        "*": ModelSpec(
155            repo_id="black-forest-labs/FLUX.1-dev",
156            file_name="ae.safetensors",
157            params=AutoEncoder1Params(
158                resolution=256,
159                in_channels=3,
160                ch=128,
161                out_ch=3,
162                ch_mult=[1, 2, 4, 4],
163                num_res_blocks=2,
164                z_channels=16,
165                scale_factor=0.3611,
166                shift_factor=0.1159,
167            ),
168        ),
169    },
170    "model.taesd.flux1-dev": {
171        "*": ModelSpec(repo_id="madebyollin/taef1", file_name="diffusion_pytorch_model.safetensors", params=AutoencoderTinyParams()),
172    },
173    "model.dit.flux1-schnell": {
174        "*": ModelSpec(
175            repo_id="black-forest-labs/FLUX.1-schnell",
176            file_name="flux1-schnell.safetensors",
177            init=InitialParamsFlux(
178                num_steps=4,
179                max_length=256,
180                guidance=2.5,
181                shift=False,
182            ),
183            params=FluxParams(
184                in_channels=64,
185                vec_in_dim=768,
186                context_in_dim=4096,
187                hidden_size=3072,
188                mlp_ratio=4.0,
189                num_heads=24,
190                depth=19,
191                depth_single_blocks=38,
192                axes_dim=[16, 56, 56],
193                theta=10_000,
194                qkv_bias=True,
195                guidance_embed=False,
196            ),
197        ),
198        "*@fp8-sai": CompatibilitySpec(
199            repo_id="Comfy-Org/flux1-schnell",
200            file_name="flux1-schnell-fp8.safetensors",
201        ),
202        "*@fp8-e4m3fn-sai": CompatibilitySpec(
203            repo_id="Kijai/flux-fp8",
204            file_name="flux1-schnell-fp8-e4m3fn.safetensors",
205        ),
206    },
207    "model.dit.flux2-dev": {
208        "*": ModelSpec(
209            repo_id="black-forest-labs/FLUX.2-dev",
210            file_name="flux2-dev.safetensors",
211            params=Flux2Params(),
212        ),
213        "*@fp8-sai": CompatibilitySpec(
214            repo_id="Comfy-Org/flux2-dev",
215            file_name="split_files/diffusion_models/flux2_dev_fp8mixed.safetensors",
216        ),
217    },
218    "model.vae.flux2-dev": {
219        "*": ModelSpec(
220            repo_id="black-forest-labs/FLUX.2-dev",
221            file_name="ae.safetensors",
222            params=AutoEncoder2Params(),
223        )
224    },
225}
226
# MMaDA-family model specs, keyed by MIR series ID and then by sub key ("*" is the base spec).
# Annotated like ``flux_configs`` so all config tables share one declared shape.
mmada_configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]] = {
    "model.mldm.mmada": {
        "*": ModelSpec(
            repo_id="Gen-Verse/MMaDA-8B-Base",
            file_name="model.safetensors",
            init=InitialParamsMMaDA(
                steps=256,
                gen_length=512,
                block_length=128,
                temperature=1.0,
                cfg_scale=0.0,
                remasking_strategy="low_confidence",
                mask_id=126336,
                max_position_embeddings=2048,
                max_text_len=512,
            ),
            # NOTE(review): mask_id (126336) exceeds vocab_size (50257) below; confirm these
            # params are overridden by the checkpoint's own config before use.
            params=MMaDAParams(
                vocab_size=50257,
                llm_vocab_size=50257,
                llm_model_path="",
                codebook_size=8192,
                num_vq_tokens=1024,
                num_new_special_tokens=0,
            ),
        ),
        "mixcot": CompatibilitySpec(
            repo_id="Gen-Verse/MMaDA-8B-MixCoT",
            file_name="model.safetensors",
        ),
        "lumina-dimoo": ModelSpec(
            repo_id="Alpha-VLLM/Lumina-DiMOO",
            file_name="model.safetensors",
            init=InitialParamsMMaDA(
                steps=128,
                gen_length=1024,
                block_length=256,
                temperature=0.0,
                cfg_scale=0.0,
                remasking_strategy="low_confidence",
                mask_id=126336,
                max_position_embeddings=2048,
                max_text_len=512,
            ),
            params=MMaDAParams(
                vocab_size=50257,
                llm_vocab_size=50257,
                llm_model_path="",
                codebook_size=8192,
                num_vq_tokens=1024,
                num_new_special_tokens=0,
            ),
        ),
    },
}
281
# ACE-Step model specs, keyed by MIR series ID and then by sub key ("*" is the base spec).
acestep_configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]] = {
    "model.dit.acestep": {
        "*": ModelSpec(
            repo_id="ACE-Step/ACE-Step-v1-3.5B",
            # Standard diffusers weight file name (was mistyped as "diffusion_pytest_model").
            file_name="ace_step_transformer/diffusion_pytorch_model.safetensors",
            init=InitialParamsAceStep(
                infer_steps=60,
                guidance_scale=15.0,
                scheduler_type="euler",
                cfg_type="apg",
            ),
            params=AceStepParams(
                attention_head_dim=128,
                in_channels=8,
                inner_dim=2560,
                max_position=32768,
                mlp_ratio=2.5,
                num_attention_heads=20,
                num_layers=24,
                rope_theta=1000000.0,
                speaker_embedding_dim=512,
                text_embedding_dim=768,
            ),
        )
    }
}
308
309
def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """Zero-pad state-dict tensors whose shape differs from the matching model parameter.\n
    Each ``state_dict`` entry named like one of ``model``'s parameters but with a different
    shape is replaced by a zero tensor shaped and typed like the parameter (on the loaded
    tensor's device), with the loaded values copied into its leading slice. Entries with no
    matching parameter, or whose shape already matches, are left as they are.
    :param model: The model to match parameters against
    :param state_dict: The state dictionary to expand (modified in place)
    :returns: The same ``state_dict`` object, with mismatched entries expanded
    """
    for param_name, model_param in model.named_parameters():
        if param_name not in state_dict:
            continue
        loaded = state_dict[param_name]
        if loaded.shape == model_param.shape:
            continue

        nfo(f"Expanding '{param_name}' with shape {loaded.shape} to model parameter with shape {model_param.shape}.")
        padded = torch.zeros_like(model_param, device=loaded.device)
        # Copy into [0:extent] along each loaded dimension; everything beyond stays zero.
        leading_region = tuple(slice(0, extent) for extent in loaded.shape)
        padded[leading_region] = loaded
        state_dict[param_name] = padded

    return state_dict
327
328
329def merge_spec(base_spec: Any, subkey_spec: Any) -> ModelSpec:
330    """Merge two dataclass or nested dataclass specs with overlapping subkey values taking precedence over base values.\n
331    :param base_spec: Base specification dataclass
332    :param subkey_spec: Subkey specification dataclass (values take precedence)
333    :returns: Merged specification with subkey values overriding base values
334    """
335    if not hasattr(subkey_spec, "__dataclass_fields__"):
336        return base_spec
337
338    merge_kwargs = {}
339    for field_name in subkey_spec.__dataclass_fields__:
340        subkey_value = getattr(subkey_spec, field_name, None)
341        base_value = getattr(base_spec, field_name, None)
342
343        if subkey_value is not None:
344            if hasattr(subkey_value, "__dataclass_fields__") and base_value is not None and hasattr(base_value, "__dataclass_fields__"):
345                nested_merge_kwargs = {}
346                for nested_field in subkey_value.__dataclass_fields__:
347                    nested_subkey_val = getattr(subkey_value, nested_field, None)
348                    if nested_subkey_val is not None:
349                        nested_merge_kwargs[nested_field] = nested_subkey_val
350                if nested_merge_kwargs:
351                    merge_kwargs[field_name] = replace(base_value, **nested_merge_kwargs)
352                else:
353                    merge_kwargs[field_name] = subkey_value
354            else:
355                merge_kwargs[field_name] = subkey_value
356
357    if merge_kwargs:
358        return replace(base_spec, **merge_kwargs)
359    return base_spec
360
361
def get_model_spec(mir_id: str, configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]]) -> ModelSpec:
    """Get the ModelSpec for a given model ID, merging in a sub-key spec when the ID names one. Use to point to a known model spec.\n
    :param mir_id: Model ID, either a series key (e.g. "model.dit.flux1-dev") or "series_key:sub_key" (e.g. "model.dit.flux1-dev:mini")
    :param configs: Configuration mapping of series key -> {sub key -> spec}; the "*" sub key holds the base ModelSpec
    :returns: The series' base ModelSpec, or for "series_key:sub_key" the base spec with the sub-key spec merged over it
    :raises ValueError: If the series has no base ModelSpec, or the requested sub key is not defined
    """
    # partition splits on the first ":" only, so extra colons stay in the sub key instead of
    # breaking tuple unpacking.
    series_key, separator, compatibility_key = mir_id.partition(":")
    series = configs.get(series_key, {})
    base_spec = series.get("*")

    # Both the plain and the sub-key lookups require a real ModelSpec as the base.
    if isinstance(base_spec, ModelSpec):
        if not separator:
            return base_spec
        if compatibility_spec := series.get(compatibility_key):
            return merge_spec(base_spec, compatibility_spec)

    raise ValueError(f"{mir_id} has no defined model spec")
382
383
# Lookup tables mapping a short model name (e.g. "flux1-dev", "mini") to its MIR ID,
# one table per config family.
mmada_map: dict[str, str] = build_available_models(mmada_configs)
flux_map: dict[str, str] = build_available_models(flux_configs)
acestep_map: dict[str, str] = build_available_models(acestep_configs)
@dataclass
class CompatibilitySpec:
@dataclass()
class CompatibilitySpec:
    """Alternate checkpoint location for a model series that otherwise reuses the series' base spec.

    Stored under a sub key of a config mapping; ``get_model_spec`` overlays these two
    fields onto the series' ``"*"`` ModelSpec through ``merge_spec``.
    """

    repo_id: str  # repository holding the alternate weights, e.g. "Kijai/flux-fp8"
    file_name: str  # path of the weights file inside that repository
CompatibilitySpec(repo_id: str, file_name: str)
repo_id: str
file_name: str
@dataclass
class InitialParamsFlux:
@dataclass()
class InitialParamsFlux:
    """Default sampling settings for a Flux-family model.

    ``width`` and ``height`` fall back to 1360 x 768 when not supplied.
    """

    num_steps: int  # presumably the number of sampling steps (28 for dev, 4 for schnell)
    max_length: int  # presumably the prompt token limit for the text encoder — TODO confirm
    guidance: float  # guidance strength
    shift: bool  # presumably toggles a timestep-schedule shift — verify against the sampler
    width: int = 1360  # output width, presumably in pixels
    height: int = 768  # output height, presumably in pixels
InitialParamsFlux( num_steps: int, max_length: int, guidance: float, shift: bool, width: int = 1360, height: int = 768)
num_steps: int
max_length: int
guidance: float
shift: bool
width: int = 1360
height: int = 768
@dataclass
class InitialParamsMMaDA:
@dataclass()
class InitialParamsMMaDA:
    """Default initialization parameters for MMaDA models.

    Every field is required; each MMaDA config entry supplies its own values.
    """

    steps: int  # presumably the number of generation steps — TODO confirm
    gen_length: int  # presumably the number of tokens to generate — TODO confirm
    block_length: int  # presumably the length of each generation block — TODO confirm
    temperature: float  # sampling temperature
    cfg_scale: float  # guidance (cfg) scale
    remasking_strategy: str  # strategy name, e.g. "low_confidence"
    mask_id: int  # token id used for the mask token
    max_position_embeddings: int
    max_text_len: int

Default initialization parameters for MMaDA models.

InitialParamsMMaDA( steps: int, gen_length: int, block_length: int, temperature: float, cfg_scale: float, remasking_strategy: str, mask_id: int, max_position_embeddings: int, max_text_len: int)
steps: int
gen_length: int
block_length: int
temperature: float
cfg_scale: float
remasking_strategy: str
mask_id: int
max_position_embeddings: int
max_text_len: int
@dataclass
class InitialParamsAceStep:
@dataclass()
class InitialParamsAceStep:
    """Default inference settings for ACE-Step models; every field is required."""

    infer_steps: int  # presumably the number of inference steps (60 in the ACE-Step config)
    guidance_scale: float  # guidance strength
    scheduler_type: str  # scheduler name, e.g. "euler"
    cfg_type: str  # guidance variant name, e.g. "apg"
InitialParamsAceStep( infer_steps: int, guidance_scale: float, scheduler_type: str, cfg_type: str)
infer_steps: int
guidance_scale: float
scheduler_type: str
cfg_type: str
@dataclass
class AutoencoderTinyParams:
@dataclass()
class AutoencoderTinyParams:
    """Field-less parameter marker for the tiny autoencoder entry (``model.taesd.flux1-dev``)."""
@dataclass
class AceStepParams:
@dataclass()
class AceStepParams:
    """Architecture hyper-parameters for the ACE-Step transformer (see ``acestep_configs``).

    Every field is required; field names mirror the model's configuration keys.
    """

    attention_head_dim: int
    in_channels: int
    inner_dim: int
    max_position: int
    mlp_ratio: float
    num_attention_heads: int
    num_layers: int
    rope_theta: float
    speaker_embedding_dim: int
    text_embedding_dim: int
AceStepParams( attention_head_dim: int, in_channels: int, inner_dim: int, max_position: int, mlp_ratio: float, num_attention_heads: int, num_layers: int, rope_theta: float, speaker_embedding_dim: int, text_embedding_dim: int)
attention_head_dim: int
in_channels: int
inner_dim: int
max_position: int
mlp_ratio: float
num_attention_heads: int
num_layers: int
rope_theta: float
speaker_embedding_dim: int
text_embedding_dim: int
@dataclass
class ModelSpec:
@dataclass
class ModelSpec:
    """Loadable model description: where the weights live and how to build the model.

    ``params`` holds the architecture configuration; ``init`` holds optional default
    inference settings for the model family (``None`` when the entry defines none).
    """

    repo_id: str  # repository holding the weights
    # AceStepParams / InitialParamsAceStep are included because ``acestep_configs`` stores them here.
    params: (
        FluxParams
        | AutoEncoder1Params
        | XFluxParams
        | Flux2Params
        | MMaDAParams
        | AutoEncoder2Params
        | AutoencoderTinyParams
        | AceStepParams
        | FluxLoraWrapper
    )
    file_name: str  # path of the weights file inside the repository
    init: InitialParamsFlux | InitialParamsMMaDA | InitialParamsAceStep | None = None
repo_id: str
file_name: str
init: InitialParamsFlux | InitialParamsMMaDA | None = None
flux_configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]] = {'model.dit.flux1-dev': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.1-dev', params=FluxParams(in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10000, qkv_bias=True, guidance_embed=True), file_name='flux1-dev.safetensors', init=InitialParamsFlux(num_steps=28, max_length=512, guidance=4.0, shift=True, width=1360, height=768)), '@@fp8-e5m2-sai': CompatibilitySpec(repo_id='Kijai/flux-fp8', file_name='flux1-dev-fp8-e5m2.safetensors'), '*@fp8-e4m3fn-sai': CompatibilitySpec(repo_id='Kijai/flux-fp8', file_name='flux1-dev-fp8-e4m3fn.safetensors'), '*@fp8-sai': CompatibilitySpec(repo_id='XLabs-AI/flux-dev-fp8', file_name='flux-dev-fp8.safetensors'), 'mini': ModelSpec(repo_id='TencentARC/flux-mini', params=XFluxParams(in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=5, depth_single_blocks=10, axes_dim=[16, 56, 56], theta=10000, qkv_bias=True, guidance_embed=True), file_name='flux-mini.safetensors', init=InitialParamsFlux(num_steps=25, max_length=512, guidance=3.5, shift=True, width=1360, height=768))}, 'model.vae.flux1-dev': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.1-dev', params=AutoEncoderParams(resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159), file_name='ae.safetensors', init=None)}, 'model.taesd.flux1-dev': {'*': ModelSpec(repo_id='madebyollin/taef1', params=AutoencoderTinyParams(), file_name='diffusion_pytorch_model.safetensors', init=None)}, 'model.dit.flux1-schnell': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.1-schnell', params=FluxParams(in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10000, qkv_bias=True, 
guidance_embed=False), file_name='flux1-schnell.safetensors', init=InitialParamsFlux(num_steps=4, max_length=256, guidance=2.5, shift=False, width=1360, height=768)), '*@fp8-sai': CompatibilitySpec(repo_id='Comfy-Org/flux1-schnell', file_name='flux1-schnell-fp8.safetensors'), '*@fp8-e4m3fn-sai': CompatibilitySpec(repo_id='Kijai/flux-fp8', file_name='flux1-schnell-fp8-e4m3fn.safetensors')}, 'model.dit.flux2-dev': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.2-dev', params=Flux2Params(in_channels=128, context_in_dim=15360, hidden_size=6144, num_heads=48, depth=8, depth_single_blocks=48, axes_dim=[32, 32, 32, 32], theta=2000, mlp_ratio=3.0), file_name='flux2-dev.safetensors', init=None), '*@fp8-sai': CompatibilitySpec(repo_id='Comfy-Org/flux2-dev', file_name='split_files/diffusion_models/flux2_dev_fp8mixed.safetensors')}, 'model.vae.flux2-dev': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.2-dev', params=AutoEncoderParams(resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=32), file_name='ae.safetensors', init=None)}}
mmada_configs = {'model.mldm.mmada': {'*': ModelSpec(repo_id='Gen-Verse/MMaDA-8B-Base', params=MMadaConfig { "codebook_size": 8192, "llm_model_path": "", "llm_vocab_size": 50257, "model_type": "mmada", "num_new_special_tokens": 0, "num_vq_tokens": 1024, "transformers_version": "4.57.3", "vocab_size": 50257 } , file_name='model.safetensors', init=InitialParamsMMaDA(steps=256, gen_length=512, block_length=128, temperature=1.0, cfg_scale=0.0, remasking_strategy='low_confidence', mask_id=126336, max_position_embeddings=2048, max_text_len=512)), 'mixcot': CompatibilitySpec(repo_id='Gen-Verse/MMaDA-8B-MixCoT', file_name='model.safetensors'), 'lumina-dimoo': ModelSpec(repo_id='Alpha-VLLM/Lumina-DiMOO', params=MMadaConfig { "codebook_size": 8192, "llm_model_path": "", "llm_vocab_size": 50257, "model_type": "mmada", "num_new_special_tokens": 0, "num_vq_tokens": 1024, "transformers_version": "4.57.3", "vocab_size": 50257 } , file_name='model.safetensors', init=InitialParamsMMaDA(steps=128, gen_length=1024, block_length=256, temperature=0.0, cfg_scale=0.0, remasking_strategy='low_confidence', mask_id=126336, max_position_embeddings=2048, max_text_len=512))}}
acestep_configs = {'model.dit.acestep': {'*': ModelSpec(repo_id='ACE-Step/ACE-Step-v1-3.5B', params=AceStepParams(attention_head_dim=128, in_channels=8, inner_dim=2560, max_position=32768, mlp_ratio=2.5, num_attention_heads=20, num_layers=24, rope_theta=1000000.0, speaker_embedding_dim=512, text_embedding_dim=768), file_name='ace_step_transformer/diffusion_pytest_model.safetensors', init=InitialParamsAceStep(infer_steps=60, guidance_scale=15.0, scheduler_type='euler', cfg_type='apg'))}}
def optionally_expand_state_dict(model: torch.nn.modules.module.Module, state_dict: dict) -> dict:
def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """Zero-pad state-dict tensors whose shape differs from the matching model parameter.\n
    Each ``state_dict`` entry named like one of ``model``'s parameters but with a different
    shape is replaced by a zero tensor shaped and typed like the parameter (on the loaded
    tensor's device), with the loaded values copied into its leading slice. Entries with no
    matching parameter, or whose shape already matches, are left as they are.
    :param model: The model to match parameters against
    :param state_dict: The state dictionary to expand (modified in place)
    :returns: The same ``state_dict`` object, with mismatched entries expanded
    """
    for param_name, model_param in model.named_parameters():
        if param_name not in state_dict:
            continue
        loaded = state_dict[param_name]
        if loaded.shape == model_param.shape:
            continue

        nfo(f"Expanding '{param_name}' with shape {loaded.shape} to model parameter with shape {model_param.shape}.")
        padded = torch.zeros_like(model_param, device=loaded.device)
        # Copy into [0:extent] along each loaded dimension; everything beyond stays zero.
        leading_region = tuple(slice(0, extent) for extent in loaded.shape)
        padded[leading_region] = loaded
        state_dict[param_name] = padded

    return state_dict

Optionally expand the state dict to match the model's parameters shapes.

Parameters
  • model: The model to match parameters against
  • state_dict: The state dictionary to expand
Returns
  • The expanded state dictionary
def merge_spec(base_spec: Any, subkey_spec: Any) -> ModelSpec:
330def merge_spec(base_spec: Any, subkey_spec: Any) -> ModelSpec:
331    """Merge two dataclass or nested dataclass specs with overlapping subkey values taking precedence over base values.\n
332    :param base_spec: Base specification dataclass
333    :param subkey_spec: Subkey specification dataclass (values take precedence)
334    :returns: Merged specification with subkey values overriding base values
335    """
336    if not hasattr(subkey_spec, "__dataclass_fields__"):
337        return base_spec
338
339    merge_kwargs = {}
340    for field_name in subkey_spec.__dataclass_fields__:
341        subkey_value = getattr(subkey_spec, field_name, None)
342        base_value = getattr(base_spec, field_name, None)
343
344        if subkey_value is not None:
345            if hasattr(subkey_value, "__dataclass_fields__") and base_value is not None and hasattr(base_value, "__dataclass_fields__"):
346                nested_merge_kwargs = {}
347                for nested_field in subkey_value.__dataclass_fields__:
348                    nested_subkey_val = getattr(subkey_value, nested_field, None)
349                    if nested_subkey_val is not None:
350                        nested_merge_kwargs[nested_field] = nested_subkey_val
351                if nested_merge_kwargs:
352                    merge_kwargs[field_name] = replace(base_value, **nested_merge_kwargs)
353                else:
354                    merge_kwargs[field_name] = subkey_value
355            else:
356                merge_kwargs[field_name] = subkey_value
357
358    if merge_kwargs:
359        return replace(base_spec, **merge_kwargs)
360    return base_spec

Merge two dataclass or nested dataclass specs with overlapping subkey values taking precedence over base values.

Parameters
  • base_spec: Base specification dataclass
  • subkey_spec: Subkey specification dataclass (values take precedence)
Returns
  • Merged specification with subkey values overriding base values
def get_model_spec( mir_id: str, configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]]) -> ModelSpec:
def get_model_spec(mir_id: str, configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]]) -> ModelSpec:
    """Get the ModelSpec for a given model ID, merging in a sub-key spec when the ID names one. Use to point to a known model spec.\n
    :param mir_id: Model ID, either a series key (e.g. "model.dit.flux1-dev") or "series_key:sub_key" (e.g. "model.dit.flux1-dev:mini")
    :param configs: Configuration mapping of series key -> {sub key -> spec}; the "*" sub key holds the base ModelSpec
    :returns: The series' base ModelSpec, or for "series_key:sub_key" the base spec with the sub-key spec merged over it
    :raises ValueError: If the series has no base ModelSpec, or the requested sub key is not defined
    """
    # partition splits on the first ":" only, so extra colons stay in the sub key instead of
    # breaking tuple unpacking.
    series_key, separator, compatibility_key = mir_id.partition(":")
    series = configs.get(series_key, {})
    base_spec = series.get("*")

    # Both the plain and the sub-key lookups require a real ModelSpec as the base.
    if isinstance(base_spec, ModelSpec):
        if not separator:
            return base_spec
        if compatibility_spec := series.get(compatibility_key):
            return merge_spec(base_spec, compatibility_spec)

    raise ValueError(f"{mir_id} has no defined model spec")

Get the ModelSpec for a given model ID, merging in a sub-key spec when the ID names one. Use to point to a known model spec.

Parameters
  • mir_id: Model ID, either a series key (e.g. "model.dit.flux1-dev") or "series_key:sub_key" (e.g. "model.dit.flux1-dev:mini")
  • configs: Configuration mapping containing model specs
Returns
  • The series' base ModelSpec, or for a sub key the base spec with the sub-key spec merged over it
Raises
  • ValueError: If the series has no base spec, or the requested sub key is not defined
mmada_map = {'mmada': 'model.mldm.mmada', 'mixcot': 'model.mldm.mmada:mixcot', 'lumina-dimoo': 'model.mldm.mmada:lumina-dimoo'}
flux_map = {'flux1-dev': 'model.dit.flux1-dev', 'mini': 'model.dit.flux1-dev:mini', 'flux1-schnell': 'model.dit.flux1-schnell', 'flux2-dev': 'model.dit.flux2-dev'}
acestep_map = {'acestep': 'model.dit.acestep'}