divisor.flux1.lora

 1# SPDX-License-Identifier:Apache-2.0
 2# original BFL Flux code from https://github.com/black-forest-labs/flux
 3
 4import torch
 5from torch import nn
 6
 7
 8def replace_linear_with_lora(
 9    module: nn.Module,
10    max_rank: int,
11    scale: float = 1.0,
12) -> None:
13    for name, child in module.named_children():
14        if isinstance(child, nn.Linear):
15            new_lora = LinearLora(
16                in_features=child.in_features,
17                out_features=child.out_features,
18                bias=child.bias,
19                rank=max_rank,
20                scale=scale,
21                dtype=child.weight.dtype,
22                device=child.weight.device,
23            )
24
25            new_lora.weight = child.weight
26            new_lora.bias = child.bias if child.bias is not None else None
27
28            setattr(module, name, new_lora)
29        else:
30            replace_linear_with_lora(
31                module=child,
32                max_rank=max_rank,
33                scale=scale,
34            )
35
36
class LinearLora(nn.Linear):
    """``nn.Linear`` plus a scaled low-rank (LoRA) update.

    ``forward(x)`` returns ``linear(x) + scale * lora_B(lora_A(x))`` where
    ``lora_A`` maps ``in_features -> rank`` (no bias) and ``lora_B`` maps
    ``rank -> out_features`` (bias controlled by ``lora_bias``). The requested
    ``rank`` is clamped to ``min(in_features, out_features)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: "bool | torch.Tensor | None",
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base linear layer and the two LoRA projections.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Whether the base layer has a bias. A ``bool`` is used as-is;
                any other value (e.g. an existing bias tensor, or ``None``, as
                passed by ``replace_linear_with_lora``) means "has a bias"
                iff it is not ``None``.
            rank: Requested LoRA rank, clamped to ``min(in_features, out_features)``.
            dtype: dtype for the base and LoRA parameters.
            device: Device for the base and LoRA parameters.
            lora_bias: Whether ``lora_B`` has a bias.
            scale: Multiplier for the LoRA update; must be a ``float``.
            *args: Forwarded to ``nn.Linear.__init__``.
            **kwargs: Forwarded to ``nn.Linear.__init__``.

        Raises:
            AssertionError: If ``scale`` is not a ``float``.
        """
        # A bool is taken literally (so bias=False really means no bias);
        # anything else counts as "has a bias" iff it is not None.
        has_bias = bias if isinstance(bias, bool) else bias is not None

        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        assert isinstance(scale, float), "scale must be a float"

        self.scale = scale
        self.lora_bias = lora_bias
        # NOTE(review): plain attributes set once here; presumably not
        # refreshed by a later .to() call — verify before relying on them.
        self.dtype = dtype
        self.device = device
        # Clamp the requested rank to the smaller of the two layer dimensions.
        self.rank = min(rank, self.in_features, self.out_features)

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )

    def set_scale(self, scale: float) -> None:
        """Set the multiplier applied to the LoRA update in ``forward``.

        Raises:
            AssertionError: If ``scale`` is not a ``float``.
        """
        assert isinstance(scale, float), "scalar value must be a float"
        self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the base linear output plus ``scale * lora_B(lora_A(input))``."""
        base_out = super().forward(input)
        lora_update = self.lora_B(self.lora_A(input)) * self.scale
        return base_out + lora_update
def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every ``nn.Linear`` under ``module`` with ``LinearLora``.

    The module tree is modified in place. Each replacement reuses the original
    layer's ``weight`` and ``bias`` Parameter objects (they are assigned, not
    copied) and adds new ``lora_A`` / ``lora_B`` projections. Because
    ``LinearLora`` subclasses ``nn.Linear``, a child that is already a
    ``LinearLora`` is itself replaced (its LoRA projections are rebuilt)
    rather than recursed into.

    Args:
        module: Root module whose descendants are rewritten.
        max_rank: Requested LoRA rank for every replaced layer
            (``LinearLora`` clamps it to each layer's smaller dimension).
        scale: Multiplier for the LoRA update; must be a ``float``.
    """
    for name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            # Not a linear layer: descend into its own children.
            replace_linear_with_lora(module=child, max_rank=max_rank, scale=scale)
            continue

        new_lora = LinearLora(
            in_features=child.in_features,
            out_features=child.out_features,
            # The bias tensor (or None) is passed; LinearLora derives
            # "has a bias" from whether this value is None.
            bias=child.bias,
            rank=max_rank,
            scale=scale,
            dtype=child.weight.dtype,
            device=child.weight.device,
        )
        # Reuse the original layer's parameters in place of the ones
        # LinearLora just allocated.
        new_lora.weight = child.weight
        new_lora.bias = child.bias
        setattr(module, name, new_lora)
class LinearLora(nn.Linear):
    """``nn.Linear`` plus a scaled low-rank (LoRA) update.

    ``forward(x)`` returns ``linear(x) + scale * lora_B(lora_A(x))`` where
    ``lora_A`` maps ``in_features -> rank`` (no bias) and ``lora_B`` maps
    ``rank -> out_features`` (bias controlled by ``lora_bias``). The requested
    ``rank`` is clamped to ``min(in_features, out_features)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: "bool | torch.Tensor | None",
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base linear layer and the two LoRA projections.

        Args:
            in_features: Size of each input sample.
            out_features: Size of each output sample.
            bias: Whether the base layer has a bias. A ``bool`` is used as-is;
                any other value (e.g. an existing bias tensor, or ``None``, as
                passed by ``replace_linear_with_lora``) means "has a bias"
                iff it is not ``None``.
            rank: Requested LoRA rank, clamped to ``min(in_features, out_features)``.
            dtype: dtype for the base and LoRA parameters.
            device: Device for the base and LoRA parameters.
            lora_bias: Whether ``lora_B`` has a bias.
            scale: Multiplier for the LoRA update; must be a ``float``.
            *args: Forwarded to ``nn.Linear.__init__``.
            **kwargs: Forwarded to ``nn.Linear.__init__``.

        Raises:
            AssertionError: If ``scale`` is not a ``float``.
        """
        # A bool is taken literally (so bias=False really means no bias);
        # anything else counts as "has a bias" iff it is not None.
        has_bias = bias if isinstance(bias, bool) else bias is not None

        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        assert isinstance(scale, float), "scale must be a float"

        self.scale = scale
        self.lora_bias = lora_bias
        # NOTE(review): plain attributes set once here; presumably not
        # refreshed by a later .to() call — verify before relying on them.
        self.dtype = dtype
        self.device = device
        # Clamp the requested rank to the smaller of the two layer dimensions.
        self.rank = min(rank, self.in_features, self.out_features)

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )

    def set_scale(self, scale: float) -> None:
        """Set the multiplier applied to the LoRA update in ``forward``.

        Raises:
            AssertionError: If ``scale`` is not a ``float``.
        """
        assert isinstance(scale, float), "scalar value must be a float"
        self.scale = scale

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the base linear output plus ``scale * lora_B(lora_A(input))``."""
        base_out = super().forward(input)
        lora_update = self.lora_B(self.lora_A(input)) * self.scale
        return base_out + lora_update

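Applies an affine linear transformation to the incoming data: \( y = xA^T + b \). (This docstring is inherited from ``torch.nn.Linear``; ``LinearLora`` additionally adds ``scale * lora_B(lora_A(x))`` to this output.)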

This module supports :ref:`TensorFloat32 <tf32_on_ampere>`.

On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision <fp16_on_mi200>` for backward.

Args:
    in_features: size of each input sample
    out_features: size of each output sample
    bias: If set to ``False``, the layer will not learn an additive bias. Default: ``True``

Shape:
    - Input: \( (*, H_\text{in}) \) where \( * \) means any number of dimensions including none and \( H_\text{in} = \text{in\_features} \).
    - Output: \( (*, H_\text{out}) \) where all but the last dimension are the same shape as the input and \( H_\text{out} = \text{out\_features} \).

Attributes:
    weight: the learnable weights of the module of shape \( (\text{out\_features}, \text{in\_features}) \). The values are initialized from \( \mathcal{U}(-\sqrt{k}, \sqrt{k}) \), where \( k = \frac{1}{\text{in\_features}} \).
    bias: the learnable bias of the module of shape \( (\text{out\_features}) \). If ``bias`` is ``True``, the values are initialized from \( \mathcal{U}(-\sqrt{k}, \sqrt{k}) \) where \( k = \frac{1}{\text{in\_features}} \).

Examples::

>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
LinearLora( in_features: int, out_features: int, bias: bool, rank: int, dtype: torch.dtype, device: torch.device, lora_bias: bool = True, scale: float = 1.0, *args, **kwargs)
# Rendered source of LinearLora.__init__: builds the base nn.Linear, then two
# projections lora_A (in_features -> rank, no bias) and lora_B
# (rank -> out_features, bias controlled by lora_bias).
39    def __init__(
40        self,
41        in_features: int,
42        out_features: int,
43        bias: bool,  # NOTE(review): replace_linear_with_lora passes a bias tensor or None here, not a bool
44        rank: int,
45        dtype: torch.dtype,
46        device: torch.device,
47        lora_bias: bool = True,
48        scale: float = 1.0,
49        *args,
50        **kwargs,
51    ) -> None:
52        super().__init__(
53            in_features=in_features,
54            out_features=out_features,
55            bias=bias is not None,  # NOTE(review): also True for bias=False, so a bool False still creates a bias — confirm intent
56            device=device,
57            dtype=dtype,
58            *args,
59            **kwargs,
60        )
61
62        assert isinstance(scale, float), "scale must be a float"
63
64        self.scale = scale
65        self.rank = rank
66        self.lora_bias = lora_bias
67        self.dtype = dtype  # NOTE(review): set once here; presumably stale after a later .to() — verify
68        self.device = device
69
70        if rank > (new_rank := min(self.out_features, self.in_features)):  # clamp rank to the smaller layer dimension
71            self.rank = new_rank
72
73        self.lora_A = nn.Linear(
74            in_features=in_features,
75            out_features=self.rank,
76            bias=False,
77            dtype=dtype,
78            device=device,
79        )
80        self.lora_B = nn.Linear(
81            in_features=self.rank,
82            out_features=out_features,
83            bias=self.lora_bias,
84            dtype=dtype,
85            device=device,
86        )

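Initialize internal Module state, shared by both nn.Module and ScriptModule. (This docstring is inherited from ``torch.nn.Module.__init__``; ``LinearLora.__init__`` itself builds the base linear layer and the ``lora_A`` / ``lora_B`` projections, with the rank clamped to ``min(in_features, out_features)``.)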

scale
rank
lora_bias
dtype
device
lora_A
lora_B
def set_scale(self, scale: float) -> None:
# Rendered source of LinearLora.set_scale: replaces the multiplier that
# forward() applies to the LoRA update. Only float values pass the check
# (an int such as 1 is rejected); it is an assert, so it is skipped under -O.
88    def set_scale(self, scale: float) -> None:
89        assert isinstance(scale, float), "scalar value must be a float"
90        self.scale = scale
def forward(self, input: torch.Tensor) -> torch.Tensor:
# Rendered source of LinearLora.forward: the base nn.Linear output plus the
# low-rank update scale * lora_B(lora_A(input)).
92    def forward(self, input: torch.Tensor) -> torch.Tensor:
93        base_out = super().forward(input)
94
95        _lora_out_B = self.lora_B(self.lora_A(input))
96        lora_update = _lora_out_B * self.scale
97
98        return base_out + lora_update

Runs the forward pass: returns the base linear output plus ``scale * lora_B(lora_A(input))``.