divisor.xflux1.model
# SPDX-License-Identifier:Apache-2.0
# original XFlux code from https://github.com/TencentARC/FluxKits
# type: ignore

from dataclasses import dataclass

import torch
from torch import Tensor, nn

from divisor.flux1.layers import EmbedND, LastLayer, MLPEmbedder, timestep_embedding
from divisor.xflux1.layers import DoubleStreamBlock, SingleStreamBlock


@dataclass
class XFluxParams:
    """Hyper-parameters read by ``XFlux.__init__`` and ``XFlux.forward``."""

    in_channels: int  # input width of ``img_in``; also reused as the output channel count
    vec_in_dim: int  # input width of the ``vector_in`` embedder (the ``y`` argument)
    context_in_dim: int  # input width of ``txt_in``
    hidden_size: int  # model width; must be divisible by ``num_heads``
    mlp_ratio: float  # forwarded to every double and single stream block
    num_heads: int  # forwarded to every double and single stream block
    depth: int  # number of ``DoubleStreamBlock`` layers
    depth_single_blocks: int  # number of ``SingleStreamBlock`` layers
    axes_dim: list[int]  # forwarded to ``EmbedND``; must sum to ``hidden_size // num_heads``
    theta: int  # forwarded to ``EmbedND``
    qkv_bias: bool  # forwarded to every ``DoubleStreamBlock``
    guidance_embed: bool  # when true, ``forward`` requires and embeds ``guidance``


