divisor.spec
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->
# adapted BFL Flux code from https://github.com/black-forest-labs/flux and https://github.com/Gen-Verse/MMaDA

"""Model specifications: weight locations, architecture parameters and initialization defaults per model series.

Each ``*_configs`` mapping is keyed by series ID (e.g. ``"model.dit.flux1-dev"``). Inside a series, the ``"*"`` key
holds the base ModelSpec and every other key holds a variant (ModelSpec or CompatibilitySpec) whose non-None values
override the base when resolved with ``get_model_spec("<series>:<key>", configs)``.
"""

from dataclasses import dataclass, fields, is_dataclass, replace
from typing import Any

from nnll.console import nfo
import torch

from divisor.registry import build_available_models
from divisor.flux1.autoencoder import AutoEncoderParams as AutoEncoder1Params
from divisor.flux1.model import FluxLoraWrapper, FluxParams
from divisor.flux2.autoencoder import AutoEncoderParams as AutoEncoder2Params
from divisor.flux2.model import Flux2Params
from divisor.xflux1.model import XFluxParams
from divisor.mmada.modeling_mmada import MMadaConfig as MMaDAParams


@dataclass
class CompatibilitySpec:
    """Alternate weights location for a series; its fields override the series' base ModelSpec when merged."""

    repo_id: str
    file_name: str


@dataclass
class InitialParamsFlux:
    """Default initialization parameters for Flux models."""

    num_steps: int
    max_length: int
    guidance: float
    shift: bool
    width: int = 1360
    height: int = 768


@dataclass
class InitialParamsMMaDA:
    """Default initialization parameters for MMaDA models."""

    steps: int
    gen_length: int
    block_length: int
    temperature: float
    cfg_scale: float
    remasking_strategy: str
    mask_id: int
    max_position_embeddings: int
    max_text_len: int


@dataclass
class InitialParamsAceStep:
    """Default initialization parameters for ACE-Step models."""

    infer_steps: int
    guidance_scale: float
    scheduler_type: str
    cfg_type: str


@dataclass
class AutoencoderTinyParams:
    """Parameter placeholder for tiny autoencoders (used by the taef1 spec); currently carries no fields."""


@dataclass
class AceStepParams:
    """Architecture hyperparameters for the ACE-Step transformer."""

    attention_head_dim: int
    in_channels: int
    inner_dim: int
    max_position: int
    mlp_ratio: float
    num_attention_heads: int
    num_layers: int
    rope_theta: float
    speaker_embedding_dim: int
    text_embedding_dim: int


@dataclass
class ModelSpec:
    """Base specification of a model: source repository, weights file, architecture params and optional init defaults."""

    repo_id: str  # repository ID, e.g. "black-forest-labs/FLUX.1-dev"
    params: (
        FluxParams
        | AutoEncoder1Params
        | XFluxParams
        | Flux2Params
        | MMaDAParams
        | AutoEncoder2Params
        | AutoencoderTinyParams
        | AceStepParams
        | FluxLoraWrapper
    )
    file_name: str  # weights file path within the repository
    init: InitialParamsFlux | InitialParamsMMaDA | InitialParamsAceStep | None = None


