divisor.flux1.model
# SPDX-License-Identifier:Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux
"""Flux transformer (``Flux``), its hyperparameters (``FluxParams``) and a LoRA variant (``FluxLoraWrapper``)."""

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from divisor.flux1.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)
from divisor.flux1.lora import LinearLora, replace_linear_with_lora
from divisor.layer_dropout import process_blocks_with_dropout


@dataclass
class FluxParams:
    """Architecture hyperparameters consumed by ``Flux.__init__``."""

    in_channels: int  # last-dim size of input image tokens; also the output channel count
    vec_in_dim: int  # input size of the ``vector_in`` embedder (forward's ``y``)
    context_in_dim: int  # last-dim size of input text tokens
    hidden_size: int  # model width; must be divisible by ``num_heads``
    mlp_ratio: float  # forwarded to every double- and single-stream block
    num_heads: int  # head count forwarded to every block
    depth: int  # number of DoubleStreamBlocks
    depth_single_blocks: int  # number of SingleStreamBlocks
    axes_dim: list[int]  # per-axis dims for EmbedND; must sum to hidden_size // num_heads
    theta: int  # forwarded to EmbedND
    qkv_bias: bool  # forwarded to every DoubleStreamBlock
    guidance_embed: bool  # when True, forward requires and embeds ``guidance``


class Flux(nn.Module):
    """Transformer model for flow matching on sequences."""

    def __init__(self, params: FluxParams):
        """Build the model from ``params``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_heads``, or if
                ``sum(axes_dim)`` differs from the per-head dim ``hidden_size // num_heads``.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        # Output tokens carry the same channel count as input image tokens.
        self.out_channels = params.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        # EmbedND's per-axis dims must add up to exactly one head's width.
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        # in_dim=256 must match the width passed to timestep_embedding(...) in forward.
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        layer_dropouts: list[int] | None = None,
    ) -> Tensor:
        """Run the transformer on image tokens conditioned on text tokens.

        Args:
            img: image token sequence; must be 3-D, last dim ``in_channels``.
            img_ids: position ids for ``img``; concatenated after ``txt_ids`` along dim 1.
            txt: text token sequence; must be 3-D, last dim ``context_in_dim``.
            txt_ids: position ids for ``txt``.
            timesteps: embedded via ``timestep_embedding(timesteps, 256)``.
            y: vector input of size ``vec_in_dim`` fed to ``vector_in``.
            guidance: guidance strength; required when ``params.guidance_embed``
                is True, ignored otherwise.
            layer_dropouts: forwarded unchanged to ``process_blocks_with_dropout``
                for both block stacks (presumably indices of blocks to skip —
                confirm against that helper).

        Returns:
            ``final_layer`` output computed on the image tokens only.

        Raises:
            ValueError: if ``img`` or ``txt`` is not 3-D, or if ``guidance`` is
                missing for a guidance-distilled model.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Text ids come first, matching the (txt, img) order of the joint sequence below.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        def run_double(block, state):
            return block(img=state[0], txt=state[1], vec=vec, pe=pe)

        def run_single(block, state):
            return block(state, vec=vec, pe=pe)

        img, txt = process_blocks_with_dropout(self.double_blocks, layer_dropouts, 0, "double", run_double, (img, txt))
        img = torch.cat((txt, img), 1)

        # NOTE(review): len(double_blocks) looks like a global block-index offset so
        # single blocks are numbered after double blocks — confirm in the helper.
        img = process_blocks_with_dropout(self.single_blocks, layer_dropouts, len(self.double_blocks), "single", run_single, img)
        # Drop the text prefix; only image tokens go to the final layer.
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * self.out_channels)
        return img


class FluxLoraWrapper(Flux):
    """Flux model with LoRA applied by ``replace_linear_with_lora`` after construction."""

    def __init__(
        self,
        lora_rank: int = 128,
        lora_scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base ``Flux`` model, then apply LoRA.

        Args:
            lora_rank: passed as ``max_rank`` to ``replace_linear_with_lora``.
            lora_scale: initial LoRA scale passed as ``scale``.
            *args, **kwargs: forwarded to ``Flux.__init__``.

        ``lora_rank`` and ``lora_scale`` are the first positional parameters, so pass
        the Flux params by keyword: ``FluxLoraWrapper(params=..., lora_rank=...)``.
        """
        super().__init__(*args, **kwargs)

        self.lora_rank = lora_rank
        # Record the scale as well, so the current value is inspectable like lora_rank.
        self.lora_scale = lora_scale

        replace_linear_with_lora(
            self,
            max_rank=lora_rank,
            scale=lora_scale,
        )

    def set_lora_scale(self, scale: float) -> None:
        """Set ``scale`` on every ``LinearLora`` submodule and record it on the wrapper."""
        for module in self.modules():
            if isinstance(module, LinearLora):
                module.set_scale(scale=scale)
        self.lora_scale = scale
@dataclass
class
FluxParams:
@dataclass
class FluxParams:
    """Architecture hyperparameters consumed by ``Flux.__init__``."""

    in_channels: int  # last-dim size of input image tokens; also the output channel count
    vec_in_dim: int  # input size of the ``vector_in`` embedder (forward's ``y``)
    context_in_dim: int  # last-dim size of input text tokens
    hidden_size: int  # model width; must be divisible by ``num_heads``
    mlp_ratio: float  # forwarded to every double- and single-stream block
    num_heads: int  # head count forwarded to every block
    depth: int  # number of DoubleStreamBlocks
    depth_single_blocks: int  # number of SingleStreamBlocks
    axes_dim: list[int]  # per-axis dims for EmbedND; must sum to hidden_size // num_heads
    theta: int  # forwarded to EmbedND
    qkv_bias: bool  # forwarded to every DoubleStreamBlock
    guidance_embed: bool  # when True, forward requires and embeds ``guidance``
class
Flux(torch.nn.modules.module.Module):
class Flux(nn.Module):
    """Transformer model for flow matching on sequences."""

    def __init__(self, params: FluxParams):
        """Build the model from ``params``.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by ``num_heads``, or if
                ``sum(axes_dim)`` differs from the per-head dim ``hidden_size // num_heads``.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        # Output tokens carry the same channel count as input image tokens.
        self.out_channels = params.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        # EmbedND's per-axis dims must add up to exactly one head's width.
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        # in_dim=256 must match the width passed to timestep_embedding(...) in forward.
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
        layer_dropouts: list[int] | None = None,
    ) -> Tensor:
        """Run the transformer on image tokens conditioned on text tokens.

        Args:
            img: image token sequence; must be 3-D, last dim ``in_channels``.
            img_ids: position ids for ``img``; concatenated after ``txt_ids`` along dim 1.
            txt: text token sequence; must be 3-D, last dim ``context_in_dim``.
            txt_ids: position ids for ``txt``.
            timesteps: embedded via ``timestep_embedding(timesteps, 256)``.
            y: vector input of size ``vec_in_dim`` fed to ``vector_in``.
            guidance: guidance strength; required when ``params.guidance_embed``
                is True, ignored otherwise.
            layer_dropouts: forwarded unchanged to ``process_blocks_with_dropout``
                for both block stacks (presumably indices of blocks to skip —
                confirm against that helper).

        Returns:
            ``final_layer`` output computed on the image tokens only.

        Raises:
            ValueError: if ``img`` or ``txt`` is not 3-D, or if ``guidance`` is
                missing for a guidance-distilled model.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Text ids come first, matching the (txt, img) order of the joint sequence below.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        def run_double(block, state):
            return block(img=state[0], txt=state[1], vec=vec, pe=pe)

        def run_single(block, state):
            return block(state, vec=vec, pe=pe)

        img, txt = process_blocks_with_dropout(self.double_blocks, layer_dropouts, 0, "double", run_double, (img, txt))
        img = torch.cat((txt, img), 1)

        # NOTE(review): len(double_blocks) looks like a global block-index offset so
        # single blocks are numbered after double blocks — confirm in the helper.
        img = process_blocks_with_dropout(self.single_blocks, layer_dropouts, len(self.double_blocks), "single", run_single, img)
        # Drop the text prefix; only image tokens go to the final layer.
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * self.out_channels)
        return img
Transformer model for flow matching on sequences.
Flux(params: FluxParams)
def __init__(self, params: FluxParams):
    """Build the Flux model from ``params``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``, or if
            ``sum(axes_dim)`` differs from the per-head dim ``hidden_size // num_heads``.
    """
    super().__init__()

    self.params = params
    self.in_channels = params.in_channels
    # Output tokens carry the same channel count as input image tokens.
    self.out_channels = params.in_channels
    if params.hidden_size % params.num_heads != 0:
        raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
    pe_dim = params.hidden_size // params.num_heads
    # EmbedND's per-axis dims must add up to exactly one head's width.
    if sum(params.axes_dim) != pe_dim:
        raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
    self.hidden_size = params.hidden_size
    self.num_heads = params.num_heads
    self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
    # in_dim=256 must match the width passed to timestep_embedding(...) in forward.
    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
    self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
    self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
    self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

    self.double_blocks = nn.ModuleList(
        [
            DoubleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                qkv_bias=params.qkv_bias,
            )
            for _ in range(params.depth)
        ]
    )

    self.single_blocks = nn.ModuleList(
        [
            SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
            for _ in range(params.depth_single_blocks)
        ]
    )

    self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
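Build the Flux model from `params`. Raises `ValueError` if `hidden_size` is not divisible by `num_heads`, or if `sum(axes_dim)` does not equal the per-head dimension `hidden_size // num_heads`.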
def
forward( self, img: torch.Tensor, img_ids: torch.Tensor, txt: torch.Tensor, txt_ids: torch.Tensor, timesteps: torch.Tensor, y: torch.Tensor, guidance: torch.Tensor | None = None, layer_dropouts: list[int] | None = None) -> torch.Tensor:
def forward(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y: Tensor,
    guidance: Tensor | None = None,
    layer_dropouts: list[int] | None = None,
) -> Tensor:
    """Run the Flux transformer on image tokens conditioned on text tokens.

    Args:
        img: image token sequence; must be 3-D, last dim ``in_channels``.
        img_ids: position ids for ``img``; concatenated after ``txt_ids`` along dim 1.
        txt: text token sequence; must be 3-D, last dim ``context_in_dim``.
        txt_ids: position ids for ``txt``.
        timesteps: embedded via ``timestep_embedding(timesteps, 256)``.
        y: vector input of size ``vec_in_dim`` fed to ``vector_in``.
        guidance: guidance strength; required when ``params.guidance_embed``
            is True, ignored otherwise.
        layer_dropouts: forwarded unchanged to ``process_blocks_with_dropout``
            for both block stacks (presumably indices of blocks to skip —
            confirm against that helper).

    Returns:
        ``final_layer`` output computed on the image tokens only.

    Raises:
        ValueError: if ``img`` or ``txt`` is not 3-D, or if ``guidance`` is
            missing for a guidance-distilled model.
    """
    if img.ndim != 3 or txt.ndim != 3:
        raise ValueError("Input img and txt tensors must have 3 dimensions.")

    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256))
    if self.params.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)

    # Text ids come first, matching the (txt, img) order of the joint sequence below.
    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)

    def run_double(block, state):
        return block(img=state[0], txt=state[1], vec=vec, pe=pe)

    def run_single(block, state):
        return block(state, vec=vec, pe=pe)

    img, txt = process_blocks_with_dropout(self.double_blocks, layer_dropouts, 0, "double", run_double, (img, txt))
    img = torch.cat((txt, img), 1)

    # NOTE(review): len(double_blocks) looks like a global block-index offset so
    # single blocks are numbered after double blocks — confirm in the helper.
    img = process_blocks_with_dropout(self.single_blocks, layer_dropouts, len(self.double_blocks), "single", run_single, img)
    # Drop the text prefix; only image tokens go to the final layer.
    img = img[:, txt.shape[1] :, ...]

    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * self.out_channels)
    return img
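Run the Flux transformer on image tokens `img`, conditioned on text tokens `txt`, `timesteps`, the vector input `y` and, for guidance-distilled models (`params.guidance_embed`), `guidance`.
Both `img` and `txt` must be 3-D. A `ValueError` is raised otherwise, and also when `guidance` is missing for a guidance-distilled model.
Text and image tokens first pass through the double-stream blocks. They are then concatenated (text first) and pass through the single-stream blocks. Only the image tokens are fed to the final layer.
`layer_dropouts` is forwarded to `process_blocks_with_dropout` for both block stacks.
Call the module instance rather than `forward` directly, so that registered hooks run.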
class FluxLoraWrapper(Flux):
    """Flux model with LoRA applied by ``replace_linear_with_lora`` after construction."""

    def __init__(
        self,
        lora_rank: int = 128,
        lora_scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base ``Flux`` model, then apply LoRA.

        Args:
            lora_rank: passed as ``max_rank`` to ``replace_linear_with_lora``.
            lora_scale: initial LoRA scale passed as ``scale``.
            *args, **kwargs: forwarded to ``Flux.__init__``.

        ``lora_rank`` and ``lora_scale`` are the first positional parameters, so pass
        the Flux params by keyword: ``FluxLoraWrapper(params=..., lora_rank=...)``.
        """
        super().__init__(*args, **kwargs)

        self.lora_rank = lora_rank
        # Record the scale as well, so the current value is inspectable like lora_rank.
        self.lora_scale = lora_scale

        replace_linear_with_lora(
            self,
            max_rank=lora_rank,
            scale=lora_scale,
        )

    def set_lora_scale(self, scale: float) -> None:
        """Set ``scale`` on every ``LinearLora`` submodule and record it on the wrapper."""
        for module in self.modules():
            if isinstance(module, LinearLora):
                module.set_scale(scale=scale)
        self.lora_scale = scale
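Flux transformer with LoRA. After building the base `Flux` model, it calls `replace_linear_with_lora` with `max_rank=lora_rank` and `scale=lora_scale`. `set_lora_scale` changes the scale of every `LinearLora` submodule.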
FluxLoraWrapper(lora_rank: int = 128, lora_scale: float = 1.0, *args, **kwargs)
def __init__(
    self,
    lora_rank: int = 128,
    lora_scale: float = 1.0,
    *args,
    **kwargs,
) -> None:
    """Build the base ``Flux`` model, then apply LoRA.

    Args:
        lora_rank: passed as ``max_rank`` to ``replace_linear_with_lora``.
        lora_scale: initial LoRA scale passed as ``scale``.
        *args, **kwargs: forwarded to ``Flux.__init__``.

    ``lora_rank`` and ``lora_scale`` are the first positional parameters, so pass
    the Flux params by keyword: ``FluxLoraWrapper(params=..., lora_rank=...)``.
    """
    super().__init__(*args, **kwargs)

    self.lora_rank = lora_rank
    # Record the scale as well, so the current value is inspectable like lora_rank.
    self.lora_scale = lora_scale

    replace_linear_with_lora(
        self,
        max_rank=lora_rank,
        scale=lora_scale,
    )
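Build the base `Flux` model from `*args`/`**kwargs`, then apply LoRA with rank `lora_rank` and scale `lora_scale`. Because `lora_rank` and `lora_scale` come first positionally, pass the `Flux` params by keyword.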