divisor.flux1.loading
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->
# adapted BFL Flux code from https://github.com/black-forest-labs/flux


import os
from pathlib import Path

import torch
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from huggingface_hub import snapshot_download
from nnll.console import nfo
from safetensors.torch import load_file as load_sft

from divisor.flux1.autoencoder import AutoEncoder as AutoEncoder1
from divisor.flux1.autoencoder import AutoEncoderParams
from divisor.flux1.model import Flux, FluxLoraWrapper
from divisor.flux1.text_embedder import HFEmbedder
from divisor.flux2.autoencoder import (
    AutoEncoder as AutoEncoder2,
)
from divisor.flux2.autoencoder import (
    AutoEncoderParams as AutoEncoder2Params,
)
from divisor.flux2.model import Flux2, Flux2Params
from divisor.flux2.text_encoder import Mistral3SmallEmbedder
from divisor.mmada.modeling_mmada import MMadaConfig as MMaDAParams
from divisor.mmada.modeling_mmada import MMadaModelLM as MMaDAModelLM
from divisor.registry import gfx_device, gfx_dtype
from divisor.spec import ModelSpec, optionally_expand_state_dict
from divisor.xflux1.model import XFlux, XFluxParams


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Report keys that did not line up while loading a state dict.

    :param missing: Keys to report as missing
    :param unexpected: Keys to report as unexpected
    """
    has_missing = len(missing) > 0
    has_unexpected = len(unexpected) > 0

    if has_missing:
        nfo(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if has_missing and has_unexpected:
        # Separator line between the two listings
        nfo("\n" + "-" * 79 + "\n")
    if has_unexpected:
        nfo(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def retrieve_model(repo_id: str, file_name: str) -> Path:
    """Get the local path for a checkpoint file, downloading if necessary.

    :param repo_id: Repository ID for the checkpoint
    :param file_name: Name of the checkpoint file
    :returns: Path to the checkpoint file
    """
    # Only files matching ``file_name`` are requested from the repository.
    model_dir = snapshot_download(repo_id=repo_id, allow_patterns=[file_name])
    return Path(model_dir) / file_name


def convert_fp8_to_bf16(model: torch.nn.Module, verbose: bool = True) -> bool:
    """Detect fp8 tensors in model parameters and buffers and convert them to bf16.

    :param model: The model whose fp8 tensors are converted in place
    :param verbose: Whether to print a message for every converted tensor
        (the final summary line is printed whenever something was converted)
    :returns: True if at least one tensor was converted, False otherwise
    """
    # Not every torch build defines every fp8 dtype, so look them up defensively.
    fp8_names = ("float8_e4m3fn", "float8_e5m2", "float8_e4m3fnuz", "float8_e5m2fnuz")
    fp8_dtypes = {dtype for dtype in (getattr(torch, name, None) for name in fp8_names) if dtype is not None}
    if not fp8_dtypes:
        return False

    converted_count = 0
    with torch.no_grad():
        for kind, named_tensors in (("parameter", model.named_parameters()), ("buffer", model.named_buffers())):
            for name, tensor in named_tensors:
                if tensor.dtype not in fp8_dtypes:
                    continue
                # Upcast via float32, then narrow to bf16 so converted tensors
                # really end up in the advertised dtype.
                tensor.data = tensor.data.float().to(torch.bfloat16)
                converted_count += 1
                if verbose:
                    nfo(f"Converted fp8 {kind} '{name}' to bf16")

    if converted_count > 0:
        nfo(f"Converted {converted_count} fp8 tensor(s) to bf16")
        return True
    return False


def load_state_dict_into_model(
    model: torch.nn.Module,
    state_dict: dict[str, torch.Tensor],
    verbose: bool = True,
) -> tuple[list[str], list[str]]:
    """Load a state dict into a model with optional expansion.

    :param model: The model to load weights into
    :param state_dict: The state dictionary to load
    :param verbose: Whether to print warnings
    :returns: Tuple of (missing_keys, unexpected_keys)
    """
    expanded_sd = optionally_expand_state_dict(model, state_dict)
    # strict=False so mismatched keys are returned instead of raising.
    missing, unexpected = model.load_state_dict(expanded_sd, strict=False, assign=True)
    if verbose:
        print_load_warning(missing, unexpected)
    return missing, unexpected


def load_lora_weights(
    model: FluxLoraWrapper,
    lora_repo_id: str,
    lora_filename: str,
    device: str | torch.device = gfx_device,
    verbose: bool = True,
) -> None:
    """Load LoRA weights into a FluxLoraWrapper model.

    :param model: The FluxLoraWrapper model
    :param lora_repo_id: Repository ID for the LoRA checkpoint
    :param lora_filename: Filename of the LoRA checkpoint
    :param device: Device to load weights on
    :param verbose: Whether to print warnings
    """
    nfo("Loading LoRA")
    lora_path = str(retrieve_model(lora_repo_id, lora_filename))
    lora_sd = load_sft(lora_path, device=str(device))
    # loading the lora params + overwriting scale values in the norms
    load_state_dict_into_model(model, lora_sd, verbose=verbose)


def load_flow_model(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
    verbose: bool = True,
    lora_repo_id: str | None = None,
    lora_filename: str | None = None,
) -> Flux:
    """Load a flow model (DiT model).

    :param model_spec: ModelSpec providing ``params``, ``repo_id`` and ``file_name``
    :param device: Device to load the model on
    :param verbose: Whether to print loading warnings
    :param lora_repo_id: Optional LoRA repository ID
    :param lora_filename: Optional LoRA filename
    :returns: Loaded model; a Flux2, FluxLoraWrapper, XFlux or Flux instance
        depending on ``model_spec.params`` and the LoRA arguments
    :raises ValueError: If LoRA arguments are given for a Flux2 spec
    """
    device = torch.device(device)  # also tolerates a plain device string
    params = model_spec.params
    # Accept the params class itself as well as an instance of it.
    is_flux2 = params is Flux2Params or isinstance(params, Flux2Params)
    is_xflux = params is XFluxParams or isinstance(params, XFluxParams)
    wants_lora = bool(lora_repo_id and lora_filename)

    if wants_lora and is_flux2:
        # Fail before any checkpoint is fetched: the LoRA path needs a FluxLoraWrapper.
        raise ValueError("LoRA weights can only be loaded into FluxLoraWrapper models")

    # Build the module on the meta device; real tensors are attached by
    # load_state_dict(assign=True) below.
    with torch.device("meta"):
        if is_flux2:
            model = Flux2(params)
        elif wants_lora:
            model = FluxLoraWrapper(params=params)
        elif is_xflux:
            model = XFlux(params)
        else:
            model = Flux(params)  # type: ignore
        model = model.to(torch.bfloat16)

    ckpt_path = str(retrieve_model(model_spec.repo_id, model_spec.file_name))
    nfo(f": {os.path.basename(ckpt_path)}")
    # str(device) keeps a device index (e.g. "cuda:1"), matching load_lora_weights.
    sd = load_sft(ckpt_path, device=str(device))
    if is_flux2:
        # Flux2 checkpoints are loaded without optional state-dict expansion.
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        if verbose:
            print_load_warning(missing, unexpected)
    else:
        load_state_dict_into_model(model, sd, verbose=verbose)  # type: ignore
    if device.type == "mps":
        convert_fp8_to_bf16(model, verbose=verbose)  # type: ignore
    if wants_lora:
        load_lora_weights(model, lora_repo_id, lora_filename, device, verbose)  # type: ignore
    return model  # type: ignore


def load_ae(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
) -> AutoEncoder1 | AutoEncoder2 | AutoencoderTiny:
    """Load the autoencoder model.

    :param model_spec: ModelSpec providing ``params``, ``repo_id`` and ``file_name``
    :param device: Device to load the checkpoint tensors on
    :returns: Loaded AutoEncoder instance
    :raises NotImplementedError: If ``model_spec.params`` is the AutoencoderTiny class
    :raises ValueError: If ``model_spec.params`` is not a supported autoencoder params object
    """
    device = torch.device(device)  # also tolerates a plain device string
    params = model_spec.params

    # Validate the spec first so an unsupported config does not trigger a download.
    if isinstance(params, AutoEncoderParams):
        ae_cls = AutoEncoder1
    elif isinstance(params, AutoEncoder2Params):
        ae_cls = AutoEncoder2
    elif params is AutoencoderTiny:
        raise NotImplementedError("AutoencoderTiny loading not yet implemented. Use model.vae.flux1-dev instead.")
    else:
        raise ValueError(f"Config {model_spec.repo_id} is not an autoencoder (expected AutoEncoder1Params or AutoEncoder2Params, got {type(model_spec.params).__name__})")

    # Build the module on the meta device; real tensors are attached by
    # load_state_dict(assign=True) below.
    with torch.device("meta"):
        ae = ae_cls(params)

    ckpt_path = str(retrieve_model(model_spec.repo_id, model_spec.file_name))
    nfo(f": {os.path.basename(ckpt_path)}")
    # str(device) keeps a device index (e.g. "cuda:1").
    sd = load_sft(ckpt_path, device=str(device))
    if isinstance(ae, AutoEncoder2):
        ae.load_state_dict(sd, strict=True, assign=True)
    else:
        load_state_dict_into_model(ae, sd, verbose=True)
    return ae  # type: ignore


def load_mmada_model(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
) -> MMaDAModelLM:
    """Load a MMaDA model.

    :param model_spec: ModelSpec object containing model details
    :param device: Device to move the loaded model to
    :returns: Loaded MMaDA model in eval mode
    :raises TypeError: If ``model_spec.params`` is not a MMaDAParams
    """
    params = model_spec.params
    if not isinstance(params, MMaDAParams):
        raise TypeError(f"MMaDA params not found for: {model_spec.repo_id} with params type {type(model_spec.params).__name__}")

    # Record the repo on the params object before loading.
    params.llm_model_path = model_spec.repo_id
    model = MMaDAModelLM.from_pretrained(model_spec.repo_id, dtype=gfx_dtype)  # type: ignore
    return model.to(device).eval()


def load_t5(device: str | torch.device = gfx_device, max_length: int = 512) -> HFEmbedder:
    """Create the "google/t5-v1_1-xxl" embedder in bfloat16 and move it to ``device``.

    :param device: Device to move the embedder to
    :param max_length: Maximum sequence length passed to the embedder
    :returns: The HFEmbedder instance
    """
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = gfx_device) -> HFEmbedder:
    """Create the "openai/clip-vit-large-patch14" embedder in bfloat16 and move it to ``device``.

    :param device: Device to move the embedder to
    :returns: The HFEmbedder instance (max_length fixed at 77)
    """
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=torch.bfloat16).to(device)


def load_mistral_small_embedder(device: str | torch.device = gfx_device) -> Mistral3SmallEmbedder:
    """Create a Mistral3SmallEmbedder with its defaults and move it to ``device``.

    :param device: Device to move the embedder to
    :returns: The Mistral3SmallEmbedder instance
    """
    return Mistral3SmallEmbedder().to(device)
def
print_load_warning(missing: list[str], unexpected: list[str]) -> None:
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Report keys that did not line up while loading a state dict.

    Nothing is printed when both lists are empty. When both are non-empty,
    a separator line is printed between the two listings.

    :param missing: Keys to report as missing
    :param unexpected: Keys to report as unexpected
    """
    has_missing = len(missing) > 0
    has_unexpected = len(unexpected) > 0

    if has_missing:
        nfo(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if has_missing and has_unexpected:
        # Separator line between the two listings
        nfo("\n" + "-" * 79 + "\n")
    if has_unexpected:
        nfo(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def
retrieve_model(repo_id: str, file_name: str) -> pathlib._local.Path:
def retrieve_model(repo_id: str, file_name: str) -> Path:
    """Resolve a checkpoint file to a local path, downloading if necessary.

    :param repo_id: Repository ID for the checkpoint
    :param file_name: Name of the checkpoint file
    :returns: Path to the checkpoint file
    """
    # Only files matching ``file_name`` are requested from the repository.
    snapshot_dir = Path(snapshot_download(repo_id=repo_id, allow_patterns=[file_name]))
    return snapshot_dir.joinpath(file_name)
Get the local path for a checkpoint file, downloading if necessary.
Parameters
- repo_id: Repository ID for the checkpoint
- file_name: Name of the checkpoint file
Returns
- Path to the checkpoint file
def
convert_fp8_to_bf16(model: torch.nn.modules.module.Module, verbose: bool = True) -> bool:
def convert_fp8_to_bf16(model: torch.nn.Module, verbose: bool = True) -> bool:
    """Detect fp8 tensors in model parameters and buffers and convert them to bf16.

    :param model: The model whose fp8 tensors are converted in place
    :param verbose: Whether to print a message for every converted tensor
        (the final summary line is printed whenever something was converted)
    :returns: True if at least one tensor was converted, False otherwise
    """
    # Not every torch build defines every fp8 dtype, so look them up defensively.
    fp8_names = ("float8_e4m3fn", "float8_e5m2", "float8_e4m3fnuz", "float8_e5m2fnuz")
    fp8_dtypes = {dtype for dtype in (getattr(torch, name, None) for name in fp8_names) if dtype is not None}
    if not fp8_dtypes:
        return False

    converted_count = 0
    with torch.no_grad():
        for kind, named_tensors in (("parameter", model.named_parameters()), ("buffer", model.named_buffers())):
            for name, tensor in named_tensors:
                if tensor.dtype not in fp8_dtypes:
                    continue
                # Upcast via float32, then narrow to bf16 so converted tensors
                # really end up in the advertised dtype.
                tensor.data = tensor.data.float().to(torch.bfloat16)
                converted_count += 1
                if verbose:
                    nfo(f"Converted fp8 {kind} '{name}' to bf16")

    if converted_count > 0:
        nfo(f"Converted {converted_count} fp8 tensor(s) to bf16")
        return True
    return False
Detect and convert fp8 tensors in model parameters and buffers to bf16.
Parameters
- model: The model to convert fp8 tensors in
- verbose: Whether to print conversion messages
def
load_state_dict_into_model( model: torch.nn.modules.module.Module, state_dict: dict[str, torch.Tensor], verbose: bool = True) -> tuple[list[str], list[str]]:
def load_state_dict_into_model(
    model: torch.nn.Module,
    state_dict: dict[str, torch.Tensor],
    verbose: bool = True,
) -> tuple[list[str], list[str]]:
    """Load a state dict into a model with optional expansion.

    :param model: The model to load weights into
    :param state_dict: The state dictionary to load
    :param verbose: Whether to print warnings
    :returns: Tuple of (missing_keys, unexpected_keys)
    """
    # strict=False so mismatched keys are returned instead of raising.
    load_result = model.load_state_dict(
        optionally_expand_state_dict(model, state_dict),
        strict=False,
        assign=True,
    )
    missing, unexpected = load_result
    if verbose:
        print_load_warning(missing, unexpected)
    return missing, unexpected
Load a state dict into a model with optional expansion.
Parameters
- model: The model to load weights into
- state_dict: The state dictionary to load
- verbose: Whether to print warnings
Returns
- Tuple of (missing_keys, unexpected_keys)
def
load_lora_weights( model: divisor.flux1.model.FluxLoraWrapper, lora_repo_id: str, lora_filename: str, device: str | torch.device = device(type='mps'), verbose: bool = True) -> None:
def load_lora_weights(
    model: FluxLoraWrapper,
    lora_repo_id: str,
    lora_filename: str,
    device: str | torch.device = gfx_device,
    verbose: bool = True,
) -> None:
    """Fetch a LoRA checkpoint and load its weights into a FluxLoraWrapper model.

    :param model: The FluxLoraWrapper model
    :param lora_repo_id: Repository ID for the LoRA checkpoint
    :param lora_filename: Filename of the LoRA checkpoint
    :param device: Device to load weights on
    :param verbose: Whether to print warnings
    """
    nfo("Loading LoRA")
    checkpoint = retrieve_model(lora_repo_id, lora_filename)
    lora_state = load_sft(str(checkpoint), device=str(device))
    # Loads the LoRA params and overwrites the scale values in the norms.
    load_state_dict_into_model(model, lora_state, verbose=verbose)
Load LoRA weights into a FluxLoraWrapper model.
Parameters
- model: The FluxLoraWrapper model
- lora_repo_id: Repository ID for the LoRA checkpoint
- lora_filename: Filename of the LoRA checkpoint
- device: Device to load weights on
- verbose: Whether to print warnings
def
load_flow_model( model_spec: divisor.spec.ModelSpec, device: torch.device = device(type='mps'), verbose: bool = True, lora_repo_id: str | None = None, lora_filename: str | None = None) -> divisor.flux1.model.Flux:
def load_flow_model(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
    verbose: bool = True,
    lora_repo_id: str | None = None,
    lora_filename: str | None = None,
) -> Flux:
    """Load a flow model (DiT model).

    :param model_spec: ModelSpec providing ``params``, ``repo_id`` and ``file_name``
    :param device: Device to load the model on
    :param verbose: Whether to print loading warnings
    :param lora_repo_id: Optional LoRA repository ID
    :param lora_filename: Optional LoRA filename
    :returns: Loaded model; a Flux2, FluxLoraWrapper, XFlux or Flux instance
        depending on ``model_spec.params`` and the LoRA arguments
    :raises ValueError: If LoRA arguments are given for a Flux2 spec
    """
    device = torch.device(device)  # also tolerates a plain device string
    params = model_spec.params
    # Accept the params class itself as well as an instance of it.
    is_flux2 = params is Flux2Params or isinstance(params, Flux2Params)
    is_xflux = params is XFluxParams or isinstance(params, XFluxParams)
    wants_lora = bool(lora_repo_id and lora_filename)

    if wants_lora and is_flux2:
        # Fail before any checkpoint is fetched: the LoRA path needs a FluxLoraWrapper.
        raise ValueError("LoRA weights can only be loaded into FluxLoraWrapper models")

    # Build the module on the meta device; real tensors are attached by
    # load_state_dict(assign=True) below.
    with torch.device("meta"):
        if is_flux2:
            model = Flux2(params)
        elif wants_lora:
            model = FluxLoraWrapper(params=params)
        elif is_xflux:
            model = XFlux(params)
        else:
            model = Flux(params)  # type: ignore
        model = model.to(torch.bfloat16)

    ckpt_path = str(retrieve_model(model_spec.repo_id, model_spec.file_name))
    nfo(f": {os.path.basename(ckpt_path)}")
    # str(device) keeps a device index (e.g. "cuda:1"), matching load_lora_weights.
    sd = load_sft(ckpt_path, device=str(device))
    if is_flux2:
        # Flux2 checkpoints are loaded without optional state-dict expansion.
        missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
        if verbose:
            print_load_warning(missing, unexpected)
    else:
        load_state_dict_into_model(model, sd, verbose=verbose)  # type: ignore
    if device.type == "mps":
        convert_fp8_to_bf16(model, verbose=verbose)  # type: ignore
    if wants_lora:
        load_lora_weights(model, lora_repo_id, lora_filename, device, verbose)  # type: ignore
    return model  # type: ignore
Load a flow model (DiT model).
Parameters
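- model_spec: ModelSpec for the model to load (documented in the source as mir_id, e.g., "model.dit.flux1-dev", but the signature above takes a ModelSpec)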
- device: Device to load the model on
- verbose: Whether to print loading warnings
- lora_repo_id: Optional LoRA repository ID (if not in config)
- lora_filename: Optional LoRA filename (if not in config)
- compatibility_key: Optional compatibility key (e.g., "fp8-sai") to override repo_id and file_name (documented in the source, but not a parameter of the signature above)
Returns
- Loaded Flux model
def
load_ae( model_spec: divisor.spec.ModelSpec, device: torch.device = device(type='mps')) -> divisor.flux1.autoencoder.AutoEncoder | divisor.flux2.autoencoder.AutoEncoder | diffusers.models.autoencoders.autoencoder_tiny.AutoencoderTiny:
def load_ae(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
) -> AutoEncoder1 | AutoEncoder2 | AutoencoderTiny:
    """Load the autoencoder model.

    :param model_spec: ModelSpec providing ``params``, ``repo_id`` and ``file_name``
    :param device: Device to load the checkpoint tensors on
    :returns: Loaded AutoEncoder instance
    :raises NotImplementedError: If ``model_spec.params`` is the AutoencoderTiny class
    :raises ValueError: If ``model_spec.params`` is not a supported autoencoder params object
    """
    device = torch.device(device)  # also tolerates a plain device string
    params = model_spec.params

    # Validate the spec first so an unsupported config does not trigger a download.
    if isinstance(params, AutoEncoderParams):
        ae_cls = AutoEncoder1
    elif isinstance(params, AutoEncoder2Params):
        ae_cls = AutoEncoder2
    elif params is AutoencoderTiny:
        raise NotImplementedError("AutoencoderTiny loading not yet implemented. Use model.vae.flux1-dev instead.")
    else:
        raise ValueError(f"Config {model_spec.repo_id} is not an autoencoder (expected AutoEncoder1Params or AutoEncoder2Params, got {type(model_spec.params).__name__})")

    # Build the module on the meta device; real tensors are attached by
    # load_state_dict(assign=True) below.
    with torch.device("meta"):
        ae = ae_cls(params)

    ckpt_path = str(retrieve_model(model_spec.repo_id, model_spec.file_name))
    nfo(f": {os.path.basename(ckpt_path)}")
    # str(device) keeps a device index (e.g. "cuda:1").
    sd = load_sft(ckpt_path, device=str(device))
    if isinstance(ae, AutoEncoder2):
        ae.load_state_dict(sd, strict=True, assign=True)
    else:
        load_state_dict_into_model(ae, sd, verbose=True)
    return ae  # type: ignore
Load the autoencoder model.
Parameters
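- model_spec: ModelSpec for the autoencoder to load (documented in the source as mir_id, e.g., "model.vae.flux1-dev" or "model.taesd.flux1-dev", but the signature above takes a ModelSpec)
- device: Device to load the model on
Returns
- Loaded AutoEncoder instance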
def
load_mmada_model( model_spec: divisor.spec.ModelSpec, device: torch.device = device(type='mps')) -> divisor.mmada.modeling_mmada.MMadaModelLM:
def load_mmada_model(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
) -> MMaDAModelLM:
    """Load a MMaDA model.

    :param model_spec: ModelSpec object containing model details
    :param device: Device to move the loaded model to
    :returns: Loaded MMaDA model in eval mode
    :raises TypeError: If ``model_spec.params`` is not a MMaDAParams
    """
    params = model_spec.params
    if not isinstance(params, MMaDAParams):
        raise TypeError(f"MMaDA params not found for: {model_spec.repo_id} with params type {type(model_spec.params).__name__}")

    # Record the repo on the params object before loading.
    params.llm_model_path = model_spec.repo_id
    loaded = MMaDAModelLM.from_pretrained(model_spec.repo_id, dtype=gfx_dtype)  # type: ignore
    return loaded.to(device).eval()
Load a MMaDA model
Parameters
- model_spec: ModelSpec object containing model details
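- device: Device to load the model on (documented in the source as target_device; the signature above names it device)
- compatibility_key: Optional compatibility key (e.g., "mixcot") to override repo_id and file_name (documented in the source, but not a parameter of the signature above)
- force_reload: If True, bypass cache and reload the model (documented in the source, but not a parameter of the signature above)
Returns
- Loaded MMaDA model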
Raises
- TypeError if model_spec.params is not a MMaDAParams
def
load_t5( device: str | torch.device = device(type='mps'), max_length: int = 512) -> divisor.flux1.text_embedder.HFEmbedder:
def
load_clip( device: str | torch.device = device(type='mps')) -> divisor.flux1.text_embedder.HFEmbedder:
def
load_mistral_small_embedder( device: str | torch.device = device(type='mps')) -> divisor.flux2.text_encoder.Mistral3SmallEmbedder: