divisor.flux1.autoencoder

  1# SPDX-License-Identifier:Apache-2.0
  2# original BFL Flux code from https://github.com/black-forest-labs/flux
  3
  4# type: ignore
  5
  6from dataclasses import dataclass
  7
  8from einops import rearrange
  9import torch
 10from torch import Tensor, nn
 11
 12
 13@dataclass
 14class AutoEncoderParams:
 15    resolution: int
 16    in_channels: int
 17    ch: int
 18    out_ch: int
 19    ch_mult: list[int]
 20    num_res_blocks: int
 21    z_channels: int
 22    scale_factor: float
 23    shift_factor: float
 24
 25
 26def swish(x: Tensor) -> Tensor:
 27    return x * torch.sigmoid(x)
 28
 29
 30class AttnBlock(nn.Module):
 31    def __init__(self, in_channels: int):
 32        super().__init__()
 33        self.in_channels = in_channels
 34
 35        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 36
 37        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
 38        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
 39        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
 40        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
 41
 42    def attention(self, h_: Tensor) -> Tensor:
 43        h_ = self.norm(h_)
 44        q = self.q(h_)
 45        k = self.k(h_)
 46        v = self.v(h_)
 47
 48        b, c, h, w = q.shape
 49        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
 50        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
 51        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
 52        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
 53
 54        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
 55
 56    def forward(self, x: Tensor) -> Tensor:
 57        return x + self.proj_out(self.attention(x))
 58
 59
 60class ResnetBlock(nn.Module):
 61    def __init__(self, in_channels: int, out_channels: int):
 62        super().__init__()
 63        self.in_channels = in_channels
 64        out_channels = in_channels if out_channels is None else out_channels
 65        self.out_channels = out_channels
 66
 67        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 68        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 69        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
 70        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
 71        if self.in_channels != self.out_channels:
 72            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 73
 74    def forward(self, x):
 75        h = x
 76        h = self.norm1(h)
 77        h = swish(h)
 78        h = self.conv1(h)
 79
 80        h = self.norm2(h)
 81        h = swish(h)
 82        h = self.conv2(h)
 83
 84        if self.in_channels != self.out_channels:
 85            x = self.nin_shortcut(x)
 86
 87        return x + h
 88
 89
 90class Downsample(nn.Module):
 91    def __init__(self, in_channels: int):
 92        super().__init__()
 93        # no asymmetric padding in torch conv, must do it ourselves
 94        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
 95
 96    def forward(self, x: Tensor):
 97        pad = (0, 1, 0, 1)
 98        x = nn.functional.pad(x, pad, mode="constant", value=0)
 99        x = self.conv(x)
