divisor.flux2.autoencoder

  1# SPDX-License-Identifier:Apache-2.0
  2# original BFL Flux code from https://github.com/black-forest-labs/flux2
  3
  4from dataclasses import dataclass, field
  5import math
  6
  7from einops import rearrange
  8import torch
  9from torch import Tensor, nn
 10
 11
 12@dataclass
 13class AutoEncoderParams:
 14    resolution: int = 256
 15    in_channels: int = 3
 16    ch: int = 128
 17    out_ch: int = 3
 18    ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
 19    num_res_blocks: int = 2
 20    z_channels: int = 32
 21
 22
 23def swish(x: Tensor) -> Tensor:
 24    return x * torch.sigmoid(x)
 25
 26
 27class AttnBlock(nn.Module):
 28    def __init__(self, in_channels: int):
 29        super().__init__()
 30        self.in_channels = in_channels
 31
 32        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 33
 34        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
 35        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
 36        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
 37        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
 38
 39    def attention(self, h_: Tensor) -> Tensor:
 40        h_ = self.norm(h_)
 41        q = self.q(h_)
 42        k = self.k(h_)
 43        v = self.v(h_)
 44
 45        b, c, h, w = q.shape
 46        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
 47        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
 48        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
 49        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
 50
 51        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
 52
 53    def forward(self, x: Tensor) -> Tensor:
 54        return x + self.proj_out(self.attention(x))
 55
 56
 57class ResnetBlock(nn.Module):
 58    def __init__(self, in_channels: int, out_channels: int):
 59        super().__init__()
 60        self.in_channels = in_channels
 61        out_channels = in_channels if out_channels is None else out_channels
 62        self.out_channels = out_channels
 63
 64        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
 65        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 66        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
 67        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
 68        if self.in_channels != self.out_channels:
 69            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
 70
 71    def forward(self, x):
 72        h = x
 73        h = self.norm1(h)
 74        h = swish(h)
 75        h = self.conv1(h)
 76
 77        h = self.norm2(h)
 78        h = swish(h)
 79        h = self.conv2(h)
 80
 81        if self.in_channels != self.out_channels:
 82            x = self.nin_shortcut(x)
 83
 84        return x + h
 85
 86
 87class Downsample(nn.Module):
 88    def __init__(self, in_channels: int):
 89        super().__init__()
 90        # no asymmetric padding in torch conv, must do it ourselves
 91        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
 92
 93    def forward(self, x: Tensor):
 94        pad = (0, 1, 0, 1)
 95        x = nn.functional.pad(x, pad, mode="constant", value=0)
 96        x = self.conv(x)
 97        return x
 98
 99
class Upsample(nn.Module):
    """Double the spatial size: nearest-neighbour 2x resize, then a 3x3 convolution."""

    def __init__(self, in_channels: int):
        super().__init__()
        # "Same" 3x3 convolution applied after the resize.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x: Tensor):
        upsampled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)
109
110
class Encoder(nn.Module):
    """Convolutional encoder: image -> ``2 * z_channels`` feature map.

    Every level except the last ends with a Downsample, so the spatial size
    shrinks by ``2 ** (len(ch_mult) - 1)``.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        """Build the encoder.

        Args:
            resolution: Nominal input resolution; stored as ``self.resolution``.
            in_channels: Channels of the input image.
            ch: Base width; level ``i`` outputs ``ch * ch_mult[i]`` channels.
            ch_mult: Per-level channel multipliers (one level per entry).
            num_res_blocks: ResnetBlocks per level.
            z_channels: The output has ``2 * z_channels`` channels.
        """
        super().__init__()
        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # Left empty; kept because forward() checks len(down.attn).
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level but the last halves the spatial size.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into a ``2 * z_channels`` feature map."""
        # Only the running activation is kept. Earlier feature maps are never
        # read again, so holding them in a list would only keep them alive until
        # forward() returns and raise peak memory.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)  # type: ignore
                if len(level.attn) > 0:  # type: ignore
                    h = level.attn[i_block](h)  # type: ignore
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)  # type: ignore

        # middle
        h = self.mid.block_1(h)  # type: ignore
        h = self.mid.attn_1(h)  # type: ignore
        h = self.mid.block_2(h)  # type: ignore
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h = self.quant_conv(h)
        return h
