divisor.flux1.lora
# SPDX-License-Identifier: Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux

import torch
from torch import nn


def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every ``nn.Linear`` under ``module`` with a ``LinearLora``.

    Each replacement reuses the original layer's ``weight`` and ``bias``
    parameter objects (shared, not copied) and adds new low-rank
    ``lora_A``/``lora_B`` projections. The module tree is modified in place.

    Args:
        module: root module whose children (at any depth) are rewritten.
        max_rank: requested LoRA rank; ``LinearLora`` clamps it per layer to
            ``min(in_features, out_features)``.
        scale: multiplier applied to the LoRA update.
    """
    for name, child in module.named_children():
        # NOTE(review): LinearLora subclasses nn.Linear, so calling this twice
        # re-wraps already-converted layers and discards their LoRA weights.
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=child.bias is not None,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            # LinearLora allocates its own base weight/bias; swap in the
            # existing parameters so the pretrained values are kept (and shared).
            new_lora.weight = child.weight
            new_lora.bias = child.bias
            setattr(module, name, new_lora)
        else:
            replace_linear_with_lora(module=child, max_rank=max_rank, scale=scale)


class LinearLora(nn.Linear):
    """``nn.Linear`` plus an additive low-rank (LoRA) update.

    ``forward(x) = linear(x, weight, bias) + scale * lora_B(lora_A(x))``, where
    ``lora_A`` maps ``in_features -> rank`` (no bias) and ``lora_B`` maps
    ``rank -> out_features`` (bias controlled by ``lora_bias``).

    ``lora_B`` is zero-initialised, so a freshly built layer computes exactly
    the base linear map until LoRA weights are trained or loaded.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base linear layer and the LoRA projections.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: whether the base layer has a bias. A real bool is honoured
                as-is; any other value (e.g. a bias tensor) means "has bias"
                unless it is ``None``.
            rank: requested LoRA rank, clamped to ``min(in_features, out_features)``.
            dtype: dtype of all created parameters.
            device: device of all created parameters.
            lora_bias: whether ``lora_B`` has a bias term.
            scale: multiplier for the LoRA update (int or float).

        Raises:
            TypeError: if ``scale`` is not an int or float.
        """
        # The previous `bias is not None` test turned an explicit bias=False
        # into a biased layer; accept real bools and keep the tensor/None form.
        has_bias = bias if isinstance(bias, bool) else bias is not None
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        self.scale = self._validate_scale(scale, "scale must be a float")
        # A rank above min(in, out) adds parameters without adding capacity.
        self.rank = min(rank, self.out_features, self.in_features)
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )
        # Standard LoRA init: A random, B zero, so B(A(x)) starts at exactly 0.
        nn.init.zeros_(self.lora_B.weight)
        if self.lora_B.bias is not None:
            nn.init.zeros_(self.lora_B.bias)

    @staticmethod
    def _validate_scale(scale: float, message: str) -> float:
        """Return ``scale`` as a float; raise ``TypeError(message)`` for non-numbers.

        Bools are rejected even though ``bool`` subclasses ``int``.
        """
        if isinstance(scale, bool) or not isinstance(scale, (int, float)):
            raise TypeError(message)
        return float(scale)

    def set_scale(self, scale: float) -> None:
        """Set the multiplier applied to the LoRA update (int or float)."""
        self.scale = self._validate_scale(scale, "scalar value must be a float")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the base linear output plus ``scale * lora_B(lora_A(input))``."""
        base_out = super().forward(input)

        _lora_out_B = self.lora_B(self.lora_A(input))
        lora_update = _lora_out_B * self.scale

        return base_out + lora_update
def replace_linear_with_lora(
    module: nn.Module,
    max_rank: int,
    scale: float = 1.0,
) -> None:
    """Recursively replace every ``nn.Linear`` under ``module`` with a ``LinearLora``.

    Each replacement reuses the original layer's ``weight`` and ``bias``
    parameter objects (shared, not copied) and adds new low-rank
    ``lora_A``/``lora_B`` projections. The module tree is modified in place.

    Args:
        module: root module whose children (at any depth) are rewritten.
        max_rank: requested LoRA rank; ``LinearLora`` clamps it per layer to
            ``min(in_features, out_features)``.
        scale: multiplier applied to the LoRA update.
    """
    for name, child in module.named_children():
        # NOTE(review): LinearLora subclasses nn.Linear, so calling this twice
        # re-wraps already-converted layers and discards their LoRA weights.
        if isinstance(child, nn.Linear):
            new_lora = LinearLora(
                in_features=child.in_features,
                out_features=child.out_features,
                # LinearLora treats None as "no bias" and a tensor as "has bias".
                bias=child.bias,
                rank=max_rank,
                scale=scale,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            # LinearLora allocates its own base weight/bias; swap in the
            # existing parameters so the pretrained values are kept (and shared).
            new_lora.weight = child.weight
            new_lora.bias = child.bias
            setattr(module, name, new_lora)
        else:
            replace_linear_with_lora(module=child, max_rank=max_rank, scale=scale)
class LinearLora(nn.Linear):
    """``nn.Linear`` plus an additive low-rank (LoRA) update.

    ``forward(x) = linear(x, weight, bias) + scale * lora_B(lora_A(x))``, where
    ``lora_A`` maps ``in_features -> rank`` (no bias) and ``lora_B`` maps
    ``rank -> out_features`` (bias controlled by ``lora_bias``).

    ``lora_B`` is zero-initialised, so a freshly built layer computes exactly
    the base linear map until LoRA weights are trained or loaded.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool,
        rank: int,
        dtype: torch.dtype,
        device: torch.device,
        lora_bias: bool = True,
        scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base linear layer and the LoRA projections.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias: whether the base layer has a bias. A real bool is honoured
                as-is; any other value (e.g. a bias tensor) means "has bias"
                unless it is ``None``.
            rank: requested LoRA rank, clamped to ``min(in_features, out_features)``.
            dtype: dtype of all created parameters.
            device: device of all created parameters.
            lora_bias: whether ``lora_B`` has a bias term.
            scale: multiplier for the LoRA update (int or float).

        Raises:
            TypeError: if ``scale`` is not an int or float.
        """
        # The previous `bias is not None` test turned an explicit bias=False
        # into a biased layer; accept real bools and keep the tensor/None form.
        has_bias = bias if isinstance(bias, bool) else bias is not None
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=has_bias,
            device=device,
            dtype=dtype,
            *args,
            **kwargs,
        )

        self.scale = self._validate_scale(scale, "scale must be a float")
        # A rank above min(in, out) adds parameters without adding capacity.
        self.rank = min(rank, self.out_features, self.in_features)
        self.lora_bias = lora_bias
        self.dtype = dtype
        self.device = device

        self.lora_A = nn.Linear(
            in_features=in_features,
            out_features=self.rank,
            bias=False,
            dtype=dtype,
            device=device,
        )
        self.lora_B = nn.Linear(
            in_features=self.rank,
            out_features=out_features,
            bias=self.lora_bias,
            dtype=dtype,
            device=device,
        )
        # Standard LoRA init: A random, B zero, so B(A(x)) starts at exactly 0.
        nn.init.zeros_(self.lora_B.weight)
        if self.lora_B.bias is not None:
            nn.init.zeros_(self.lora_B.bias)

    @staticmethod
    def _validate_scale(scale: float, message: str) -> float:
        """Return ``scale`` as a float; raise ``TypeError(message)`` for non-numbers.

        Bools are rejected even though ``bool`` subclasses ``int``.
        """
        if isinstance(scale, bool) or not isinstance(scale, (int, float)):
            raise TypeError(message)
        return float(scale)

    def set_scale(self, scale: float) -> None:
        """Set the multiplier applied to the LoRA update (int or float)."""
        self.scale = self._validate_scale(scale, "scalar value must be a float")

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the base linear output plus ``scale * lora_B(lora_A(input))``."""
        base_out = super().forward(input)

        _lora_out_B = self.lora_B(self.lora_A(input))
        lora_update = _lora_out_B * self.scale

        return base_out + lora_update
Applies an affine linear transformation to the incoming data: \( y = xA^T + b \).
This module supports :ref:`TensorFloat32 <tf32_on_ampere>`.
On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision <fp16_on_mi200>` for backward.
Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to False, the layer will not learn an additive bias.
Default: True
Shape: - Input: \( (*, H_\text{in}) \) where \( * \) means any number of dimensions including none and \( H_\text{in} = \text{in\_features} \). - Output: \( (*, H_\text{out}) \) where all but the last dimension are the same shape as the input and \( H_\text{out} = \text{out\_features} \).
Attributes:
weight: the learnable weights of the module of shape
\( (\text{out\_features}, \text{in\_features}) \). The values are
initialized from \( \mathcal{U}(-\sqrt{k}, \sqrt{k}) \), where
\( k = \frac{1}{\text{in\_features}} \)
bias: the learnable bias of the module of shape \( (\text{out\_features}) \).
If bias is True, the values are initialized from
\( \mathcal{U}(-\sqrt{k}, \sqrt{k}) \) where
\( k = \frac{1}{\text{in\_features}} \)
Examples::
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
def __init__(
    self,
    in_features: int,
    out_features: int,
    bias: bool,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
    lora_bias: bool = True,
    scale: float = 1.0,
    *args,
    **kwargs,
) -> None:
    """Build the base linear layer and the LoRA projections of ``LinearLora``.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the base layer has a bias. A real bool is honoured
            as-is; any other value (e.g. a bias tensor) means "has bias"
            unless it is ``None``.
        rank: requested LoRA rank, clamped to ``min(in_features, out_features)``.
        dtype: dtype of all created parameters.
        device: device of all created parameters.
        lora_bias: whether ``lora_B`` has a bias term.
        scale: multiplier for the LoRA update (int or float).

    Raises:
        TypeError: if ``scale`` is not an int or float.
    """
    # The previous `bias is not None` test turned an explicit bias=False
    # into a biased layer; accept real bools and keep the tensor/None form.
    has_bias = bias if isinstance(bias, bool) else bias is not None
    super().__init__(
        in_features=in_features,
        out_features=out_features,
        bias=has_bias,
        device=device,
        dtype=dtype,
        *args,
        **kwargs,
    )

    # Explicit check instead of `assert` (stripped under -O); ints accepted,
    # bools rejected even though bool subclasses int.
    if isinstance(scale, bool) or not isinstance(scale, (int, float)):
        raise TypeError("scale must be a float")
    self.scale = float(scale)
    # A rank above min(in, out) adds parameters without adding capacity.
    self.rank = min(rank, self.out_features, self.in_features)
    self.lora_bias = lora_bias
    self.dtype = dtype
    self.device = device

    self.lora_A = nn.Linear(
        in_features=in_features,
        out_features=self.rank,
        bias=False,
        dtype=dtype,
        device=device,
    )
    self.lora_B = nn.Linear(
        in_features=self.rank,
        out_features=out_features,
        bias=self.lora_bias,
        dtype=dtype,
        device=device,
    )
    # Standard LoRA init: A random, B zero, so B(A(x)) starts at exactly 0.
    nn.init.zeros_(self.lora_B.weight)
    if self.lora_B.bias is not None:
        nn.init.zeros_(self.lora_B.bias)
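Initialize a LinearLora layer. Builds the base nn.Linear, which gets a bias only when bias is not None. Then adds the low-rank projections lora_A (in_features → rank, no bias) and lora_B (rank → out_features, bias controlled by lora_bias). Any rank larger than min(in_features, out_features) is clamped down to that value.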
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Compute the LoRA-augmented linear map.

    Returns ``linear(input, weight, bias) + lora_B(lora_A(input)) * scale``.
    """
    # Inline form of the parent nn.Linear forward: an affine map using this
    # layer's own weight and bias.
    base = torch.nn.functional.linear(input, self.weight, self.bias)
    low_rank = self.lora_B(self.lora_A(input))
    return base + low_rank * self.scale
Runs the forward pass. Returns the base linear output plus the LoRA update, `lora_B(lora_A(input)) * scale`.