divisor.flux1.autoencoder
# SPDX-License-Identifier: Apache-2.0
# original BFL Flux code from https://github.com/black-forest-labs/flux

# type: ignore

from dataclasses import dataclass

from einops import rearrange
import torch
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    """Hyper-parameters shared by :class:`Encoder`, :class:`Decoder` and :class:`AutoEncoder`."""

    resolution: int  # nominal image side; the Decoder derives ``z_shape`` from it
    in_channels: int  # channels of the image fed to the encoder
    ch: int  # base channel width, multiplied per level by ``ch_mult``
    out_ch: int  # channels of the image produced by the decoder
    ch_mult: list[int]  # per-level width multipliers; its length is the number of levels
    num_res_blocks: int  # ResnetBlocks per encoder level (the decoder uses one more)
    z_channels: int  # channels of the latent returned by ``AutoEncoder.encode``
    scale_factor: float  # multiplier applied to the latent after shifting in ``encode``
    shift_factor: float  # offset subtracted from the latent before scaling in ``encode``


def swish(x: Tensor) -> Tensor:
    """Return the swish (SiLU) activation ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, added back to the input.

    Args:
        in_channels: channel count of the input; must be divisible by 32 because of
            the GroupNorm with 32 groups.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalise ``h_``, attend across its ``h * w`` positions, return a ``(b, c, h, w)`` map."""
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        # Flatten the spatial grid into a sequence with a single head: (b, 1, h*w, c).
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output; ``None`` (the default) keeps
            ``in_channels``. When the widths differ the skip path is a 1x1 conv
            (``nin_shortcut``).
    """

    def __init__(self, in_channels: int, out_channels: "int | None" = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            # project the skip path to the new width so the residual add lines up
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    """Halve the spatial size with an unpadded 3x3 stride-2 conv after right/bottom padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        pad = (0, 1, 0, 1)  # (left, right, top, bottom): one extra pixel right and bottom
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    """Double the spatial size: nearest-neighbour 2x interpolation followed by a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)
        return x


class Encoder(nn.Module):
    """Convolutional encoder mapping an image to ``2 * z_channels`` channels (mean and log-variance).

    Each of the ``len(ch_mult)`` levels applies ``num_res_blocks`` ResnetBlocks and,
    except the last level, a 2x Downsample. A ResnetBlock/AttnBlock/ResnetBlock
    middle section and a norm/swish/conv head follow.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # input width multiplier per level: level i consumes the output width of level i-1
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # left empty: forward() skips attention when this is empty
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling: only the current activation is carried forward, so earlier
        # feature maps are not kept referenced for the whole pass
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    """Convolutional decoder mapping a ``z_channels`` latent to an ``out_ch`` image.

    Mirrors :class:`Encoder`: a middle section at the lowest resolution, then one
    level per ``ch_mult`` entry (``num_res_blocks + 1`` ResnetBlocks each), with a
    2x Upsample after every level except the final (full-resolution) one.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # channel width and spatial size at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # left empty: forward() skips attention when this is empty
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class DiagonalGaussian(nn.Module):
    """Split ``z`` into mean / log-variance halves and return a sample or the mean.

    Args:
        sample: if True, return ``mean + exp(0.5 * logvar) * eps`` with
            ``eps ~ N(0, 1)``; otherwise return ``mean``.
        chunk_dim: dimension along which ``z`` is split in two.
    """

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if self.sample:
            std = torch.exp(0.5 * logvar)
            return mean + std * torch.randn_like(mean)
        else:
            return mean


class AutoEncoder(nn.Module):
    """Encoder + diagonal-Gaussian bottleneck + decoder, with latent shift and scale.

    Args:
        params: architecture and latent-normalisation hyper-parameters.
        sample_z: if True, ``encode`` samples from the posterior; otherwise it
            returns the posterior mean.
    """

    def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian(sample=sample_z)

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x`` into the shifted and scaled latent ``scale * (z - shift)``."""
        encoder_dtype = next(self.encoder.parameters()).dtype
        x = x.to(encoder_dtype)  # Quality of life: sample dtype always matches encoder dtype
        z = self.reg(self.encoder(x))
        z = self.scale_factor * (z - self.shift_factor)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """Undo the latent shift/scale and decode to image space.

        Mirrors ``encode``: the un-normalised latent is cast to the dtype of the
        decoder's first parameter (its input conv) so the dtypes always match.
        """
        z = z / self.scale_factor + self.shift_factor
        decoder_dtype = next(self.decoder.parameters()).dtype
        return self.decoder(z.to(decoder_dtype))

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.encode(x))