100        return x
101
102
103class Upsample(nn.Module):
104    def __init__(self, in_channels: int):
105        super().__init__()
106        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
107
108    def forward(self, x: Tensor):
109        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
110        x = self.conv(x)
111        return x
112
113
114class Encoder(nn.Module):
115    def __init__(
116        self,
117        resolution: int,
118        in_channels: int,
119        ch: int,
120        ch_mult: list[int],
121        num_res_blocks: int,
122        z_channels: int,
123    ):
124        super().__init__()
125        self.ch = ch
126        self.num_resolutions = len(ch_mult)
127        self.num_res_blocks = num_res_blocks
128        self.resolution = resolution
129        self.in_channels = in_channels
130        # downsampling
131        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
132
133        curr_res = resolution
134        in_ch_mult = (1,) + tuple(ch_mult)
135        self.in_ch_mult = in_ch_mult
136        self.down = nn.ModuleList()
137        block_in = self.ch
138        for i_level in range(self.num_resolutions):
139            block = nn.ModuleList()
140            attn = nn.ModuleList()
141            block_in = ch * in_ch_mult[i_level]
142            block_out = ch * ch_mult[i_level]
143            for _ in range(self.num_res_blocks):
144                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
145                block_in = block_out
146            down = nn.Module()
147            down.block = block
148            down.attn = attn
149            if i_level != self.num_resolutions - 1:
150                down.downsample = Downsample(block_in)
151                curr_res = curr_res // 2
152            self.down.append(down)
153
154        # middle
155        self.mid = nn.Module()
156        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
157        self.mid.attn_1 = AttnBlock(block_in)
158        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
159
160        # end
161        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
162        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
163
164    def forward(self, x: Tensor) -> Tensor:
165        # downsampling
166        hs = [self.conv_in(x)]
167        for i_level in range(self.num_resolutions):
168            for i_block in range(self.num_res_blocks):
169                h = self.down[i_level].block[i_block](hs[-1])
170                if len(self.down[i_level].attn) > 0:
171                    h = self.down[i_level].attn[i_block](h)
172                hs.append(h)
173            if i_level != self.num_resolutions - 1:
174                hs.append(self.down[i_level].downsample(hs[-1]))
175
176        # middle
177        h = hs[-1]
178        h = self.mid.block_1(h)
179        h = self.mid.attn_1(h)
180        h = self.mid.block_2(h)
181        # end
182        h = self.norm_out(h)
183        h = swish(h)
184        h = self.conv_out(h)
185        return h
186
187
188class Decoder(nn.Module):
189    def __init__(
190        self,
191        ch: int,
192        out_ch: int,
193        ch_mult: list[int],
194        num_res_blocks: int,
195        in_channels: int,
196        resolution: int,
197        z_channels: int,
198    ):
199        super().__init__()
200        self.ch = ch
201        self.num_resolutions = len(ch_mult)
202        self.num_res_blocks = num_res_blocks
203        self.resolution = resolution
204        self.in_channels = in_channels
205        self.ffactor = 2 ** (self.num_resolutions - 1)
206
207        # compute in_ch_mult, block_in and curr_res at lowest res
208        block_in = ch * ch_mult[self.num_resolutions - 1]
209        curr_res = resolution // 2 ** (self.num_resolutions - 1)
210        self.z_shape = (1, z_channels, curr_res, curr_res)
211
212        # z to block_in
213        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
214
215        # middle
216        self.mid = nn.Module()
217        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
218        self.mid.attn_1 = AttnBlock(block_in)
219        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
220
221        # upsampling
222        self.up = nn.ModuleList()
223        for i_level in reversed(range(self.num_resolutions)):
224            block = nn.ModuleList()
225            attn = nn.ModuleList()
226            block_out = ch * ch_mult[i_level]
227            for _ in range(self.num_res_blocks + 1):
228                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
229                block_in = block_out
230            up = nn.Module()
231            up.block = block
232            up.attn = attn
233            if i_level != 0:
234                up.upsample = Upsample(block_in)
235                curr_res = curr_res * 2
236            self.up.insert(0, up)  # prepend to get consistent order
237
238        # end
239        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
240        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
241
242    def forward(self, z: Tensor) -> Tensor:
243        # get dtype for proper tracing
244        upscale_dtype = next(self.up.parameters()).dtype
245
246        # z to block_in
247        h = self.conv_in(z)
248
249        # middle
250        h = self.mid.block_1(h)
251        h = self.mid.attn_1(h)
252        h = self.mid.block_2(h)
253
254        # cast to proper dtype
255        h = h.to(upscale_dtype)
256        # upsampling
257        for i_level in reversed(range(self.num_resolutions)):
258            for i_block in range(self.num_res_blocks + 1):
259                h = self.up[i_level].block[i_block](h)
260                if len(self.up[i_level].attn) > 0:
261                    h = self.up[i_level].attn[i_block](h)
262            if i_level != 0:
263                h = self.up[i_level].upsample(h)
264
265        # end
266        h = self.norm_out(h)
267        h = swish(h)
268        h = self.conv_out(h)
269        return h
270
271
272class DiagonalGaussian(nn.Module):
273    def __init__(self, sample: bool = True, chunk_dim: int = 1):
274        super().__init__()
275        self.sample = sample
276        self.chunk_dim = chunk_dim
277
278    def forward(self, z: Tensor) -> Tensor:
279        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
280        if self.sample:
281            std = torch.exp(0.5 * logvar)
282            return mean + std * torch.randn_like(mean)
283        else:
284            return mean
285
286
287class AutoEncoder(nn.Module):
288    def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
289        super().__init__()
290        self.params = params
291        self.encoder = Encoder(
292            resolution=params.resolution,
293            in_channels=params.in_channels,
294            ch=params.ch,
295            ch_mult=params.ch_mult,
296            num_res_blocks=params.num_res_blocks,
297            z_channels=params.z_channels,
298        )
299        self.decoder = Decoder(
300            resolution=params.resolution,
301            in_channels=params.in_channels,
302            ch=params.ch,
303            out_ch=params.out_ch,
304            ch_mult=params.ch_mult,
305            num_res_blocks=params.num_res_blocks,
306            z_channels=params.z_channels,
307        )
308        self.reg = DiagonalGaussian(sample=sample_z)
309
310        self.scale_factor = params.scale_factor
311        self.shift_factor = params.shift_factor
312
313    def encode(self, x: Tensor) -> Tensor:
314        encoder_dtype = next(self.encoder.parameters()).dtype
315        x = x.to(encoder_dtype)  # Quality of Life:sample dtype always matches encoder dtype
316        z = self.reg(self.encoder(x))
317        z = self.scale_factor * (z - self.shift_factor)
318        return z
319
320    def decode(self, z: Tensor) -> Tensor:
321        z = z / self.scale_factor + self.shift_factor
322        return self.decoder(z)
323
324    def forward(self, x: Tensor) -> Tensor:
325        return self.decode(self.encode(x))
326
327
328if __name__ == "__main__":
329    import torch
330
331    params = AutoEncoderParams(
332        resolution=256,
333        in_channels=3,
334        ch=128,
335        out_ch=3,
336        ch_mult=[1, 2, 4, 4],
337        num_res_blocks=2,
338        z_channels=16,
339        scale_factor=0.3611,
340        shift_factor=0.1159,
341    )
342
343    model = AutoEncoder(params)
344    x = torch.randn(1, 3, 512, 512)
345    z = model.encode(x)
@dataclass
class AutoEncoderParams:
14@dataclass
15class AutoEncoderParams:
16    resolution: int
17    in_channels: int
18    ch: int
19    out_ch: int
20    ch_mult: list[int]
21    num_res_blocks: int
22    z_channels: int
23    scale_factor: float
24    shift_factor: float
AutoEncoderParams(resolution: int, in_channels: int, ch: int, out_ch: int, ch_mult: list[int], num_res_blocks: int, z_channels: int, scale_factor: float, shift_factor: float)
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
def swish(x: torch.Tensor) -> torch.Tensor:
27def swish(x: Tensor) -> Tensor:
28    return x * torch.sigmoid(x)
class AttnBlock(torch.nn.modules.module.Module):
31class AttnBlock(nn.Module):
32    def __init__(self, in_channels: int):
33        super().__init__()
34        self.in_channels = in_channels
35
36        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
37
38        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
39        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
40        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
41        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
42
43    def attention(self, h_: Tensor) -> Tensor:
44        h_ = self.norm(h_)
45        q = self.q(h_)
46        k = self.k(h_)
47        v = self.v(h_)
48
49        b, c, h, w = q.shape
50        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
51        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
52        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
53        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
54
55        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
56
57    def forward(self, x: Tensor) -> Tensor:
58        return x + self.proj_out(self.attention(x))

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute `training` (bool): represents whether this module is in training or evaluation mode.