flux_configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]] = {
    "model.dit.flux1-dev": {
        "*": ModelSpec(
            repo_id="black-forest-labs/FLUX.1-dev",
            file_name="flux1-dev.safetensors",
            init=InitialParamsFlux(
                num_steps=28,
                max_length=512,
                guidance=4.0,
                shift=True,
            ),
            params=FluxParams(
                in_channels=64,
                vec_in_dim=768,
                context_in_dim=4096,
                hidden_size=3072,
                mlp_ratio=4.0,
                num_heads=24,
                depth=19,
                depth_single_blocks=38,
                axes_dim=[16, 56, 56],
                theta=10_000,
                qkv_bias=True,
                guidance_embed=True,
            ),
        ),
        # compatibility keys follow the "*@<variant>" form used by every sibling entry
        "*@fp8-e5m2-sai": CompatibilitySpec(
            repo_id="Kijai/flux-fp8",
            file_name="flux1-dev-fp8-e5m2.safetensors",
        ),
        "*@fp8-e4m3fn-sai": CompatibilitySpec(
            repo_id="Kijai/flux-fp8",
            file_name="flux1-dev-fp8-e4m3fn.safetensors",
        ),
        "*@fp8-sai": CompatibilitySpec(
            repo_id="XLabs-AI/flux-dev-fp8",
            file_name="flux-dev-fp8.safetensors",
        ),
        "mini": ModelSpec(
            repo_id="TencentARC/flux-mini",
            file_name="flux-mini.safetensors",
            init=InitialParamsFlux(
                num_steps=25,
                max_length=512,
                guidance=3.5,
                shift=True,
            ),
            params=XFluxParams(
                in_channels=64,
                vec_in_dim=768,
                context_in_dim=4096,
                hidden_size=3072,
                mlp_ratio=4.0,
                num_heads=24,
                depth=5,
                depth_single_blocks=10,
                axes_dim=[16, 56, 56],
                theta=10_000,
                qkv_bias=True,
                guidance_embed=True,
            ),
        ),
    },
    "model.vae.flux1-dev": {
        "*": ModelSpec(
            repo_id="black-forest-labs/FLUX.1-dev",
            file_name="ae.safetensors",
            params=AutoEncoder1Params(
                resolution=256,
                in_channels=3,
                ch=128,
                out_ch=3,
                ch_mult=[1, 2, 4, 4],
                num_res_blocks=2,
                z_channels=16,
                scale_factor=0.3611,
                shift_factor=0.1159,
            ),
        ),
    },
    "model.taesd.flux1-dev": {
        "*": ModelSpec(repo_id="madebyollin/taef1", file_name="diffusion_pytorch_model.safetensors", params=AutoencoderTinyParams()),
    },
    "model.dit.flux1-schnell": {
        "*": ModelSpec(
            repo_id="black-forest-labs/FLUX.1-schnell",
            file_name="flux1-schnell.safetensors",
            init=InitialParamsFlux(
                num_steps=4,
                max_length=256,
                guidance=2.5,
                shift=False,
            ),
            params=FluxParams(
                in_channels=64,
                vec_in_dim=768,
                context_in_dim=4096,
                hidden_size=3072,
                mlp_ratio=4.0,
                num_heads=24,
                depth=19,
                depth_single_blocks=38,
                axes_dim=[16, 56, 56],
                theta=10_000,
                qkv_bias=True,
                guidance_embed=False,
            ),
        ),
        "*@fp8-sai": CompatibilitySpec(
            repo_id="Comfy-Org/flux1-schnell",
            file_name="flux1-schnell-fp8.safetensors",
        ),
        "*@fp8-e4m3fn-sai": CompatibilitySpec(
            repo_id="Kijai/flux-fp8",
            file_name="flux1-schnell-fp8-e4m3fn.safetensors",
        ),
    },
    "model.dit.flux2-dev": {
        "*": ModelSpec(
            repo_id="black-forest-labs/FLUX.2-dev",
            file_name="flux2-dev.safetensors",
            params=Flux2Params(),
        ),
        "*@fp8-sai": CompatibilitySpec(
            repo_id="Comfy-Org/flux2-dev",
            file_name="split_files/diffusion_models/flux2_dev_fp8mixed.safetensors",
        ),
    },
    "model.vae.flux2-dev": {
        "*": ModelSpec(
            repo_id="black-forest-labs/FLUX.2-dev",
            file_name="ae.safetensors",
            params=AutoEncoder2Params(),
        )
    },
}