if __name__ == "__main__":
    params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )

    model = AutoEncoder(params)
    x = torch.randn(1, 3, 512, 512)
    # smoke test only: no gradients needed
    with torch.no_grad():
        z = model.encode(x)
    print(tuple(z.shape))
@dataclass
class AutoEncoderParams:
    """Hyper-parameters shared by the Encoder, Decoder and AutoEncoder."""

    resolution: int  # nominal image side; the Decoder derives ``z_shape`` from it
    in_channels: int  # channels of the image fed to the encoder
    ch: int  # base channel width, multiplied per level by ``ch_mult``
    out_ch: int  # channels of the image produced by the decoder
    ch_mult: list[int]  # per-level width multipliers; its length is the number of levels
    num_res_blocks: int  # ResnetBlocks per encoder level (the decoder uses one more)
    z_channels: int  # channels of the latent returned by ``AutoEncoder.encode``
    scale_factor: float  # multiplier applied to the latent after shifting in ``encode``
    shift_factor: float  # offset subtracted from the latent before scaling in ``encode``
class AttnBlock(nn.Module):
    """Spatial self-attention (one head) with a residual connection.

    The input is group-normalised, projected to queries/keys/values with 1x1
    convs, attended across all ``h * w`` positions, projected back with another
    1x1 conv and added to the input.

    Args:
        in_channels: channel count; also the width of every projection.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)

        # Creation order q, k, v, proj_out is preserved so registration order and the
        # random-initialisation sequence are unchanged.
        for attr in ("q", "k", "v", "proj_out"):
            setattr(self, attr, nn.Conv2d(in_channels, in_channels, kernel_size=1))

    def attention(self, h_: Tensor) -> Tensor:
        """Return the attention output for ``h_`` reshaped back to ``(b, c, h, w)``."""
        normed = self.norm(h_)
        projections = [proj(normed) for proj in (self.q, self.k, self.v)]
        b, c, h, w = projections[0].shape
        query, key, value = (
            rearrange(t, "b c h w -> b 1 (h w) c").contiguous() for t in projections
        )
        attended = nn.functional.scaled_dot_product_attention(query, key, value)
        return rearrange(attended, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        update = self.proj_out(self.attention(x))
        return x + update
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels: int):
    """Create the GroupNorm and the four 1x1 projections of the attention block.

    Args:
        in_channels: channel count of the input; also the width of the
            q/k/v/output projections.
    """
    super().__init__()
    self.in_channels = in_channels

    self.norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)

    # Creation order q, k, v, proj_out is preserved so registration order and the
    # random-initialisation sequence are unchanged.
    for attr in ("q", "k", "v", "proj_out"):
        setattr(self, attr, nn.Conv2d(in_channels, in_channels, kernel_size=1))
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def attention(self, h_: Tensor) -> Tensor:
    """Normalise ``h_``, attend over its ``h * w`` positions, return a ``(b, c, h, w)`` map."""
    normed = self.norm(h_)
    projections = [proj(normed) for proj in (self.q, self.k, self.v)]
    b, c, h, w = projections[0].shape
    # one attention head over the flattened spatial grid: (b, 1, h*w, c)
    query, key, value = (
        rearrange(t, "b c h w -> b 1 (h w) c").contiguous() for t in projections
    )
    attended = nn.functional.scaled_dot_product_attention(query, key, value)
    return rearrange(attended, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output; ``None`` (the default) keeps
            ``in_channels``. When the widths differ the skip path is a 1x1 conv
            (``nin_shortcut``).
    """

    def __init__(self, in_channels: int, out_channels: "int | None" = None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            # project the skip path to the new width so the residual add lines up
            x = self.nin_shortcut(x)

        return x + h
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels: int, out_channels: "int | None" = None):
    """Build the two norm/conv stages and, if the width changes, a 1x1 skip projection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; ``None`` (the default) keeps
            ``in_channels``.
    """
    super().__init__()
    self.in_channels = in_channels
    out_channels = in_channels if out_channels is None else out_channels
    self.out_channels = out_channels

    self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
    self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
    if self.in_channels != self.out_channels:
        # 1x1 projection so the skip path matches the residual's width
        self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Return ``skip(x) + conv2(swish(norm2(conv1(swish(norm1(x))))))``."""
    residual = self.conv1(swish(self.norm1(x)))
    residual = self.conv2(swish(self.norm2(residual)))
    # identity skip unless the block changes width, then the 1x1 projection
    shortcut = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
    return shortcut + residual
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Downsample(nn.Module):
    """Spatially halve a feature map with a stride-2 3x3 convolution.

    torch convolutions only pad symmetrically, so the input is padded by one
    pixel on the right and bottom by hand, and the convolution itself is unpadded.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, 3, 2, 0)

    def forward(self, x: Tensor):
        right_bottom_pad = (0, 1, 0, 1)  # (left, right, top, bottom)
        return self.conv(nn.functional.pad(x, right_bottom_pad, mode="constant", value=0))
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels: int):
    """Create the unpadded 3x3, stride-2 convolution of the downsampler.

    Args:
        in_channels: channel count, preserved by the convolution.
    """
    super().__init__()
    # torch convs pad symmetrically only; the one-sided pad is applied in forward()
    self.conv = nn.Conv2d(in_channels, in_channels, 3, 2, 0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor):
    """Zero-pad one pixel on the right and bottom, then apply the stride-2 conv."""
    right_bottom_pad = (0, 1, 0, 1)  # (left, right, top, bottom)
    padded = nn.functional.pad(x, right_bottom_pad, mode="constant", value=0)
    return self.conv(padded)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Upsample(nn.Module):
    """Double the spatial size: nearest-neighbour 2x interpolation, then a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, 3, 1, 1)

    def forward(self, x: Tensor):
        return self.conv(nn.functional.interpolate(x, scale_factor=2.0, mode="nearest"))
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels: int):
    """Create the width-preserving 3x3 conv applied after 2x interpolation.

    Args:
        in_channels: channel count, preserved by the convolution.
    """
    super().__init__()
    self.conv = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor):
    """Nearest-neighbour upsample ``x`` by 2x, then apply the 3x3 conv."""
    upsampled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
    return self.conv(upsampled)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to ``2 * z_channels`` channels (mean and log-variance).

    Each of the ``len(ch_mult)`` levels applies ``num_res_blocks`` ResnetBlocks and,
    except the last level, a 2x Downsample. A ResnetBlock/AttnBlock/ResnetBlock
    middle section and a norm/swish/conv head follow.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # input width multiplier per level: level i consumes the output width of level i-1
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # left empty: forward() skips attention when this is empty
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        # downsampling: only the current activation is carried forward, so earlier
        # feature maps are not kept referenced for the whole pass
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    resolution: int,
    in_channels: int,
    ch: int,
    ch_mult: list[int],
    num_res_blocks: int,
    z_channels: int,
):
    """Build the encoder's input conv, per-level ResnetBlocks/Downsamples, middle and head.

    Args:
        resolution: nominal image side; only stored on the module.
        in_channels: channels of the input image.
        ch: base channel width.
        ch_mult: per-level width multipliers; ``len(ch_mult)`` is the number of levels.
        num_res_blocks: ResnetBlocks per level.
        z_channels: latent channels; the head emits ``2 * z_channels``.
    """
    super().__init__()
    self.ch = ch
    self.num_resolutions = len(ch_mult)
    self.num_res_blocks = num_res_blocks
    self.resolution = resolution
    self.in_channels = in_channels
    # downsampling
    self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

    # input width multiplier per level: level i consumes the output width of level i-1
    in_ch_mult = (1,) + tuple(ch_mult)
    self.in_ch_mult = in_ch_mult
    self.down = nn.ModuleList()
    block_in = self.ch
    for i_level in range(self.num_resolutions):
        block = nn.ModuleList()
        attn = nn.ModuleList()  # left empty: forward() skips attention when this is empty
        block_in = ch * in_ch_mult[i_level]
        block_out = ch * ch_mult[i_level]
        for _ in range(self.num_res_blocks):
            block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
            block_in = block_out
        down = nn.Module()
        down.block = block
        down.attn = attn
        if i_level != self.num_resolutions - 1:
            down.downsample = Downsample(block_in)
        self.down.append(down)

    # middle
    self.mid = nn.Module()
    self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
    self.mid.attn_1 = AttnBlock(block_in)
    self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

    # end
    self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
    self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x: Tensor) -> Tensor:
    """Encode ``x`` into the ``2 * z_channels`` output of the head convolution."""
    # downsampling: only the current activation is carried forward, so earlier
    # feature maps are not kept referenced for the whole pass
    h = self.conv_in(x)
    for i_level in range(self.num_resolutions):
        level = self.down[i_level]
        for i_block in range(self.num_res_blocks):
            h = level.block[i_block](h)
            if len(level.attn) > 0:
                h = level.attn[i_block](h)
        if i_level != self.num_resolutions - 1:
            h = level.downsample(h)

    # middle
    h = self.mid.block_1(h)
    h = self.mid.attn_1(h)
    h = self.mid.block_2(h)
    # end
    h = self.norm_out(h)
    h = swish(h)
    h = self.conv_out(h)
    return h
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class Decoder(nn.Module):
    """Convolutional decoder mapping a ``z_channels`` latent to an ``out_ch`` image.

    A middle section runs at the lowest resolution. Then there is one level per
    ``ch_mult`` entry (``num_res_blocks + 1`` ResnetBlocks each), with a 2x Upsample
    after every level except the final, full-resolution one.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # channel width and spatial size at the lowest resolution
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # left empty: forward() skips attention when this is empty
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        # get dtype for proper tracing
        upscale_dtype = next(self.up.parameters()).dtype

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # cast to proper dtype
        h = h.to(upscale_dtype)
        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(
    self,
    ch: int,
    out_ch: int,
    ch_mult: list[int],
    num_res_blocks: int,
    in_channels: int,
    resolution: int,
    z_channels: int,
):
    """Build the decoder's input conv, middle, per-level ResnetBlocks/Upsamples and head.

    Args:
        ch: base channel width.
        out_ch: channels of the produced image.
        ch_mult: per-level width multipliers; ``len(ch_mult)`` is the number of levels.
        num_res_blocks: the decoder uses ``num_res_blocks + 1`` ResnetBlocks per level.
        in_channels: only stored on the module.
        resolution: nominal image side, used to compute ``z_shape``.
        z_channels: channels of the incoming latent.
    """
    super().__init__()
    self.ch = ch
    self.num_resolutions = len(ch_mult)
    self.num_res_blocks = num_res_blocks
    self.resolution = resolution
    self.in_channels = in_channels
    self.ffactor = 2 ** (self.num_resolutions - 1)

    # channel width and spatial size at the lowest resolution
    block_in = ch * ch_mult[self.num_resolutions - 1]
    curr_res = resolution // 2 ** (self.num_resolutions - 1)
    self.z_shape = (1, z_channels, curr_res, curr_res)

    # z to block_in
    self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

    # middle
    self.mid = nn.Module()
    self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
    self.mid.attn_1 = AttnBlock(block_in)
    self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

    # upsampling
    self.up = nn.ModuleList()
    for i_level in reversed(range(self.num_resolutions)):
        block = nn.ModuleList()
        attn = nn.ModuleList()  # left empty: forward() skips attention when this is empty
        block_out = ch * ch_mult[i_level]
        for _ in range(self.num_res_blocks + 1):
            block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
            block_in = block_out
        up = nn.Module()
        up.block = block
        up.attn = attn
        if i_level != 0:
            up.upsample = Upsample(block_in)
        self.up.insert(0, up)  # prepend to get consistent order

    # end
    self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
    self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, z: Tensor) -> Tensor:
    """Decode latent ``z`` into an ``out_ch`` image.

    After the middle section the activation is cast to the dtype of the
    upsampling stack's parameters, so that stack may run in a different dtype.
    """
    # dtype of the first parameter of the upsampling stack (used for tracing/casting)
    target_dtype = next(self.up.parameters()).dtype

    h = self.conv_in(z)
    for layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
        h = layer(h)
    h = h.to(target_dtype)

    for i_level in reversed(range(self.num_resolutions)):
        stage = self.up[i_level]
        has_attn = len(stage.attn) > 0
        for i_block in range(self.num_res_blocks + 1):
            h = stage.block[i_block](h)
            if has_attn:
                h = stage.attn[i_block](h)
        if i_level != 0:
            h = stage.upsample(h)

    return self.conv_out(swish(self.norm_out(h)))
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class DiagonalGaussian(nn.Module):
    """Split ``z`` into mean / log-variance halves; return a sample or the mean.

    Args:
        sample: if True, return ``mean + exp(0.5 * logvar) * eps`` with
            ``eps ~ N(0, 1)``; otherwise return ``mean``.
        chunk_dim: dimension along which ``z`` is split in two.
    """

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample, self.chunk_dim = sample, chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        mean, logvar = z.chunk(2, dim=self.chunk_dim)
        if not self.sample:
            return mean
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
274 def __init__(self, sample: bool = True, chunk_dim: int = 1): 275 super().__init__() 276 self.sample = sample 277 self.chunk_dim = chunk_dim
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, z: Tensor) -> Tensor:
    """Split ``z`` into (mean, logvar) along ``self.chunk_dim`` and reduce it.

    Returns the mean when ``self.sample`` is false; otherwise returns
    ``mean + exp(0.5 * logvar) * eps`` with ``eps`` from ``torch.randn_like``.
    """
    mean, logvar = z.chunk(2, dim=self.chunk_dim)
    if not self.sample:
        return mean
    std = logvar.mul(0.5).exp()
    return mean + std * torch.randn_like(mean)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