AttnBlock(in_channels: int)
32    def __init__(self, in_channels: int):
33        super().__init__()
34        self.in_channels = in_channels
35
36        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
37
38        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
39        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
40        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
41        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_channels
norm
q
k
v
proj_out
def attention(self, h_: torch.Tensor) -> torch.Tensor:
43    def attention(self, h_: Tensor) -> Tensor:
44        h_ = self.norm(h_)
45        q = self.q(h_)
46        k = self.k(h_)
47        v = self.v(h_)
48
49        b, c, h, w = q.shape
50        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
51        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
52        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
53        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
54
55        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: torch.Tensor) -> torch.Tensor:
57    def forward(self, x: Tensor) -> Tensor:
58        return x + self.proj_out(self.attention(x))

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class ResnetBlock(torch.nn.modules.module.Module):
61class ResnetBlock(nn.Module):
62    def __init__(self, in_channels: int, out_channels: int):
63        super().__init__()
64        self.in_channels = in_channels
65        out_channels = in_channels if out_channels is None else out_channels
66        self.out_channels = out_channels
67
68        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
69        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
70        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
71        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
72        if self.in_channels != self.out_channels:
73            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
74
75    def forward(self, x):
76        h = x
77        h = self.norm1(h)
78        h = swish(h)
79        h = self.conv1(h)
80
81        h = self.norm2(h)
82        h = swish(h)
83        h = self.conv2(h)
84
85        if self.in_channels != self.out_channels:
86            x = self.nin_shortcut(x)
87
88        return x + h

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute `training` (bool): represents whether this module is in training or evaluation mode.