185
186
class Decoder(nn.Module):
    """Convolutional decoder: ``z_channels`` latent -> ``out_ch`` image.

    The latent is upsampled by ``2 ** (len(ch_mult) - 1)``. Each level holds
    ``num_res_blocks + 1`` ResnetBlocks, and every level except level 0 ends
    with an Upsample.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        lowest = self.num_resolutions - 1
        self.ffactor = 2**lowest

        # Width and spatial size of the latent at the lowest resolution.
        block_in = ch * ch_mult[lowest]
        latent_res = resolution // self.ffactor
        self.z_shape = (1, z_channels, latent_res, latent_res)

        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # Levels are created deepest-first (the order they run in) but stored
        # shallowest-first, so self.up[i] is the level for ch_mult[i].
        levels = []
        for i_level in range(lowest, -1, -1):
            block_out = ch * ch_mult[i_level]
            resblocks = nn.ModuleList()
            for _ in range(num_res_blocks + 1):
                resblocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            level = nn.Module()
            level.block = resblocks
            level.attn = nn.ModuleList()  # left empty; forward() checks its length
            if i_level != 0:
                level.upsample = Upsample(block_in)
            levels.append(level)
        self.up = nn.ModuleList(reversed(levels))

        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode latent ``z`` into an ``out_ch``-channel image."""
        z = self.post_quant_conv(z)

        # The upsampling stack runs in the dtype of its own parameters.
        upscale_dtype = next(self.up.parameters()).dtype

        h = self.conv_in(z)
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))  # type: ignore
        h = h.to(upscale_dtype)

        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            for i_block, resblock in enumerate(level.block):  # type: ignore
                h = resblock(h)
                if len(level.attn) > 0:  # type: ignore
                    h = level.attn[i_block](h)  # type: ignore
            if i_level != 0:
                h = level.upsample(h)  # type: ignore

        return self.conv_out(swish(self.norm_out(h)))
272
273
class AutoEncoder(nn.Module):
    """Encoder/Decoder pair with 2x2 latent patchification and latent normalization.

    ``encode`` keeps the first half of the encoder's channels and folds every
    ``ps[0] x ps[1]`` spatial patch into channels. It then standardizes the
    result with the running statistics stored in ``self.bn``. ``decode`` inverts
    those two steps before running the decoder.
    """

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        self.encoder = Encoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.decoder = Decoder(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            out_ch=params.out_ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )

        self.bn_eps = 1e-4
        self.bn_momentum = 0.1
        # Patch size folded into channels by encode() and unfolded by decode().
        self.ps = [2, 2]
        # Only running_mean / running_var are used (affine=False: no learned
        # scale/shift); presumably they come from a checkpoint.
        self.bn = torch.nn.BatchNorm2d(
            math.prod(self.ps) * params.z_channels,
            eps=self.bn_eps,
            momentum=self.bn_momentum,
            affine=False,
            track_running_stats=True,
        )

    def normalize(self, z):
        """Standardize ``z`` (N, C, H, W) with the stored running statistics.

        Computes the same result as ``self.bn`` in eval mode: batch statistics
        are never used and the running statistics are never updated. It does not
        call ``self.bn.eval()``, so the submodule's train/eval state is left as
        the caller set it.

        Raises:
            ValueError: if ``z`` is not 4-D (same check BatchNorm2d performs).
        """
        if z.dim() != 4:
            raise ValueError(f"expected 4D input (got {z.dim()}D input)")
        return nn.functional.batch_norm(
            z,
            self.bn.running_mean,
            self.bn.running_var,
            weight=None,
            bias=None,
            training=False,
            momentum=self.bn.momentum,
            eps=self.bn.eps,
        )

    def inv_normalize(self, z):
        """Undo ``normalize``: ``z * sqrt(running_var + bn_eps) + running_mean``."""
        s = torch.sqrt(self.bn.running_var.view(1, -1, 1, 1) + self.bn_eps)  # type: ignore
        m = self.bn.running_mean.view(1, -1, 1, 1)  # type: ignore
        return z * s + m

    def encode(self, x: Tensor) -> Tensor:
        """Image -> normalized, patchified latent (``prod(ps) * z_channels`` channels)."""
        moments = self.encoder(x)
        # The encoder emits 2 * z_channels channels; only the first half is used.
        mean = torch.chunk(moments, 2, dim=1)[0]

        z = rearrange(
            mean,
            "... c (i pi) (j pj)  -> ... (c pi pj) i j",
            pi=self.ps[0],
            pj=self.ps[1],
        )
        z = self.normalize(z)
        return z

    def decode(self, z: Tensor) -> Tensor:
        """Latent produced by ``encode`` -> image."""
        z = self.inv_normalize(z)
        z = rearrange(
            z,
            "... (c pi pj) i j -> ... c (i pi) (j pj)",
            pi=self.ps[0],
            pj=self.ps[1],
        )
        dec = self.decoder(z)
        return dec