mmada_configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]] = {
    "model.mldm.mmada": {
        "*": ModelSpec(
            repo_id="Gen-Verse/MMaDA-8B-Base",
            file_name="model.safetensors",
            init=InitialParamsMMaDA(
                steps=256,
                gen_length=512,
                block_length=128,
                temperature=1.0,
                cfg_scale=0.0,
                remasking_strategy="low_confidence",
                mask_id=126336,
                max_position_embeddings=2048,
                max_text_len=512,
            ),
            params=MMaDAParams(
                vocab_size=50257,
                llm_vocab_size=50257,
                llm_model_path="",
                codebook_size=8192,
                num_vq_tokens=1024,
                num_new_special_tokens=0,
            ),
        ),
        "mixcot": CompatibilitySpec(
            repo_id="Gen-Verse/MMaDA-8B-MixCoT",
            file_name="model.safetensors",
        ),
        "lumina-dimoo": ModelSpec(
            repo_id="Alpha-VLLM/Lumina-DiMOO",
            file_name="model.safetensors",
            init=InitialParamsMMaDA(
                steps=128,
                gen_length=1024,
                block_length=256,
                temperature=0.0,
                cfg_scale=0.0,
                remasking_strategy="low_confidence",
                mask_id=126336,
                max_position_embeddings=2048,
                max_text_len=512,
            ),
            params=MMaDAParams(
                vocab_size=50257,
                llm_vocab_size=50257,
                llm_model_path="",
                codebook_size=8192,
                num_vq_tokens=1024,
                num_new_special_tokens=0,
            ),
        ),
    },
}

acestep_configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]] = {
    "model.dit.acestep": {
        "*": ModelSpec(
            repo_id="ACE-Step/ACE-Step-v1-3.5B",
            # diffusers-style weights file name (was misspelled "diffusion_pytest_model")
            file_name="ace_step_transformer/diffusion_pytorch_model.safetensors",
            init=InitialParamsAceStep(
                infer_steps=60,
                guidance_scale=15.0,
                scheduler_type="euler",
                cfg_type="apg",
            ),
            params=AceStepParams(
                attention_head_dim=128,
                in_channels=8,
                inner_dim=2560,
                max_position=32768,
                mlp_ratio=2.5,
                num_attention_heads=20,
                num_layers=24,
                rope_theta=1000000.0,
                speaker_embedding_dim=512,
                text_embedding_dim=768,
            ),
        )
    }
}


def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """Optionally expand the state dict to match the model's parameter shapes.\n
    Checkpoint tensors smaller than the matching model parameter are zero-padded up to the parameter's shape;
    tensors that already match are left untouched. The dictionary is modified in place and also returned.\n
    :param model: The model to match parameters against
    :param state_dict: The state dictionary to expand
    :returns: The expanded state dictionary
    :raises ValueError: If a checkpoint tensor's rank differs from its parameter's or it is larger along any dimension
    """
    for name, param in model.named_parameters():
        weight = state_dict.get(name)
        if weight is None or weight.shape == param.shape:
            continue
        # Only zero-padding is supported: a rank mismatch would otherwise broadcast silently into the
        # padded buffer, and a larger tensor cannot fit inside the parameter.
        if weight.ndim != param.ndim or any(src > dst for src, dst in zip(weight.shape, param.shape)):
            raise ValueError(f"Cannot expand '{name}' with shape {weight.shape} to model parameter with shape {param.shape}.")
        nfo(f"Expanding '{name}' with shape {weight.shape} to model parameter with shape {param.shape}.")
        expanded = torch.zeros_like(param, device=weight.device)
        # Copy the checkpoint values into the leading corner; everything beyond stays zero.
        expanded[tuple(slice(0, dim) for dim in weight.shape)] = weight
        state_dict[name] = expanded

    return state_dict


def merge_spec(base_spec: Any, subkey_spec: Any) -> ModelSpec:
    """Merge two dataclass or nested dataclass specs with overlapping subkey values taking precedence over base values.\n
    Only non-None subkey values override. Nested dataclass values of the same type are merged recursively;
    when the nested types differ (e.g. ``XFluxParams`` overriding ``FluxParams``) the subkey value replaces the
    base value whole so its type is preserved.\n
    :param base_spec: Base specification dataclass
    :param subkey_spec: Subkey specification dataclass (values take precedence)
    :returns: Merged specification with subkey values overriding base values (``base_spec`` itself if nothing overrides)
    """
    if not is_dataclass(subkey_spec) or isinstance(subkey_spec, type):
        return base_spec

    overrides = {}
    for spec_field in fields(subkey_spec):
        if not spec_field.init:
            continue  # replace() rejects fields excluded from __init__
        subkey_value = getattr(subkey_spec, spec_field.name)
        if subkey_value is None:
            continue
        base_value = getattr(base_spec, spec_field.name, None)
        if is_dataclass(subkey_value) and not isinstance(subkey_value, type) and type(subkey_value) is type(base_value):
            merged = merge_spec(base_value, subkey_value)
            if merged is not base_value:
                overrides[spec_field.name] = merged
        else:
            overrides[spec_field.name] = subkey_value

    return replace(base_spec, **overrides) if overrides else base_spec