ResnetBlock(in_channels: int, out_channels: int)
62    def __init__(self, in_channels: int, out_channels: int):
63        super().__init__()
64        self.in_channels = in_channels
65        out_channels = in_channels if out_channels is None else out_channels
66        self.out_channels = out_channels
67
68        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
69        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
70        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
71        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
72        if self.in_channels != self.out_channels:
73            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_channels
out_channels
norm1
conv1
norm2
conv2
def forward(self, x):
75    def forward(self, x):
76        h = x
77        h = self.norm1(h)
78        h = swish(h)
79        h = self.conv1(h)
80
81        h = self.norm2(h)
82        h = swish(h)
83        h = self.conv2(h)
84
85        if self.in_channels != self.out_channels:
86            x = self.nin_shortcut(x)
87
88        return x + h

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class Downsample(torch.nn.modules.module.Module):
 91class Downsample(nn.Module):
 92    def __init__(self, in_channels: int):
 93        super().__init__()
 94        # no asymmetric padding in torch conv, must do it ourselves
 95        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
 96
 97    def forward(self, x: Tensor):
 98        pad = (0, 1, 0, 1)
 99        x = nn.functional.pad(x, pad, mode="constant", value=0)
100        x = self.conv(x)
101        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute `training` (bool): represents whether this module is in training or evaluation mode.

Downsample(in_channels: int)
92    def __init__(self, in_channels: int):
93        super().__init__()
94        # no asymmetric padding in torch conv, must do it ourselves
95        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x: torch.Tensor):
 97    def forward(self, x: Tensor):
 98        pad = (0, 1, 0, 1)
 99        x = nn.functional.pad(x, pad, mode="constant", value=0)
100        x = self.conv(x)
101        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class Upsample(torch.nn.modules.module.Module):
104class Upsample(nn.Module):
105    def __init__(self, in_channels: int):
106        super().__init__()
107        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
108
109    def forward(self, x: Tensor):
110        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
111        x = self.conv(x)
112        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute `training` (bool): represents whether this module is in training or evaluation mode.

Upsample(in_channels: int)
105    def __init__(self, in_channels: int):
106        super().__init__()
107        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x: torch.Tensor):
109    def forward(self, x: Tensor):
110        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
111        x = self.conv(x)
112        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class Encoder(torch.nn.modules.module.Module):
115class Encoder(nn.Module):
116    def __init__(
117        self,
118        resolution: int,
119        in_channels: int,
120        ch: int,
121        ch_mult: list[int],
122        num_res_blocks: int,
123        z_channels: int,
124    ):
125        super().__init__()
126        self.ch = ch
127        self.num_resolutions = len(ch_mult)
128        self.num_res_blocks = num_res_blocks
129        self.resolution = resolution
130        self.in_channels = in_channels
131        # downsampling
132        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
133
134        curr_res = resolution
135        in_ch_mult = (1,) + tuple(ch_mult)
136        self.in_ch_mult = in_ch_mult
137        self.down = nn.ModuleList()
138        block_in = self.ch
139        for i_level in range(self.num_resolutions):
140            block = nn.ModuleList()
141            attn = nn.ModuleList()
142            block_in = ch * in_ch_mult[i_level]
143            block_out = ch * ch_mult[i_level]
144            for _ in range(self.num_res_blocks):
145                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
146                block_in = block_out
147            down = nn.Module()
148            down.block = block
149            down.attn = attn
150            if i_level != self.num_resolutions - 1:
151                down.downsample = Downsample(block_in)
152                curr_res = curr_res // 2
153            self.down.append(down)
154
155        # middle
156        self.mid = nn.Module()
157        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
158        self.mid.attn_1 = AttnBlock(block_in)
159        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
160
161        # end
162        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
163        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
164
165    def forward(self, x: Tensor) -> Tensor:
166        # downsampling
167        hs = [self.conv_in(x)]
168        for i_level in range(self.num_resolutions):
169            for i_block in range(self.num_res_blocks):
170                h = self.down[i_level].block[i_block](hs[-1])
171                if len(self.down[i_level].attn) > 0:
172                    h = self.down[i_level].attn[i_block](h)
173                hs.append(h)
174            if i_level != self.num_resolutions - 1:
175                hs.append(self.down[i_level].downsample(hs[-1]))
176
177        # middle
178        h = hs[-1]
179        h = self.mid.block_1(h)
180        h = self.mid.attn_1(h)
181        h = self.mid.block_2(h)
182        # end
183        h = self.norm_out(h)
184        h = swish(h)
185        h = self.conv_out(h)
186        return h

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attribute `training` (bool): represents whether this module is in training or evaluation mode.

