divisor.flux2.autoencoder
# SPDX-License-Identifier: Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux2

from dataclasses import dataclass, field
import math
from typing import Optional

from einops import rearrange
import torch
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    """Hyper-parameters shared by :class:`Encoder`, :class:`Decoder` and :class:`AutoEncoder`."""

    resolution: int = 256  # stored by Encoder/Decoder; Decoder derives z_shape from it
    in_channels: int = 3  # channels of the image fed to the encoder
    ch: int = 128  # base width; level i uses ch * ch_mult[i] channels
    out_ch: int = 3  # channels of the image produced by the decoder
    ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
    num_res_blocks: int = 2
    z_channels: int = 32  # latent channels before 2x2 patchification


def swish(x: Tensor) -> Tensor:
    """Swish / SiLU activation, ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with a residual add."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # 1x1 convolutions act as per-position linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Return the attended ``(b, c, h, w)`` features, before ``proj_out``."""
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        # (b, c, h, w) -> (b, 1, h*w, c): one head, spatial positions as the sequence.
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output; ``None`` (the default) keeps
            ``in_channels``. When the width changes, a 1x1 ``nin_shortcut``
            conv projects the skip path to the new width.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 conv after right/bottom zero padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)  # (left, right, top, bottom) for the last two dims
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    """Double the spatial size (nearest neighbour) and refine with a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    """Map an image to ``2 * z_channels`` moment channels.

    Level ``i`` runs ``num_res_blocks`` residual blocks at width
    ``ch * ch_mult[i]``; every level except the last ends with a stride-2
    :class:`Downsample`. A ResNet/attention/ResNet middle stage, GroupNorm,
    swish, a 3x3 output conv and a 1x1 ``quant_conv`` follow.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # always empty; forward skips attention when empty
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling. Each stage only consumes the previous activation, so a
        # single running tensor is kept instead of a list of every intermediate
        # feature map (which kept them all alive until the end of forward).
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            down = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = down.block[i_block](h)  # type: ignore
                if len(down.attn) > 0:  # type: ignore
                    h = down.attn[i_block](h)  # type: ignore
            if i_level != self.num_resolutions - 1:
                h = down.downsample(h)  # type: ignore

        # middle
        h = self.mid.block_1(h)  # type: ignore
        h = self.mid.attn_1(h)  # type: ignore
        h = self.mid.block_2(h)  # type: ignore
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h = self.quant_conv(h)
        return h