def get_model_spec(mir_id: str, configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]]) -> ModelSpec:
    """Get the ModelSpec for a given model ID. Use to point to a known model spec.\n
    A plain series ID (e.g. ``"model.dit.flux1-dev"``) resolves to the series' base ``"*"`` spec. An ID of the form
    ``"<series>:<subkey>"`` (e.g. ``"model.dit.flux1-dev:mini"``) resolves to the base spec merged with the subkey's
    spec, subkey values taking precedence.\n
    :param mir_id: Model ID (e.g., "model.dit.flux1-dev"), optionally suffixed with ":<subkey>"
    :param configs: Configuration mapping containing model specs
    :returns: The base ModelSpec, or the base merged with the subkey spec
    :raises ValueError: If the series has no base ModelSpec or the requested subkey is not defined
    """
    # partition splits on the first ':' only, so malformed IDs reach the ValueError below
    series_key, separator, compatibility_key = mir_id.partition(":")
    series = configs.get(series_key, {})
    base_spec = series.get("*", None)
    if separator:
        compatibility_spec = series.get(compatibility_key, None)
        if base_spec and compatibility_spec:
            return merge_spec(base_spec, compatibility_spec)
    elif isinstance(base_spec, ModelSpec):
        return base_spec

    raise ValueError(f"{mir_id} has no defined model spec")


mmada_map = build_available_models(mmada_configs)

flux_map = build_available_models(flux_configs)

acestep_map = build_available_models(acestep_configs)
@dataclass
class
CompatibilitySpec:
@dataclass
class
InitialParamsFlux:
@dataclass()
class InitialParamsFlux:
    """Initialization defaults carried by Flux ``ModelSpec.init`` entries.

    The first four fields are required; ``width`` and ``height`` fall back to 1360 and 768.
    """

    # required, in positional order
    num_steps: int
    max_length: int
    guidance: float
    shift: bool
    # optional, with defaults
    width: int = 1360
    height: int = 768
@dataclass
class
InitialParamsMMaDA:
@dataclass()
class InitialParamsMMaDA:
    """Initialization defaults carried by MMaDA ``ModelSpec.init`` entries; every field is required."""

    # generation schedule
    steps: int
    gen_length: int
    block_length: int
    # sampling controls
    temperature: float
    cfg_scale: float
    remasking_strategy: str
    # token / length limits
    mask_id: int
    max_position_embeddings: int
    max_text_len: int
Default initialization parameters for MMaDA models.
@dataclass
class
InitialParamsAceStep:
@dataclass
class
AutoencoderTinyParams:
@dataclass
class
AceStepParams:
@dataclass
class AceStepParams:
    """Architecture hyperparameters for the ACE-Step transformer."""

    attention_head_dim: int
    in_channels: int
    inner_dim: int
    max_position: int
    mlp_ratio: float
    num_attention_heads: int
    num_layers: int
    rope_theta: float
    speaker_embedding_dim: int
    text_embedding_dim: int
@dataclass
class
ModelSpec:
@dataclass
class ModelSpec:
    """Base specification of a model: source repository, weights file, architecture params and optional init defaults."""

    repo_id: str  # repository ID, e.g. "black-forest-labs/FLUX.1-dev"
    # includes AceStepParams, which the ACE-Step spec uses but the annotation previously omitted
    params: (
        FluxParams
        | AutoEncoder1Params
        | XFluxParams
        | Flux2Params
        | MMaDAParams
        | AutoEncoder2Params
        | AutoencoderTinyParams
        | AceStepParams
        | FluxLoraWrapper
    )
    file_name: str  # weights file path within the repository
    init: InitialParamsFlux | InitialParamsMMaDA | InitialParamsAceStep | None = None