Encoder(resolution: int, in_channels: int, ch: int, ch_mult: list[int], num_res_blocks: int, z_channels: int)
116    def __init__(
117        self,
118        resolution: int,
119        in_channels: int,
120        ch: int,
121        ch_mult: list[int],
122        num_res_blocks: int,
123        z_channels: int,
124    ):
125        super().__init__()
126        self.ch = ch
127        self.num_resolutions = len(ch_mult)
128        self.num_res_blocks = num_res_blocks
129        self.resolution = resolution
130        self.in_channels = in_channels
131        # downsampling
132        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
133
134        curr_res = resolution
135        in_ch_mult = (1,) + tuple(ch_mult)
136        self.in_ch_mult = in_ch_mult
137        self.down = nn.ModuleList()
138        block_in = self.ch
139        for i_level in range(self.num_resolutions):
140            block = nn.ModuleList()
141            attn = nn.ModuleList()
142            block_in = ch * in_ch_mult[i_level]
143            block_out = ch * ch_mult[i_level]
144            for _ in range(self.num_res_blocks):
145                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
146                block_in = block_out
147            down = nn.Module()
148            down.block = block
149            down.attn = attn
150            if i_level != self.num_resolutions - 1:
151                down.downsample = Downsample(block_in)
152                curr_res = curr_res // 2
153            self.down.append(down)
154
155        # middle
156        self.mid = nn.Module()
157        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
158        self.mid.attn_1 = AttnBlock(block_in)
159        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
160
161        # end
162        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
163        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

ch
num_resolutions
num_res_blocks
resolution
in_channels
conv_in
in_ch_mult
down
mid
norm_out
conv_out
def forward(self, x: torch.Tensor) -> torch.Tensor:
165    def forward(self, x: Tensor) -> Tensor:
166        # downsampling
167        hs = [self.conv_in(x)]
168        for i_level in range(self.num_resolutions):
169            for i_block in range(self.num_res_blocks):
170                h = self.down[i_level].block[i_block](hs[-1])
171                if len(self.down[i_level].attn) > 0:
172                    h = self.down[i_level].attn[i_block](h)
173                hs.append(h)
174            if i_level != self.num_resolutions - 1:
175                hs.append(self.down[i_level].downsample(hs[-1]))
176
177        # middle
178        h = hs[-1]
179        h = self.mid.block_1(h)
180        h = self.mid.attn_1(h)
181        h = self.mid.block_2(h)
182        # end
183        h = self.norm_out(h)
184        h = swish(h)
185        h = self.conv_out(h)
186        return h

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

class Decoder(torch.nn.modules.module.Module):
189class Decoder(nn.Module):
190    def __init__(
191        self,
192        ch: int,
193        out_ch: int,
194        ch_mult: list[int],
195        num_res_blocks: int,
196        in_channels: int,
197        resolution: int,
198        z_channels: int,
199    ):
200        super().__init__()
201        self.ch = ch
202        self.num_resolutions = len(ch_mult)
203        self.num_res_blocks = num_res_blocks
204        self.resolution = resolution
205        self.in_channels = in_channels
206        self.ffactor = 2 ** (self.num_resolutions - 1)
207
208        # compute in_ch_mult, block_in and curr_res at lowest res
209        block_in = ch * ch_mult[self.num_resolutions - 1]
210        curr_res = resolution // 2 ** (self.num_resolutions - 1)
211        self.z_shape = (1, z_channels, curr_res, curr_res)
212
213        # z to block_in
214        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
215
216        # middle
217        self.mid = nn.Module()
218        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
219        self.mid.attn_1 = AttnBlock(block_in)
220        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
221
222        # upsampling
223        self.up = nn.ModuleList()
224        for i_level in reversed(range(self.num_resolutions)):
225            block = nn.ModuleList()
226            attn = nn.ModuleList()
227            block_out = ch * ch_mult[i_level]
228            for _ in range(self.num_res_blocks + 1):
229                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
230                block_in = block_out
231            up = nn.Module()
232            up.block = block
233            up.attn = attn
234            if i_level != 0:
235                up.upsample = Upsample(block_in)
236                curr_res = curr_res * 2
237            self.up.insert(0, up)  # prepend to get consistent order
238
239        # end
240        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
241        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
242
243    def forward(self, z: Tensor) -> Tensor:
244        # get dtype for proper tracing
245        upscale_dtype = next(self.up.parameters()).dtype
246
247        # z to block_in
248        h = self.conv_in(z)
249
250        # middle
251        h = self.mid.block_1(h)
252        h = self.mid.attn_1(h)
253        h = self.mid.block_2(h)
254
255        # cast to proper dtype
256        h = h.to(upscale_dtype)
257        # upsampling
258        for i_level in reversed(range(self.num_resolutions)):
259            for i_block in range(self.num_res_blocks + 1):
260                h = self.up[i_level].block[i_block](h)
261                if len(self.up[i_level].attn) > 0:
262                    h = self.up[i_level].attn[i_block](h)
263            if i_level != 0:
264                h = self.up[i_level].upsample(h)
265
266        # end
267        h = self.norm_out(h)
268        h = swish(h)
269        h = self.conv_out(h)
270        return h

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