class XFlux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    _supports_gradient_checkpointing = True

    def __init__(self, params: XFluxParams):
        """Build the embedders, stream blocks and output layer described by ``params``.

        Args:
            params: Model hyper-parameters.

        Raises:
            ValueError: If ``params.hidden_size`` is not a multiple of
                ``params.num_heads``, or if ``params.axes_dim`` does not sum
                to ``params.hidden_size // params.num_heads``.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        # Without guidance embedding the slot is filled with a parameter-free placeholder.
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias)
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        self.gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``module.gradient_checkpointing`` to ``value`` if the module has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self):
        """Collect the processor of every descendant that exposes ``set_processor``.

        Returns:
            dict: Maps ``"<dotted module path>.processor"`` to that module's
            ``processor`` attribute, in pre-order traversal order.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
            KeyError: If a dict is passed that lacks the key for one of the attention layers.
        """
        count = len(self.attn_processors)

        if isinstance(processor, dict):
            if len(processor) != count:
                raise ValueError(
                    f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                    f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
                )
            # Entries are popped as they are assigned; pop from a shallow copy
            # so the mapping owned by the caller is left untouched.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        block_controlnet_hidden_states=None,
        guidance: Tensor = None,
        image_proj: Tensor = None,
        ip_scale: Tensor = 1.0,
        return_intermediate: bool = False,
    ):
        """Run the double and single stream blocks and project the image tokens.

        Args:
            img: Image tokens, projected by ``img_in``.
                Presumably (batch, tokens, in_channels) -- TODO confirm.
            img_ids: Position ids of the image tokens; concatenated after
                ``txt_ids`` along dim 1 before positional embedding.
            txt: Text tokens, projected by ``txt_in``.
            txt_ids: Position ids of the text tokens.
            timesteps: Embedded with ``timestep_embedding(..., 256)`` and ``time_in``.
            y: Conditioning vector embedded by ``vector_in``.
            block_controlnet_hidden_states: Optional indexable of residuals
                added to ``img`` after each double block.
            guidance: Guidance strength; required when
                ``params.guidance_embed`` is true.
            image_proj: Forwarded unchanged to every double block.
            ip_scale: Forwarded unchanged to every double block.
            return_intermediate: When true, also return per-block outputs.

        Returns:
            The output of ``final_layer`` on the image tokens. When
            ``return_intermediate`` is true, a tuple ``(img,
            intermediate_double, intermediate_single)`` where the lists hold
            ``[img, txt]`` pairs recorded after each double / single block.

        Raises:
            ValueError: If ``params.guidance_embed`` is true and ``guidance`` is None.
        """
        # Always bound so the names exist on every path; only filled on request.
        intermediate_double = []
        intermediate_single = []

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)
        for index_block, block in enumerate(self.double_blocks):
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe, image_proj=image_proj, ip_scale=ip_scale)

            if return_intermediate:
                intermediate_double.append([img, txt])

            if block_controlnet_hidden_states is not None:
                # Residuals alternate: even-indexed blocks add entry 0, odd-indexed
                # blocks add entry 1; entries past the second are never read.
                img = img + block_controlnet_hidden_states[index_block % 2]

        # The single blocks run on one joint sequence with the text tokens first.
        txt_len = txt.shape[1]
        img = torch.cat((txt, img), dim=1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)

            if return_intermediate:
                img_ = img[:, txt_len:, ...]
                txt_ = img[:, :txt_len, ...]
                intermediate_single.append([img_, txt_])
                # Hand the next block a fresh tensor so it never shares storage
                # with the slices that were just recorded.
                img = torch.cat([txt_, img_], dim=1)

        img = img[:, txt_len:, ...]
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        if return_intermediate:
            return img, intermediate_double, intermediate_single
        return img
@dataclass
class
XFluxParams:
@dataclass
class XFluxParams:
    """Hyper-parameters read by ``XFlux.__init__`` and ``XFlux.forward``."""

    in_channels: int  # input width of ``img_in``; also reused as the output channel count
    vec_in_dim: int  # input width of the ``vector_in`` embedder (the ``y`` argument)
    context_in_dim: int  # input width of ``txt_in``
    hidden_size: int  # model width; must be divisible by ``num_heads``
    mlp_ratio: float  # forwarded to every double and single stream block
    num_heads: int  # forwarded to every double and single stream block
    depth: int  # number of ``DoubleStreamBlock`` layers
    depth_single_blocks: int  # number of ``SingleStreamBlock`` layers
    axes_dim: list[int]  # forwarded to ``EmbedND``; must sum to ``hidden_size // num_heads``
    theta: int  # forwarded to ``EmbedND``
    qkv_bias: bool  # forwarded to every ``DoubleStreamBlock``
    guidance_embed: bool  # when true, ``forward`` requires and embeds ``guidance``
class
XFlux(torch.nn.modules.module.Module):
class XFlux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    _supports_gradient_checkpointing = True

    def __init__(self, params: XFluxParams):
        """Build the embedders, stream blocks and output layer described by ``params``.

        Args:
            params: Model hyper-parameters.

        Raises:
            ValueError: If ``params.hidden_size`` is not a multiple of
                ``params.num_heads``, or if ``params.axes_dim`` does not sum
                to ``params.hidden_size // params.num_heads``.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)

        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        # Without guidance embedding the slot is filled with a parameter-free placeholder.
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias)
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
        self.gradient_checkpointing = True

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``module.gradient_checkpointing`` to ``value`` if the module has that attribute."""
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    def attn_processors(self):
        """Collect the processor of every descendant that exposes ``set_processor``.

        Returns:
            dict: Maps ``"<dotted module path>.processor"`` to that module's
            ``processor`` attribute, in pre-order traversal order.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "set_processor"):
                processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose size differs from the number of attention layers.
            KeyError: If a dict is passed that lacks the key for one of the attention layers.
        """
        count = len(self.attn_processors)

        if isinstance(processor, dict):
            if len(processor) != count:
                raise ValueError(
                    f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                    f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
                )
            # Entries are popped as they are assigned; pop from a shallow copy
            # so the mapping owned by the caller is left untouched.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        block_controlnet_hidden_states=None,
        guidance: Tensor = None,
        image_proj: Tensor = None,
        ip_scale: Tensor = 1.0,
        return_intermediate: bool = False,
    ):
        """Run the double and single stream blocks and project the image tokens.

        Args:
            img: Image tokens, projected by ``img_in``.
                Presumably (batch, tokens, in_channels) -- TODO confirm.
            img_ids: Position ids of the image tokens; concatenated after
                ``txt_ids`` along dim 1 before positional embedding.
            txt: Text tokens, projected by ``txt_in``.
            txt_ids: Position ids of the text tokens.
            timesteps: Embedded with ``timestep_embedding(..., 256)`` and ``time_in``.
            y: Conditioning vector embedded by ``vector_in``.
            block_controlnet_hidden_states: Optional indexable of residuals
                added to ``img`` after each double block.
            guidance: Guidance strength; required when
                ``params.guidance_embed`` is true.
            image_proj: Forwarded unchanged to every double block.
            ip_scale: Forwarded unchanged to every double block.
            return_intermediate: When true, also return per-block outputs.

        Returns:
            The output of ``final_layer`` on the image tokens. When
            ``return_intermediate`` is true, a tuple ``(img,
            intermediate_double, intermediate_single)`` where the lists hold
            ``[img, txt]`` pairs recorded after each double / single block.

        Raises:
            ValueError: If ``params.guidance_embed`` is true and ``guidance`` is None.
        """
        # Always bound so the names exist on every path; only filled on request.
        intermediate_double = []
        intermediate_single = []

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)
        for index_block, block in enumerate(self.double_blocks):
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe, image_proj=image_proj, ip_scale=ip_scale)

            if return_intermediate:
                intermediate_double.append([img, txt])

            if block_controlnet_hidden_states is not None:
                # Residuals alternate: even-indexed blocks add entry 0, odd-indexed
                # blocks add entry 1; entries past the second are never read.
                img = img + block_controlnet_hidden_states[index_block % 2]

        # The single blocks run on one joint sequence with the text tokens first.
        txt_len = txt.shape[1]
        img = torch.cat((txt, img), dim=1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)

            if return_intermediate:
                img_ = img[:, txt_len:, ...]
                txt_ = img[:, :txt_len, ...]
                intermediate_single.append([img_, txt_])
                # Hand the next block a fresh tensor so it never shares storage
                # with the slices that were just recorded.
                img = torch.cat([txt_, img_], dim=1)

        img = img[:, txt_len:, ...]
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        if return_intermediate:
            return img, intermediate_double, intermediate_single
        return img
Transformer model for flow matching on sequences.
XFlux(params: XFluxParams)
def __init__(self, params: XFluxParams):
    """Build the embedders, stream blocks and output layer described by ``params``.

    Args:
        params: Model hyper-parameters.

    Raises:
        ValueError: If ``params.hidden_size`` is not a multiple of
            ``params.num_heads``, or if ``params.axes_dim`` does not sum to
            ``params.hidden_size // params.num_heads``.
    """
    super().__init__()

    self.params = params
    self.in_channels = params.in_channels
    self.out_channels = self.in_channels

    # Validate the head geometry before any sub-module is allocated.
    if params.hidden_size % params.num_heads != 0:
        raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
    head_dim = params.hidden_size // params.num_heads
    if sum(params.axes_dim) != head_dim:
        raise ValueError(f"Got {params.axes_dim} but expected positional dim {head_dim}")

    self.hidden_size = params.hidden_size
    self.num_heads = params.num_heads
    self.pe_embedder = EmbedND(dim=head_dim, theta=params.theta, axes_dim=params.axes_dim)

    # Input projections and conditioning embedders (construction order is kept
    # so parameter registration and initialisation order do not change).
    self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
    self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
    self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
    if params.guidance_embed:
        self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
    else:
        # Parameter-free placeholder when guidance is not embedded.
        self.guidance_in = nn.Identity()
    self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

    double_stack = [
        DoubleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias)
        for _ in range(params.depth)
    ]
    self.double_blocks = nn.ModuleList(double_stack)

    single_stack = [
        SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
        for _ in range(params.depth_single_blocks)
    ]
    self.single_blocks = nn.ModuleList(single_stack)

    self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
    self.gradient_checkpointing = True
Initialize internal Module state, shared by both nn.Module and ScriptModule.
attn_processors
@property
def attn_processors(self):
    """Collect the processor of every descendant that exposes ``set_processor``.

    The module tree below ``self`` is walked in pre-order (a module first,
    then its children in ``named_children()`` order); ``self`` itself is not
    inspected.

    Returns:
        dict: Maps ``"<dotted module path>.processor"`` to that module's
        ``processor`` attribute, in traversal order.
    """
    found = {}

    # Explicit stack; children are pushed reversed so they pop in declaration order.
    pending = list(self.named_children())
    pending.reverse()
    while pending:
        path, node = pending.pop()
        if hasattr(node, "set_processor"):
            found[f"{path}.processor"] = node.processor
        nested = [(f"{path}.{child_name}", child) for child_name, child in node.named_children()]
        pending.extend(reversed(nested))

    return found
def
set_attn_processor(self, processor):
def set_attn_processor(self, processor):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors.

    Raises:
        ValueError: If a dict is passed whose size differs from the number of attention layers.
        KeyError: If a dict is passed that lacks the key for one of the attention layers.
    """
    count = len(self.attn_processors)

    if isinstance(processor, dict):
        if len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )
        # Entries are popped as they are assigned; pop from a shallow copy so
        # the mapping owned by the caller is left untouched.
        processor = dict(processor)

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        # Modules exposing ``set_processor`` take either the shared processor
        # or the dict entry stored under their own dotted path.
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
Sets the attention processor to use to compute attention.
Parameters:
processor (dict of AttentionProcessor or only AttentionProcessor):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for all Attention layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
def
forward( self, img: torch.Tensor, img_ids: torch.Tensor, txt: torch.Tensor, txt_ids: torch.Tensor, timesteps: torch.Tensor, y: torch.Tensor, block_controlnet_hidden_states=None, guidance: torch.Tensor = None, image_proj: torch.Tensor = None, ip_scale: torch.Tensor = 1.0, return_intermediate: bool = False):
def forward(
    self,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    timesteps: Tensor,
    y: Tensor,
    block_controlnet_hidden_states=None,
    guidance: Tensor = None,
    image_proj: Tensor = None,
    ip_scale: Tensor = 1.0,
    return_intermediate: bool = False,
):
    """Run the double and single stream blocks and project the image tokens.

    Args:
        img: Image tokens, projected by ``img_in``.
            Presumably (batch, tokens, in_channels) -- TODO confirm.
        img_ids: Position ids of the image tokens; concatenated after
            ``txt_ids`` along dim 1 before positional embedding.
        txt: Text tokens, projected by ``txt_in``.
        txt_ids: Position ids of the text tokens.
        timesteps: Embedded with ``timestep_embedding(..., 256)`` and ``time_in``.
        y: Conditioning vector embedded by ``vector_in``.
        block_controlnet_hidden_states: Optional indexable of residuals added
            to ``img`` after each double block.
        guidance: Guidance strength; required when ``params.guidance_embed``
            is true.
        image_proj: Forwarded unchanged to every double block.
        ip_scale: Forwarded unchanged to every double block.
        return_intermediate: When true, also return per-block outputs.

    Returns:
        The output of ``final_layer`` on the image tokens. When
        ``return_intermediate`` is true, a tuple ``(img, intermediate_double,
        intermediate_single)`` where the lists hold ``[img, txt]`` pairs
        recorded after each double / single block.

    Raises:
        ValueError: If ``params.guidance_embed`` is true and ``guidance`` is None.
    """
    # Always bound so the names exist on every path; only filled on request.
    intermediate_double = []
    intermediate_single = []

    # running on sequences img
    img = self.img_in(img)
    vec = self.time_in(timestep_embedding(timesteps, 256))
    if self.params.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
    vec = vec + self.vector_in(y)
    txt = self.txt_in(txt)

    ids = torch.cat((txt_ids, img_ids), dim=1)
    pe = self.pe_embedder(ids)
    for index_block, block in enumerate(self.double_blocks):
        img, txt = block(img=img, txt=txt, vec=vec, pe=pe, image_proj=image_proj, ip_scale=ip_scale)

        if return_intermediate:
            intermediate_double.append([img, txt])

        if block_controlnet_hidden_states is not None:
            # Residuals alternate: even-indexed blocks add entry 0, odd-indexed
            # blocks add entry 1; entries past the second are never read.
            img = img + block_controlnet_hidden_states[index_block % 2]

    # The single blocks run on one joint sequence with the text tokens first.
    txt_len = txt.shape[1]
    img = torch.cat((txt, img), dim=1)
    for block in self.single_blocks:
        img = block(img, vec=vec, pe=pe)

        if return_intermediate:
            img_ = img[:, txt_len:, ...]
            txt_ = img[:, :txt_len, ...]
            intermediate_single.append([img_, txt_])
            # Hand the next block a fresh tensor so it never shares storage
            # with the slices that were just recorded.
            img = torch.cat([txt_, img_], dim=1)

    img = img[:, txt_len:, ...]
    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
    if return_intermediate:
        return img, intermediate_double, intermediate_single
    return img
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.