divisor.registry
# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
# <!-- // /* d a r k s h a p e s */ -->

import functools
from typing import Any

from nnll.init_gpu import Gfx


@functools.cache
def _init_gfx() -> Gfx:
    """Initialise (and cache) a single :class:`~nnll.init_gpu.Gfx` instance.

    ``functools.cache`` guarantees every call returns the same object, so the
    "single instance" promise holds even if this is called again after import.
    """
    return Gfx(full_precision=False)


gfx: Gfx = _init_gfx()
gfx_device = gfx.device
gfx_dtype = gfx.dtype
# NOTE(review): the rendered docs for this module show ``gfx_sync`` and
# ``empty_cache`` evaluating to ``None`` on an MPS device; callers should
# check for ``None`` before invoking them.
gfx_sync = gfx.sync
empty_cache = gfx.empty_cache


def get_lora_rank(checkpoint):
    """Return the LoRA rank stored in a checkpoint.

    The rank is read as ``shape[0]`` of the first entry whose key ends with
    ``".down.weight"``.

    :param checkpoint: Mapping of parameter names to weights (assumed to be
        tensor-like objects exposing ``.shape`` -- TODO confirm with callers)
    :returns: The rank, or ``None`` when no ``.down.weight`` key is present
    """
    for key in checkpoint.keys():
        if key.endswith(".down.weight"):
            return checkpoint[key].shape[0]
    return None


def populate_model_choices(configs: dict[str, Any]) -> list[str]:
    """Generate model choices from all entries in configs.

    Each top-level model id is emitted, followed by one ``"<model_id>:<variant>"``
    entry per variant key (the ``"*"`` key is skipped).

    :param configs: Mapping of model id to a mapping of variant configs
    :returns: List of model choices
    """
    model_choices = []
    for model_id, model_config in configs.items():
        model_choices.append(model_id)
        model_choices.extend([f"{model_id}:{n}" for n in model_config.keys() if n != "*"])
    return model_choices


def build_available_models(configs: dict[str, Any]) -> dict[str, str]:
    """Build a mapping of short model keys to full model choice strings.

    Choices containing ``"fp8-"`` or ``".vae."`` are skipped. The short key is
    the text after the last ``":"`` when present, otherwise the text after the
    last ``"."``. When several choices share a key, the first one wins.

    :param configs: Configuration mapping containing model specs
    :returns: Dict mapping each short key to its full model choice string
    """
    excluded = ("fp8-", ".vae.")
    model_args: dict[str, str] = {}
    for model in populate_model_choices(configs):
        if any(marker in model for marker in excluded):
            continue
        # "model_id:variant" -> "variant"; plain "a.b.c" ids -> "c"
        separator = ":" if ":" in model else "."
        key = model.rsplit(separator, 1)[-1]
        # setdefault keeps the first choice that maps to a given key
        model_args.setdefault(key, model)
    return model_args


if __name__ == "__main__":

    def _debug_dump() -> dict[str, Any]:
        """
        Return a dictionary with the most important registry values.
        Useful when debugging import-order issues.
        """
        return {
            "gfx": gfx,
            "device": gfx_device,
            "device_repr": repr(gfx_device),
            "device.type": gfx_device.type,
        }

    print(_debug_dump())
gfx: nnll.init_gpu.Gfx =
<nnll.init_gpu.Gfx object>
gfx_device =
device(type='mps')
gfx_dtype =
torch.bfloat16
gfx_sync =
None
empty_cache =
None
def
get_lora_rank(checkpoint):
def
populate_model_choices(configs: dict[str, typing.Any]) -> list[str]:
def populate_model_choices(configs: dict[str, Any]) -> list[str]:
    """Flatten a config mapping into selectable model identifiers.

    For every top-level model id, the id itself is listed first, followed by
    one ``"<model_id>:<variant>"`` entry per variant key; the ``"*"`` key is
    never turned into a choice.

    :param configs: Mapping of model id to a mapping of variant configs
    :returns: List of model choices, in the iteration order of ``configs``
    """
    choices: list[str] = []
    for mid, variants in configs.items():
        suffixes = [name for name in variants.keys() if name != "*"]
        choices.append(mid)
        choices.extend(f"{mid}:{name}" for name in suffixes)
    return choices
Generate model choices from all entries in configs
:returns: List of model choices
def
build_available_models(configs: dict[str, typing.Any]) -> dict[str, str]:
def build_available_models(configs: dict[str, Any]) -> dict[str, str]:
    """Build a mapping of short model keys to full model choice strings.

    Every entry produced by ``populate_model_choices`` is considered, except
    those containing ``"fp8-"`` or ``".vae."``. The short key is the text after
    the last ``":"`` when present, otherwise the text after the last ``"."``.
    When several choices share a key, the first one encountered is kept.

    :param configs: Configuration mapping containing model specs
    :returns: Dict mapping each short key to its full model choice string
    """
    excluded = ("fp8-", ".vae.")
    model_args: dict[str, str] = {}
    for model in populate_model_choices(configs):
        if any(marker in model for marker in excluded):
            continue
        # "model_id:variant" -> "variant"; plain "a.b.c" ids -> "c"
        separator = ":" if ":" in model else "."
        key = model.rsplit(separator, 1)[-1]
        # setdefault keeps the first choice that maps to a given key
        model_args.setdefault(key, model)
    return model_args
Build model arguments from configs.
Parameters
- configs: Configuration mapping containing model specs

Returns
- Dict mapping each short model key to its full model choice string