divisor.flux1.loading

  1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
  2# <!-- // /*  d a r k s h a p e s */ -->
  3# adapted BFL Flux code from https://github.com/black-forest-labs/flux
  4
  5
  6import os
  7from pathlib import Path
  8
  9import torch
 10from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 11from huggingface_hub import snapshot_download
 12from nnll.console import nfo
 13from safetensors.torch import load_file as load_sft
 14
 15from divisor.flux1.autoencoder import AutoEncoder as AutoEncoder1
 16from divisor.flux1.autoencoder import AutoEncoderParams
 17from divisor.flux1.model import Flux, FluxLoraWrapper
 18from divisor.flux1.text_embedder import HFEmbedder
 19from divisor.flux2.autoencoder import (
 20    AutoEncoder as AutoEncoder2,
 21)
 22from divisor.flux2.autoencoder import (
 23    AutoEncoderParams as AutoEncoder2Params,
 24)
 25from divisor.flux2.model import Flux2, Flux2Params
 26from divisor.flux2.text_encoder import Mistral3SmallEmbedder
 27from divisor.mmada.modeling_mmada import MMadaConfig as MMaDAParams
 28from divisor.mmada.modeling_mmada import MMadaModelLM as MMaDAModelLM
 29from divisor.registry import gfx_device, gfx_dtype
 30from divisor.spec import ModelSpec, optionally_expand_state_dict
 31from divisor.xflux1.model import XFlux, XFluxParams
 32
 33
 34def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
 35    if len(missing) > 0 and len(unexpected) > 0:
 36        nfo(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
 37        nfo("\n" + "-" * 79 + "\n")
 38        nfo(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
 39    elif len(missing) > 0:
 40        nfo(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
 41    elif len(unexpected) > 0:
 42        nfo(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
 43
 44
 45def retrieve_model(repo_id: str, file_name: str) -> Path:
 46    """Get the local path for a checkpoint file, downloading if necessary.
 47    :param repo_id: Repository ID for the checkpoint
 48    :param file_name: Name of the checkpoint file
 49    :returns: Path to the checkpoint file
 50    """
 51
 52    model_dir = snapshot_download(repo_id=repo_id, allow_patterns=[file_name])
 53    return Path(model_dir) / file_name
 54
 55
 56def convert_fp8_to_bf16(model: torch.nn.Module, verbose: bool = True) -> bool:
 57    """Detect and convert fp8 tensors in model parameters and buffers to bf16.
 58    :param model: The model to convert fp8 tensors in
 59    :param verbose: Whether to print conversion messages
 60    """
 61    # Define fp8 dtypes to detect
 62    fp8_dtypes = [
 63        getattr(torch, "float8_e4m3fn", None),
 64        getattr(torch, "float8_e5m2", None),
 65        getattr(torch, "float8_e4m3fnuz", None),
 66        getattr(torch, "float8_e5m2fnuz", None),
 67    ]
 68
 69    fp8_dtypes = [dtype for dtype in fp8_dtypes if dtype is not None]
 70    if not fp8_dtypes:
 71        return False
 72
 73    converted_count = 0
 74
 75    for name, param in model.named_parameters():
 76        if param.dtype in fp8_dtypes:
 77            with torch.no_grad():
 78                param.data = param.data.float()
 79            converted_count += 1
 80            if verbose:
 81                nfo(f"Converted fp8 parameter '{name}' to bf16")
 82
 83    for name, buffer in model.named_buffers():
 84        if buffer.dtype in fp8_dtypes:
 85            buffer.data = buffer.data.float()
 86            converted_count += 1
 87            if verbose:
 88                nfo(f"Converted fp8 buffer '{name}' to bf16")
 89
 90    if converted_count > 0:
 91        nfo(f"Converted {converted_count} fp8 tensor(s) to bf16")
 92        return True
 93    return False
 94
 95
 96def load_state_dict_into_model(
 97    model: torch.nn.Module,
 98    state_dict: dict[str, torch.Tensor],
 99    verbose: bool = True,
100) -> tuple[list[str], list[str]]:
101    """Load a state dict into a model with optional expansion.\n
102    :param model: The model to load weights into
103    :param state_dict: The state dictionary to load
104    :param verbose: Whether to print warnings
105    :returns: Tuple of (missing_keys, unexpected_keys)
106    """
107    expanded_sd = optionally_expand_state_dict(model, state_dict)
108    missing, unexpected = model.load_state_dict(expanded_sd, strict=False, assign=True)
109    if verbose:
110        print_load_warning(missing, unexpected)
111    return missing, unexpected
112
113
def load_lora_weights(
    model: FluxLoraWrapper,
    lora_repo_id: str,
    lora_filename: str,
    device: str | torch.device = gfx_device,
    verbose: bool = True,
) -> None:
    """Fetch a LoRA checkpoint and load it into a FluxLoraWrapper model.

    :param model: The FluxLoraWrapper model
    :param lora_repo_id: Repository ID for the LoRA checkpoint
    :param lora_filename: Filename of the LoRA checkpoint
    :param device: Device to load weights on
    :param verbose: Whether to print warnings
    """
    nfo("Loading LoRA")
    checkpoint_file = retrieve_model(lora_repo_id, lora_filename)
    # The device is passed to the safetensors loader in string form.
    lora_state = load_sft(str(checkpoint_file), device=str(device))
    # loading the lora params + overwriting scale values in the norms
    load_state_dict_into_model(model, lora_state, verbose=verbose)
133
134
def load_flow_model(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
    verbose: bool = True,
    lora_repo_id: str | None = None,
    lora_filename: str | None = None,
) -> Flux:
    """Load a flow model (DiT model).

    The model class is chosen in this order: Flux2 when ``model_spec.params``
    is ``Flux2Params``; FluxLoraWrapper when both LoRA arguments are given;
    XFlux when ``model_spec.params`` is ``XFluxParams``; otherwise Flux.

    :param model_spec: ModelSpec supplying ``params``, ``repo_id`` and ``file_name``
    :param device: Device to load the model on
    :param verbose: Whether to print loading warnings
    :param lora_repo_id: Optional LoRA repository ID
    :param lora_filename: Optional LoRA filename
    :returns: Loaded model
    :raises ValueError: If LoRA arguments are given but the constructed model is not a FluxLoraWrapper
    """
    params = model_spec.params
    is_flux2 = params is Flux2Params
    wants_lora = bool(lora_repo_id and lora_filename)

    # Construct under the "meta" device; weights are attached afterwards with assign=True.
    with torch.device("meta"):
        if is_flux2:
            model = Flux2(params)
        elif wants_lora:
            model = FluxLoraWrapper(params=params)
        elif params is XFluxParams:
            model = XFlux(params)
        else:
            model = Flux(params)  # type: ignore
        model = model.to(torch.bfloat16)

    ckpt_path = str(retrieve_model(model_spec.repo_id, model_spec.file_name))
    nfo(f": {os.path.basename(ckpt_path)}")
    weights = load_sft(ckpt_path, device=device.type)

    if is_flux2:
        # Flux2 checkpoints are loaded directly, without the expansion helper.
        model.load_state_dict(weights, strict=False, assign=True)
    else:
        load_state_dict_into_model(model, weights, verbose=verbose)  # type: ignore

    if device.type == "mps":
        convert_fp8_to_bf16(model, verbose=verbose)  # type: ignore

    if wants_lora:
        if not isinstance(model, FluxLoraWrapper):  # type: ignore
            raise ValueError("LoRA weights can only be loaded into FluxLoraWrapper models")
        load_lora_weights(model, lora_repo_id, lora_filename, device, verbose)
    return model  # type: ignore
175
176
def load_ae(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
) -> AutoEncoder1 | AutoEncoder2 | AutoencoderTiny:
    """Load the autoencoder model.

    :param model_spec: ModelSpec supplying ``params``, ``repo_id`` and ``file_name``
    :param device: Device whose ``type`` is used when reading the checkpoint tensors
    :returns: Loaded AutoEncoder instance
    :raises NotImplementedError: If ``model_spec.params`` is ``AutoencoderTiny``
    :raises ValueError: If ``model_spec.params`` is not a supported autoencoder parameter object
    """
    # The checkpoint is resolved before the params are validated.
    ckpt_path = str(retrieve_model(model_spec.repo_id, model_spec.file_name))
    params = model_spec.params

    with torch.device("meta"):
        if isinstance(params, AutoEncoderParams):
            ae = AutoEncoder1(params)
        elif isinstance(params, AutoEncoder2Params):
            ae = AutoEncoder2(params)
        else:
            if params is AutoencoderTiny:
                raise NotImplementedError("AutoencoderTiny loading not yet implemented. Use model.vae.flux1-dev instead.")
            raise ValueError(f"Config {model_spec.repo_id} is not an autoencoder (expected AutoEncoder1Params or AutoEncoder2Params, got {type(params).__name__})")

    nfo(f": {os.path.basename(ckpt_path)}")
    weights = load_sft(ckpt_path, device=device.type)
    if not isinstance(ae, AutoEncoder2):
        load_state_dict_into_model(ae, weights, verbose=True)
    else:
        # The flux2 autoencoder is loaded strictly and directly.
        ae.load_state_dict(weights, strict=True, assign=True)
    return ae  # type: ignore # to device flux2
205
206
def load_mmada_model(
    model_spec: ModelSpec,
    device: torch.device = gfx_device,
) -> MMaDAModelLM:
    """Load a MMaDA model.

    Also sets ``model_spec.params.llm_model_path`` to ``model_spec.repo_id``.

    :param model_spec: ModelSpec object containing model details
    :param device: Device to move the loaded model to
    :returns: Loaded MMaDA model, in eval mode
    :raises TypeError: If ``model_spec.params`` is not a MMaDAParams
    """
    params = model_spec.params
    if not isinstance(params, MMaDAParams):
        raise TypeError(f"MMaDA params not found for: {model_spec.repo_id} with params type {type(params).__name__}")

    params.llm_model_path = model_spec.repo_id
    loaded = MMaDAModelLM.from_pretrained(model_spec.repo_id, dtype=gfx_dtype)  # type: ignore
    return loaded.to(device).eval()
228
229
def load_t5(device: str | torch.device = gfx_device, max_length: int = 512) -> HFEmbedder:
    """Build the T5 text embedder ("google/t5-v1_1-xxl") in bfloat16 and move it to ``device``.

    :param device: Device to move the embedder to
    :param max_length: Maximum sequence length passed to the embedder
    :returns: The HFEmbedder on the requested device
    """
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    embedder = HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, dtype=torch.bfloat16)
    return embedder.to(device)
233
234
def load_clip(device: str | torch.device = gfx_device) -> HFEmbedder:
    """Build the CLIP text embedder ("openai/clip-vit-large-patch14") in bfloat16 and move it to ``device``.

    :param device: Device to move the embedder to
    :returns: The HFEmbedder on the requested device
    """
    embedder = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=torch.bfloat16)
    return embedder.to(device)
237
238
def load_mistral_small_embedder(device: str | torch.device = gfx_device) -> Mistral3SmallEmbedder:
    """Construct a Mistral3SmallEmbedder with its default arguments and move it to ``device``.

    :param device: Device to move the embedder to
    :returns: The Mistral3SmallEmbedder on the requested device
    """
    embedder = Mistral3SmallEmbedder()
    return embedder.to(device)
def retrieve_model(repo_id: str, file_name: str) -> pathlib._local.Path:
46def retrieve_model(repo_id: str, file_name: str) -> Path:
47    """Get the local path for a checkpoint file, downloading if necessary.
48    :param repo_id: Repository ID for the checkpoint
49    :param file_name: Name of the checkpoint file
50    :returns: Path to the checkpoint file
51    """
52
53    model_dir = snapshot_download(repo_id=repo_id, allow_patterns=[file_name])
54    return Path(model_dir) / file_name

Get the local path for a checkpoint file, downloading if necessary.

Parameters
  • repo_id: Repository ID for the checkpoint
  • file_name: Name of the checkpoint file
Returns
  • Path to the checkpoint file
def convert_fp8_to_bf16(model: torch.nn.modules.module.Module, verbose: bool = True) -> bool:
57def convert_fp8_to_bf16(model: torch.nn.Module, verbose: bool = True) -> bool:
58    """Detect and convert fp8 tensors in model parameters and buffers to bf16.
59    :param model: The model to convert fp8 tensors in
60    :param verbose: Whether to print conversion messages
61    """
62    # Define fp8 dtypes to detect
63    fp8_dtypes = [
64        getattr(torch, "float8_e4m3fn", None),
65        getattr(torch, "float8_e5m2", None),
66        getattr(torch, "float8_e4m3fnuz", None),
67        getattr(torch, "float8_e5m2fnuz", None),
68    ]
69
70    fp8_dtypes = [dtype for dtype in fp8_dtypes if dtype is not None]
71    if not fp8_dtypes:
72        return False
73
74    converted_count = 0
75
76    for name, param in model.named_parameters():
77        if param.dtype in fp8_dtypes:
78            with torch.no_grad():
79                param.data = param.data.float()
80            converted_count += 1
81            if verbose:
82                nfo(f"Converted fp8 parameter '{name}' to bf16")
83
84    for name, buffer in model.named_buffers():
85        if buffer.dtype in fp8_dtypes:
86            buffer.data = buffer.data.float()
87            converted_count += 1
88            if verbose:
89                nfo(f"Converted fp8 buffer '{name}' to bf16")
90
91    if converted_count > 0:
92        nfo(f"Converted {converted_count} fp8 tensor(s) to bf16")
93        return True
94    return False

Detect and convert fp8 tensors in model parameters and buffers to bf16.

Parameters
  • model: The model to convert fp8 tensors in
  • verbose: Whether to print conversion messages
def load_state_dict_into_model( model: torch.nn.modules.module.Module, state_dict: dict[str, torch.Tensor], verbose: bool = True) -> tuple[list[str], list[str]]:
 97def load_state_dict_into_model(
 98    model: torch.nn.Module,
 99    state_dict: dict[str, torch.Tensor],
100    verbose: bool = True,
101) -> tuple[list[str], list[str]]:
102    """Load a state dict into a model with optional expansion.\n
103    :param model: The model to load weights into
104    :param state_dict: The state dictionary to load
105    :param verbose: Whether to print warnings
106    :returns: Tuple of (missing_keys, unexpected_keys)
107    """
108    expanded_sd = optionally_expand_state_dict(model, state_dict)
109    missing, unexpected = model.load_state_dict(expanded_sd, strict=False, assign=True)
110    if verbose:
111        print_load_warning(missing, unexpected)
112    return missing, unexpected

Load a state dict into a model with optional expansion.

Parameters
  • model: The model to load weights into
  • state_dict: The state dictionary to load
  • verbose: Whether to print warnings
Returns
  • Tuple of (missing_keys, unexpected_keys)
def load_lora_weights( model: divisor.flux1.model.FluxLoraWrapper, lora_repo_id: str, lora_filename: str, device: str | torch.device = device(type='mps'), verbose: bool = True) -> None:
115def load_lora_weights(
116    model: FluxLoraWrapper,
117    lora_repo_id: str,
118    lora_filename: str,
119    device: str | torch.device = gfx_device,
120    verbose: bool = True,
121) -> None:
122    """Load LoRA weights into a FluxLoraWrapper model.\n
123    :param model: The FluxLoraWrapper model
124    :param lora_repo_id: Repository ID for the LoRA checkpoint
125    :param lora_filename: Filename of the LoRA checkpoint
126    :param device: Device to load weights on
127    :param verbose: Whether to print warnings
128    """
129    nfo("Loading LoRA")
130    lora_path = str(retrieve_model(lora_repo_id, lora_filename))
131    lora_sd = load_sft(lora_path, device=str(device))
132    # loading the lora params + overwriting scale values in the norms
133    load_state_dict_into_model(model, lora_sd, verbose=verbose)

Load LoRA weights into a FluxLoraWrapper model.

Parameters
  • model: The FluxLoraWrapper model
  • lora_repo_id: Repository ID for the LoRA checkpoint
  • lora_filename: Filename of the LoRA checkpoint
  • device: Device to load weights on
  • verbose: Whether to print warnings
def load_flow_model( model_spec: divisor.spec.ModelSpec, device: torch.device = device(type='mps'), verbose: bool = True, lora_repo_id: str | None = None, lora_filename: str | None = None) -> divisor.flux1.model.Flux:
136def load_flow_model(
137    model_spec: ModelSpec,
138    device: torch.device = gfx_device,
139    verbose: bool = True,
140    lora_repo_id: str | None = None,
141    lora_filename: str | None = None,
142) -> Flux:
143    """Load a flow model (DiT model).\n
144    :param mir_id: Model ID (e.g., "model.dit.flux1-dev")
145    :param device: Device to load the model on
146    :param verbose: Whether to print loading warnings
147    :param lora_repo_id: Optional LoRA repository ID (if not in config)
148    :param lora_filename: Optional LoRA filename (if not in config)
149    :param compatibility_key: Optional compatibility key (e.g., "fp8-sai") to override repo_id and file_name
150    :returns: Loaded Flux model"""
151
152    with torch.device("meta"):
153        if model_spec.params is Flux2Params:
154            model = Flux2(model_spec.params).to(torch.bfloat16)
155        elif lora_repo_id and lora_filename:
156            model = FluxLoraWrapper(params=model_spec.params).to(torch.bfloat16)
157        elif model_spec.params is XFluxParams:
158            model = XFlux(model_spec.params).to(torch.bfloat16)
159        else:
160            model = Flux(model_spec.params).to(torch.bfloat16)  # type: ignore
161
162    ckpt_path = str(retrieve_model(model_spec.repo_id, model_spec.file_name))
163    nfo(f": {os.path.basename(ckpt_path)}")
164    sd = load_sft(ckpt_path, device=device.type)
165    if model_spec.params is Flux2Params:
166        model.load_state_dict(sd, strict=False, assign=True)
167    else:
168        load_state_dict_into_model(model, sd, verbose=verbose)  # type: ignore
169    if device.type == "mps":
170        convert_fp8_to_bf16(model, verbose=verbose)  # type: ignore
171    if lora_repo_id and lora_filename:
172        if not isinstance(model, FluxLoraWrapper):  # type: ignore
173            raise ValueError("LoRA weights can only be loaded into FluxLoraWrapper models")
174        load_lora_weights(model, lora_repo_id, lora_filename, device, verbose)
175    return model  # type: ignore

Load a flow model (DiT model).

Parameters
  • model_spec: ModelSpec for the model to load (e.g., the spec for "model.dit.flux1-dev")
  • device: Device to load the model on
  • verbose: Whether to print loading warnings
  • lora_repo_id: Optional LoRA repository ID (if not in config)
  • lora_filename: Optional LoRA filename (if not in config)
  • compatibility_key: Optional compatibility key (e.g., "fp8-sai") to override repo_id and file_name (note: not a parameter of the signature shown above)
Returns
  • Loaded Flux model
def load_ae( model_spec: divisor.spec.ModelSpec, device: torch.device = device(type='mps')) -> divisor.flux1.autoencoder.AutoEncoder | divisor.flux2.autoencoder.AutoEncoder | diffusers.models.autoencoders.autoencoder_tiny.AutoencoderTiny:
178def load_ae(
179    model_spec: ModelSpec,
180    device: torch.device = gfx_device,
181) -> AutoEncoder1 | AutoEncoder2 | AutoencoderTiny:
182    """Load the autoencoder model.\n
183    :param mir_id: Model ID (e.g., "model.vae.flux1-dev" or "model.taesd.flux1-dev")
184    :param device: Device to load the model on
185    :returns: Loaded AutoEncoder instance
186    """
187    ckpt_path = str(retrieve_model(model_spec.repo_id, model_spec.file_name))
188
189    with torch.device("meta"):
190        if isinstance(model_spec.params, AutoEncoderParams):
191            ae = AutoEncoder1(model_spec.params)
192        elif isinstance(model_spec.params, AutoEncoder2Params):
193            ae = AutoEncoder2(model_spec.params)
194        elif model_spec.params is AutoencoderTiny:
195            raise NotImplementedError("AutoencoderTiny loading not yet implemented. Use model.vae.flux1-dev instead.")
196        else:
197            raise ValueError(f"Config {model_spec.repo_id} is not an autoencoder (expected AutoEncoder1Params or AutoEncoder2Params, got {type(model_spec.params).__name__})")
198
199    nfo(f": {os.path.basename(ckpt_path)}")
200    sd = load_sft(ckpt_path, device=device.type)
201    if isinstance(ae, AutoEncoder2):
202        ae.load_state_dict(sd, strict=True, assign=True)
203    else:
204        load_state_dict_into_model(ae, sd, verbose=True)
205    return ae  # type: ignore # to device flux2

Load the autoencoder model.

Parameters
  • model_spec: ModelSpec for the autoencoder to load (e.g., the spec for "model.vae.flux1-dev" or "model.taesd.flux1-dev")
  • device: Device to load the model on
Returns
  • Loaded AutoEncoder instance
def load_mmada_model( model_spec: divisor.spec.ModelSpec, device: torch.device = device(type='mps')) -> divisor.mmada.modeling_mmada.MMadaModelLM:
208def load_mmada_model(
209    model_spec: ModelSpec,
210    device: torch.device = gfx_device,
211) -> MMaDAModelLM:
212    """Load a MMaDA model\n
213    :param model_spec: ModelSpec object containing model details
214    :param target_device: Device to load the model on
215    :param compatibility_key: Optional compatibility key (e.g., "mixcot") to override repo_id and file_name
216    :param force_reload: If True, bypass cache and reload the model
217    :returns: Loaded MMaDA model
218    :raises: TypeError if model_spec.params is not a MMaDAParams
219    """
220    precision = gfx_dtype
221    if isinstance(model_spec.params, MMaDAParams):
222        model_spec.params.llm_model_path = model_spec.repo_id
223        model = MMaDAModelLM.from_pretrained(model_spec.repo_id, dtype=precision)  # type: ignore
224
225        model = model.to(device).eval()
226
227        return model
228    raise TypeError(f"MMaDA params not found for: {model_spec.repo_id} with params type {type(model_spec.params).__name__}")

Load a MMaDA model

Parameters
  • model_spec: ModelSpec object containing model details
  • device: Device to load the model on
  • compatibility_key: Optional compatibility key (e.g., "mixcot") to override repo_id and file_name (note: not a parameter of the signature shown above)
  • force_reload: If True, bypass cache and reload the model (note: not a parameter of the signature shown above)
Returns
  • Loaded MMaDA model
Raises
  • TypeError if model_spec.params is not a MMaDAParams
def load_t5( device: str | torch.device = device(type='mps'), max_length: int = 512) -> divisor.flux1.text_embedder.HFEmbedder:
231def load_t5(device: str | torch.device = gfx_device, max_length: int = 512) -> HFEmbedder:
232    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
233    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, dtype=torch.bfloat16).to(device)
def load_clip( device: str | torch.device = device(type='mps')) -> divisor.flux1.text_embedder.HFEmbedder:
236def load_clip(device: str | torch.device = gfx_device) -> HFEmbedder:
237    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, dtype=torch.bfloat16).to(device)
def load_mistral_small_embedder( device: str | torch.device = device(type='mps')) -> divisor.flux2.text_encoder.Mistral3SmallEmbedder:
240def load_mistral_small_embedder(device: str | torch.device = gfx_device) -> Mistral3SmallEmbedder:
241    return Mistral3SmallEmbedder().to(device)