divisor.xflux1.model

  1# SPDX-License-Identifier:Apache-2.0
  2# original XFlux code from https://github.com/TencentARC/FluxKits
  3# type: ignore
  4
  5from dataclasses import dataclass
  6
  7import torch
  8from torch import Tensor, nn
  9
 10from divisor.flux1.layers import EmbedND, LastLayer, MLPEmbedder, timestep_embedding
 11from divisor.xflux1.layers import DoubleStreamBlock, SingleStreamBlock
 12
 13
 14@dataclass
 15class XFluxParams:
 16    in_channels: int
 17    vec_in_dim: int
 18    context_in_dim: int
 19    hidden_size: int
 20    mlp_ratio: float
 21    num_heads: int
 22    depth: int
 23    depth_single_blocks: int
 24    axes_dim: list[int]
 25    theta: int
 26    qkv_bias: bool
 27    guidance_embed: bool
 28
 29
 30class XFlux(nn.Module):
 31    """
 32    Transformer model for flow matching on sequences.
 33    """
 34
 35    _supports_gradient_checkpointing = True
 36
 37    def __init__(self, params: XFluxParams):
 38        super().__init__()
 39
 40        self.params = params
 41        self.in_channels = params.in_channels
 42        self.out_channels = self.in_channels
 43        if params.hidden_size % params.num_heads != 0:
 44            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
 45        pe_dim = params.hidden_size // params.num_heads
 46        if sum(params.axes_dim) != pe_dim:
 47            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
 48        self.hidden_size = params.hidden_size
 49        self.num_heads = params.num_heads
 50        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
 51
 52        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
 53        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
 54        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
 55        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
 56        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
 57
 58        self.double_blocks = nn.ModuleList(
 59            [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias) for i in range(1, params.depth + 1)]
 60        )
 61
 62        self.single_blocks = nn.ModuleList([SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for i in range(1, params.depth_single_blocks + 1)])
 63
 64        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 65        self.gradient_checkpointing = True
 66
 67    def _set_gradient_checkpointing(self, module, value=False):
 68        if hasattr(module, "gradient_checkpointing"):
 69            module.gradient_checkpointing = value
 70
 71    @property
 72    def attn_processors(self):
 73        # set recursively
 74        processors = {}
 75
 76        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
 77            if hasattr(module, "set_processor"):
 78                processors[f"{name}.processor"] = module.processor
 79
 80            for sub_name, child in module.named_children():
 81                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
 82
 83            return processors
 84
 85        for name, module in self.named_children():
 86            fn_recursive_add_processors(name, module, processors)
 87
 88        return processors
 89
 90    def set_attn_processor(self, processor):
 91        r"""
 92        Sets the attention processor to use to compute attention.
 93
 94        Parameters:
 95            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
 96                The instantiated processor class or a dictionary of processor classes that will be set as the processor
 97                for **all** `Attention` layers.
 98
 99                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
100                processor. This is strongly recommended when setting trainable attention processors.
101
102        """
103        count = len(self.attn_processors.keys())
104
105        if isinstance(processor, dict) and len(processor) != count:
106            raise ValueError(
107                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
108                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
109            )
110
111        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
112            if hasattr(module, "set_processor"):
113                if not isinstance(processor, dict):
114                    module.set_processor(processor)
115                else:
116                    module.set_processor(processor.pop(f"{name}.processor"))
117
118            for sub_name, child in module.named_children():
119                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
120
121        for name, module in self.named_children():
122            fn_recursive_attn_processor(name, module, processor)
123
124    def forward(
125        self,
126        img: Tensor,
127        img_ids: Tensor,
128        txt: Tensor,
129        txt_ids: Tensor,
130        timesteps: Tensor,
131        y: Tensor,
132        block_controlnet_hidden_states=None,
133        guidance: Tensor = None,
134        image_proj: Tensor = None,
135        ip_scale: Tensor = 1.0,
136        return_intermediate: bool = False,
137    ):
138        if return_intermediate:
139            intermediate_double = []
140            intermediate_single = []
141
142        # running on sequences img
143        img = self.img_in(img)
144        vec = self.time_in(timestep_embedding(timesteps, 256))
145        if self.params.guidance_embed:
146            if guidance is None:
147                raise ValueError("Didn't get guidance strength for guidance distilled model.")
148            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
149        vec = vec + self.vector_in(y)
150        txt = self.txt_in(txt)
151
152        ids = torch.cat((txt_ids, img_ids), dim=1)
153        pe = self.pe_embedder(ids)
154        for index_block, block in enumerate(self.double_blocks):
155            img, txt = block(img=img, txt=txt, vec=vec, pe=pe, image_proj=image_proj, ip_scale=ip_scale)
156
157            if return_intermediate:
158                intermediate_double.append([img, txt])
159
160            if block_controlnet_hidden_states is not None:
161                img = img + block_controlnet_hidden_states[index_block % 2]
162
163        img = torch.cat((txt, img), dim=1)
164        txt_dim = txt.shape[1]
165        for index_block, block in enumerate(self.single_blocks):
166            img = block(img, vec=vec, pe=pe)
167
168            # if return_intermediate:
169            img_ = img[:, txt.shape[1] :, ...]
170            txt_ = img[:, : txt.shape[1], ...]
171
172            if return_intermediate:
173                intermediate_single.append([img_, txt_])
174
175            img = torch.cat([txt_, img_], dim=1)
176
177        img = img[:, txt.shape[1] :, ...]
178        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
179        if return_intermediate:
180            return img, intermediate_double, intermediate_single
181        else:
182            return img
@dataclass
class XFluxParams:
@dataclass
class XFluxParams:
    """Hyper-parameters consumed by ``XFlux.__init__`` to size the model."""

    # --- input feature sizes -------------------------------------------------
    in_channels: int  # input width of ``img_in``; also reused as the output channel count
    vec_in_dim: int  # input width of the ``vector_in`` embedder (the ``y`` argument of forward)
    context_in_dim: int  # input width of ``txt_in``

    # --- transformer geometry ------------------------------------------------
    hidden_size: int  # shared model width; must be divisible by ``num_heads``
    mlp_ratio: float  # forwarded unchanged to every double/single stream block
    num_heads: int  # attention heads; ``hidden_size // num_heads`` is the positional dim
    depth: int  # number of ``DoubleStreamBlock`` layers
    depth_single_blocks: int  # number of ``SingleStreamBlock`` layers

    # --- positional embedding ------------------------------------------------
    axes_dim: list[int]  # passed to ``EmbedND``; must sum to ``hidden_size // num_heads``
    theta: int  # passed to ``EmbedND`` as ``theta``

    # --- switches ------------------------------------------------------------
    qkv_bias: bool  # forwarded to every ``DoubleStreamBlock``
    guidance_embed: bool  # when true, forward() requires ``guidance`` and embeds it