class Decoder(nn.Module):
    """Map ``z_channels`` latents back to an ``out_ch``-channel image.

    A 1x1 ``post_quant_conv`` and a 3x3 ``conv_in`` feed a middle stage at the
    lowest resolution. Then ``len(ch_mult)`` levels, from the deepest to level
    0, each run ``num_res_blocks + 1`` residual blocks. Every level except
    level 0 ends with a 2x :class:`Upsample`.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels  # stored only; not used to build layers
        # total spatial upsampling factor: one 2x Upsample per level except level 0
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # channel width and spatial size at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # always empty; forward skips attention when empty
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
            self.up.insert(0, up)  # prepend so that self.up[i] is level i

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        z = self.post_quant_conv(z)

        # dtype of the upsampling stack; the mid-stage output is cast to it below
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)  # type: ignore
        h = self.mid.attn_1(h)  # type: ignore
        h = self.mid.block_2(h)  # type: ignore

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)  # type: ignore
                if len(self.up[i_level].attn) > 0:  # type: ignore
                    h = self.up[i_level].attn[i_block](h)  # type: ignore
            if i_level != 0:
                h = self.up[i_level].upsample(h)  # type: ignore

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class AutoEncoder(nn.Module):
    """Encoder/Decoder pair plus a fixed BatchNorm latent normalizer.

    :meth:`encode` keeps the first half of the encoder's output channels. It
    folds each ``ps[0] x ps[1]`` spatial patch into the channel axis and
    standardizes the result with the BatchNorm running statistics.
    :meth:`decode` inverts those steps and then runs the decoder.
    """

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )

        self.bn_eps = 1e-4
        self.bn_momentum = 0.1
        self.ps = [2, 2]  # patch size (rows, cols) folded into the channel axis
        # Only the running statistics are used (affine=False), and nothing in this
        # class updates them. NOTE(review): presumably loaded from a checkpoint; confirm.
        self.bn = torch.nn.BatchNorm2d(
            math.prod(self.ps) * params.z_channels,
            eps=self.bn_eps,
            momentum=self.bn_momentum,
            affine=False,
            track_running_stats=True,
        )

    def normalize(self, z: Tensor) -> Tensor:
        """Standardize ``z`` with the BatchNorm running statistics.

        This is equivalent to ``self.bn`` in eval mode. It uses the functional
        form with ``training=False``, so the running statistics are never
        updated, and the training flag of ``self.bn`` is not changed as a side
        effect (the previous ``self.bn.eval()`` call did change it).
        """
        return nn.functional.batch_norm(
            z,
            self.bn.running_mean,
            self.bn.running_var,
            weight=None,
            bias=None,
            training=False,
            momentum=self.bn_momentum,
            eps=self.bn_eps,
        )

    def inv_normalize(self, z: Tensor) -> Tensor:
        """Undo :meth:`normalize`: ``z * sqrt(running_var + eps) + running_mean``."""
        s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)  # type: ignore
        m = self.bn.running_mean.view(1, -1, 1, 1)  # type: ignore
        return z * s + m

    def encode(self, x: Tensor) -> Tensor:
        """Encode images to normalized, patchified latents.

        The encoder output is split in two along dim 1, and only the first
        half is kept. The second half is unused here (presumably the
        log-variance; TODO confirm).
        """
        moments = self.encoder(x)
        mean = torch.chunk(moments, 2, dim=1)[0]

        # fold each (pi x pj) spatial patch into channels: c -> c * pi * pj
        z = rearrange(
            mean,
            "... c (i pi) (j pj) -> ... (c pi pj) i j",
            pi=self.ps[0],
            pj=self.ps[1],
        )
        z = self.normalize(z)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """Invert :meth:`encode`'s normalization and patchification, then decode."""
        z = self.inv_normalize(z)
        z = rearrange(
            z,
            "... (c pi pj) i j -> ... c (i pi) (j pj)",
            pi=self.ps[0],
            pj=self.ps[1],
        )
        dec = self.decoder(z)
        return dec
@dataclass
class AutoEncoderParams:
    """Configuration values consumed when building the autoencoder's encoder and decoder."""

    # stored by the encoder/decoder; the decoder derives its z_shape from it
    resolution: int = 256
    # channels of the encoder's input image
    in_channels: int = 3
    # base channel width, scaled per level by ch_mult
    ch: int = 128
    # channels of the decoder's output image
    out_ch: int = 3
    # per-level width multipliers; its length is the number of resolution levels
    ch_mult: list[int] = field(default_factory=lambda: list((1, 2, 4, 4)))
    # residual blocks per level
    num_res_blocks: int = 2
    # latent channel count
    z_channels: int = 32
