divisor.acestep.schedulers.scheduling_flow_match_heun_discrete

  1# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
  2#
  3# Licensed under the Apache License, Version 2.0 (the "License");
  4# you may not use this file except in compliance with the License.
  5# You may obtain a copy of the License at
  6#
  7#     http://www.apache.org/licenses/LICENSE-2.0
  8#
  9# Unless required by applicable law or agreed to in writing, software
 10# distributed under the License is distributed on an "AS IS" BASIS,
 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12# See the License for the specific language governing permissions and
 13# limitations under the License.
 14
 15from dataclasses import dataclass
 16from typing import Optional, Tuple, Union
 17
 18import numpy as np
 19import torch
 20
 21from diffusers.configuration_utils import ConfigMixin, register_to_config
 22from diffusers.utils import BaseOutput, logging
 23from diffusers.utils.torch_utils import randn_tensor
 24from diffusers.schedulers.scheduling_utils import SchedulerMixin
 25
 26
 27logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 28
 29
 30@dataclass
 31class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
 32    """
 33    Output class for the scheduler's `step` function output.
 34
 35    Args:
 36        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
 37            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
 38            denoising loop.
 39    """
 40
 41    prev_sample: torch.FloatTensor
 42
 43
 44class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
 45    """
 46    Heun scheduler.
 47
 48    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
 49    methods the library implements for all schedulers such as loading and saving.
 50
 51    Args:
 52        num_train_timesteps (`int`, defaults to 1000):
 53            The number of diffusion steps to train the model.
 54        timestep_spacing (`str`, defaults to `"linspace"`):
 55            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
 56            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
 57        shift (`float`, defaults to 1.0):
 58            The shift value for the timestep schedule.
 59    """
 60
 61    _compatibles = []
 62    order = 2
 63
 64    @register_to_config
 65    def __init__(
 66        self,
 67        num_train_timesteps: int = 1000,
 68        shift: float = 1.0,
 69        sigma_max: Optional[float] = 1.0,
 70    ):
 71        timesteps = np.linspace(
 72            1.0, sigma_max*num_train_timesteps, num_train_timesteps, dtype=np.float32
 73        )[::-1].copy()
 74        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 75
 76        sigmas = timesteps / num_train_timesteps
 77        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
 78
 79        self.timesteps = sigmas * num_train_timesteps
 80
 81        self._step_index = None
 82        self._begin_index = None
 83
 84        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 85        self.sigma_min = self.sigmas[-1].item()
 86        self.sigma_max = self.sigmas[0].item()
 87
 88    @property
 89    def step_index(self):
 90        """
 91        The index counter for current timestep. It will increase 1 after each scheduler step.
 92        """
 93        return self._step_index
 94
 95    @property
 96    def begin_index(self):
 97        """
 98        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
 99        """
100        return self._begin_index
101
102    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
103    def set_begin_index(self, begin_index: int = 0):
104        """
105        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
106
107        Args:
108            begin_index (`int`):
109                The begin index for the scheduler.
110        """
111        self._begin_index = begin_index
112
113    def scale_noise(
114        self,
115        sample: torch.FloatTensor,
116        timestep: Union[float, torch.FloatTensor],
117        noise: Optional[torch.FloatTensor] = None,
118    ) -> torch.FloatTensor:
119        """
120        Forward process in flow-matching
121
122        Args:
123            sample (`torch.FloatTensor`):
124                The input sample.
125            timestep (`int`, *optional*):
126                The current timestep in the diffusion chain.
127
128        Returns:
129            `torch.FloatTensor`:
130                A scaled input sample.
131        """
132        if self.step_index is None:
133            self._init_step_index(timestep)
134
135        sigma = self.sigmas[self.step_index]
136        sample = sigma * noise + (1.0 - sigma) * sample
137
138        return sample
139
140    def _sigma_to_t(self, sigma):
141        return sigma * self.config.num_train_timesteps
142
143    def set_timesteps(
144        self, num_inference_steps: int, device: Union[str, torch.device] = None
145    ):
146        """
147        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
148
149        Args:
150            num_inference_steps (`int`):
151                The number of diffusion steps used when generating samples with a pre-trained model.
152            device (`str` or `torch.device`, *optional*):
153                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
154        """
155        self.num_inference_steps = num_inference_steps
156
157        timesteps = np.linspace(
158            self._sigma_to_t(self.sigma_max),
159            self._sigma_to_t(self.sigma_min),
160            num_inference_steps,
161        )
162
163        sigmas = timesteps / self.config.num_train_timesteps
164        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
165        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
166
167        timesteps = sigmas * self.config.num_train_timesteps
168        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
169        self.timesteps = timesteps.to(device=device)
170
171        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
172        self.sigmas = torch.cat(
173            [sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
174        )
175
176        # empty dt and derivative
177        self.prev_derivative = None
178        self.dt = None
179
180        self._step_index = None
181        self._begin_index = None
182
183    def index_for_timestep(self, timestep, schedule_timesteps=None):
184        if schedule_timesteps is None:
185            schedule_timesteps = self.timesteps
186
187        indices = (schedule_timesteps == timestep).nonzero()
188
189        # The sigma index that is taken for the **very** first `step`
190        # is always the second index (or the last index if there is only 1)
191        # This way we can ensure we don't accidentally skip a sigma in
192        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
193        pos = 1 if len(indices) > 1 else 0
194
195        return indices[pos].item()
196
197    def _init_step_index(self, timestep):
198        if self.begin_index is None:
199            if isinstance(timestep, torch.Tensor):
200                timestep = timestep.to(self.timesteps.device)
201            self._step_index = self.index_for_timestep(timestep)
202        else:
203            self._step_index = self._begin_index
204
205    @property
206    def state_in_first_order(self):
207        return self.dt is None
208
209    def step(
210        self,
211        model_output: torch.FloatTensor,
212        timestep: Union[float, torch.FloatTensor],
213        sample: torch.FloatTensor,
214        s_churn: float = 0.0,
215        s_tmin: float = 0.0,
216        s_tmax: float = float("inf"),
217        s_noise: float = 1.0,
218        generator: Optional[torch.Generator] = None,
219        return_dict: bool = True,
220        omega: Union[float, np.array] = 0.0,
221    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
222        """
223        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
224        process from the learned model outputs (most often the predicted noise).
225
226        Args:
227            model_output (`torch.FloatTensor`):
228                The direct output from learned diffusion model.
229            timestep (`float`):
230                The current discrete timestep in the diffusion chain.
231            sample (`torch.FloatTensor`):
232                A current instance of a sample created by the diffusion process.
233            s_churn (`float`):
234            s_tmin  (`float`):
235            s_tmax  (`float`):
236            s_noise (`float`, defaults to 1.0):
237                Scaling factor for noise added to the sample.
238            generator (`torch.Generator`, *optional*):
239                A random number generator.
240            return_dict (`bool`):
241                Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
242                tuple.
243
244        Returns:
245            [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
246                If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
247                returned, otherwise a tuple is returned where the first element is the sample tensor.
248        """
249
250        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
251            # L = Lower bound
252            # U = Upper bound
253            # x_0 = Midpoint (x corresponding to y = 1.0)
254            # k = Steepness, can adjust based on preference
255
256            if isinstance(x, torch.Tensor):
257                device_ = x.device
258                x = x.to(torch.float).cpu().numpy()
259
260            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
261
262            if isinstance(new_x, np.ndarray):
263                new_x = torch.from_numpy(new_x).to(device_)
264            return new_x
265
266        self.omega_bef_rescale = omega
267        omega = logistic_function(omega, k=0.1)
268        self.omega_aft_rescale = omega
269
270        if (
271            isinstance(timestep, int)
272            or isinstance(timestep, torch.IntTensor)
273            or isinstance(timestep, torch.LongTensor)
274        ):
275            raise ValueError(
276                (
277                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
278                    " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
279                    " one of the `scheduler.timesteps` as a timestep."
280                ),
281            )
282
283        if self.step_index is None:
284            self._init_step_index(timestep)
285
286        # Upcast to avoid precision issues when computing prev_sample
287        sample = sample.to(torch.float32)
288
289        if self.state_in_first_order:
290            sigma = self.sigmas[self.step_index]
291            sigma_next = self.sigmas[self.step_index + 1]
292        else:
293            # 2nd order / Heun's method
294            sigma = self.sigmas[self.step_index - 1]
295            sigma_next = self.sigmas[self.step_index]
296
297        gamma = (
298            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
299            if s_tmin <= sigma <= s_tmax
300            else 0.0
301        )
302
303        sigma_hat = sigma * (gamma + 1)
304
305        if gamma > 0:
306            noise = randn_tensor(
307                model_output.shape,
308                dtype=model_output.dtype,
309                device=model_output.device,
310                generator=generator,
311            )
312            eps = noise * s_noise
313            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
314
315        if self.state_in_first_order:
316            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
317            denoised = sample - model_output * sigma
318            # 2. convert to an ODE derivative for 1st order
319            derivative = (sample - denoised) / sigma_hat
320            # 3. Delta timestep
321            dt = sigma_next - sigma_hat
322
323            # store for 2nd order step
324            self.prev_derivative = derivative
325            self.dt = dt
326            self.sample = sample
327        else:
328            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
329            denoised = sample - model_output * sigma_next
330            # 2. 2nd order / Heun's method
331            derivative = (sample - denoised) / sigma_next
332            derivative = 0.5 * (self.prev_derivative + derivative)
333
334            # 3. take prev timestep & sample
335            dt = self.dt
336            sample = self.sample
337
338            # free dt and derivative
339            # Note, this puts the scheduler in "first order mode"
340            self.prev_derivative = None
341            self.dt = None
342            self.sample = None
343
344        # original sample way
345        # prev_sample = sample + derivative * dt
346
347        dx = derivative * dt
348        m = dx.mean()
349        dx_ = (dx - m) * omega + m
350        prev_sample = sample + dx_
351
352        # Cast sample back to model compatible dtype
353        prev_sample = prev_sample.to(model_output.dtype)
354
355        # upon completion increase step index by one
356        self._step_index += 1
357
358        if not return_dict:
359            return (prev_sample,)
360
361        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
362
363    def __len__(self):
364        return self.config.num_train_timesteps
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(diffusers.utils.outputs.BaseOutput):
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` method.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample `(x_{t-1})` computed for the previous timestep. Feed it back to the model as the next input
            of the denoising loop.
    """

    # Result of one scheduler step; the only payload carried by this output.
    prev_sample: torch.FloatTensor

Output class for the scheduler's step function output.

Args: prev_sample (torch.FloatTensor of shape (batch_size, num_channels, height, width) for images): Computed sample (x_{t-1}) of previous timestep. prev_sample should be used as next model input in the denoising loop.

prev_sample: torch.FloatTensor
class FlowMatchHeunDiscreteScheduler(diffusers.schedulers.scheduling_utils.SchedulerMixin, diffusers.configuration_utils.ConfigMixin):
 45class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
 46    """
 47    Heun scheduler.
 48
 49    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
 50    methods the library implements for all schedulers such as loading and saving.
 51
 52    Args:
 53        num_train_timesteps (`int`, defaults to 1000):
 54            The number of diffusion steps to train the model.
 55        timestep_spacing (`str`, defaults to `"linspace"`):
 56            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
 57            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
 58        shift (`float`, defaults to 1.0):
 59            The shift value for the timestep schedule.
 60    """
 61
 62    _compatibles = []
 63    order = 2
 64
 65    @register_to_config
 66    def __init__(
 67        self,
 68        num_train_timesteps: int = 1000,
 69        shift: float = 1.0,
 70        sigma_max: Optional[float] = 1.0,
 71    ):
 72        timesteps = np.linspace(
 73            1.0, sigma_max*num_train_timesteps, num_train_timesteps, dtype=np.float32
 74        )[::-1].copy()
 75        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 76
 77        sigmas = timesteps / num_train_timesteps
 78        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
 79
 80        self.timesteps = sigmas * num_train_timesteps
 81
 82        self._step_index = None
 83        self._begin_index = None
 84
 85        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 86        self.sigma_min = self.sigmas[-1].item()
 87        self.sigma_max = self.sigmas[0].item()
 88
 89    @property
 90    def step_index(self):
 91        """
 92        The index counter for current timestep. It will increase 1 after each scheduler step.
 93        """
 94        return self._step_index
 95
 96    @property
 97    def begin_index(self):
 98        """
 99        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
100        """
101        return self._begin_index
102
103    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
104    def set_begin_index(self, begin_index: int = 0):
105        """
106        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
107
108        Args:
109            begin_index (`int`):
110                The begin index for the scheduler.
111        """
112        self._begin_index = begin_index
113
114    def scale_noise(
115        self,
116        sample: torch.FloatTensor,
117        timestep: Union[float, torch.FloatTensor],
118        noise: Optional[torch.FloatTensor] = None,
119    ) -> torch.FloatTensor:
120        """
121        Forward process in flow-matching
122
123        Args:
124            sample (`torch.FloatTensor`):
125                The input sample.
126            timestep (`int`, *optional*):
127                The current timestep in the diffusion chain.
128
129        Returns:
130            `torch.FloatTensor`:
131                A scaled input sample.
132        """
133        if self.step_index is None:
134            self._init_step_index(timestep)
135
136        sigma = self.sigmas[self.step_index]
137        sample = sigma * noise + (1.0 - sigma) * sample
138
139        return sample
140
141    def _sigma_to_t(self, sigma):
142        return sigma * self.config.num_train_timesteps
143
144    def set_timesteps(
145        self, num_inference_steps: int, device: Union[str, torch.device] = None
146    ):
147        """
148        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
149
150        Args:
151            num_inference_steps (`int`):
152                The number of diffusion steps used when generating samples with a pre-trained model.
153            device (`str` or `torch.device`, *optional*):
154                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
155        """
156        self.num_inference_steps = num_inference_steps
157
158        timesteps = np.linspace(
159            self._sigma_to_t(self.sigma_max),
160            self._sigma_to_t(self.sigma_min),
161            num_inference_steps,
162        )
163
164        sigmas = timesteps / self.config.num_train_timesteps
165        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
166        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
167
168        timesteps = sigmas * self.config.num_train_timesteps
169        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
170        self.timesteps = timesteps.to(device=device)
171
172        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
173        self.sigmas = torch.cat(
174            [sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
175        )
176
177        # empty dt and derivative
178        self.prev_derivative = None
179        self.dt = None
180
181        self._step_index = None
182        self._begin_index = None
183
184    def index_for_timestep(self, timestep, schedule_timesteps=None):
185        if schedule_timesteps is None:
186            schedule_timesteps = self.timesteps
187
188        indices = (schedule_timesteps == timestep).nonzero()
189
190        # The sigma index that is taken for the **very** first `step`
191        # is always the second index (or the last index if there is only 1)
192        # This way we can ensure we don't accidentally skip a sigma in
193        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
194        pos = 1 if len(indices) > 1 else 0
195
196        return indices[pos].item()
197
198    def _init_step_index(self, timestep):
199        if self.begin_index is None:
200            if isinstance(timestep, torch.Tensor):
201                timestep = timestep.to(self.timesteps.device)
202            self._step_index = self.index_for_timestep(timestep)
203        else:
204            self._step_index = self._begin_index
205
206    @property
207    def state_in_first_order(self):
208        return self.dt is None
209
210    def step(
211        self,
212        model_output: torch.FloatTensor,
213        timestep: Union[float, torch.FloatTensor],
214        sample: torch.FloatTensor,
215        s_churn: float = 0.0,
216        s_tmin: float = 0.0,
217        s_tmax: float = float("inf"),
218        s_noise: float = 1.0,
219        generator: Optional[torch.Generator] = None,
220        return_dict: bool = True,
221        omega: Union[float, np.array] = 0.0,
222    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
223        """
224        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
225        process from the learned model outputs (most often the predicted noise).
226
227        Args:
228            model_output (`torch.FloatTensor`):
229                The direct output from learned diffusion model.
230            timestep (`float`):
231                The current discrete timestep in the diffusion chain.
232            sample (`torch.FloatTensor`):
233                A current instance of a sample created by the diffusion process.
234            s_churn (`float`):
235            s_tmin  (`float`):
236            s_tmax  (`float`):
237            s_noise (`float`, defaults to 1.0):
238                Scaling factor for noise added to the sample.
239            generator (`torch.Generator`, *optional*):
240                A random number generator.
241            return_dict (`bool`):
242                Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
243                tuple.
244
245        Returns:
246            [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
247                If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
248                returned, otherwise a tuple is returned where the first element is the sample tensor.
249        """
250
251        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
252            # L = Lower bound
253            # U = Upper bound
254            # x_0 = Midpoint (x corresponding to y = 1.0)
255            # k = Steepness, can adjust based on preference
256
257            if isinstance(x, torch.Tensor):
258                device_ = x.device
259                x = x.to(torch.float).cpu().numpy()
260
261            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
262
263            if isinstance(new_x, np.ndarray):
264                new_x = torch.from_numpy(new_x).to(device_)
265            return new_x
266
267        self.omega_bef_rescale = omega
268        omega = logistic_function(omega, k=0.1)
269        self.omega_aft_rescale = omega
270
271        if (
272            isinstance(timestep, int)
273            or isinstance(timestep, torch.IntTensor)
274            or isinstance(timestep, torch.LongTensor)
275        ):
276            raise ValueError(
277                (
278                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
279                    " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
280                    " one of the `scheduler.timesteps` as a timestep."
281                ),
282            )
283
284        if self.step_index is None:
285            self._init_step_index(timestep)
286
287        # Upcast to avoid precision issues when computing prev_sample
288        sample = sample.to(torch.float32)
289
290        if self.state_in_first_order:
291            sigma = self.sigmas[self.step_index]
292            sigma_next = self.sigmas[self.step_index + 1]
293        else:
294            # 2nd order / Heun's method
295            sigma = self.sigmas[self.step_index - 1]
296            sigma_next = self.sigmas[self.step_index]
297
298        gamma = (
299            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
300            if s_tmin <= sigma <= s_tmax
301            else 0.0
302        )
303
304        sigma_hat = sigma * (gamma + 1)
305
306        if gamma > 0:
307            noise = randn_tensor(
308                model_output.shape,
309                dtype=model_output.dtype,
310                device=model_output.device,
311                generator=generator,
312            )
313            eps = noise * s_noise
314            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
315
316        if self.state_in_first_order:
317            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
318            denoised = sample - model_output * sigma
319            # 2. convert to an ODE derivative for 1st order
320            derivative = (sample - denoised) / sigma_hat
321            # 3. Delta timestep
322            dt = sigma_next - sigma_hat
323
324            # store for 2nd order step
325            self.prev_derivative = derivative
326            self.dt = dt
327            self.sample = sample
328        else:
329            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
330            denoised = sample - model_output * sigma_next
331            # 2. 2nd order / Heun's method
332            derivative = (sample - denoised) / sigma_next
333            derivative = 0.5 * (self.prev_derivative + derivative)
334
335            # 3. take prev timestep & sample
336            dt = self.dt
337            sample = self.sample
338
339            # free dt and derivative
340            # Note, this puts the scheduler in "first order mode"
341            self.prev_derivative = None
342            self.dt = None
343            self.sample = None
344
345        # original sample way
346        # prev_sample = sample + derivative * dt
347
348        dx = derivative * dt
349        m = dx.mean()
350        dx_ = (dx - m) * omega + m
351        prev_sample = sample + dx_
352
353        # Cast sample back to model compatible dtype
354        prev_sample = prev_sample.to(model_output.dtype)
355
356        # upon completion increase step index by one
357        self._step_index += 1
358
359        if not return_dict:
360            return (prev_sample,)
361
362        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
363
364    def __len__(self):
365        return self.config.num_train_timesteps

Heun scheduler.

This model inherits from [SchedulerMixin] and [ConfigMixin]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving.

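Args: num_train_timesteps (int, defaults to 1000): The number of diffusion steps to train the model. shift (float, defaults to 1.0): The shift value for the timestep schedule. sigma_max (float, optional, defaults to 1.0): The upper end of the training sigma range before the shift is applied. Note: the original docstring also lists timestep_spacing (str, defaults to "linspace"): The way the timesteps should be scaled. Refer to Table 2 of the Common Diffusion Noise Schedules and Sample Steps are Flawed for more information. The constructor signature shown below does not accept this argument.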

@register_to_config
FlowMatchHeunDiscreteScheduler( num_train_timesteps: int = 1000, shift: float = 1.0, sigma_max: Optional[float] = 1.0)
65    @register_to_config
66    def __init__(
67        self,
68        num_train_timesteps: int = 1000,
69        shift: float = 1.0,
70        sigma_max: Optional[float] = 1.0,
71    ):
72        timesteps = np.linspace(
73            1.0, sigma_max*num_train_timesteps, num_train_timesteps, dtype=np.float32
74        )[::-1].copy()
75        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
76
77        sigmas = timesteps / num_train_timesteps
78        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
79
80        self.timesteps = sigmas * num_train_timesteps
81
82        self._step_index = None
83        self._begin_index = None
84
85        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
86        self.sigma_min = self.sigmas[-1].item()
87        self.sigma_max = self.sigmas[0].item()
order = 2
timesteps
sigmas
sigma_min
sigma_max
step_index
89    @property
90    def step_index(self):
91        """
92        The index counter for current timestep. It will increase 1 after each scheduler step.
93        """
94        return self._step_index

The index counter for current timestep. It will increase 1 after each scheduler step.

begin_index
 96    @property
 97    def begin_index(self):
 98        """
 99        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
100        """
101        return self._begin_index

The index for the first timestep. It should be set from pipeline with set_begin_index method.

def set_begin_index(self, begin_index: int = 0):
104    def set_begin_index(self, begin_index: int = 0):
105        """
106        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
107
108        Args:
109            begin_index (`int`):
110                The begin index for the scheduler.
111        """
112        self._begin_index = begin_index

Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

Args: begin_index (int): The begin index for the scheduler.

def scale_noise( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], noise: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
114    def scale_noise(
115        self,
116        sample: torch.FloatTensor,
117        timestep: Union[float, torch.FloatTensor],
118        noise: Optional[torch.FloatTensor] = None,
119    ) -> torch.FloatTensor:
120        """
121        Forward process in flow-matching
122
123        Args:
124            sample (`torch.FloatTensor`):
125                The input sample.
126            timestep (`int`, *optional*):
127                The current timestep in the diffusion chain.
128
129        Returns:
130            `torch.FloatTensor`:
131                A scaled input sample.
132        """
133        if self.step_index is None:
134            self._init_step_index(timestep)
135
136        sigma = self.sigmas[self.step_index]
137        sample = sigma * noise + (1.0 - sigma) * sample
138
139        return sample

Forward process in flow-matching

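Args: sample (torch.FloatTensor): The input sample. timestep (float or torch.FloatTensor): The current timestep in the diffusion chain. noise (torch.FloatTensor, optional): The noise tensor that is blended with the sample.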

Returns: torch.FloatTensor: A scaled input sample.

def set_timesteps( self, num_inference_steps: int, device: Union[str, torch.device] = None):
144    def set_timesteps(
145        self, num_inference_steps: int, device: Union[str, torch.device] = None
146    ):
147        """
148        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
149
150        Args:
151            num_inference_steps (`int`):
152                The number of diffusion steps used when generating samples with a pre-trained model.
153            device (`str` or `torch.device`, *optional*):
154                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
155        """
156        self.num_inference_steps = num_inference_steps
157
158        timesteps = np.linspace(
159            self._sigma_to_t(self.sigma_max),
160            self._sigma_to_t(self.sigma_min),
161            num_inference_steps,
162        )
163
164        sigmas = timesteps / self.config.num_train_timesteps
165        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
166        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
167
168        timesteps = sigmas * self.config.num_train_timesteps
169        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
170        self.timesteps = timesteps.to(device=device)
171
172        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
173        self.sigmas = torch.cat(
174            [sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
175        )
176
177        # empty dt and derivative
178        self.prev_derivative = None
179        self.dt = None
180
181        self._step_index = None
182        self._begin_index = None

Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args: num_inference_steps (int): The number of diffusion steps used when generating samples with a pre-trained model. device (str or torch.device, optional): The device to which the timesteps should be moved to. If None, the timesteps are not moved.

def index_for_timestep(self, timestep, schedule_timesteps=None):
184    def index_for_timestep(self, timestep, schedule_timesteps=None):
185        if schedule_timesteps is None:
186            schedule_timesteps = self.timesteps
187
188        indices = (schedule_timesteps == timestep).nonzero()
189
190        # The sigma index that is taken for the **very** first `step`
191        # is always the second index (or the last index if there is only 1)
192        # This way we can ensure we don't accidentally skip a sigma in
193        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
194        pos = 1 if len(indices) > 1 else 0
195
196        return indices[pos].item()
state_in_first_order
206    @property
207    def state_in_first_order(self):
208        return self.dt is None
def step( self, model_output: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], sample: torch.FloatTensor, s_churn: float = 0.0, s_tmin: float = 0.0, s_tmax: float = inf, s_noise: float = 1.0, generator: Optional[torch._C.Generator] = None, return_dict: bool = True, omega: Union[float, <built-in function array>] = 0.0) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
210    def step(
211        self,
212        model_output: torch.FloatTensor,
213        timestep: Union[float, torch.FloatTensor],
214        sample: torch.FloatTensor,
215        s_churn: float = 0.0,
216        s_tmin: float = 0.0,
217        s_tmax: float = float("inf"),
218        s_noise: float = 1.0,
219        generator: Optional[torch.Generator] = None,
220        return_dict: bool = True,
221        omega: Union[float, np.array] = 0.0,
222    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
223        """
224        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
225        process from the learned model outputs (most often the predicted noise).
226
227        Args:
228            model_output (`torch.FloatTensor`):
229                The direct output from learned diffusion model.
230            timestep (`float`):
231                The current discrete timestep in the diffusion chain.
232            sample (`torch.FloatTensor`):
233                A current instance of a sample created by the diffusion process.
234            s_churn (`float`):
235            s_tmin  (`float`):
236            s_tmax  (`float`):
237            s_noise (`float`, defaults to 1.0):
238                Scaling factor for noise added to the sample.
239            generator (`torch.Generator`, *optional*):
240                A random number generator.
241            return_dict (`bool`):
242                Whether or not to return a [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or
243                tuple.
244
245        Returns:
246            [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] or `tuple`:
247                If return_dict is `True`, [`~schedulers.scheduling_Heun_discrete.HeunDiscreteSchedulerOutput`] is
248                returned, otherwise a tuple is returned where the first element is the sample tensor.
249        """
250
251        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
252            # L = Lower bound
253            # U = Upper bound
254            # x_0 = Midpoint (x corresponding to y = 1.0)
255            # k = Steepness, can adjust based on preference
256
257            if isinstance(x, torch.Tensor):
258                device_ = x.device
259                x = x.to(torch.float).cpu().numpy()
260
261            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))
262
263            if isinstance(new_x, np.ndarray):
264                new_x = torch.from_numpy(new_x).to(device_)
265            return new_x
266
267        self.omega_bef_rescale = omega
268        omega = logistic_function(omega, k=0.1)
269        self.omega_aft_rescale = omega
270
271        if (
272            isinstance(timestep, int)
273            or isinstance(timestep, torch.IntTensor)
274            or isinstance(timestep, torch.LongTensor)
275        ):
276            raise ValueError(
277                (
278                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
279                    " `HeunDiscreteScheduler.step()` is not supported. Make sure to pass"
280                    " one of the `scheduler.timesteps` as a timestep."
281                ),
282            )
283
284        if self.step_index is None:
285            self._init_step_index(timestep)
286
287        # Upcast to avoid precision issues when computing prev_sample
288        sample = sample.to(torch.float32)
289
290        if self.state_in_first_order:
291            sigma = self.sigmas[self.step_index]
292            sigma_next = self.sigmas[self.step_index + 1]
293        else:
294            # 2nd order / Heun's method
295            sigma = self.sigmas[self.step_index - 1]
296            sigma_next = self.sigmas[self.step_index]
297
298        gamma = (
299            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
300            if s_tmin <= sigma <= s_tmax
301            else 0.0
302        )
303
304        sigma_hat = sigma * (gamma + 1)
305
306        if gamma > 0:
307            noise = randn_tensor(
308                model_output.shape,
309                dtype=model_output.dtype,
310                device=model_output.device,
311                generator=generator,
312            )
313            eps = noise * s_noise
314            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
315
316        if self.state_in_first_order:
317            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
318            denoised = sample - model_output * sigma
319            # 2. convert to an ODE derivative for 1st order
320            derivative = (sample - denoised) / sigma_hat
321            # 3. Delta timestep
322            dt = sigma_next - sigma_hat
323
324            # store for 2nd order step
325            self.prev_derivative = derivative
326            self.dt = dt
327            self.sample = sample
328        else:
329            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
330            denoised = sample - model_output * sigma_next
331            # 2. 2nd order / Heun's method
332            derivative = (sample - denoised) / sigma_next
333            derivative = 0.5 * (self.prev_derivative + derivative)
334
335            # 3. take prev timestep & sample
336            dt = self.dt
337            sample = self.sample
338
339            # free dt and derivative
340            # Note, this puts the scheduler in "first order mode"
341            self.prev_derivative = None
342            self.dt = None
343            self.sample = None
344
345        # original sample way
346        # prev_sample = sample + derivative * dt
347
348        dx = derivative * dt
349        m = dx.mean()
350        dx_ = (dx - m) * omega + m
351        prev_sample = sample + dx_
352
353        # Cast sample back to model compatible dtype
354        prev_sample = prev_sample.to(model_output.dtype)
355
356        # upon completion increase step index by one
357        self._step_index += 1
358
359        if not return_dict:
360            return (prev_sample,)
361
362        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise).

Args: model_output (torch.FloatTensor): The direct output from learned diffusion model. timestep (float): The current discrete timestep in the diffusion chain. sample (torch.FloatTensor): A current instance of a sample created by the diffusion process. s_churn (float): Controls how much noise is injected; 0 disables noise injection. s_tmin (float): Lower sigma bound for noise injection. s_tmax (float): Upper sigma bound for noise injection. s_noise (float, defaults to 1.0): Scaling factor for noise added to the sample. generator (torch.Generator, *optional*): A random number generator. return_dict (bool): Whether or not to return a [FlowMatchHeunDiscreteSchedulerOutput] or tuple. omega (float or array, defaults to 0.0): Value that is rescaled with a logistic function and used to scale the deviation of the update from its mean.

Returns: [FlowMatchHeunDiscreteSchedulerOutput] or tuple: If return_dict is True, [FlowMatchHeunDiscreteSchedulerOutput] is returned, otherwise a tuple is returned where the first element is the sample tensor.