XFluxParams( in_channels: int, vec_in_dim: int, context_in_dim: int, hidden_size: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], theta: int, qkv_bias: bool, guidance_embed: bool)
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class XFlux(torch.nn.modules.module.Module):
 31class XFlux(nn.Module):
 32    """
 33    Transformer model for flow matching on sequences.
 34    """
 35
 36    _supports_gradient_checkpointing = True
 37
 38    def __init__(self, params: XFluxParams):
 39        super().__init__()
 40
 41        self.params = params
 42        self.in_channels = params.in_channels
 43        self.out_channels = self.in_channels
 44        if params.hidden_size % params.num_heads != 0:
 45            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
 46        pe_dim = params.hidden_size // params.num_heads
 47        if sum(params.axes_dim) != pe_dim:
 48            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
 49        self.hidden_size = params.hidden_size
 50        self.num_heads = params.num_heads
 51        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
 52
 53        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
 54        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
 55        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
 56        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
 57        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
 58
 59        self.double_blocks = nn.ModuleList(
 60            [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias) for i in range(1, params.depth + 1)]
 61        )
 62
 63        self.single_blocks = nn.ModuleList([SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for i in range(1, params.depth_single_blocks + 1)])
 64
 65        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
 66        self.gradient_checkpointing = True
 67
 68    def _set_gradient_checkpointing(self, module, value=False):
 69        if hasattr(module, "gradient_checkpointing"):
 70            module.gradient_checkpointing = value
 71
 72    @property
 73    def attn_processors(self):
 74        # set recursively
 75        processors = {}
 76
 77        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
 78            if hasattr(module, "set_processor"):
 79                processors[f"{name}.processor"] = module.processor
 80
 81            for sub_name, child in module.named_children():
 82                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
 83
 84            return processors
 85
 86        for name, module in self.named_children():
 87            fn_recursive_add_processors(name, module, processors)
 88
 89        return processors
 90
 91    def set_attn_processor(self, processor):
 92        r"""
 93        Sets the attention processor to use to compute attention.
 94
 95        Parameters:
 96            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
 97                The instantiated processor class or a dictionary of processor classes that will be set as the processor
 98                for **all** `Attention` layers.
 99
100                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
101                processor. This is strongly recommended when setting trainable attention processors.
102
103        """
104        count = len(self.attn_processors.keys())
105
106        if isinstance(processor, dict) and len(processor) != count:
107            raise ValueError(
108                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
109                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
110            )
111
112        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
113            if hasattr(module, "set_processor"):
114                if not isinstance(processor, dict):
115                    module.set_processor(processor)
116                else:
117                    module.set_processor(processor.pop(f"{name}.processor"))
118
119            for sub_name, child in module.named_children():
120                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
121
122        for name, module in self.named_children():
123            fn_recursive_attn_processor(name, module, processor)
124
125    def forward(
126        self,
127        img: Tensor,
128        img_ids: Tensor,
129        txt: Tensor,
130        txt_ids: Tensor,
131        timesteps: Tensor,
132        y: Tensor,
133        block_controlnet_hidden_states=None,
134        guidance: Tensor = None,
135        image_proj: Tensor = None,
136        ip_scale: Tensor = 1.0,
137        return_intermediate: bool = False,
138    ):
139        if return_intermediate:
140            intermediate_double = []
141            intermediate_single = []
142
143        # running on sequences img
144        img = self.img_in(img)
145        vec = self.time_in(timestep_embedding(timesteps, 256))
146        if self.params.guidance_embed:
147            if guidance is None:
148                raise ValueError("Didn't get guidance strength for guidance distilled model.")
149            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
150        vec = vec + self.vector_in(y)
151        txt = self.txt_in(txt)
152
153        ids = torch.cat((txt_ids, img_ids), dim=1)
154        pe = self.pe_embedder(ids)
155        for index_block, block in enumerate(self.double_blocks):
156            img, txt = block(img=img, txt=txt, vec=vec, pe=pe, image_proj=image_proj, ip_scale=ip_scale)
157
158            if return_intermediate:
159                intermediate_double.append([img, txt])
160
161            if block_controlnet_hidden_states is not None:
162                img = img + block_controlnet_hidden_states[index_block % 2]
163
164        img = torch.cat((txt, img), dim=1)
165        txt_dim = txt.shape[1]
166        for index_block, block in enumerate(self.single_blocks):
167            img = block(img, vec=vec, pe=pe)
168
169            # if return_intermediate:
170            img_ = img[:, txt.shape[1] :, ...]
171            txt_ = img[:, : txt.shape[1], ...]
172
173            if return_intermediate:
174                intermediate_single.append([img_, txt_])
175
176            img = torch.cat([txt_, img_], dim=1)
177
178        img = img[:, txt.shape[1] :, ...]
179        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
180        if return_intermediate:
181            return img, intermediate_double, intermediate_single
182        else:
183            return img