ModelSpec( repo_id: str, params: divisor.flux1.model.FluxParams | divisor.flux1.autoencoder.AutoEncoderParams | divisor.xflux1.model.XFluxParams | divisor.flux2.model.Flux2Params | divisor.mmada.modeling_mmada.MMadaConfig | divisor.flux2.autoencoder.AutoEncoderParams | AutoencoderTinyParams | divisor.flux1.model.FluxLoraWrapper, file_name: str, init: InitialParamsFlux | InitialParamsMMaDA | None = None)
params: divisor.flux1.model.FluxParams | divisor.flux1.autoencoder.AutoEncoderParams | divisor.xflux1.model.XFluxParams | divisor.flux2.model.Flux2Params | divisor.mmada.modeling_mmada.MMadaConfig | divisor.flux2.autoencoder.AutoEncoderParams | AutoencoderTinyParams | divisor.flux1.model.FluxLoraWrapper
flux_configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]] =
{'model.dit.flux1-dev': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.1-dev', params=FluxParams(in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10000, qkv_bias=True, guidance_embed=True), file_name='flux1-dev.safetensors', init=InitialParamsFlux(num_steps=28, max_length=512, guidance=4.0, shift=True, width=1360, height=768)), '@@fp8-e5m2-sai': CompatibilitySpec(repo_id='Kijai/flux-fp8', file_name='flux1-dev-fp8-e5m2.safetensors'), '*@fp8-e4m3fn-sai': CompatibilitySpec(repo_id='Kijai/flux-fp8', file_name='flux1-dev-fp8-e4m3fn.safetensors'), '*@fp8-sai': CompatibilitySpec(repo_id='XLabs-AI/flux-dev-fp8', file_name='flux-dev-fp8.safetensors'), 'mini': ModelSpec(repo_id='TencentARC/flux-mini', params=XFluxParams(in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=5, depth_single_blocks=10, axes_dim=[16, 56, 56], theta=10000, qkv_bias=True, guidance_embed=True), file_name='flux-mini.safetensors', init=InitialParamsFlux(num_steps=25, max_length=512, guidance=3.5, shift=True, width=1360, height=768))}, 'model.vae.flux1-dev': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.1-dev', params=AutoEncoderParams(resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159), file_name='ae.safetensors', init=None)}, 'model.taesd.flux1-dev': {'*': ModelSpec(repo_id='madebyollin/taef1', params=AutoencoderTinyParams(), file_name='diffusion_pytorch_model.safetensors', init=None)}, 'model.dit.flux1-schnell': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.1-schnell', params=FluxParams(in_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10000, qkv_bias=True, guidance_embed=False), file_name='flux1-schnell.safetensors', 
init=InitialParamsFlux(num_steps=4, max_length=256, guidance=2.5, shift=False, width=1360, height=768)), '*@fp8-sai': CompatibilitySpec(repo_id='Comfy-Org/flux1-schnell', file_name='flux1-schnell-fp8.safetensors'), '*@fp8-e4m3fn-sai': CompatibilitySpec(repo_id='Kijai/flux-fp8', file_name='flux1-schnell-fp8-e4m3fn.safetensors')}, 'model.dit.flux2-dev': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.2-dev', params=Flux2Params(in_channels=128, context_in_dim=15360, hidden_size=6144, num_heads=48, depth=8, depth_single_blocks=48, axes_dim=[32, 32, 32, 32], theta=2000, mlp_ratio=3.0), file_name='flux2-dev.safetensors', init=None), '*@fp8-sai': CompatibilitySpec(repo_id='Comfy-Org/flux2-dev', file_name='split_files/diffusion_models/flux2_dev_fp8mixed.safetensors')}, 'model.vae.flux2-dev': {'*': ModelSpec(repo_id='black-forest-labs/FLUX.2-dev', params=AutoEncoderParams(resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=32), file_name='ae.safetensors', init=None)}}
mmada_configs =
{'model.mldm.mmada': {'*': ModelSpec(repo_id='Gen-Verse/MMaDA-8B-Base', params=MMadaConfig {
"codebook_size": 8192,
"llm_model_path": "",
"llm_vocab_size": 50257,
"model_type": "mmada",
"num_new_special_tokens": 0,
"num_vq_tokens": 1024,
"transformers_version": "4.57.3",
"vocab_size": 50257
}
, file_name='model.safetensors', init=InitialParamsMMaDA(steps=256, gen_length=512, block_length=128, temperature=1.0, cfg_scale=0.0, remasking_strategy='low_confidence', mask_id=126336, max_position_embeddings=2048, max_text_len=512)), 'mixcot': CompatibilitySpec(repo_id='Gen-Verse/MMaDA-8B-MixCoT', file_name='model.safetensors'), 'lumina-dimoo': ModelSpec(repo_id='Alpha-VLLM/Lumina-DiMOO', params=MMadaConfig {
"codebook_size": 8192,
"llm_model_path": "",
"llm_vocab_size": 50257,
"model_type": "mmada",
"num_new_special_tokens": 0,
"num_vq_tokens": 1024,
"transformers_version": "4.57.3",
"vocab_size": 50257
}
, file_name='model.safetensors', init=InitialParamsMMaDA(steps=128, gen_length=1024, block_length=256, temperature=0.0, cfg_scale=0.0, remasking_strategy='low_confidence', mask_id=126336, max_position_embeddings=2048, max_text_len=512))}}
acestep_configs =
{'model.dit.acestep': {'*': ModelSpec(repo_id='ACE-Step/ACE-Step-v1-3.5B', params=AceStepParams(attention_head_dim=128, in_channels=8, inner_dim=2560, max_position=32768, mlp_ratio=2.5, num_attention_heads=20, num_layers=24, rope_theta=1000000.0, speaker_embedding_dim=512, text_embedding_dim=768), file_name='ace_step_transformer/diffusion_pytest_model.safetensors', init=InitialParamsAceStep(infer_steps=60, guidance_scale=15.0, scheduler_type='euler', cfg_type='apg'))}}
def
optionally_expand_state_dict(model: torch.nn.modules.module.Module, state_dict: dict) -> dict:
def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
    """Optionally expand the state dict to match the model's parameter shapes.\n
    Checkpoint tensors smaller than the matching model parameter are zero-padded up to the parameter's shape;
    tensors that already match are left untouched. The dictionary is modified in place and also returned.\n
    :param model: The model to match parameters against
    :param state_dict: The state dictionary to expand
    :returns: The expanded state dictionary
    :raises ValueError: If a checkpoint tensor's rank differs from its parameter's or it is larger along any dimension
    """
    for name, param in model.named_parameters():
        weight = state_dict.get(name)
        if weight is None or weight.shape == param.shape:
            continue
        # Only zero-padding is supported: a rank mismatch would otherwise broadcast silently into the
        # padded buffer, and a larger tensor cannot fit inside the parameter.
        if weight.ndim != param.ndim or any(src > dst for src, dst in zip(weight.shape, param.shape)):
            raise ValueError(f"Cannot expand '{name}' with shape {weight.shape} to model parameter with shape {param.shape}.")
        nfo(f"Expanding '{name}' with shape {weight.shape} to model parameter with shape {param.shape}.")
        expanded = torch.zeros_like(param, device=weight.device)
        # Copy the checkpoint values into the leading corner; everything beyond stays zero.
        expanded[tuple(slice(0, dim) for dim in weight.shape)] = weight
        state_dict[name] = expanded

    return state_dict