Decoder( ch: int, out_ch: int, ch_mult: list[int], num_res_blocks: int, in_channels: int, resolution: int, z_channels: int)
190    def __init__(
191        self,
192        ch: int,
193        out_ch: int,
194        ch_mult: list[int],
195        num_res_blocks: int,
196        in_channels: int,
197        resolution: int,
198        z_channels: int,
199    ):
200        super().__init__()
201        self.ch = ch
202        self.num_resolutions = len(ch_mult)
203        self.num_res_blocks = num_res_blocks
204        self.resolution = resolution
205        self.in_channels = in_channels
206        self.ffactor = 2 ** (self.num_resolutions - 1)
207
208        # compute in_ch_mult, block_in and curr_res at lowest res
209        block_in = ch * ch_mult[self.num_resolutions - 1]
210        curr_res = resolution // 2 ** (self.num_resolutions - 1)
211        self.z_shape = (1, z_channels, curr_res, curr_res)
212
213        # z to block_in
214        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
215
216        # middle
217        self.mid = nn.Module()
218        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
219        self.mid.attn_1 = AttnBlock(block_in)
220        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
221
222        # upsampling
223        self.up = nn.ModuleList()
224        for i_level in reversed(range(self.num_resolutions)):
225            block = nn.ModuleList()
226            attn = nn.ModuleList()
227            block_out = ch * ch_mult[i_level]
228            for _ in range(self.num_res_blocks + 1):
229                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
230                block_in = block_out
231            up = nn.Module()
232            up.block = block
233            up.attn = attn
234            if i_level != 0:
235                up.upsample = Upsample(block_in)
236                curr_res = curr_res * 2
237            self.up.insert(0, up)  # prepend to get consistent order
238
239        # end
240        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
241        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

ch
num_resolutions
num_res_blocks
resolution
in_channels
ffactor
z_shape
conv_in
mid
up
norm_out
conv_out
def forward(self, z: torch.Tensor) -> torch.Tensor:
243    def forward(self, z: Tensor) -> Tensor:
244        # get dtype for proper tracing
245        upscale_dtype = next(self.up.parameters()).dtype
246
247        # z to block_in
248        h = self.conv_in(z)
249
250        # middle
251        h = self.mid.block_1(h)
252        h = self.mid.attn_1(h)
253        h = self.mid.block_2(h)
254
255        # cast to proper dtype
256        h = h.to(upscale_dtype)
257        # upsampling
258        for i_level in reversed(range(self.num_resolutions)):
259            for i_block in range(self.num_res_blocks + 1):
260                h = self.up[i_level].block[i_block](h)
261                if len(self.up[i_level].attn) > 0:
262                    h = self.up[i_level].attn[i_block](h)
263            if i_level != 0:
264                h = self.up[i_level].upsample(h)
265
266        # end
267        h = self.norm_out(h)
268        h = swish(h)
269        h = self.conv_out(h)
270        return h

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.

