divisor.acestep.schedulers.scheduling_flow_match_euler_discrete

  1# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
  2#
  3# Licensed under the Apache License, Version 2.0 (the "License");
  4# you may not use this file except in compliance with the License.
  5# You may obtain a copy of the License at
  6#
  7#     http://www.apache.org/licenses/LICENSE-2.0
  8#
  9# Unless required by applicable law or agreed to in writing, software
 10# distributed under the License is distributed on an "AS IS" BASIS,
 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12# See the License for the specific language governing permissions and
 13# limitations under the License.
 14
 15import math
 16from dataclasses import dataclass
 17from typing import List, Optional, Tuple, Union
 18
 19import numpy as np
 20import torch
 21
 22from diffusers.configuration_utils import ConfigMixin, register_to_config
 23from diffusers.utils import BaseOutput, logging
 24from diffusers.schedulers.scheduling_utils import SchedulerMixin
 25
 26
 27logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 28
 29
 30@dataclass
 31class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
 32    """
 33    Output class for the scheduler's `step` function output.
 34
 35    Args:
 36        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
 37            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 38            denoising loop.
 39    """
 40
 41    prev_sample: torch.FloatTensor
 42
 43
 44class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 45    """
 46    Euler scheduler.
 47
 48    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
 49    methods the library implements for all schedulers such as loading and saving.
 50
 51    Args:
 52        num_train_timesteps (`int`, defaults to 1000):
 53            The number of diffusion steps to train the model.
 54        timestep_spacing (`str`, defaults to `"linspace"`):
 55            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
 56            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
 57        shift (`float`, defaults to 1.0):
 58            The shift value for the timestep schedule.
 59    """
 60
 61    _compatibles = []
 62    order = 1
 63
 64    @register_to_config
 65    def __init__(
 66        self,
 67        num_train_timesteps: int = 1000,
 68        shift: float = 1.0,
 69        use_dynamic_shifting=False,
 70        base_shift: Optional[float] = 0.5,
 71        max_shift: Optional[float] = 1.15,
 72        base_image_seq_len: Optional[int] = 256,
 73        max_image_seq_len: Optional[int] = 4096,
 74        sigma_max: Optional[float] = 1.0,
 75    ):
 76        timesteps = np.linspace(
 77            1.0, sigma_max*num_train_timesteps, num_train_timesteps, dtype=np.float32
 78        )[::-1].copy()
 79        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 80
 81        sigmas = timesteps / num_train_timesteps
 82        if not use_dynamic_shifting:
 83            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
 84            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
 85
 86        self.timesteps = sigmas * num_train_timesteps
 87
 88        self._step_index = None
 89        self._begin_index = None
 90
 91        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 92        self.sigma_min = self.sigmas[-1].item()
 93        self.sigma_max = self.sigmas[0].item()
 94
 95    @property
 96    def step_index(self):
 97        """
 98        The index counter for current timestep. It will increase 1 after each scheduler step.
 99        """
100        return self._step_index
101
102    @property
103    def begin_index(self):
104        """
105        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
106        """
107        return self._begin_index
108
109    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
110    def set_begin_index(self, begin_index: int = 0):
111        """
112        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
113
114        Args:
115            begin_index (`int`):
116                The begin index for the scheduler.
117        """
118        self._begin_index = begin_index
119
120    def scale_noise(
121        self,
122        sample: torch.FloatTensor,
123        timestep: Union[float, torch.FloatTensor],
124        noise: Optional[torch.FloatTensor] = None,
125    ) -> torch.FloatTensor:
126        """
127        Forward process in flow-matching
128
129        Args:
130            sample (`torch.FloatTensor`):
131                The input sample.
132            timestep (`int`, *optional*):
133                The current timestep in the diffusion chain.
134
135        Returns:
136            `torch.FloatTensor`:
137                A scaled input sample.
138        """
139        # Make sure sigmas and timesteps have the same device and dtype as original_samples
140        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
141
142        if sample.device.type == "mps" and torch.is_floating_point(timestep):
143            # mps does not support float64
144            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
145            timestep = timestep.to(sample.device, dtype=torch.float32)
146        else:
147            schedule_timesteps = self.timesteps.to(sample.device)
148            timestep = timestep.to(sample.device)
149
150        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
151        if self.begin_index is None:
152            step_indices = [
153                self.index_for_timestep(t, schedule_timesteps) for t in timestep
154            ]
155        elif self.step_index is not None:
156            # add_noise is called after first denoising step (for inpainting)
157            step_indices = [self.step_index] * timestep.shape[0]
158        else:
159            # add noise is called before first denoising step to create initial latent(img2img)
160            step_indices = [self.begin_index] * timestep.shape[0]
161
162        sigma = sigmas[step_indices].flatten()
163        while len(sigma.shape) < len(sample.shape):
164            sigma = sigma.unsqueeze(-1)
165
166        sample = sigma * noise + (1.0 - sigma) * sample
167
168        return sample
169
170    def _sigma_to_t(self, sigma):
171        return sigma * self.config.num_train_timesteps
172
173    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
174        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
175
176    def set_timesteps(
177        self,
178        num_inference_steps: int = None,
179        device: Union[str, torch.device] = None,
180        sigmas: Optional[List[float]] = None,
181        mu: Optional[float] = None,
182    ):
183        """
184        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
185
186        Args:
187            num_inference_steps (`int`):
188                The number of diffusion steps used when generating samples with a pre-trained model.
189            device (`str` or `torch.device`, *optional*):
190                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
191        """
192
193        if self.config.use_dynamic_shifting and mu is None:
194            raise ValueError(
195                " you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
196            )
197
198        if sigmas is None:
199            self.num_inference_steps = num_inference_steps
200            timesteps = np.linspace(
201                self._sigma_to_t(self.sigma_max),
202                self._sigma_to_t(self.sigma_min),
203                num_inference_steps,
204            )
205
206            sigmas = timesteps / self.config.num_train_timesteps
207
208        if self.config.use_dynamic_shifting:
209            sigmas = self.time_shift(mu, 1.0, sigmas)
210        else:
211            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
212
213        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
214        timesteps = sigmas * self.config.num_train_timesteps
215
216        self.timesteps = timesteps.to(device=device)
217        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
218
219        self._step_index = None
220        self._begin_index = None
221
222    def index_for_timestep(self, timestep, schedule_timesteps=None):
223        if schedule_timesteps is None:
224            schedule_timesteps = self.timesteps
225
226        indices = (schedule_timesteps == timestep).nonzero()
227
228        # The sigma index that is taken for the **very** first `step`
229        # is always the second index (or the last index if there is only 1)
230        # This way we can ensure we don't accidentally skip a sigma in
231        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
232        pos = 1 if len(indices) > 1 else 0
233
234        return indices[pos].item()
235
236    def _init_step_index(self, timestep):
237        if self.begin_index is None:
238            if isinstance(timestep, torch.Tensor):
239                timestep = timestep.to(self.timesteps.device)
240            self._step_index = self.index_for_timestep(timestep)
241        else:
242            self._step_index = self._begin_index
243
244    def step(
245        self,
246        model_output: torch.FloatTensor,
247        timestep: Union[float, torch.FloatTensor],
248        sample: torch.FloatTensor,
249        s_churn: float = 0.0,
250        s_tmin: float = 0.0,
251        s_tmax: float = float("inf"),
252        s_noise: float = 1.0,
253        generator: Optional[torch.Generator] = None,
254        return_dict: bool = True,
255        omega: Union[float, np.array] = 0.0,
256    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
257        """
258        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
259        process from the learned model outputs (most often the predicted noise).
260
261        Args:
262            model_output (`torch.FloatTensor`):
263                The direct output from learned diffusion model.
264            timestep (`float`):
265                The current discrete timestep in the diffusion chain.
266            sample (`torch.FloatTensor`):
267                A current instance of a sample created by the diffusion process.
268            s_churn (`float`):
269            s_tmin  (`float`):
270            s_tmax  (`float`):
271            s_noise (`float`, defaults to 1.0):
272                Scaling factor for noise added to the sample.
273            generator (`torch.Generator`, *optional*):
274                A random number generator.
275            return_dict (`bool`):
276                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
277                tuple.
278
279        Returns:
280            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
281                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
282                returned, otherwise a tuple is returned where the first element is the sample tensor.
283        """
284
285        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
286            # L = Lower bound
287            # U = Upper bound
288            # x_0 = Midpoint (x corresponding to y = 1.0)
289            # k = Steepness, can adjust based on preference
290
291            if isinstance(x, torch.Tensor):
292                device_ = x.device
293                x = x.to(torch.float).cpu().numpy()
294
295            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
296
297            if isinstance(new_x, np.ndarray):
298                new_x = torch.from_numpy(new_x).to(device_)
299            return new_x
300
301        self.omega_bef_rescale = omega
302        omega = logistic_function(omega, k=0.1)
303        self.omega_aft_rescale = omega
304
305        if (
306            isinstance(timestep, int)
307            or isinstance(timestep, torch.IntTensor)
308            or isinstance(timestep, torch.LongTensor)
309        ):
310            raise ValueError(
311                (
312                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
313                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
314                    " one of the `scheduler.timesteps` as a timestep."
315                ),
316            )
317
318        if self.step_index is None:
319            self._init_step_index(timestep)
320
321        # Upcast to avoid precision issues when computing prev_sample
322        sample = sample.to(torch.float32)
323
324        sigma = self.sigmas[self.step_index]
325        sigma_next = self.sigmas[self.step_index + 1]
326
327        ## --
328        ## mean shift 1
329        dx = (sigma_next - sigma) * model_output
330        m = dx.mean()
331        # print(dx.shape) # torch.Size([1, 16, 128, 128])
332        # print(f'm: {m}') # m: -0.0014209747314453125
333        # raise NotImplementedError
334        dx_ = (dx - m) * omega + m
335        prev_sample = sample + dx_
336
337        # ## --
338        # ## mean shift 2
339        # m = model_output.mean()
340        # model_output_ = (model_output - m) * omega + m
341        # prev_sample = sample + (sigma_next - sigma) * model_output_
342
343        # ## --
344        # ## original
345        # prev_sample = sample + (sigma_next - sigma) * model_output * omega
346
347        # ## --
348        # ## spatial mean 1
349        # dx = (sigma_next - sigma) * model_output
350        # m = dx.mean(dim=(0, 1), keepdim=True)
351        # # print(dx.shape) # torch.Size([1, 16, 128, 128])
352        # # print(m.shape) # torch.Size([1, 1, 128, 128])
353        # # raise NotImplementedError
354        # dx_ = (dx - m) * omega + m
355        # prev_sample = sample + dx_
356
357        # ## --
358        # ## spatial mean 2
359        # m = model_output.mean(dim=(0, 1), keepdim=True)
360        # model_output_ = (model_output - m) * omega + m
361        # prev_sample = sample + (sigma_next - sigma) * model_output_
362
363        # ## --
364        # ## channel mean 1
365        # m = model_output.mean(dim=(2, 3), keepdim=True)
366        # # print(m.shape) # torch.Size([1, 16, 1, 1])
367        # model_output_ = (model_output - m) * omega + m
368        # prev_sample = sample + (sigma_next - sigma) * model_output_
369
370        # ## --
371        # ## channel mean 2
372        # dx = (sigma_next - sigma) * model_output
373        # m = dx.mean(dim=(2, 3), keepdim=True)
374        # # print(m.shape) # torch.Size([1, 16, 1, 1])
375        # dx_ = (dx - m) * omega + m
376        # prev_sample = sample + dx_
377
378        # ## --
379        # ## keep sample mean
380        # m_tgt = sample.mean()
381        # prev_sample_ = sample + (sigma_next - sigma) * model_output * omega
382        # m_src = prev_sample_.mean()
383        # prev_sample = prev_sample_ - m_src + m_tgt
384
385        # ## --
386        # ## test
387        # # print(sample.mean())
388        # prev_sample = sample + (sigma_next - sigma) * model_output * omega
389        # # raise NotImplementedError
390
391        # Cast sample back to model compatible dtype
392        prev_sample = prev_sample.to(model_output.dtype)
393
394        # upon completion increase step index by one
395        self._step_index += 1
396
397        if not return_dict:
398            return (prev_sample,)
399
400        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
401
402    def __len__(self):
403        return self.config.num_train_timesteps
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(diffusers.utils.outputs.BaseOutput):
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` method when `return_dict=True`.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample produced by the step, i.e. `(x_{t-1})`. Feed it back to the model as the input of the
            next iteration of the denoising loop.
    """

    # Result of one scheduler step.
    prev_sample: torch.FloatTensor

Output class for the scheduler's step function output.

Args: prev_sample (torch.FloatTensor of shape (batch_size, num_channels, height, width) for images): Computed sample (x_{t-1}) of previous timestep. prev_sample should be used as next model input in the denoising loop.

prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(diffusers.schedulers.scheduling_utils.SchedulerMixin, diffusers.configuration_utils.ConfigMixin):
 45class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
 46    """
 47    Euler scheduler.
 48
 49    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
 50    methods the library implements for all schedulers such as loading and saving.
 51
 52    Args:
 53        num_train_timesteps (`int`, defaults to 1000):
 54            The number of diffusion steps to train the model.
 55        timestep_spacing (`str`, defaults to `"linspace"`):
 56            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
 57            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
 58        shift (`float`, defaults to 1.0):
 59            The shift value for the timestep schedule.
 60    """
 61
 62    _compatibles = []
 63    order = 1
 64
 65    @register_to_config
 66    def __init__(
 67        self,
 68        num_train_timesteps: int = 1000,
 69        shift: float = 1.0,
 70        use_dynamic_shifting=False,
 71        base_shift: Optional[float] = 0.5,
 72        max_shift: Optional[float] = 1.15,
 73        base_image_seq_len: Optional[int] = 256,
 74        max_image_seq_len: Optional[int] = 4096,
 75        sigma_max: Optional[float] = 1.0,
 76    ):
 77        timesteps = np.linspace(
 78            1.0, sigma_max*num_train_timesteps, num_train_timesteps, dtype=np.float32
 79        )[::-1].copy()
 80        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 81
 82        sigmas = timesteps / num_train_timesteps
 83        if not use_dynamic_shifting:
 84            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
 85            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
 86
 87        self.timesteps = sigmas * num_train_timesteps
 88
 89        self._step_index = None
 90        self._begin_index = None
 91
 92        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 93        self.sigma_min = self.sigmas[-1].item()
 94        self.sigma_max = self.sigmas[0].item()
 95
 96    @property
 97    def step_index(self):
 98        """
 99        The index counter for current timestep. It will increase 1 after each scheduler step.
100        """
101        return self._step_index
102
103    @property
104    def begin_index(self):
105        """
106        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
107        """
108        return self._begin_index
109
110    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
111    def set_begin_index(self, begin_index: int = 0):
112        """
113        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
114
115        Args:
116            begin_index (`int`):
117                The begin index for the scheduler.
118        """
119        self._begin_index = begin_index
120
121    def scale_noise(
122        self,
123        sample: torch.FloatTensor,
124        timestep: Union[float, torch.FloatTensor],
125        noise: Optional[torch.FloatTensor] = None,
126    ) -> torch.FloatTensor:
127        """
128        Forward process in flow-matching
129
130        Args:
131            sample (`torch.FloatTensor`):
132                The input sample.
133            timestep (`int`, *optional*):
134                The current timestep in the diffusion chain.
135
136        Returns:
137            `torch.FloatTensor`:
138                A scaled input sample.
139        """
140        # Make sure sigmas and timesteps have the same device and dtype as original_samples
141        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
142
143        if sample.device.type == "mps" and torch.is_floating_point(timestep):
144            # mps does not support float64
145            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
146            timestep = timestep.to(sample.device, dtype=torch.float32)
147        else:
148            schedule_timesteps = self.timesteps.to(sample.device)
149            timestep = timestep.to(sample.device)
150
151        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
152        if self.begin_index is None:
153            step_indices = [
154                self.index_for_timestep(t, schedule_timesteps) for t in timestep
155            ]
156        elif self.step_index is not None:
157            # add_noise is called after first denoising step (for inpainting)
158            step_indices = [self.step_index] * timestep.shape[0]
159        else:
160            # add noise is called before first denoising step to create initial latent(img2img)
161            step_indices = [self.begin_index] * timestep.shape[0]
162
163        sigma = sigmas[step_indices].flatten()
164        while len(sigma.shape) < len(sample.shape):
165            sigma = sigma.unsqueeze(-1)
166
167        sample = sigma * noise + (1.0 - sigma) * sample
168
169        return sample
170
171    def _sigma_to_t(self, sigma):
172        return sigma * self.config.num_train_timesteps
173
174    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
175        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
176
177    def set_timesteps(
178        self,
179        num_inference_steps: int = None,
180        device: Union[str, torch.device] = None,
181        sigmas: Optional[List[float]] = None,
182        mu: Optional[float] = None,
183    ):
184        """
185        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
186
187        Args:
188            num_inference_steps (`int`):
189                The number of diffusion steps used when generating samples with a pre-trained model.
190            device (`str` or `torch.device`, *optional*):
191                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
192        """
193
194        if self.config.use_dynamic_shifting and mu is None:
195            raise ValueError(
196                " you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
197            )
198
199        if sigmas is None:
200            self.num_inference_steps = num_inference_steps
201            timesteps = np.linspace(
202                self._sigma_to_t(self.sigma_max),
203                self._sigma_to_t(self.sigma_min),
204                num_inference_steps,
205            )
206
207            sigmas = timesteps / self.config.num_train_timesteps
208
209        if self.config.use_dynamic_shifting:
210            sigmas = self.time_shift(mu, 1.0, sigmas)
211        else:
212            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
213
214        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
215        timesteps = sigmas * self.config.num_train_timesteps
216
217        self.timesteps = timesteps.to(device=device)
218        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
219
220        self._step_index = None
221        self._begin_index = None
222
223    def index_for_timestep(self, timestep, schedule_timesteps=None):
224        if schedule_timesteps is None:
225            schedule_timesteps = self.timesteps
226
227        indices = (schedule_timesteps == timestep).nonzero()
228
229        # The sigma index that is taken for the **very** first `step`
230        # is always the second index (or the last index if there is only 1)
231        # This way we can ensure we don't accidentally skip a sigma in
232        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
233        pos = 1 if len(indices) > 1 else 0
234
235        return indices[pos].item()
236
237    def _init_step_index(self, timestep):
238        if self.begin_index is None:
239            if isinstance(timestep, torch.Tensor):
240                timestep = timestep.to(self.timesteps.device)
241            self._step_index = self.index_for_timestep(timestep)
242        else:
243            self._step_index = self._begin_index
244
245    def step(
246        self,
247        model_output: torch.FloatTensor,
248        timestep: Union[float, torch.FloatTensor],
249        sample: torch.FloatTensor,
250        s_churn: float = 0.0,
251        s_tmin: float = 0.0,
252        s_tmax: float = float("inf"),
253        s_noise: float = 1.0,
254        generator: Optional[torch.Generator] = None,
255        return_dict: bool = True,
256        omega: Union[float, np.array] = 0.0,
257    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
258        """
259        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
260        process from the learned model outputs (most often the predicted noise).
261
262        Args:
263            model_output (`torch.FloatTensor`):
264                The direct output from learned diffusion model.
265            timestep (`float`):
266                The current discrete timestep in the diffusion chain.
267            sample (`torch.FloatTensor`):
268                A current instance of a sample created by the diffusion process.
269            s_churn (`float`):
270            s_tmin  (`float`):
271            s_tmax  (`float`):
272            s_noise (`float`, defaults to 1.0):
273                Scaling factor for noise added to the sample.
274            generator (`torch.Generator`, *optional*):
275                A random number generator.
276            return_dict (`bool`):
277                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
278                tuple.
279
280        Returns:
281            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
282                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
283                returned, otherwise a tuple is returned where the first element is the sample tensor.
284        """
285
286        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
287            # L = Lower bound
288            # U = Upper bound
289            # x_0 = Midpoint (x corresponding to y = 1.0)
290            # k = Steepness, can adjust based on preference
291
292            if isinstance(x, torch.Tensor):
293                device_ = x.device
294                x = x.to(torch.float).cpu().numpy()
295
296            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
297
298            if isinstance(new_x, np.ndarray):
299                new_x = torch.from_numpy(new_x).to(device_)
300            return new_x
301
302        self.omega_bef_rescale = omega
303        omega = logistic_function(omega, k=0.1)
304        self.omega_aft_rescale = omega
305
306        if (
307            isinstance(timestep, int)
308            or isinstance(timestep, torch.IntTensor)
309            or isinstance(timestep, torch.LongTensor)
310        ):
311            raise ValueError(
312                (
313                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
314                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
315                    " one of the `scheduler.timesteps` as a timestep."
316                ),
317            )
318
319        if self.step_index is None:
320            self._init_step_index(timestep)
321
322        # Upcast to avoid precision issues when computing prev_sample
323        sample = sample.to(torch.float32)
324
325        sigma = self.sigmas[self.step_index]
326        sigma_next = self.sigmas[self.step_index + 1]
327
328        ## --
329        ## mean shift 1
330        dx = (sigma_next - sigma) * model_output
331        m = dx.mean()
332        # print(dx.shape) # torch.Size([1, 16, 128, 128])
333        # print(f'm: {m}') # m: -0.0014209747314453125
334        # raise NotImplementedError
335        dx_ = (dx - m) * omega + m
336        prev_sample = sample + dx_
337
338        # ## --
339        # ## mean shift 2
340        # m = model_output.mean()
341        # model_output_ = (model_output - m) * omega + m
342        # prev_sample = sample + (sigma_next - sigma) * model_output_
343
344        # ## --
345        # ## original
346        # prev_sample = sample + (sigma_next - sigma) * model_output * omega
347
348        # ## --
349        # ## spatial mean 1
350        # dx = (sigma_next - sigma) * model_output
351        # m = dx.mean(dim=(0, 1), keepdim=True)
352        # # print(dx.shape) # torch.Size([1, 16, 128, 128])
353        # # print(m.shape) # torch.Size([1, 1, 128, 128])
354        # # raise NotImplementedError
355        # dx_ = (dx - m) * omega + m
356        # prev_sample = sample + dx_
357
358        # ## --
359        # ## spatial mean 2
360        # m = model_output.mean(dim=(0, 1), keepdim=True)
361        # model_output_ = (model_output - m) * omega + m
362        # prev_sample = sample + (sigma_next - sigma) * model_output_
363
364        # ## --
365        # ## channel mean 1
366        # m = model_output.mean(dim=(2, 3), keepdim=True)
367        # # print(m.shape) # torch.Size([1, 16, 1, 1])
368        # model_output_ = (model_output - m) * omega + m
369        # prev_sample = sample + (sigma_next - sigma) * model_output_
370
371        # ## --
372        # ## channel mean 2
373        # dx = (sigma_next - sigma) * model_output
374        # m = dx.mean(dim=(2, 3), keepdim=True)
375        # # print(m.shape) # torch.Size([1, 16, 1, 1])
376        # dx_ = (dx - m) * omega + m
377        # prev_sample = sample + dx_
378
379        # ## --
380        # ## keep sample mean
381        # m_tgt = sample.mean()
382        # prev_sample_ = sample + (sigma_next - sigma) * model_output * omega
383        # m_src = prev_sample_.mean()
384        # prev_sample = prev_sample_ - m_src + m_tgt
385
386        # ## --
387        # ## test
388        # # print(sample.mean())
389        # prev_sample = sample + (sigma_next - sigma) * model_output * omega
390        # # raise NotImplementedError
391
392        # Cast sample back to model compatible dtype
393        prev_sample = prev_sample.to(model_output.dtype)
394
395        # upon completion increase step index by one
396        self._step_index += 1
397
398        if not return_dict:
399            return (prev_sample,)
400
401        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
402
403    def __len__(self):
404        return self.config.num_train_timesteps

Euler scheduler.

This model inherits from [SchedulerMixin] and [ConfigMixin]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving.

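Args: num_train_timesteps (int, defaults to 1000): The number of diffusion steps to train the model. shift (float, defaults to 1.0): The shift value for the timestep schedule. use_dynamic_shifting (bool, defaults to False): When True, the static shift is skipped and set_timesteps requires a mu value. base_shift, max_shift, base_image_seq_len, max_image_seq_len: accepted by the constructor but not read by this class. sigma_max (float, defaults to 1.0): Multiplier applied to num_train_timesteps for the upper end of the training timestep range. Note: timestep_spacing (str, "linspace"), the way the timesteps should be scaled (see Table 2 of Common Diffusion Noise Schedules and Sample Steps are Flawed), is described in the docstring but is not a constructor parameter; the constructor always builds the timesteps with linear spacing.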

@register_to_config
FlowMatchEulerDiscreteScheduler( num_train_timesteps: int = 1000, shift: float = 1.0, use_dynamic_shifting=False, base_shift: Optional[float] = 0.5, max_shift: Optional[float] = 1.15, base_image_seq_len: Optional[int] = 256, max_image_seq_len: Optional[int] = 4096, sigma_max: Optional[float] = 1.0)
65    @register_to_config
66    def __init__(
67        self,
68        num_train_timesteps: int = 1000,
69        shift: float = 1.0,
70        use_dynamic_shifting=False,
71        base_shift: Optional[float] = 0.5,
72        max_shift: Optional[float] = 1.15,
73        base_image_seq_len: Optional[int] = 256,
74        max_image_seq_len: Optional[int] = 4096,
75        sigma_max: Optional[float] = 1.0,
76    ):
77        timesteps = np.linspace(
78            1.0, sigma_max*num_train_timesteps, num_train_timesteps, dtype=np.float32
79        )[::-1].copy()
80        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
81
82        sigmas = timesteps / num_train_timesteps
83        if not use_dynamic_shifting:
84            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
85            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
86
87        self.timesteps = sigmas * num_train_timesteps
88
89        self._step_index = None
90        self._begin_index = None
91
92        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
93        self.sigma_min = self.sigmas[-1].item()
94        self.sigma_max = self.sigmas[0].item()
order = 1
timesteps
sigmas
sigma_min
sigma_max
step_index
 96    @property
 97    def step_index(self):
 98        """
 99        The index counter for current timestep. It will increase 1 after each scheduler step.
100        """
101        return self._step_index

The index counter for the current timestep. It increases by 1 after each scheduler step.

begin_index
103    @property
104    def begin_index(self):
105        """
106        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
107        """
108        return self._begin_index

The index for the first timestep. It should be set from pipeline with set_begin_index method.

def set_begin_index(self, begin_index: int = 0):
111    def set_begin_index(self, begin_index: int = 0):
112        """
113        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
114
115        Args:
116            begin_index (`int`):
117                The begin index for the scheduler.
118        """
119        self._begin_index = begin_index

Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

Args: begin_index (int): The begin index for the scheduler.

def scale_noise( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], noise: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
121    def scale_noise(
122        self,
123        sample: torch.FloatTensor,
124        timestep: Union[float, torch.FloatTensor],
125        noise: Optional[torch.FloatTensor] = None,
126    ) -> torch.FloatTensor:
127        """
128        Forward process in flow-matching
129
130        Args:
131            sample (`torch.FloatTensor`):
132                The input sample.
133            timestep (`int`, *optional*):
134                The current timestep in the diffusion chain.
135
136        Returns:
137            `torch.FloatTensor`:
138                A scaled input sample.
139        """
140        # Make sure sigmas and timesteps have the same device and dtype as original_samples
141        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
142
143        if sample.device.type == "mps" and torch.is_floating_point(timestep):
144            # mps does not support float64
145            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
146            timestep = timestep.to(sample.device, dtype=torch.float32)
147        else:
148            schedule_timesteps = self.timesteps.to(sample.device)
149            timestep = timestep.to(sample.device)
150
151        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
152        if self.begin_index is None:
153            step_indices = [
154                self.index_for_timestep(t, schedule_timesteps) for t in timestep
155            ]
156        elif self.step_index is not None:
157            # add_noise is called after first denoising step (for inpainting)
158            step_indices = [self.step_index] * timestep.shape[0]
159        else:
160            # add noise is called before first denoising step to create initial latent(img2img)
161            step_indices = [self.begin_index] * timestep.shape[0]
162
163        sigma = sigmas[step_indices].flatten()
164        while len(sigma.shape) < len(sample.shape):
165            sigma = sigma.unsqueeze(-1)
166
167        sample = sigma * noise + (1.0 - sigma) * sample
168
169        return sample

Forward process in flow-matching

Args: sample (torch.FloatTensor): The input sample. timestep (int, optional): The current timestep in the diffusion chain.

Returns: torch.FloatTensor: A scaled input sample.

def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
174    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
175        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def set_timesteps( self, num_inference_steps: int = None, device: Union[str, torch.device] = None, sigmas: Optional[List[float]] = None, mu: Optional[float] = None):
177    def set_timesteps(
178        self,
179        num_inference_steps: int = None,
180        device: Union[str, torch.device] = None,
181        sigmas: Optional[List[float]] = None,
182        mu: Optional[float] = None,
183    ):
184        """
185        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
186
187        Args:
188            num_inference_steps (`int`):
189                The number of diffusion steps used when generating samples with a pre-trained model.
190            device (`str` or `torch.device`, *optional*):
191                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
192        """
193
194        if self.config.use_dynamic_shifting and mu is None:
195            raise ValueError(
196                " you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
197            )
198
199        if sigmas is None:
200            self.num_inference_steps = num_inference_steps
201            timesteps = np.linspace(
202                self._sigma_to_t(self.sigma_max),
203                self._sigma_to_t(self.sigma_min),
204                num_inference_steps,
205            )
206
207            sigmas = timesteps / self.config.num_train_timesteps
208
209        if self.config.use_dynamic_shifting:
210            sigmas = self.time_shift(mu, 1.0, sigmas)
211        else:
212            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
213
214        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
215        timesteps = sigmas * self.config.num_train_timesteps
216
217        self.timesteps = timesteps.to(device=device)
218        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
219
220        self._step_index = None
221        self._begin_index = None

Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args: num_inference_steps (int): The number of diffusion steps used when generating samples with a pre-trained model. device (str or torch.device, optional): The device to which the timesteps should be moved to. If None, the timesteps are not moved.

def index_for_timestep(self, timestep, schedule_timesteps=None):
223    def index_for_timestep(self, timestep, schedule_timesteps=None):
224        if schedule_timesteps is None:
225            schedule_timesteps = self.timesteps
226
227        indices = (schedule_timesteps == timestep).nonzero()
228
229        # The sigma index that is taken for the **very** first `step`
230        # is always the second index (or the last index if there is only 1)
231        # This way we can ensure we don't accidentally skip a sigma in
232        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
233        pos = 1 if len(indices) > 1 else 0
234
235        return indices[pos].item()
def step( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, s_churn: float = 0.0, s_tmin: float = 0.0, s_tmax: float = inf, s_noise: float = 1.0, generator: Optional[torch.Generator] = None, return_dict: bool = True, omega: Union[float, np.ndarray] = 0.0) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
245    def step(
246        self,
247        model_output: torch.FloatTensor,
248        timestep: Union[float, torch.FloatTensor],
249        sample: torch.FloatTensor,
250        s_churn: float = 0.0,
251        s_tmin: float = 0.0,
252        s_tmax: float = float("inf"),
253        s_noise: float = 1.0,
254        generator: Optional[torch.Generator] = None,
255        return_dict: bool = True,
256        omega: Union[float, np.array] = 0.0,
257    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
258        """
259        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
260        process from the learned model outputs (most often the predicted noise).
261
262        Args:
263            model_output (`torch.FloatTensor`):
264                The direct output from learned diffusion model.
265            timestep (`float`):
266                The current discrete timestep in the diffusion chain.
267            sample (`torch.FloatTensor`):
268                A current instance of a sample created by the diffusion process.
269            s_churn (`float`):
270            s_tmin  (`float`):
271            s_tmax  (`float`):
272            s_noise (`float`, defaults to 1.0):
273                Scaling factor for noise added to the sample.
274            generator (`torch.Generator`, *optional*):
275                A random number generator.
276            return_dict (`bool`):
277                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
278                tuple.
279
280        Returns:
281            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
282                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
283                returned, otherwise a tuple is returned where the first element is the sample tensor.
284        """
285
286        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
287            # L = Lower bound
288            # U = Upper bound
289            # x_0 = Midpoint (x corresponding to y = 1.0)
290            # k = Steepness, can adjust based on preference
291
292            if isinstance(x, torch.Tensor):
293                device_ = x.device
294                x = x.to(torch.float).cpu().numpy()
295
296            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
297
298            if isinstance(new_x, np.ndarray):
299                new_x = torch.from_numpy(new_x).to(device_)
300            return new_x
301
302        self.omega_bef_rescale = omega
303        omega = logistic_function(omega, k=0.1)
304        self.omega_aft_rescale = omega
305
306        if (
307            isinstance(timestep, int)
308            or isinstance(timestep, torch.IntTensor)
309            or isinstance(timestep, torch.LongTensor)
310        ):
311            raise ValueError(
312                (
313                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
314                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
315                    " one of the `scheduler.timesteps` as a timestep."
316                ),
317            )
318
319        if self.step_index is None:
320            self._init_step_index(timestep)
321
322        # Upcast to avoid precision issues when computing prev_sample
323        sample = sample.to(torch.float32)
324
325        sigma = self.sigmas[self.step_index]
326        sigma_next = self.sigmas[self.step_index + 1]
327
328        ## --
329        ## mean shift 1
330        dx = (sigma_next - sigma) * model_output
331        m = dx.mean()
332        # print(dx.shape) # torch.Size([1, 16, 128, 128])
333        # print(f'm: {m}') # m: -0.0014209747314453125
334        # raise NotImplementedError
335        dx_ = (dx - m) * omega + m
336        prev_sample = sample + dx_
337
338        # ## --
339        # ## mean shift 2
340        # m = model_output.mean()
341        # model_output_ = (model_output - m) * omega + m
342        # prev_sample = sample + (sigma_next - sigma) * model_output_
343
344        # ## --
345        # ## original
346        # prev_sample = sample + (sigma_next - sigma) * model_output * omega
347
348        # ## --
349        # ## spatial mean 1
350        # dx = (sigma_next - sigma) * model_output
351        # m = dx.mean(dim=(0, 1), keepdim=True)
352        # # print(dx.shape) # torch.Size([1, 16, 128, 128])
353        # # print(m.shape) # torch.Size([1, 1, 128, 128])
354        # # raise NotImplementedError
355        # dx_ = (dx - m) * omega + m
356        # prev_sample = sample + dx_
357
358        # ## --
359        # ## spatial mean 2
360        # m = model_output.mean(dim=(0, 1), keepdim=True)
361        # model_output_ = (model_output - m) * omega + m
362        # prev_sample = sample + (sigma_next - sigma) * model_output_
363
364        # ## --
365        # ## channel mean 1
366        # m = model_output.mean(dim=(2, 3), keepdim=True)
367        # # print(m.shape) # torch.Size([1, 16, 1, 1])
368        # model_output_ = (model_output - m) * omega + m
369        # prev_sample = sample + (sigma_next - sigma) * model_output_
370
371        # ## --
372        # ## channel mean 2
373        # dx = (sigma_next - sigma) * model_output
374        # m = dx.mean(dim=(2, 3), keepdim=True)
375        # # print(m.shape) # torch.Size([1, 16, 1, 1])
376        # dx_ = (dx - m) * omega + m
377        # prev_sample = sample + dx_
378
379        # ## --
380        # ## keep sample mean
381        # m_tgt = sample.mean()
382        # prev_sample_ = sample + (sigma_next - sigma) * model_output * omega
383        # m_src = prev_sample_.mean()
384        # prev_sample = prev_sample_ - m_src + m_tgt
385
386        # ## --
387        # ## test
388        # # print(sample.mean())
389        # prev_sample = sample + (sigma_next - sigma) * model_output * omega
390        # # raise NotImplementedError
391
392        # Cast sample back to model compatible dtype
393        prev_sample = prev_sample.to(model_output.dtype)
394
395        # upon completion increase step index by one
396        self._step_index += 1
397
398        if not return_dict:
399            return (prev_sample,)
400
401        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise).

Args: model_output (torch.FloatTensor): The direct output from learned diffusion model. timestep (float): The current discrete timestep in the diffusion chain. sample (torch.FloatTensor): A current instance of a sample created by the diffusion process. s_churn (float): s_tmin (float): s_tmax (float): s_noise (float, defaults to 1.0): Scaling factor for noise added to the sample. generator (torch.Generator, *optional*): A random number generator. return_dict (bool): Whether or not to return a [~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput] or tuple.

Returns: [FlowMatchEulerDiscreteSchedulerOutput] or tuple: If return_dict is True, a [FlowMatchEulerDiscreteSchedulerOutput] is returned, otherwise a tuple is returned where the first element is the sample tensor.