Optionally expand the state dict to match the model's parameter shapes.
Parameters
- model: The model to match parameters against
- state_dict: The state dictionary to expand
Returns
- The expanded state dictionary
def merge_spec(base_spec: Any, subkey_spec: Any) -> ModelSpec:
    """Merge two dataclass or nested dataclass specs with overlapping subkey values taking precedence over base values.\n
    Only non-None subkey values override. Nested dataclass values of the same type are merged recursively;
    when the nested types differ (e.g. ``XFluxParams`` overriding ``FluxParams``) the subkey value replaces the
    base value whole so its type is preserved.\n
    :param base_spec: Base specification dataclass
    :param subkey_spec: Subkey specification dataclass (values take precedence)
    :returns: Merged specification with subkey values overriding base values (``base_spec`` itself if nothing overrides)
    """
    from dataclasses import fields, is_dataclass

    if not is_dataclass(subkey_spec) or isinstance(subkey_spec, type):
        return base_spec

    overrides = {}
    for spec_field in fields(subkey_spec):
        if not spec_field.init:
            continue  # replace() rejects fields excluded from __init__
        subkey_value = getattr(subkey_spec, spec_field.name)
        if subkey_value is None:
            continue
        base_value = getattr(base_spec, spec_field.name, None)
        if is_dataclass(subkey_value) and not isinstance(subkey_value, type) and type(subkey_value) is type(base_value):
            merged = merge_spec(base_value, subkey_value)
            if merged is not base_value:
                overrides[spec_field.name] = merged
        else:
            overrides[spec_field.name] = subkey_value

    return replace(base_spec, **overrides) if overrides else base_spec