class AutoEncoder(nn.Module):
    """Encoder/decoder pair with a diagonal-Gaussian bottleneck and latent scale/shift.

    ``encode`` returns ``scale_factor * (z - shift_factor)``; ``decode`` inverts
    that affine map before running the decoder.

    Args:
        params: Architecture hyper-parameters plus ``scale_factor``/``shift_factor``.
        sample_z: If True, ``encode`` samples from the Gaussian instead of taking
            its mean (forwarded as ``sample`` to ``DiagonalGaussian``).
    """

    def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.reg = DiagonalGaussian(sample=sample_z)

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x`` into a scaled latent.

        ``x`` is first cast to the dtype of the encoder's first parameter.
        """
        # Quality of Life: sample dtype always matches encoder dtype
        encoder_dtype = next(self.encoder.parameters()).dtype
        x = x.to(encoder_dtype)
        z = self.reg(self.encoder(x))
        return self.scale_factor * (z - self.shift_factor)

    def decode(self, z: Tensor) -> Tensor:
        """Undo the latent scale/shift and decode ``z``.

        The unscaled latent is cast to the dtype of the decoder's first parameter
        before decoding. This mirrors ``encode``, so e.g. fp32 latents can feed a
        reduced-precision decoder. The cast is a no-op when the dtypes already match.
        """
        z = z / self.scale_factor + self.shift_factor
        # Decoder registers conv_in first, so its first parameter is the input layer's weight.
        decoder_dtype = next(self.decoder.parameters()).dtype
        return self.decoder(z.to(decoder_dtype))

    def forward(self, x: Tensor) -> Tensor:
        """Round-trip ``x`` through ``encode`` and ``decode``."""
        return self.decode(self.encode(x))
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
    """Build the encoder, the decoder and the latent regulariser from ``params``.

    Args:
        params: Architecture hyper-parameters plus ``scale_factor``/``shift_factor``.
        sample_z: Forwarded as ``sample`` to the ``DiagonalGaussian`` regulariser.
    """
    super().__init__()
    self.params = params
    # Hyper-parameters shared by both halves of the autoencoder.
    common = dict(
        resolution=params.resolution,
        in_channels=params.in_channels,
        ch=params.ch,
        ch_mult=params.ch_mult,
        num_res_blocks=params.num_res_blocks,
        z_channels=params.z_channels,
    )
    self.encoder = Encoder(**common)
    self.decoder = Decoder(out_ch=params.out_ch, **common)
    self.reg = DiagonalGaussian(sample=sample_z)
    self.scale_factor, self.shift_factor = params.scale_factor, params.shift_factor
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for forward pass needs to be defined within
this function, one should call the Module instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.