divisor.registry

 1# SPDX-License-Identifier: MPL-2.0 AND LicenseRef-Commons-Clause-License-Condition-1.0
 2# <!-- // /*  d a r k s h a p e s */ -->
 3
 4from typing import Any
 5
 6from nnll.init_gpu import Gfx
 7
 8
 9def _init_gfx() -> Gfx:
10    """Initialise (and cache) a single :class:`~nnll.init_gpu.Gfx` instance."""
11    return Gfx(full_precision=False)
12
13
# One shared Gfx handle, created once when this module is first imported.
gfx: Gfx = _init_gfx()

# Convenience aliases for attributes of the shared handle, read once at import
# time. NOTE(review): `gfx_sync` and `empty_cache` may be None on some
# backends - confirm before calling them unconditionally.
gfx_device, gfx_dtype = gfx.device, gfx.dtype
gfx_sync, empty_cache = gfx.sync, gfx.empty_cache
19
20
def get_lora_rank(checkpoint):
    """Return the first dimension of the first ``*.down.weight`` entry.

    Keys are scanned in the mapping's iteration order; the search stops at
    the first key whose name ends with ``".down.weight"``.

    :param checkpoint: Mapping of parameter names to objects exposing ``.shape``
    :returns: ``shape[0]`` of the first matching entry, or ``None`` when no
        key ends with ``".down.weight"``
    """
    down_keys = (name for name in checkpoint.keys() if name.endswith(".down.weight"))
    first_match = next(down_keys, None)
    if first_match is None:
        # No down-projection weight present: same None result as falling off
        # the end of the original loop.
        return None
    return checkpoint[first_match].shape[0]
25
26
def populate_model_choices(configs: dict[str, Any]) -> list[str]:
    """List every selectable model name derived from ``configs``.

    Each top-level key is emitted as-is, immediately followed by one
    ``"<model_id>:<name>"`` entry per key of its nested mapping. The nested
    key ``"*"`` is skipped.

    :param configs: Mapping of model id to a nested mapping of named entries
    :returns: List of model choices, in ``configs`` iteration order
    """
    choices: list[str] = []
    for base_id, variants in configs.items():
        choices.append(base_id)
        for variant_name in variants.keys():
            if variant_name == "*":
                # "*" never becomes a separate "<id>:*" choice.
                continue
            choices.append(f"{base_id}:{variant_name}")
    return choices
35
36
def build_available_models(configs: dict[str, Any]) -> dict[str, str]:
    """Build a mapping of short model keys to full model choices.

    Choices come from :func:`populate_model_choices`. Any choice containing
    ``"fp8-"`` or ``".vae."`` is skipped. The short key is the text after the
    last ``":"`` when the choice contains one, otherwise the text after the
    last ``"."``. When several choices yield the same key, the first one wins.

    :param configs: Configuration mapping containing model specs
    :returns: Dict mapping each short key to the first model choice producing it
    """
    excluded_markers = ("fp8-", ".vae.")
    model_args: dict[str, str] = {}
    for model in populate_model_choices(configs):
        if any(marker in model for marker in excluded_markers):
            continue
        separator = ":" if ":" in model else "."
        # rpartition(...)[2] equals split(...)[-1], including the case where
        # the separator is absent (the whole string is returned).
        key = model.rpartition(separator)[2]
        # setdefault keeps the first choice seen for a given key.
        model_args.setdefault(key, model)
    return model_args
56
57
if __name__ == "__main__":
    import typing as _t

    def _debug_dump() -> dict[str, _t.Any]:
        """Collect the registry's shared GPU handle and device details.

        Handy for checking what the module resolved to at import time, for
        example when debugging import-order issues.

        :returns: Dict with the ``gfx`` handle, the device object, its
            ``repr`` and its ``type`` attribute
        """
        snapshot: dict[str, _t.Any] = {"gfx": gfx}
        snapshot["device"] = gfx_device
        snapshot["device_repr"] = repr(gfx_device)
        snapshot["device.type"] = gfx_device.type
        return snapshot

    print(_debug_dump())
gfx: nnll.init_gpu.Gfx = <nnll.init_gpu.Gfx object>
gfx_device = device(type='mps')
gfx_dtype = torch.bfloat16
gfx_sync = None
empty_cache = None
def get_lora_rank(checkpoint):
def get_lora_rank(checkpoint):
    """Return the first dimension of the first ``*.down.weight`` entry.

    Keys are scanned in the mapping's iteration order; the search stops at
    the first key whose name ends with ``".down.weight"``.

    :param checkpoint: Mapping of parameter names to objects exposing ``.shape``
    :returns: ``shape[0]`` of the first matching entry, or ``None`` when no
        key ends with ``".down.weight"``
    """
    down_keys = (name for name in checkpoint.keys() if name.endswith(".down.weight"))
    first_match = next(down_keys, None)
    if first_match is None:
        # No down-projection weight present: same None result as falling off
        # the end of the original loop.
        return None
    return checkpoint[first_match].shape[0]
def populate_model_choices(configs: dict[str, typing.Any]) -> list[str]:
def populate_model_choices(configs: dict[str, Any]) -> list[str]:
    """List every selectable model name derived from ``configs``.

    Each top-level key is emitted as-is, immediately followed by one
    ``"<model_id>:<name>"`` entry per key of its nested mapping. The nested
    key ``"*"`` is skipped.

    :param configs: Mapping of model id to a nested mapping of named entries
    :returns: List of model choices, in ``configs`` iteration order
    """
    choices: list[str] = []
    for base_id, variants in configs.items():
        choices.append(base_id)
        for variant_name in variants.keys():
            if variant_name == "*":
                # "*" never becomes a separate "<id>:*" choice.
                continue
            choices.append(f"{base_id}:{variant_name}")
    return choices

Generate model choices from all entries in configs

:returns: List of model choices

def build_available_models(configs: dict[str, typing.Any]) -> dict[str, str]:
def build_available_models(configs: dict[str, Any]) -> dict[str, str]:
    """Build a mapping of short model keys to full model choices.

    Choices come from :func:`populate_model_choices`. Any choice containing
    ``"fp8-"`` or ``".vae."`` is skipped. The short key is the text after the
    last ``":"`` when the choice contains one, otherwise the text after the
    last ``"."``. When several choices yield the same key, the first one wins.

    :param configs: Configuration mapping containing model specs
    :returns: Dict mapping each short key to the first model choice producing it
    """
    excluded_markers = ("fp8-", ".vae.")
    model_args: dict[str, str] = {}
    for model in populate_model_choices(configs):
        if any(marker in model for marker in excluded_markers):
            continue
        separator = ":" if ":" in model else "."
        # rpartition(...)[2] equals split(...)[-1], including the case where
        # the separator is absent (the whole string is returned).
        key = model.rpartition(separator)[2]
        # setdefault keeps the first choice seen for a given key.
        model_args.setdefault(key, model)
    return model_args

Build model arguments from configs.

Parameters
  • configs: Configuration mapping containing model specs

Returns
  Dict of model arguments, mapping each short model key to its full model choice