@dataclass
class AutoEncoderParams:
@dataclass
class AutoEncoderParams:
    """Hyper-parameters shared by the Encoder, Decoder and AutoEncoder.

    Attributes:
        resolution: Nominal image resolution (stored by Encoder/Decoder).
        in_channels: Channels of the input image.
        ch: Base channel width; level ``i`` uses ``ch * ch_mult[i]`` channels.
        out_ch: Channels produced by the decoder.
        ch_mult: Per-level channel multipliers; its length is the number of levels.
        num_res_blocks: ResnetBlocks per encoder level (the decoder uses one more).
        z_channels: Latent channels before the AutoEncoder's 2x2 patchification.
    """

    resolution: int = field(default=256)
    in_channels: int = field(default=3)
    ch: int = field(default=128)
    out_ch: int = field(default=3)
    # A factory gives every instance its own list object.
    ch_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
    num_res_blocks: int = field(default=2)
    z_channels: int = field(default=32)
AutoEncoderParams( resolution: int = 256, in_channels: int = 3, ch: int = 128, out_ch: int = 3, ch_mult: list[int] = <factory>, num_res_blocks: int = 2, z_channels: int = 32)
resolution: int = 256
in_channels: int = 3
ch: int = 128
out_ch: int = 3
ch_mult: list[int]
num_res_blocks: int = 2
z_channels: int = 32
def swish(x: torch.Tensor) -> torch.Tensor:
def swish(x: Tensor) -> Tensor:
    """Swish / SiLU activation, ``x * sigmoid(x)``, applied element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
class AttnBlock(torch.nn.modules.module.Module):
class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every pixel of the feature map attends to every other pixel; the output
    has the same shape as the input.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        # GroupNorm with 32 groups, so in_channels must be divisible by 32.
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # Per-pixel (1x1) projections for query, key, value and output.
        self.q = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize ``h_`` (b, c, h, w) and attend over its h*w positions; returns (b, c, h, w)."""
        normed = self.norm(h_)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)
        batch, chans, height, width = query.shape

        def as_tokens(t: Tensor) -> Tensor:
            # (b, c, h, w) -> (b, 1, h*w, c): one head whose tokens are pixels.
            return t.flatten(2).transpose(1, 2).unsqueeze(1).contiguous()

        attended = nn.functional.scaled_dot_product_attention(
            as_tokens(query), as_tokens(key), as_tokens(value)
        )
        # (b, 1, h*w, c) -> (b, c, h, w)
        return attended.squeeze(1).transpose(1, 2).reshape(batch, chans, height, width)

    def forward(self, x: Tensor) -> Tensor:
        """Add the projected attention output back onto the input."""
        attended = self.attention(x)
        return x + self.proj_out(attended)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Minimal example module: two 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # 1 -> 20 -> 20 channels.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attributes:
    training (bool): Whether this module is in training or evaluation mode.

AttnBlock(in_channels: int)
29    def __init__(self, in_channels: int):
30        super().__init__()
31        self.in_channels = in_channels
32
33        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
34
35        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
37        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
38        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_channels
norm
q
k
v
proj_out
def attention(self, h_: torch.Tensor) -> torch.Tensor:
40    def attention(self, h_: Tensor) -> Tensor:
41        h_ = self.norm(h_)
42        q = self.q(h_)
43        k = self.k(h_)
44        v = self.v(h_)
45
46        b, c, h, w = q.shape
47        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
48        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
49        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
50        h_ = nn.functional.scaled_dot_product_attention(q, k, v)
51
52        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: torch.Tensor) -> torch.Tensor:
54    def forward(self, x: Tensor) -> Tensor:
55        return x + self.proj_out(self.attention(x))

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the forward pass must be defined within this function, call the Module instance itself rather than this method: calling the instance runs any registered hooks, while calling forward() directly silently skips them.

class ResnetBlock(torch.nn.modules.module.Module):
class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    When ``in_channels != out_channels`` the skip path goes through a 1x1
    convolution (``nin_shortcut``) so it can be added to the residual branch.
    """

    def __init__(self, in_channels: int, out_channels: "int | None" = None):
        """Build the block.

        Args:
            in_channels: Channels of the input; normalized by a 32-group
                GroupNorm, so it must be divisible by 32.
            out_channels: Channels of the output. ``None`` (the default) keeps
                ``in_channels``.
        """
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # 1x1 projection that makes the skip path match the new channel count.
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``skip(x) + branch(x)``; the spatial size is unchanged."""
        h = self.conv1(swish(self.norm1(x)))
        h = self.conv2(swish(self.norm2(h)))
        # The skip path is projected only when the channel count changes.
        skip = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
        return skip + h

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Minimal example module: two 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # 1 -> 20 -> 20 channels.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attributes:
    training (bool): Whether this module is in training or evaluation mode.