Transformer model for flow matching on sequences.

XFlux(params: XFluxParams)
38    def __init__(self, params: XFluxParams):
39        super().__init__()
40
41        self.params = params
42        self.in_channels = params.in_channels
43        self.out_channels = self.in_channels
44        if params.hidden_size % params.num_heads != 0:
45            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
46        pe_dim = params.hidden_size // params.num_heads
47        if sum(params.axes_dim) != pe_dim:
48            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
49        self.hidden_size = params.hidden_size
50        self.num_heads = params.num_heads
51        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
52
53        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
54        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
55        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
56        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
57        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
58
59        self.double_blocks = nn.ModuleList(
60            [DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias) for i in range(1, params.depth + 1)]
61        )
62
63        self.single_blocks = nn.ModuleList([SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for i in range(1, params.depth_single_blocks + 1)])
64
65        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
66        self.gradient_checkpointing = True

Initialize internal Module state, shared by both nn.Module and ScriptModule.

params
in_channels
out_channels
hidden_size
num_heads
pe_embedder
img_in
time_in
vector_in
guidance_in
txt_in
double_blocks
single_blocks
final_layer
gradient_checkpointing
attn_processors
72    @property
73    def attn_processors(self):
74        # set recursively
75        processors = {}
76
77        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
78            if hasattr(module, "set_processor"):
79                processors[f"{name}.processor"] = module.processor
80
81            for sub_name, child in module.named_children():
82                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
83
84            return processors
85
86        for name, module in self.named_children():
87            fn_recursive_add_processors(name, module, processors)
88
89        return processors
def set_attn_processor(self, processor):
 91    def set_attn_processor(self, processor):
 92        r"""
 93        Sets the attention processor to use to compute attention.
 94
 95        Parameters:
 96            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
 97                The instantiated processor class or a dictionary of processor classes that will be set as the processor
 98                for **all** `Attention` layers.
 99
100                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
101                processor. This is strongly recommended when setting trainable attention processors.
102
103        """
104        count = len(self.attn_processors.keys())
105
106        if isinstance(processor, dict) and len(processor) != count:
107            raise ValueError(
108                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
109                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
110            )
111
112        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
113            if hasattr(module, "set_processor"):
114                if not isinstance(processor, dict):
115                    module.set_processor(processor)
116                else:
117                    module.set_processor(processor.pop(f"{name}.processor"))
118
119            for sub_name, child in module.named_children():
120                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
121
122        for name, module in self.named_children():
123            fn_recursive_attn_processor(name, module, processor)

