divisor.flux1.model

  1# SPDX-License-Identifier:Apache-2.0
  2# original BFL Flux code from https://github.com/black-forest-labs/flux
  3
  4from dataclasses import dataclass
  5
  6import torch
  7from torch import Tensor, nn
  8
  9from divisor.flux1.layers import (
 10    DoubleStreamBlock,
 11    EmbedND,
 12    LastLayer,
 13    MLPEmbedder,
 14    SingleStreamBlock,
 15    timestep_embedding,
 16)
 17from divisor.flux1.lora import LinearLora, replace_linear_with_lora
 18from divisor.layer_dropout import process_blocks_with_dropout
 19
 20
 21@dataclass
 22class FluxParams:
 23    in_channels: int
 24    vec_in_dim: int
 25    context_in_dim: int
 26    hidden_size: int
 27    mlp_ratio: float
 28    num_heads: int
 29    depth: int
 30    depth_single_blocks: int
 31    axes_dim: list[int]
 32    theta: int
 33    qkv_bias: bool
 34    guidance_embed: bool
 35
 36
 37class Flux(nn.Module):
 38    """Transformer model for flow matching on sequences."""
 39
 40    def __init__(self, params: FluxParams):
 41        super().__init__()
 42
 43        self.params = params
 44        self.in_channels = params.in_channels
 45        self.out_channels = params.in_channels
 46        if params.hidden_size % params.num_heads != 0:
 47            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
 48        pe_dim = params.hidden_size // params.num_heads
 49        if sum(params.axes_dim) != pe_dim:
 50            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
 51        self.hidden_size = params.hidden_size
 52        self.num_heads = params.num_heads
 53        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
 54        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
 55        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
 56        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
 57        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
 58        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
 59
 60        self.double_blocks = nn.ModuleList(
 61            [
 62                DoubleStreamBlock(
 63                    self.hidden_size,
 64                    self.num_heads,
 65                    mlp_ratio=params.mlp_ratio,
 66                    qkv_bias=params.qkv_bias,
 67                )
 68                for _ in range(params.depth)
 69            ]
 70        )
 71
 72        self.single_blocks = nn.ModuleList([SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(params.depth_single_blocks)])
 73
 74        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 75
 76    def forward(
 77        self,
 78        img: Tensor,
 79        img_ids: Tensor,
 80        txt: Tensor,
 81        txt_ids: Tensor,
 82        timesteps: Tensor,
 83        y: Tensor,
 84        guidance: Tensor | None = None,
 85        layer_dropouts: list[int] | None = None,
 86    ) -> Tensor:
 87        if img.ndim != 3 or txt.ndim != 3:
 88            raise ValueError("Input img and txt tensors must have 3 dimensions.")
 89
 90        # running on sequences img
 91        img = self.img_in(img)
 92        vec = self.time_in(timestep_embedding(timesteps, 256))
 93        if self.params.guidance_embed:
 94            if guidance is None:
 95                raise ValueError("Didn't get guidance strength for guidance distilled model.")
 96            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
 97        vec = vec + self.vector_in(y)
 98        txt = self.txt_in(txt)
 99
100        ids = torch.cat((txt_ids, img_ids), dim=1)
101        pe = self.pe_embedder(ids)
102
103        img, txt = process_blocks_with_dropout(self.double_blocks, layer_dropouts, 0, "double", lambda block, state: block(img=state[0], txt=state[1], vec=vec, pe=pe), (img, txt))
104        img = torch.cat((txt, img), 1)
105
106        img = process_blocks_with_dropout(self.single_blocks, layer_dropouts, len(self.double_blocks), "single", lambda block, state: block(state, vec=vec, pe=pe), img)
107        img = img[:, txt.shape[1] :, ...]
108
109        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * self.out_channels)
110        return img
111
112
class FluxLoraWrapper(Flux):
    """``Flux`` that calls ``replace_linear_with_lora`` on itself at construction.

    ``lora_rank`` and ``lora_scale`` occupy the first two positional slots, so
    the ``Flux`` constructor argument should be passed by keyword, e.g.
    ``FluxLoraWrapper(params=params, lora_rank=64)``.
    """

    def __init__(
        self,
        lora_rank: int = 128,
        lora_scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base ``Flux`` model, then apply ``replace_linear_with_lora``.

        Args:
            lora_rank: Stored on ``self.lora_rank`` and passed as ``max_rank``.
            lora_scale: Passed as ``scale`` to ``replace_linear_with_lora``.
            *args: Forwarded unchanged to ``Flux.__init__``.
            **kwargs: Forwarded unchanged to ``Flux.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.lora_rank = lora_rank
        replace_linear_with_lora(self, max_rank=lora_rank, scale=lora_scale)

    def set_lora_scale(self, scale: float) -> None:
        """Call ``set_scale(scale=scale)`` on every ``LinearLora`` sub-module."""
        lora_layers = (m for m in self.modules() if isinstance(m, LinearLora))
        for layer in lora_layers:
            layer.set_scale(scale=scale)
@dataclass
class FluxParams:
@dataclass
class FluxParams:
    """Architecture hyper-parameters consumed by ``Flux`` when it builds itself.

    Field order is significant: it is the positional order of the generated
    ``__init__``.
    """

    # Width of each image token given to Flux.forward; Flux also uses it as out_channels.
    in_channels: int
    # Input width of the ``vector_in`` embedder (the ``y`` argument of Flux.forward).
    vec_in_dim: int
    # Width of each text token given to Flux.forward (input of ``txt_in``).
    context_in_dim: int
    # Model width; Flux requires it to be divisible by ``num_heads``.
    hidden_size: int
    # Forwarded as ``mlp_ratio`` to every double- and single-stream block.
    mlp_ratio: float
    # Number of heads passed to every block.
    num_heads: int
    # Number of DoubleStreamBlock layers.
    depth: int
    # Number of SingleStreamBlock layers.
    depth_single_blocks: int
    # Per-axis sizes for EmbedND; Flux requires sum(axes_dim) == hidden_size // num_heads.
    axes_dim: list[int]
    # Forwarded as ``theta`` to EmbedND.
    theta: int
    # Forwarded as ``qkv_bias`` to each DoubleStreamBlock (not to single-stream blocks).
    qkv_bias: bool
    # If True, Flux builds a guidance embedder and Flux.forward requires ``guidance``.
    guidance_embed: bool
FluxParams( in_channels: int, vec_in_dim: int, context_in_dim: int, hidden_size: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], theta: int, qkv_bias: bool, guidance_embed: bool)
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(torch.nn.modules.module.Module):
 38class Flux(nn.Module):
 39    """Transformer model for flow matching on sequences."""
 40
 41    def __init__(self, params: FluxParams):
 42        super().__init__()
 43
 44        self.params = params
 45        self.in_channels = params.in_channels
 46        self.out_channels = params.in_channels
 47        if params.hidden_size % params.num_heads != 0:
 48            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
 49        pe_dim = params.hidden_size // params.num_heads
 50        if sum(params.axes_dim) != pe_dim:
 51            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
 52        self.hidden_size = params.hidden_size
 53        self.num_heads = params.num_heads
 54        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
 55        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
 56        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
 57        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
 58        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
 59        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
 60
 61        self.double_blocks = nn.ModuleList(
 62            [
 63                DoubleStreamBlock(
 64                    self.hidden_size,
 65                    self.num_heads,
 66                    mlp_ratio=params.mlp_ratio,
 67                    qkv_bias=params.qkv_bias,
 68                )
 69                for _ in range(params.depth)
 70            ]
 71        )
 72
 73        self.single_blocks = nn.ModuleList([SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(params.depth_single_blocks)])
 74
 75        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 76
 77    def forward(
 78        self,
 79        img: Tensor,
 80        img_ids: Tensor,
 81        txt: Tensor,
 82        txt_ids: Tensor,
 83        timesteps: Tensor,
 84        y: Tensor,
 85        guidance: Tensor | None = None,
 86        layer_dropouts: list[int] | None = None,
 87    ) -> Tensor:
 88        if img.ndim != 3 or txt.ndim != 3:
 89            raise ValueError("Input img and txt tensors must have 3 dimensions.")
 90
 91        # running on sequences img
 92        img = self.img_in(img)
 93        vec = self.time_in(timestep_embedding(timesteps, 256))
 94        if self.params.guidance_embed:
 95            if guidance is None:
 96                raise ValueError("Didn't get guidance strength for guidance distilled model.")
 97            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
 98        vec = vec + self.vector_in(y)
 99        txt = self.txt_in(txt)
100
101        ids = torch.cat((txt_ids, img_ids), dim=1)
102        pe = self.pe_embedder(ids)
103
104        img, txt = process_blocks_with_dropout(self.double_blocks, layer_dropouts, 0, "double", lambda block, state: block(img=state[0], txt=state[1], vec=vec, pe=pe), (img, txt))
105        img = torch.cat((txt, img), 1)
106
107        img = process_blocks_with_dropout(self.single_blocks, layer_dropouts, len(self.double_blocks), "single", lambda block, state: block(state, vec=vec, pe=pe), img)
108        img = img[:, txt.shape[1] :, ...]
109
110        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * self.out_channels)
111        return img

Transformer model for flow matching on sequences.

Flux(params: FluxParams)
41    def __init__(self, params: FluxParams):
42        super().__init__()
43
44        self.params = params
45        self.in_channels = params.in_channels
46        self.out_channels = params.in_channels
47        if params.hidden_size % params.num_heads != 0:
48            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
49        pe_dim = params.hidden_size // params.num_heads
50        if sum(params.axes_dim) != pe_dim:
51            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
52        self.hidden_size = params.hidden_size
53        self.num_heads = params.num_heads
54        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
55        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
56        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
57        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
58        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
59        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
60
61        self.double_blocks = nn.ModuleList(
62            [
63                DoubleStreamBlock(
64                    self.hidden_size,
65                    self.num_heads,
66                    mlp_ratio=params.mlp_ratio,
67                    qkv_bias=params.qkv_bias,
68                )
69                for _ in range(params.depth)
70            ]
71        )
72
73        self.single_blocks = nn.ModuleList([SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(params.depth_single_blocks)])
74
75        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

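Build the Flux transformer from `params`. First check that `hidden_size` is divisible by `num_heads` and that `sum(axes_dim)` equals `hidden_size // num_heads`, raising `ValueError` otherwise. Then create, in order: the positional embedder, the image/time/vector/guidance/text input layers, `depth` double-stream blocks, `depth_single_blocks` single-stream blocks, and the final layer.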

params
in_channels
out_channels
hidden_size
num_heads
pe_embedder
img_in
time_in
vector_in
guidance_in
txt_in
double_blocks
single_blocks
final_layer
def forward( self, img: torch.Tensor, img_ids: torch.Tensor, txt: torch.Tensor, txt_ids: torch.Tensor, timesteps: torch.Tensor, y: torch.Tensor, guidance: torch.Tensor | None = None, layer_dropouts: list[int] | None = None) -> torch.Tensor:
 77    def forward(
 78        self,
 79        img: Tensor,
 80        img_ids: Tensor,
 81        txt: Tensor,
 82        txt_ids: Tensor,
 83        timesteps: Tensor,
 84        y: Tensor,
 85        guidance: Tensor | None = None,
 86        layer_dropouts: list[int] | None = None,
 87    ) -> Tensor:
 88        if img.ndim != 3 or txt.ndim != 3:
 89            raise ValueError("Input img and txt tensors must have 3 dimensions.")
 90
 91        # running on sequences img
 92        img = self.img_in(img)
 93        vec = self.time_in(timestep_embedding(timesteps, 256))
 94        if self.params.guidance_embed:
 95            if guidance is None:
 96                raise ValueError("Didn't get guidance strength for guidance distilled model.")
 97            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
 98        vec = vec + self.vector_in(y)
 99        txt = self.txt_in(txt)
100
101        ids = torch.cat((txt_ids, img_ids), dim=1)
102        pe = self.pe_embedder(ids)
103
104        img, txt = process_blocks_with_dropout(self.double_blocks, layer_dropouts, 0, "double", lambda block, state: block(img=state[0], txt=state[1], vec=vec, pe=pe), (img, txt))
105        img = torch.cat((txt, img), 1)
106
107        img = process_blocks_with_dropout(self.single_blocks, layer_dropouts, len(self.double_blocks), "single", lambda block, state: block(state, vec=vec, pe=pe), img)
108        img = img[:, txt.shape[1] :, ...]
109
110        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * self.out_channels)
111        return img

Run the model on a batch of image and text token sequences and return the final-layer output for the image positions only.

`img` and `txt` must be 3-D tensors (otherwise `ValueError`). `guidance` is required when `params.guidance_embed` is set.

Call the module instance (`model(...)`) rather than `forward` directly: calling the instance runs registered hooks, whereas calling `forward` silently ignores them.

class FluxLoraWrapper(Flux):
class FluxLoraWrapper(Flux):
    """``Flux`` that calls ``replace_linear_with_lora`` on itself at construction.

    ``lora_rank`` and ``lora_scale`` occupy the first two positional slots, so
    the ``Flux`` constructor argument should be passed by keyword, e.g.
    ``FluxLoraWrapper(params=params, lora_rank=64)``.
    """

    def __init__(
        self,
        lora_rank: int = 128,
        lora_scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        """Build the base ``Flux`` model, then apply ``replace_linear_with_lora``.

        Args:
            lora_rank: Stored on ``self.lora_rank`` and passed as ``max_rank``.
            lora_scale: Passed as ``scale`` to ``replace_linear_with_lora``.
            *args: Forwarded unchanged to ``Flux.__init__``.
            **kwargs: Forwarded unchanged to ``Flux.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.lora_rank = lora_rank
        replace_linear_with_lora(self, max_rank=lora_rank, scale=lora_scale)

    def set_lora_scale(self, scale: float) -> None:
        """Call ``set_scale(scale=scale)`` on every ``LinearLora`` sub-module."""
        lora_layers = (m for m in self.modules() if isinstance(m, LinearLora))
        for layer in lora_layers:
            layer.set_scale(scale=scale)

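Transformer model for flow matching on sequences, with LoRA applied. The constructor builds a regular `Flux` model and then calls `replace_linear_with_lora` on it with `max_rank=lora_rank` and `scale=lora_scale`.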

FluxLoraWrapper(lora_rank: int = 128, lora_scale: float = 1.0, *args, **kwargs)
115    def __init__(
116        self,
117        lora_rank: int = 128,
118        lora_scale: float = 1.0,
119        *args,
120        **kwargs,
121    ) -> None:
122        super().__init__(*args, **kwargs)
123
124        self.lora_rank = lora_rank
125
126        replace_linear_with_lora(
127            self,
128            max_rank=lora_rank,
129            scale=lora_scale,
130        )

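Build the base `Flux` model from `*args`/`**kwargs`, store `lora_rank`, then call `replace_linear_with_lora(self, max_rank=lora_rank, scale=lora_scale)`. Because `lora_rank` and `lora_scale` occupy the first positional slots, pass the `Flux` parameters by keyword (e.g. `params=...`).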

lora_rank
def set_lora_scale(self, scale: float) -> None:
132    def set_lora_scale(self, scale: float) -> None:
133        for module in self.modules():
134            if isinstance(module, LinearLora):
135                module.set_scale(scale=scale)