class DiagonalGaussian(torch.nn.modules.module.Module):
273class DiagonalGaussian(nn.Module):
274    def __init__(self, sample: bool = True, chunk_dim: int = 1):
275        super().__init__()
276        self.sample = sample
277        self.chunk_dim = chunk_dim
278
279    def forward(self, z: Tensor) -> Tensor:
280        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
281        if self.sample:
282            std = torch.exp(0.5 * logvar)
283            return mean + std * torch.randn_like(mean)
284        else:
285            return mean

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

DiagonalGaussian(sample: bool = True, chunk_dim: int = 1)
274    def __init__(self, sample: bool = True, chunk_dim: int = 1):
275        super().__init__()
276        self.sample = sample
277        self.chunk_dim = chunk_dim

Initialize internal Module state, shared by both nn.Module and ScriptModule.

sample
chunk_dim
def forward(self, z: torch.Tensor) -> torch.Tensor:
279    def forward(self, z: Tensor) -> Tensor:
280        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
281        if self.sample:
282            std = torch.exp(0.5 * logvar)
283            return mean + std * torch.randn_like(mean)
284        else:
285            return mean

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.

class AutoEncoder(torch.nn.modules.module.Module):
288class AutoEncoder(nn.Module):
289    def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
290        super().__init__()
291        self.params = params
292        self.encoder = Encoder(
293            resolution=params.resolution,
294            in_channels=params.in_channels,
295            ch=params.ch,
296            ch_mult=params.ch_mult,
297            num_res_blocks=params.num_res_blocks,
298            z_channels=params.z_channels,
299        )
300        self.decoder = Decoder(
301            resolution=params.resolution,
302            in_channels=params.in_channels,
303            ch=params.ch,
304            out_ch=params.out_ch,
305            ch_mult=params.ch_mult,
306            num_res_blocks=params.num_res_blocks,
307            z_channels=params.z_channels,
308        )
309        self.reg = DiagonalGaussian(sample=sample_z)
310
311        self.scale_factor = params.scale_factor
312        self.shift_factor = params.shift_factor
313
314    def encode(self, x: Tensor) -> Tensor:
315        encoder_dtype = next(self.encoder.parameters()).dtype
316        x = x.to(encoder_dtype)  # Quality of Life:sample dtype always matches encoder dtype
317        z = self.reg(self.encoder(x))
318        z = self.scale_factor * (z - self.shift_factor)
319        return z
320
321    def decode(self, z: Tensor) -> Tensor:
322        z = z / self.scale_factor + self.shift_factor
323        return self.decoder(z)
324
325    def forward(self, x: Tensor) -> Tensor:
326        return self.decode(self.encode(x))

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool

AutoEncoder( params: AutoEncoderParams, sample_z: bool = False)
289    def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
290        super().__init__()
291        self.params = params
292        self.encoder = Encoder(
293            resolution=params.resolution,
294            in_channels=params.in_channels,
295            ch=params.ch,
296            ch_mult=params.ch_mult,
297            num_res_blocks=params.num_res_blocks,
298            z_channels=params.z_channels,
299        )
300        self.decoder = Decoder(
301            resolution=params.resolution,
302            in_channels=params.in_channels,
303            ch=params.ch,
304            out_ch=params.out_ch,
305            ch_mult=params.ch_mult,
306            num_res_blocks=params.num_res_blocks,
307            z_channels=params.z_channels,
308        )
309        self.reg = DiagonalGaussian(sample=sample_z)
310
311        self.scale_factor = params.scale_factor
312        self.shift_factor = params.shift_factor

Initialize internal Module state, shared by both nn.Module and ScriptModule.

params
encoder
decoder
reg
scale_factor
shift_factor
def encode(self, x: torch.Tensor) -> torch.Tensor:
314    def encode(self, x: Tensor) -> Tensor:
315        encoder_dtype = next(self.encoder.parameters()).dtype
316        x = x.to(encoder_dtype)  # Quality of Life:sample dtype always matches encoder dtype
317        z = self.reg(self.encoder(x))
318        z = self.scale_factor * (z - self.shift_factor)
319        return z
def decode(self, z: torch.Tensor) -> torch.Tensor:
321    def decode(self, z: Tensor) -> Tensor:
322        z = z / self.scale_factor + self.shift_factor
323        return self.decoder(z)
def forward(self, x: torch.Tensor) -> torch.Tensor:
325    def forward(self, x: Tensor) -> Tensor:
326        return self.decode(self.encode(x))

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.