ResnetBlock(in_channels: int, out_channels: int)
59    def __init__(self, in_channels: int, out_channels: int):
60        super().__init__()
61        self.in_channels = in_channels
62        out_channels = in_channels if out_channels is None else out_channels
63        self.out_channels = out_channels
64
65        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
66        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
67        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
68        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
69        if self.in_channels != self.out_channels:
70            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

in_channels
out_channels
norm1
conv1
norm2
conv2
def forward(self, x):
72    def forward(self, x):
73        h = x
74        h = self.norm1(h)
75        h = swish(h)
76        h = self.conv1(h)
77
78        h = self.norm2(h)
79        h = swish(h)
80        h = self.conv2(h)
81
82        if self.in_channels != self.out_channels:
83            x = self.nin_shortcut(x)
84
85        return x + h

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the forward pass must be defined within this function, call the Module instance itself rather than this method: calling the instance runs any registered hooks, while calling forward() directly silently skips them.

class Downsample(torch.nn.modules.module.Module):
class Downsample(nn.Module):
    """Halve the spatial size with a stride-2 3x3 convolution.

    The input is zero-padded by one pixel on the right and bottom only. Torch
    convolutions pad symmetrically, so that pad is applied explicitly before a
    ``padding=0`` convolution.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=2,
            padding=0,
        )

    def forward(self, x: Tensor):
        # F.pad order for the last two dims: (left, right, top, bottom).
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Minimal example module: two 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # 1 -> 20 -> 20 channels.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attributes:
    training (bool): Whether this module is in training or evaluation mode.

Downsample(in_channels: int)
89    def __init__(self, in_channels: int):
90        super().__init__()
91        # no asymmetric padding in torch conv, must do it ourselves
92        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x: torch.Tensor):
94    def forward(self, x: Tensor):
95        pad = (0, 1, 0, 1)
96        x = nn.functional.pad(x, pad, mode="constant", value=0)
97        x = self.conv(x)
98        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the forward pass must be defined within this function, call the Module instance itself rather than this method: calling the instance runs any registered hooks, while calling forward() directly silently skips them.

class Upsample(torch.nn.modules.module.Module):
class Upsample(nn.Module):
    """Double the spatial size: nearest-neighbour 2x resize, then a 3x3 convolution."""

    def __init__(self, in_channels: int):
        super().__init__()
        # "Same" 3x3 convolution applied after the resize.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x: Tensor):
        upsampled = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Minimal example module: two 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # 1 -> 20 -> 20 channels.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attributes:
    training (bool): Whether this module is in training or evaluation mode.

Upsample(in_channels: int)
102    def __init__(self, in_channels: int):
103        super().__init__()
104        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x: torch.Tensor):
106    def forward(self, x: Tensor):
107        x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
108        x = self.conv(x)
109        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the forward pass must be defined within this function, call the Module instance itself rather than this method: calling the instance runs any registered hooks, while calling forward() directly silently skips them.

class Encoder(torch.nn.modules.module.Module):
class Encoder(nn.Module):
    """Convolutional encoder: image -> ``2 * z_channels`` feature map.

    Every level except the last ends with a Downsample, so the spatial size
    shrinks by ``2 ** (len(ch_mult) - 1)``.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        """Build the encoder.

        Args:
            resolution: Nominal input resolution; stored as ``self.resolution``.
            in_channels: Channels of the input image.
            ch: Base width; level ``i`` outputs ``ch * ch_mult[i]`` channels.
            ch_mult: Per-level channel multipliers (one level per entry).
            num_res_blocks: ResnetBlocks per level.
            z_channels: The output has ``2 * z_channels`` channels.
        """
        super().__init__()
        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            # Left empty; kept because forward() checks len(down.attn).
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level but the last halves the spatial size.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` into a ``2 * z_channels`` feature map."""
        # Only the running activation is kept. Earlier feature maps are never
        # read again, so holding them in a list would only keep them alive until
        # forward() returns and raise peak memory.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)  # type: ignore
                if len(level.attn) > 0:  # type: ignore
                    h = level.attn[i_block](h)  # type: ignore
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)  # type: ignore

        # middle
        h = self.mid.block_1(h)  # type: ignore
        h = self.mid.attn_1(h)  # type: ignore
        h = self.mid.block_2(h)  # type: ignore
        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h = self.quant_conv(h)
        return h

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Minimal example module: two 5x5 convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # 1 -> 20 -> 20 channels.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attributes:
    training (bool): Whether this module is in training or evaluation mode.

Encoder( resolution: int, in_channels: int, ch: int, ch_mult: list[int], num_res_blocks: int, z_channels: int)
113    def __init__(
114        self,
115        resolution: int,
116        in_channels: int,
117        ch: int,
118        ch_mult: list[int],
119        num_res_blocks: int,
120        z_channels: int,
121    ):
122        super().__init__()
123        self.quant_conv = torch.nn.Conv2d(2 * z_channels, 2 * z_channels, 1)
124        self.ch = ch
125        self.num_resolutions = len(ch_mult)
126        self.num_res_blocks = num_res_blocks
127        self.resolution = resolution
128        self.in_channels = in_channels
129        # downsampling
130        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
131
132        curr_res = resolution
133        in_ch_mult = (1,) + tuple(ch_mult)
134        self.in_ch_mult = in_ch_mult
135        self.down = nn.ModuleList()
136        block_in = self.ch
137        for i_level in range(self.num_resolutions):
138            block = nn.ModuleList()
139            attn = nn.ModuleList()
140            block_in = ch * in_ch_mult[i_level]
141            block_out = ch * ch_mult[i_level]
142            for _ in range(self.num_res_blocks):
143                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
144                block_in = block_out
145            down = nn.Module()
146            down.block = block
147            down.attn = attn
148            if i_level != self.num_resolutions - 1:
149                down.downsample = Downsample(block_in)
150                curr_res = curr_res // 2
151            self.down.append(down)
152
153        # middle
154        self.mid = nn.Module()
155        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
156        self.mid.attn_1 = AttnBlock(block_in)
157        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
158
159        # end
160        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
161        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

quant_conv
ch
num_resolutions
num_res_blocks
resolution
in_channels
conv_in
in_ch_mult
down
mid
norm_out
conv_out
def forward(self, x: Tensor) -> Tensor:
    """Encode an input batch into latent moments.

    Runs ``conv_in``, then every resolution level of ``self.down`` (its
    ResnetBlocks, any per-block attention modules, and the level's
    ``downsample`` for every level except the last), then the middle
    ResnetBlock -> AttnBlock -> ResnetBlock stack, and finally
    GroupNorm -> swish -> ``conv_out`` -> ``quant_conv``.

    Args:
        x: Tensor fed to ``conv_in`` (presumably ``(B, in_channels, H, W)``
            images -- TODO confirm against callers).

    Returns:
        The output of ``quant_conv``, which has ``2 * z_channels`` channels
        as configured in ``__init__``.
    """
    # Only the most recent activation is ever consumed, so keep one running
    # tensor. Appending every intermediate to a list (as before) held all
    # full-resolution activations alive until forward returned, raising peak
    # memory without affecting the result.
    h = self.conv_in(x)
    last_level = self.num_resolutions - 1
    for i_level in range(self.num_resolutions):
        level = self.down[i_level]
        for i_block in range(self.num_res_blocks):
            h = level.block[i_block](h)  # type: ignore
            if len(level.attn) > 0:  # type: ignore
                h = level.attn[i_block](h)  # type: ignore
        if i_level != last_level:
            h = level.downsample(h)  # type: ignore

    # middle
    h = self.mid.block_1(h)  # type: ignore
    h = self.mid.attn_1(h)  # type: ignore
    h = self.mid.block_2(h)  # type: ignore

    # end
    h = self.norm_out(h)
    h = swish(h)
    h = self.conv_out(h)
    h = self.quant_conv(h)
    return h

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass must be defined within this function, callers should invoke the Module instance rather than calling this method directly: calling the instance runs the registered hooks, while calling `forward()` directly silently skips them.

class Decoder(nn.Module):
    """Convolutional decoder mapping ``z_channels`` latents to ``out_ch`` outputs.

    Layout: ``post_quant_conv`` -> ``conv_in`` -> middle
    ResnetBlock/AttnBlock/ResnetBlock -> upsampling levels (deepest first)
    -> GroupNorm -> swish -> ``conv_out``. Level ``i`` works at
    ``ch * ch_mult[i]`` channels and every level except level 0 ends with an
    ``Upsample`` (assumed to double the spatial size, matching the
    ``2 ** (len(ch_mult) - 1)`` ``ffactor``).
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        """Build the decoder layers.

        Args:
            ch: Base channel width; level ``i`` uses ``ch * ch_mult[i]``.
            out_ch: Channels produced by ``conv_out``.
            ch_mult: Per-level channel multipliers; its length is the number
                of resolution levels.
            num_res_blocks: Each upsampling level gets ``num_res_blocks + 1``
                ResnetBlocks.
            in_channels: Stored on the instance; not otherwise used here.
            resolution: Stored, and used to compute ``z_shape``.
            z_channels: Channels of the latent consumed by ``post_quant_conv``.
        """
        super().__init__()
        self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        deepest = self.num_resolutions - 1
        self.ffactor = 2**deepest

        # Channel width and spatial size at the lowest resolution level.
        width = ch * ch_mult[deepest]
        lowest_res = resolution // 2**deepest
        self.z_shape = (1, z_channels, lowest_res, lowest_res)

        # latent -> working width
        self.conv_in = nn.Conv2d(z_channels, width, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=width, out_channels=width)
        self.mid.attn_1 = AttnBlock(width)
        self.mid.block_2 = ResnetBlock(in_channels=width, out_channels=width)

        # Stages are constructed deepest-first (this fixes the parameter
        # initialization order) and then stored shallowest-first, so that
        # self.up[i] is resolution level i.
        stages = []
        for level in range(deepest, -1, -1):
            target = ch * ch_mult[level]
            resnets = nn.ModuleList()
            for _ in range(num_res_blocks + 1):
                resnets.append(ResnetBlock(in_channels=width, out_channels=target))
                width = target
            stage = nn.Module()
            stage.block = resnets
            # Kept (always empty) so the module/state_dict layout is unchanged.
            stage.attn = nn.ModuleList()
            if level > 0:
                stage.upsample = Upsample(width)
            stages.append(stage)
        self.up = nn.ModuleList(stages[::-1])

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=width, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(width, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode a latent tensor.

        Args:
            z: Latent input for ``post_quant_conv`` (presumably
                ``(B, z_channels, h, w)`` -- TODO confirm).

        Returns:
            The output of ``conv_out`` (``out_ch`` channels).
        """
        z = self.post_quant_conv(z)

        # The activation is cast to the dtype of the upsampling stack's first
        # parameter (the original code labels this "for proper tracing").
        stage_dtype = next(self.up.parameters()).dtype

        mid = self.mid
        h = mid.block_2(mid.attn_1(mid.block_1(self.conv_in(z))))  # type: ignore
        h = h.to(stage_dtype)

        # Walk levels from deepest to shallowest.
        for level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up[level]
            for idx in range(self.num_res_blocks + 1):
                h = stage.block[idx](h)  # type: ignore
                if len(stage.attn) > 0:  # type: ignore
                    h = stage.attn[idx](h)  # type: ignore
            if level != 0:
                h = stage.upsample(h)  # type: ignore

        return self.conv_out(swish(self.norm_out(h)))

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Whether this module is in training mode (True) or evaluation mode (False).
:vartype training: bool

def __init__(
    self,
    ch: int,
    out_ch: int,
    ch_mult: list[int],
    num_res_blocks: int,
    in_channels: int,
    resolution: int,
    z_channels: int,
):
    """Build the decoder layers.

    Args:
        ch: Base channel width; level ``i`` uses ``ch * ch_mult[i]``.
        out_ch: Channels produced by ``conv_out``.
        ch_mult: Per-level channel multipliers; its length is the number of
            resolution levels.
        num_res_blocks: Each upsampling level gets ``num_res_blocks + 1``
            ResnetBlocks.
        in_channels: Stored on the instance; not otherwise used here.
        resolution: Stored, and used to compute ``z_shape``.
        z_channels: Channels of the latent consumed by ``post_quant_conv``.
    """
    super().__init__()
    self.post_quant_conv = torch.nn.Conv2d(z_channels, z_channels, 1)
    self.ch = ch
    self.num_resolutions = len(ch_mult)
    self.num_res_blocks = num_res_blocks
    self.resolution = resolution
    self.in_channels = in_channels
    deepest = self.num_resolutions - 1
    self.ffactor = 2**deepest

    # Channel width and spatial size at the lowest resolution level.
    width = ch * ch_mult[deepest]
    lowest_res = resolution // 2**deepest
    self.z_shape = (1, z_channels, lowest_res, lowest_res)

    # latent -> working width
    self.conv_in = nn.Conv2d(z_channels, width, kernel_size=3, stride=1, padding=1)

    # middle
    self.mid = nn.Module()
    self.mid.block_1 = ResnetBlock(in_channels=width, out_channels=width)
    self.mid.attn_1 = AttnBlock(width)
    self.mid.block_2 = ResnetBlock(in_channels=width, out_channels=width)

    # Stages are constructed deepest-first (this fixes the parameter
    # initialization order) and then stored shallowest-first, so that
    # self.up[i] is resolution level i.
    stages = []
    for level in range(deepest, -1, -1):
        target = ch * ch_mult[level]
        resnets = nn.ModuleList()
        for _ in range(num_res_blocks + 1):
            resnets.append(ResnetBlock(in_channels=width, out_channels=target))
            width = target
        stage = nn.Module()
        stage.block = resnets
        # Kept (always empty) so the module/state_dict layout is unchanged.
        stage.attn = nn.ModuleList()
        if level > 0:
            stage.upsample = Upsample(width)
        stages.append(stage)
    self.up = nn.ModuleList(stages[::-1])

    # end
    self.norm_out = nn.GroupNorm(num_groups=32, num_channels=width, eps=1e-6, affine=True)
    self.conv_out = nn.Conv2d(width, out_ch, kernel_size=3, stride=1, padding=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

post_quant_conv
ch
num_resolutions
num_res_blocks
resolution
in_channels
ffactor
z_shape
conv_in
mid
up
norm_out
conv_out
def forward(self, z: Tensor) -> Tensor:
    """Decode a latent tensor.

    Applies ``post_quant_conv`` and ``conv_in``, the middle
    ResnetBlock/AttnBlock/ResnetBlock stack, casts the activation to the
    dtype of the first parameter in ``self.up``, walks the upsampling levels
    from deepest (``num_resolutions - 1``) to shallowest (``0``), and
    finishes with GroupNorm -> swish -> ``conv_out``.

    Args:
        z: Latent input for ``post_quant_conv`` (presumably
            ``(B, z_channels, h, w)`` -- TODO confirm).

    Returns:
        The output of ``conv_out`` (``out_ch`` channels).
    """
    z = self.post_quant_conv(z)

    # The activation is cast to the dtype of the upsampling stack's first
    # parameter (the original code labels this "for proper tracing").
    stage_dtype = next(self.up.parameters()).dtype

    mid = self.mid
    h = mid.block_2(mid.attn_1(mid.block_1(self.conv_in(z))))  # type: ignore
    h = h.to(stage_dtype)

    # Walk levels from deepest to shallowest.
    for level in range(self.num_resolutions - 1, -1, -1):
        stage = self.up[level]
        for idx in range(self.num_res_blocks + 1):
            h = stage.block[idx](h)  # type: ignore
            if len(stage.attn) > 0:  # type: ignore
                h = stage.attn[idx](h)  # type: ignore
        if level != 0:
            h = stage.upsample(h)  # type: ignore

    return self.conv_out(swish(self.norm_out(h)))

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for the forward pass must be defined within this function, callers should invoke the Module instance rather than calling this method directly: calling the instance runs the registered hooks, while calling `forward()` directly silently skips them.

class AutoEncoder(nn.Module):
    """Encoder/Decoder pair with patchified, BatchNorm-standardized latents.

    ``encode`` keeps the first of the two channel chunks of the encoder
    output, folds each ``ps[0] x ps[1]`` spatial patch into the channel axis,
    and standardizes the result with the running statistics of a non-affine
    ``BatchNorm2d``. ``decode`` undoes the standardization with the same
    statistics, unfolds the patches, and runs the decoder.
    """

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params
        # Hyperparameters passed identically to the encoder and the decoder.
        shared = {
            "resolution": params.resolution,
            "in_channels": params.in_channels,
            "ch": params.ch,
            "ch_mult": params.ch_mult,
            "num_res_blocks": params.num_res_blocks,
            "z_channels": params.z_channels,
        }
        self.encoder = Encoder(**shared)
        self.decoder = Decoder(out_ch=params.out_ch, **shared)

        self.bn_eps = 1e-4
        self.bn_momentum = 0.1
        # Latent patch size (rows, cols) folded into the channel axis.
        self.ps = [2, 2]
        patch_area = math.prod(self.ps)
        self.bn = torch.nn.BatchNorm2d(
            patch_area * params.z_channels,
            eps=self.bn_eps,
            momentum=self.bn_momentum,
            affine=False,
            track_running_stats=True,
        )

    def normalize(self, z):
        """Standardize ``z`` with the BatchNorm running statistics.

        Side effect: leaves ``self.bn`` in eval mode. ``Module.eval()``
        returns the module itself, so it is called directly.
        """
        return self.bn.eval()(z)

    def inv_normalize(self, z):
        """Invert ``normalize``: ``z * sqrt(running_var + bn_eps) + running_mean``.

        Side effect: leaves ``self.bn`` in eval mode.
        """
        bn = self.bn.eval()
        running_std = (bn.running_var.view(1, -1, 1, 1) + self.bn_eps).sqrt()  # type: ignore
        running_mean = bn.running_mean.view(1, -1, 1, 1)  # type: ignore
        return z * running_std + running_mean

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x`` into normalized, patchified latents."""
        # First of two channel chunks (named "mean" by the original code).
        mean = self.encoder(x).chunk(2, dim=1)[0]
        patchified = rearrange(
            mean,
            "... c (i pi) (j pj)  -> ... (c pi pj) i j",
            pi=self.ps[0],
            pj=self.ps[1],
        )
        return self.normalize(patchified)

    def decode(self, z: Tensor) -> Tensor:
        """Decode normalized, patchified latents ``z`` back through the decoder."""
        unpatchified = rearrange(
            self.inv_normalize(z),
            "... (c pi pj) i j -> ... c (i pi) (j pj)",
            pi=self.ps[0],
            pj=self.ps[1],
        )
        return self.decoder(unpatchified)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Whether this module is in training mode (True) or evaluation mode (False).
:vartype training: bool

def __init__(self, params: AutoEncoderParams):
    """Build the encoder, decoder and latent BatchNorm from ``params``.

    The BatchNorm has ``prod(ps) * z_channels`` features, no affine
    parameters, and tracks running statistics.
    """
    super().__init__()
    self.params = params
    # Hyperparameters passed identically to the encoder and the decoder.
    shared = {
        "resolution": params.resolution,
        "in_channels": params.in_channels,
        "ch": params.ch,
        "ch_mult": params.ch_mult,
        "num_res_blocks": params.num_res_blocks,
        "z_channels": params.z_channels,
    }
    self.encoder = Encoder(**shared)
    self.decoder = Decoder(out_ch=params.out_ch, **shared)

    self.bn_eps = 1e-4
    self.bn_momentum = 0.1
    # Latent patch size (rows, cols) folded into the channel axis.
    self.ps = [2, 2]
    patch_area = math.prod(self.ps)
    self.bn = torch.nn.BatchNorm2d(
        patch_area * params.z_channels,
        eps=self.bn_eps,
        momentum=self.bn_momentum,
        affine=False,
        track_running_stats=True,
    )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

params
encoder
decoder
bn_eps
bn_momentum
ps
bn
def normalize(self, z):
    """Standardize ``z`` with the BatchNorm running statistics.

    Side effect: leaves ``self.bn`` in eval mode, so the running statistics
    are used and not updated. ``Module.eval()`` returns the module itself,
    so it is called directly.
    """
    return self.bn.eval()(z)
def inv_normalize(self, z):
    """Invert ``normalize``: ``z * sqrt(running_var + bn_eps) + running_mean``.

    Side effect: leaves ``self.bn`` in eval mode.
    """
    bn = self.bn.eval()
    running_std = (bn.running_var.view(1, -1, 1, 1) + self.bn_eps).sqrt()  # type: ignore
    running_mean = bn.running_mean.view(1, -1, 1, 1)  # type: ignore
    return z * running_std + running_mean
def encode(self, x: Tensor) -> Tensor:
    """Encode ``x`` into normalized, patchified latents.

    Keeps the first of the two channel chunks of the encoder output, folds
    each ``ps[0] x ps[1]`` spatial patch into the channel axis, then applies
    ``self.normalize``.
    """
    # First of two channel chunks (named "mean" by the original code).
    mean = self.encoder(x).chunk(2, dim=1)[0]
    patchified = rearrange(
        mean,
        "... c (i pi) (j pj)  -> ... (c pi pj) i j",
        pi=self.ps[0],
        pj=self.ps[1],
    )
    return self.normalize(patchified)
def decode(self, z: Tensor) -> Tensor:
    """Decode normalized, patchified latents ``z``.

    Applies ``self.inv_normalize``, unfolds the channel axis back into
    ``ps[0] x ps[1]`` spatial patches, and runs ``self.decoder``.
    """
    unpatchified = rearrange(
        self.inv_normalize(z),
        "... (c pi pj) i j -> ... c (i pi) (j pj)",
        pi=self.ps[0],
        pj=self.ps[1],
    )
    return self.decoder(unpatchified)