class AttnBlock(nn.Module):
    """Residual single-head self-attention across every spatial position of a feature map."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        def _pointwise() -> nn.Conv2d:
            # a 1x1 conv is a per-position linear projection
            return nn.Conv2d(in_channels, in_channels, kernel_size=1)

        self.q = _pointwise()
        self.k = _pointwise()
        self.v = _pointwise()
        self.proj_out = _pointwise()

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize, project to q/k/v, and attend; returns the pre-``proj_out`` features."""
        normed = self.norm(h_)
        query, key, value = (proj(normed) for proj in (self.q, self.k, self.v))
        batch, chans, height, width = query.shape

        # sequence layout (batch, 1 head, height*width positions, channels)
        to_seq = "b c h w -> b 1 (h w) c"
        query, key, value = (rearrange(t, to_seq).contiguous() for t in (query, key, value))
        attended = nn.functional.scaled_dot_product_attention(query, key, value)

        return rearrange(attended, "b 1 (h w) c -> b c h w", h=height, w=width, c=chans, b=batch)

    def forward(self, x: Tensor) -> Tensor:
        residual = self.proj_out(self.attention(x))
        return x + residual
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that indicates whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels: int):
    """Create the GroupNorm and the four 1x1 projection convolutions."""
    super().__init__()
    self.in_channels = in_channels
    self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

    def _pointwise() -> nn.Conv2d:
        # 1x1 conv: a per-position linear map with in_channels -> in_channels
        return nn.Conv2d(in_channels, in_channels, kernel_size=1)

    # created in the order q, k, v, proj_out
    self.q = _pointwise()
    self.k = _pointwise()
    self.v = _pointwise()
    self.proj_out = _pointwise()
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def attention(self, h_: Tensor) -> Tensor:
    """Self-attention over spatial positions of a 4-D map; returns the same 4-D layout."""
    normed = self.norm(h_)
    query, key, value = (proj(normed) for proj in (self.q, self.k, self.v))
    batch, chans, height, width = query.shape

    # flatten the grid into a sequence with a single head axis for SDPA
    to_seq = "b c h w -> b 1 (h w) c"
    query, key, value = (rearrange(t, to_seq).contiguous() for t in (query, key, value))
    attended = nn.functional.scaled_dot_product_attention(query, key, value)

    return rearrange(attended, "b 1 (h w) c -> b c h w", h=height, w=width, c=chans, b=batch)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class ResnetBlock(nn.Module):
    """Two (GroupNorm -> swish -> 3x3 conv) stages added to a skip connection.

    ``out_channels`` of ``None`` means "same as ``in_channels``". When the two
    widths differ, a 1x1 ``nin_shortcut`` conv reshapes the skip path.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        width_out = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = width_out

        def _norm(channels: int) -> nn.GroupNorm:
            return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)

        self.norm1 = _norm(in_channels)
        self.conv1 = nn.Conv2d(in_channels, width_out, kernel_size=3, stride=1, padding=1)
        self.norm2 = _norm(width_out)
        self.conv2 = nn.Conv2d(width_out, width_out, kernel_size=3, stride=1, padding=1)
        if in_channels != width_out:
            self.nin_shortcut = nn.Conv2d(in_channels, width_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        branch = self.conv1(swish(self.norm1(x)))
        branch = self.conv2(swish(self.norm2(branch)))
        if self.in_channels == self.out_channels:
            skip = x
        else:
            skip = self.nin_shortcut(x)
        return skip + branch
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that indicates whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels: int, out_channels: int):
    """Create both norm/conv stages and, if the width changes, a 1x1 skip projection."""
    super().__init__()
    width_out = in_channels if out_channels is None else out_channels
    self.in_channels = in_channels
    self.out_channels = width_out

    def _norm(channels: int) -> nn.GroupNorm:
        return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)

    self.norm1 = _norm(in_channels)
    self.conv1 = nn.Conv2d(in_channels, width_out, kernel_size=3, stride=1, padding=1)
    self.norm2 = _norm(width_out)
    self.conv2 = nn.Conv2d(width_out, width_out, kernel_size=3, stride=1, padding=1)
    if in_channels != width_out:
        # 1x1 conv so the skip path matches the new channel count
        self.nin_shortcut = nn.Conv2d(in_channels, width_out, kernel_size=1, stride=1, padding=0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Return ``skip(x) + conv2(swish(norm2(conv1(swish(norm1(x))))))``."""
    branch = self.conv1(swish(self.norm1(x)))
    branch = self.conv2(swish(self.norm2(branch)))
    if self.in_channels == self.out_channels:
        skip = x
    else:
        skip = self.nin_shortcut(x)
    return skip + branch
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 convolution.

    ``Conv2d`` padding is symmetric, so the conv uses ``padding=0``. One zero
    row/column is added on the bottom/right by hand before convolving.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=2,
            padding=0,
        )

    def forward(self, x: Tensor):
        # (left, right, top, bottom) for the last two dims
        right_bottom_pad = (0, 1, 0, 1)
        return self.conv(nn.functional.pad(x, right_bottom_pad, mode="constant", value=0))
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that indicates whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels: int):
    """Create the stride-2 3x3 conv; the asymmetric padding happens in ``forward``."""
    super().__init__()
    # padding=0 because torch convs only pad symmetrically
    self.conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=in_channels,
        kernel_size=3,
        stride=2,
        padding=0,
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor):
    """Zero-pad one pixel on the right and bottom, then apply ``self.conv``."""
    # (left, right, top, bottom) for the last two dims
    right_bottom_pad = (0, 1, 0, 1)
    return self.conv(nn.functional.pad(x, right_bottom_pad, mode="constant", value=0))
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Upsample(nn.Module):
    """Nearest-neighbour 2x spatial upsampling followed by a size-preserving 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x: Tensor):
        doubled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(doubled)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that indicates whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels: int):
    """Create the 3x3, stride-1, padding-1 conv applied after upsampling."""
    super().__init__()
    self.conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=in_channels,
        kernel_size=3,
        stride=1,
        padding=1,
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor):
    """Double height and width by nearest-neighbour interpolation, then apply ``self.conv``."""
    doubled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
    return self.conv(doubled)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Encoder(nn.Module):
    """Map an image to ``2 * z_channels`` moment channels.

    Level ``i`` runs ``num_res_blocks`` residual blocks at width
    ``ch * ch_mult[i]``; every level except the last ends with a
    :class:`Downsample`. A ResNet/attention/ResNet middle stage, GroupNorm,
    swish, a 3x3 output conv and a 1x1 ``quant_conv`` follow.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # always empty; forward skips attention when empty
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling. Each stage only consumes the previous activation, so keep
        # one running tensor rather than a list that held every feature map alive.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            down = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = down.block[i_block](h)  # type: ignore
                if len(down.attn) > 0:  # type: ignore
                    h = down.attn[i_block](h)  # type: ignore
            if i_level != self.num_resolutions - 1:
                h = down.downsample(h)  # type: ignore

        # middle
        h = self.mid.block_1(h)  # type: ignore
        h = self.mid.attn_1(h)  # type: ignore
        h = self.mid.block_2(h)  # type: ignore
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h = self.quant_conv(h)
        return h
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that indicates whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    resolution: int,
    in_channels: int,
    ch: int,
    ch_mult: list[int],
    num_res_blocks: int,
    z_channels: int,
):
    """Build the downsampling levels, the middle stage and the output head.

    Level ``i`` has ``num_res_blocks`` residual blocks going from
    ``ch * in_ch_mult[i]`` to ``ch * ch_mult[i]`` channels, where
    ``in_ch_mult = (1,) + tuple(ch_mult)``. Every level except the last
    gets a ``Downsample``.
    """
    super().__init__()
    self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
    self.ch = ch
    self.num_resolutions = len(ch_mult)
    self.num_res_blocks = num_res_blocks
    self.resolution = resolution
    self.in_channels = in_channels
    # downsampling
    self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

    in_ch_mult = (1,) + tuple(ch_mult)
    self.in_ch_mult = in_ch_mult
    self.down = nn.ModuleList()
    block_in = self.ch  # used directly by the middle stage if ch_mult is empty
    for i_level in range(self.num_resolutions):
        block = nn.ModuleList()
        attn = nn.ModuleList()  # always empty; forward skips attention when empty
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]
        for _ in range(self.num_res_blocks):
            block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
            block_in = block_out
        down = nn.Module()
        down.block = block
        down.attn = attn
        if i_level != self.num_resolutions - 1:
            down.downsample = Downsample(block_in)
        self.down.append(down)

    # middle
    self.mid = nn.Module()
    self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
    self.mid.attn_1 = AttnBlock(block_in)
    self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

    # end
    self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
    self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor) -> Tensor:
    """Run the encoder and return the ``quant_conv`` output (``2 * z_channels`` channels)."""
    # downsampling. Each stage only consumes the previous activation, so keep one
    # running tensor instead of a list that held every intermediate feature map alive.
    h = self.conv_in(x)
    for i_level in range(self.num_resolutions):
        down = self.down[i_level]
        for i_block in range(self.num_res_blocks):
            h = down.block[i_block](h)  # type: ignore
            if len(down.attn) > 0:  # type: ignore
                h = down.attn[i_block](h)  # type: ignore
        if i_level != self.num_resolutions - 1:
            h = down.downsample(h)  # type: ignore

    # middle
    h = self.mid.block_1(h)  # type: ignore
    h = self.mid.attn_1(h)  # type: ignore
    h = self.mid.block_2(h)  # type: ignore
    # end
    h = self.norm_out(h)
    h = swish(h)
    h = self.conv_out(h)
    h = self.quant_conv(h)
    return h
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Decoder(nn.Module):
    """Map ``z_channels`` latents back to an ``out_ch``-channel image.

    A 1x1 ``post_quant_conv`` and a 3x3 ``conv_in`` feed a middle stage at the
    lowest resolution. Then ``len(ch_mult)`` levels, from the deepest to level
    0, each run ``num_res_blocks + 1`` residual blocks. Every level except
    level 0 ends with a 2x :class:`Upsample`.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels  # stored only; not used to build layers
        # total spatial upsampling factor: one 2x Upsample per level except level 0
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # channel width and spatial size at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # always empty; forward skips attention when empty
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
            self.up.insert(0, up)  # prepend so that self.up[i] is level i

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        z = self.post_quant_conv(z)

        # dtype of the upsampling stack; the mid-stage output is cast to it below
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)  # type: ignore
        h = self.mid.attn_1(h)  # type: ignore
        h = self.mid.block_2(h)  # type: ignore

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)  # type: ignore
                if len(self.up[i_level].attn) > 0:  # type: ignore
                    h = self.up[i_level].attn[i_block](h)  # type: ignore
            if i_level != 0:
                h = self.up[i_level].upsample(h)  # type: ignore

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    ch: int,
    out_ch: int,
    ch_mult: list[int],
    num_res_blocks: int,
    in_channels: int,
    resolution: int,
    z_channels: int,
):
    """Build the decoder: latent projection, bottleneck, up stages, output head.

    Args:
        ch: Base channel count; level ``i`` works at ``ch * ch_mult[i]`` channels.
        out_ch: Number of channels produced by ``conv_out``.
        ch_mult: Per-level channel multipliers.
        num_res_blocks: Every up stage holds ``num_res_blocks + 1`` ResnetBlocks.
        in_channels: Stored on the instance; not otherwise used here.
        resolution: Stored on the instance and used to derive ``z_shape``.
        z_channels: Channel count of the incoming latent.
    """
    super().__init__()
    self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
    self.ch = ch
    self.num_resolutions = len(ch_mult)
    self.num_res_blocks = num_res_blocks
    self.resolution = resolution
    self.in_channels = in_channels
    # One factor of 2 for every level that owns an Upsample (all but level 0);
    # presumably the total spatial upscaling, assuming Upsample doubles H and W.
    self.ffactor = 2 ** (self.num_resolutions - 1)

    # Channel width and spatial size at the deepest (lowest-resolution) level.
    width = ch * ch_mult[self.num_resolutions - 1]
    latent_res = resolution // self.ffactor
    # Presumably the latent shape for a resolution x resolution input.
    self.z_shape = (1, z_channels, latent_res, latent_res)

    # Latent channels -> deepest-level width.
    self.conv_in = nn.Conv2d(z_channels, width, kernel_size=3, stride=1, padding=1)

    # Bottleneck at the lowest resolution.
    self.mid = nn.Module()
    self.mid.block_1 = ResnetBlock(in_channels=width, out_channels=width)
    self.mid.attn_1 = AttnBlock(width)
    self.mid.block_2 = ResnetBlock(in_channels=width, out_channels=width)

    # Stages are created deepest-first (keeping submodule creation order,
    # and therefore random init, unchanged) and then stored so that
    # self.up[level] is the stage for that level.
    stages = []
    for level in range(self.num_resolutions - 1, -1, -1):
        stage_width = ch * ch_mult[level]
        resnets = nn.ModuleList()
        for _ in range(self.num_res_blocks + 1):
            resnets.append(ResnetBlock(in_channels=width, out_channels=stage_width))
            width = stage_width
        stage = nn.Module()
        stage.block = resnets
        # Nothing is ever appended: the up stages carry no attention layers.
        stage.attn = nn.ModuleList()
        if level != 0:
            stage.upsample = Upsample(width)
        stages.append(stage)
    self.up = nn.ModuleList(stages[::-1])

    # Output head.
    self.norm_out = nn.GroupNorm(num_groups=32, num_channels=width, eps=1e-6, affine=True)
    self.conv_out = nn.Conv2d(width, out_ch, kernel_size=3, stride=1, padding=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, z: Tensor) -> Tensor:
    """Decode ``z``; returns the output of ``conv_out``."""
    h = self.post_quant_conv(z)
    # The up path's parameter dtype is used for everything after the
    # bottleneck (originally commented "get dtype for proper tracing").
    up_dtype = next(self.up.parameters()).dtype

    h = self.conv_in(h)
    mid = self.mid
    h = mid.block_2(mid.attn_1(mid.block_1(h)))  # type: ignore
    h = h.to(up_dtype)

    # Walk the stages from the deepest level back to level 0.
    for level in range(self.num_resolutions - 1, -1, -1):
        stage = self.up[level]
        has_attn = len(stage.attn) > 0  # type: ignore
        for idx in range(self.num_res_blocks + 1):
            h = stage.block[idx](h)  # type: ignore
            if has_attn:
                h = stage.attn[idx](h)  # type: ignore
        if level != 0:
            h = stage.upsample(h)  # type: ignore

    return self.conv_out(swish(self.norm_out(h)))
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass must be defined within
this function, call the Module instance itself rather than this
method: calling the instance runs the registered hooks, whereas
calling forward() directly silently skips them.
class AutoEncoder(nn.Module):
    """Encoder/decoder pair with a BatchNorm over patchified latents.

    ``encode`` keeps the first half of the encoder output's channels, folds
    every ``ps[0] x ps[1]`` spatial patch into the channel axis, and
    standardizes the result with the BatchNorm's running statistics.
    ``decode`` applies those steps in reverse and then runs the decoder.
    The BatchNorm is switched to eval mode on every use, so only its
    running statistics are applied here.
    """

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        # Arguments common to both halves; the decoder also takes out_ch.
        shared = dict(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.encoder = Encoder(**shared)
        self.decoder = Decoder(out_ch=params.out_ch, **shared)

        self.bn_eps = 1e-4
        self.bn_momentum = 0.1
        # Patch size (rows, cols) folded into channels before normalization.
        self.ps = [2, 2]
        patched_channels = math.prod(self.ps) * params.z_channels
        self.bn = torch.nn.BatchNorm2d(
            patched_channels,
            eps=self.bn_eps,
            momentum=self.bn_momentum,
            affine=False,
            track_running_stats=True,
        )

    def normalize(self, z):
        """Standardize ``z`` with the BatchNorm running stats (forces eval mode)."""
        norm = self.bn
        norm.eval()
        return norm(z)

    def inv_normalize(self, z):
        """Undo ``normalize``: return ``z * sqrt(running_var + eps) + running_mean``."""
        self.bn.eval()
        std = (self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps).sqrt()  # type: ignore
        mean = self.bn.running_mean.view(1, -1, 1, 1)  # type: ignore
        return z * std + mean

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x`` into a normalized, patchified latent."""
        # First channel half of the encoder output; named "mean" here
        # (presumably the mean of a diagonal Gaussian -- the rest is dropped).
        mean = self.encoder(x).chunk(2, dim=1)[0]
        row_patch, col_patch = self.ps[0], self.ps[1]
        patched = rearrange(
            mean,
            "... c (i pi) (j pj) -> ... (c pi pj) i j",
            pi=row_patch,
            pj=col_patch,
        )
        return self.normalize(patched)

    def decode(self, z: Tensor) -> Tensor:
        """Invert ``encode``'s normalization and patching, then decode."""
        row_patch, col_patch = self.ps[0], self.ps[1]
        unpatched = rearrange(
            self.inv_normalize(z),
            "... (c pi pj) i j -> ... c (i pi) (j pj)",
            pi=row_patch,
            pj=col_patch,
        )
        return self.decoder(unpatched)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean that represents whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, params: AutoEncoderParams):
    """Build the encoder, the decoder and the latent BatchNorm from ``params``."""
    super().__init__()
    self.params = params
    # Arguments common to both halves; the decoder also takes out_ch.
    shared = dict(
        resolution=params.resolution,
        in_channels=params.in_channels,
        ch=params.ch,
        ch_mult=params.ch_mult,
        num_res_blocks=params.num_res_blocks,
        z_channels=params.z_channels,
    )
    self.encoder = Encoder(**shared)
    self.decoder = Decoder(out_ch=params.out_ch, **shared)

    self.bn_eps = 1e-4
    self.bn_momentum = 0.1
    # Patch size (rows, cols) folded into channels before normalization.
    self.ps = [2, 2]
    patched_channels = math.prod(self.ps) * params.z_channels
    self.bn = torch.nn.BatchNorm2d(
        patched_channels,
        eps=self.bn_eps,
        momentum=self.bn_momentum,
        affine=False,
        track_running_stats=True,
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.