Sets the attention processor to use to compute attention.

Parameters: processor (dict of AttentionProcessor or only AttentionProcessor): The instantiated processor class or a dictionary of processor classes that will be set as the processor for all Attention layers.

    If `processor` is a dict, the key needs to define the path to the corresponding cross attention
    processor. This is strongly recommended when setting trainable attention processors.
def forward( self, img: torch.Tensor, img_ids: torch.Tensor, txt: torch.Tensor, txt_ids: torch.Tensor, timesteps: torch.Tensor, y: torch.Tensor, block_controlnet_hidden_states=None, guidance: torch.Tensor = None, image_proj: torch.Tensor = None, ip_scale: torch.Tensor = 1.0, return_intermediate: bool = False):
125    def forward(
126        self,
127        img: Tensor,
128        img_ids: Tensor,
129        txt: Tensor,
130        txt_ids: Tensor,
131        timesteps: Tensor,
132        y: Tensor,
133        block_controlnet_hidden_states=None,
134        guidance: Tensor = None,
135        image_proj: Tensor = None,
136        ip_scale: Tensor = 1.0,
137        return_intermediate: bool = False,
138    ):
139        if return_intermediate:
140            intermediate_double = []
141            intermediate_single = []
142
143        # running on sequences img
144        img = self.img_in(img)
145        vec = self.time_in(timestep_embedding(timesteps, 256))
146        if self.params.guidance_embed:
147            if guidance is None:
148                raise ValueError("Didn't get guidance strength for guidance distilled model.")
149            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
150        vec = vec + self.vector_in(y)
151        txt = self.txt_in(txt)
152
153        ids = torch.cat((txt_ids, img_ids), dim=1)
154        pe = self.pe_embedder(ids)
155        for index_block, block in enumerate(self.double_blocks):
156            img, txt = block(img=img, txt=txt, vec=vec, pe=pe, image_proj=image_proj, ip_scale=ip_scale)
157
158            if return_intermediate:
159                intermediate_double.append([img, txt])
160
161            if block_controlnet_hidden_states is not None:
162                img = img + block_controlnet_hidden_states[index_block % 2]
163
164        img = torch.cat((txt, img), dim=1)
165        txt_dim = txt.shape[1]
166        for index_block, block in enumerate(self.single_blocks):
167            img = block(img, vec=vec, pe=pe)
168
169            # if return_intermediate:
170            img_ = img[:, txt.shape[1] :, ...]
171            txt_ = img[:, : txt.shape[1], ...]
172
173            if return_intermediate:
174                intermediate_single.append([img_, txt_])
175
176            img = torch.cat([txt_, img_], dim=1)
177
178        img = img[:, txt.shape[1] :, ...]
179        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
180        if return_intermediate:
181            return img, intermediate_double, intermediate_single
182        else:
183            return img

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.