Merge two dataclass or nested dataclass specs with overlapping subkey values taking precedence over base values.
Parameters
- base_spec: Base specification dataclass
- subkey_spec: Subkey specification dataclass (values take precedence)
Returns
- Merged specification with subkey values overriding base values
def
get_model_spec( mir_id: str, configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]]) -> ModelSpec:
def get_model_spec(mir_id: str, configs: dict[str, dict[str, ModelSpec | CompatibilitySpec]]) -> ModelSpec:
    """Get the ModelSpec for a given model ID. Use to point to a known model spec.\n
    A plain series ID (e.g. ``"model.dit.flux1-dev"``) resolves to the series' base ``"*"`` spec. An ID of the form
    ``"<series>:<subkey>"`` (e.g. ``"model.dit.flux1-dev:mini"``) resolves to the base spec merged with the subkey's
    spec, subkey values taking precedence.\n
    :param mir_id: Model ID (e.g., "model.dit.flux1-dev"), optionally suffixed with ":<subkey>"
    :param configs: Configuration mapping containing model specs
    :returns: The base ModelSpec, or the base merged with the subkey spec
    :raises ValueError: If the series has no base ModelSpec or the requested subkey is not defined
    """
    # partition splits on the first ':' only, so malformed IDs reach the ValueError below
    series_key, separator, compatibility_key = mir_id.partition(":")
    series = configs.get(series_key, {})
    base_spec = series.get("*", None)
    if separator:
        compatibility_spec = series.get(compatibility_key, None)
        if base_spec and compatibility_spec:
            return merge_spec(base_spec, compatibility_spec)
    elif isinstance(base_spec, ModelSpec):
        return base_spec

    raise ValueError(f"{mir_id} has no defined model spec")
Get the ModelSpec for a given model ID, merged with the subkey's spec when the ID has the form "series:subkey". Use to point to a known model spec.
Parameters
- mir_id: Model ID (e.g., "model.dit.flux1-dev")
- configs: Configuration mapping containing model specs
Returns
- The series' base ModelSpec for a plain ID, or, for an ID of the form "series:subkey", the base ModelSpec merged with the subkey's spec (subkey values take precedence)
Raises
- ValueError: If model ID does not have a base ModelSpec
mmada_map =
{'mmada': 'model.mldm.mmada', 'mixcot': 'model.mldm.mmada:mixcot', 'lumina-dimoo': 'model.mldm.mmada:lumina-dimoo'}
flux_map =
{'flux1-dev': 'model.dit.flux1-dev', 'mini': 'model.dit.flux1-dev:mini', 'flux1-schnell': 'model.dit.flux1-schnell', 'flux2-dev': 'model.dit.flux2-dev'}
acestep_map =
{'acestep': 'model.dit.acestep'}