divisor.acestep.schedulers.scheduling_flow_match_heun_discrete
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.schedulers.scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Heun scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    `step` alternates between a first-order pass that stores its derivative and a second-order pass that averages
    the stored derivative with the new one.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        sigma_max (`float`, *optional*, defaults to 1.0):
            Largest sigma of the schedule before `shift` is applied. `None` is treated as 1.0.
    """

    _compatibles = []
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        sigma_max: Optional[float] = 1.0,
    ):
        """
        Build the training-time sigma and timestep tables.

        Args:
            num_train_timesteps (`int`, defaults to 1000):
                Number of entries in the training schedule.
            shift (`float`, defaults to 1.0):
                Shift applied to every sigma as `shift * s / (1 + (shift - 1) * s)`.
            sigma_max (`float`, *optional*, defaults to 1.0):
                Largest sigma before shifting; the unshifted sigmas run from `sigma_max` down to
                `1 / num_train_timesteps`. `None` is treated as 1.0.
        """
        # The annotation allows `None`; use the default instead of failing on `None * int`.
        if sigma_max is None:
            sigma_max = 1.0

        timesteps = np.linspace(
            1.0, sigma_max * num_train_timesteps, num_train_timesteps, dtype=np.float32
        )[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        # Note: these are the *shifted* extremes of the training schedule.
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: blend `sample` with `noise` as `sigma * noise + (1 - sigma) * sample`,
        where `sigma` is the schedule entry at the current step index.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. It is only used to initialise the step index when
                that has not been set yet.
            noise (`torch.FloatTensor`, *optional*):
                Noise to blend in. When `None`, standard normal noise with the shape, dtype and device of `sample`
                is drawn.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # `noise` is declared optional, so supply a default instead of failing on `sigma * None`.
        if noise is None:
            noise = torch.randn_like(sample)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        """Convert a sigma to a timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self, num_inference_steps: int, device: Union[str, torch.device] = None
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Every timestep except the first is stored twice (one entry per first-/second-order `step` call), and a
        final sigma of 0 is appended to `sigmas`. All per-run state is cleared.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max),
            self._sigma_to_t(self.sigma_min),
            num_inference_steps,
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        timesteps = sigmas * self.config.num_train_timesteps
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat(
            [sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
        )

        # Clear everything a pending second-order step may have stored, including the
        # cached sample, so no stale tensor survives a re-schedule.
        self.prev_derivative = None
        self.dt = None
        self.sample = None

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the schedule index of `timestep`.

        Args:
            timestep:
                Value compared for exact equality against the schedule.
            schedule_timesteps (*optional*):
                Schedule to search; defaults to `self.timesteps`.

        Returns:
            `int`: the second matching index when the value occurs more than once, otherwise the only match.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Set the step index from `begin_index` when one was given, else by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        """`True` while no first-order result is pending, i.e. `self.dt` is `None`."""
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        omega: Union[float, np.ndarray, torch.Tensor] = 0.0,
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Churn amount; sets `gamma = min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)` and raises the current
                sigma to `sigma * (1 + gamma)`.
            s_tmin (`float`, defaults to 0.0):
                Lower bound of the sigma range in which churn is applied (`gamma` is 0 outside it).
            s_tmax (`float`, defaults to `inf`):
                Upper bound of the sigma range in which churn is applied.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchHeunDiscreteSchedulerOutput`] or tuple.
            omega (`float`, `np.ndarray` or `torch.Tensor`, defaults to 0.0):
                Raw control value. It is squashed by a logistic curve (steepness 0.1) into the range (0.9, 1.1) and
                then scales the deviation of the update from its mean. 0.0 maps to 1.0, i.e. the plain update.

        Returns:
            [`FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchHeunDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """

        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
            # L = Lower bound
            # U = Upper bound
            # x_0 = Midpoint (x corresponding to y = 1.0)
            # k = Steepness, can adjust based on preference

            # NumPy inputs carry no device, so their result is placed next to `sample`.
            device_ = sample.device
            if isinstance(x, torch.Tensor):
                device_ = x.device
                x = x.to(torch.float).cpu().numpy()
            elif isinstance(x, np.ndarray):
                x = x.astype(np.float32)

            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))

            if isinstance(new_x, np.ndarray):
                new_x = torch.from_numpy(new_x).to(device_)
            return new_x

        self.omega_bef_rescale = omega
        omega = logistic_function(omega, k=0.1)
        self.omega_aft_rescale = omega

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        gamma = (
            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigma <= s_tmax
            else 0.0
        )

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape,
                dtype=model_output.dtype,
                device=model_output.device,
                generator=generator,
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        # The plain update would be `sample + derivative * dt`; here the deviation of the
        # update from its mean is additionally scaled by the rescaled omega.
        dx = derivative * dt
        m = dx.mean()
        dx_ = (dx - m) * omega + m
        prev_sample = sample + dx_

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        """Return the number of training timesteps from the config."""
        return self.config.num_train_timesteps
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Container for the result of the scheduler's `step` method.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample computed for the previous timestep, `(x_{t-1})`. Feed it back to the model as the input of
            the next iteration of the denoising loop.
    """

    prev_sample: torch.FloatTensor
Output class for the scheduler's step function output.
Args:
prev_sample (torch.FloatTensor of shape (batch_size, num_channels, height, width) for images):
Computed sample (x_{t-1}) of previous timestep. prev_sample should be used as next model input in the
denoising loop.
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Heun scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    `step` alternates between a first-order pass that stores its derivative and a second-order pass that averages
    the stored derivative with the new one.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        sigma_max (`float`, *optional*, defaults to 1.0):
            Largest sigma of the schedule before `shift` is applied. `None` is treated as 1.0.
    """

    _compatibles = []
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        sigma_max: Optional[float] = 1.0,
    ):
        """
        Build the training-time sigma and timestep tables.

        Args:
            num_train_timesteps (`int`, defaults to 1000):
                Number of entries in the training schedule.
            shift (`float`, defaults to 1.0):
                Shift applied to every sigma as `shift * s / (1 + (shift - 1) * s)`.
            sigma_max (`float`, *optional*, defaults to 1.0):
                Largest sigma before shifting; the unshifted sigmas run from `sigma_max` down to
                `1 / num_train_timesteps`. `None` is treated as 1.0.
        """
        # The annotation allows `None`; use the default instead of failing on `None * int`.
        if sigma_max is None:
            sigma_max = 1.0

        timesteps = np.linspace(
            1.0, sigma_max * num_train_timesteps, num_train_timesteps, dtype=np.float32
        )[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        # Note: these are the *shifted* extremes of the training schedule.
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: blend `sample` with `noise` as `sigma * noise + (1 - sigma) * sample`,
        where `sigma` is the schedule entry at the current step index.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. It is only used to initialise the step index when
                that has not been set yet.
            noise (`torch.FloatTensor`, *optional*):
                Noise to blend in. When `None`, standard normal noise with the shape, dtype and device of `sample`
                is drawn.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # `noise` is declared optional, so supply a default instead of failing on `sigma * None`.
        if noise is None:
            noise = torch.randn_like(sample)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        """Convert a sigma to a timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self, num_inference_steps: int, device: Union[str, torch.device] = None
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Every timestep except the first is stored twice (one entry per first-/second-order `step` call), and a
        final sigma of 0 is appended to `sigmas`. All per-run state is cleared.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max),
            self._sigma_to_t(self.sigma_min),
            num_inference_steps,
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        timesteps = sigmas * self.config.num_train_timesteps
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat(
            [sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
        )

        # Clear everything a pending second-order step may have stored, including the
        # cached sample, so no stale tensor survives a re-schedule.
        self.prev_derivative = None
        self.dt = None
        self.sample = None

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the schedule index of `timestep`.

        Args:
            timestep:
                Value compared for exact equality against the schedule.
            schedule_timesteps (*optional*):
                Schedule to search; defaults to `self.timesteps`.

        Returns:
            `int`: the second matching index when the value occurs more than once, otherwise the only match.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Set the step index from `begin_index` when one was given, else by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        """`True` while no first-order result is pending, i.e. `self.dt` is `None`."""
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        omega: Union[float, np.ndarray, torch.Tensor] = 0.0,
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Churn amount; sets `gamma = min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)` and raises the current
                sigma to `sigma * (1 + gamma)`.
            s_tmin (`float`, defaults to 0.0):
                Lower bound of the sigma range in which churn is applied (`gamma` is 0 outside it).
            s_tmax (`float`, defaults to `inf`):
                Upper bound of the sigma range in which churn is applied.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchHeunDiscreteSchedulerOutput`] or tuple.
            omega (`float`, `np.ndarray` or `torch.Tensor`, defaults to 0.0):
                Raw control value. It is squashed by a logistic curve (steepness 0.1) into the range (0.9, 1.1) and
                then scales the deviation of the update from its mean. 0.0 maps to 1.0, i.e. the plain update.

        Returns:
            [`FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchHeunDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """

        def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
            # L = Lower bound
            # U = Upper bound
            # x_0 = Midpoint (x corresponding to y = 1.0)
            # k = Steepness, can adjust based on preference

            # NumPy inputs carry no device, so their result is placed next to `sample`.
            device_ = sample.device
            if isinstance(x, torch.Tensor):
                device_ = x.device
                x = x.to(torch.float).cpu().numpy()
            elif isinstance(x, np.ndarray):
                x = x.astype(np.float32)

            new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))

            if isinstance(new_x, np.ndarray):
                new_x = torch.from_numpy(new_x).to(device_)
            return new_x

        self.omega_bef_rescale = omega
        omega = logistic_function(omega, k=0.1)
        self.omega_aft_rescale = omega

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        gamma = (
            min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigma <= s_tmax
            else 0.0
        )

        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            noise = randn_tensor(
                model_output.shape,
                dtype=model_output.dtype,
                device=model_output.device,
                generator=generator,
            )
            eps = noise * s_noise
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        # The plain update would be `sample + derivative * dt`; here the deviation of the
        # update from its mean is additionally scaled by the rescaled omega.
        dx = derivative * dt
        m = dx.mean()
        dx_ = (dx - m) * omega + m
        prev_sample = sample + dx_

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        """Return the number of training timesteps from the config."""
        return self.config.num_train_timesteps
Heun scheduler.
This model inherits from [SchedulerMixin] and [ConfigMixin]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (int, defaults to 1000):
The number of diffusion steps to train the model.
sigma_max (float, optional, defaults to 1.0):
The largest sigma of the schedule before the shift is applied. (The constructor takes num_train_timesteps, shift
and sigma_max; it has no timestep_spacing parameter.)
shift (float, defaults to 1.0):
The shift value for the timestep schedule.
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    shift: float = 1.0,
    sigma_max: Optional[float] = 1.0,
):
    """
    Build the training-time sigma and timestep tables.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            Number of entries in the training schedule.
        shift (`float`, defaults to 1.0):
            Shift applied to every sigma as `shift * s / (1 + (shift - 1) * s)`.
        sigma_max (`float`, *optional*, defaults to 1.0):
            Largest sigma before shifting; the unshifted sigmas run from `sigma_max` down to
            `1 / num_train_timesteps`. `None` is treated as 1.0.
    """
    # The annotation allows `None`; use the default instead of failing on `None * int`.
    if sigma_max is None:
        sigma_max = 1.0

    timesteps = np.linspace(
        1.0, sigma_max * num_train_timesteps, num_train_timesteps, dtype=np.float32
    )[::-1].copy()
    timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

    sigmas = timesteps / num_train_timesteps
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

    self.timesteps = sigmas * num_train_timesteps

    self._step_index = None
    self._begin_index = None

    self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
    # Note: these are the *shifted* extremes of the training schedule.
    self.sigma_min = self.sigmas[-1].item()
    self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
    """
    Current position of the scheduler in its schedule, as stored in `_step_index`.
    """
    return self._step_index
The index counter for the current timestep. It increases by 1 after each scheduler step.
@property
def begin_index(self):
    """
    Index of the first timestep, as stored in `_begin_index`. Pipelines set it through `set_begin_index`.
    """
    return self._begin_index
The index for the first timestep. It should be set from the pipeline with the set_begin_index method.
def set_begin_index(self, begin_index: int = 0):
    """
    Record the index at which the scheduler should start. Meant to be called by the pipeline before inference.

    Args:
        begin_index (`int`, defaults to 0):
            The index to store as the scheduler's begin index.
    """
    self._begin_index = begin_index
Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
Args:
begin_index (int):
The begin index for the scheduler.
def scale_noise(
    self,
    sample: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Forward process in flow-matching: blend `sample` with `noise` as `sigma * noise + (1 - sigma) * sample`,
    where `sigma` is the schedule entry at the current step index.

    Args:
        sample (`torch.FloatTensor`):
            The input sample.
        timestep (`float` or `torch.FloatTensor`):
            The current timestep in the diffusion chain. It is only used to initialise the step index when
            that has not been set yet.
        noise (`torch.FloatTensor`, *optional*):
            Noise to blend in. When `None`, standard normal noise with the shape, dtype and device of `sample`
            is drawn.

    Returns:
        `torch.FloatTensor`:
            A scaled input sample.
    """
    if self.step_index is None:
        self._init_step_index(timestep)

    # `noise` is declared optional, so supply a default instead of failing on `sigma * None`.
    if noise is None:
        noise = torch.randn_like(sample)

    sigma = self.sigmas[self.step_index]
    sample = sigma * noise + (1.0 - sigma) * sample

    return sample
Forward process in flow-matching
Args:
sample (torch.FloatTensor):
The input sample.
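timestep (float or torch.FloatTensor):
The current timestep in the diffusion chain. It is used to initialise the step index when that is not set yet.
noise (torch.FloatTensor, optional):
The noise that is blended with the sample as sigma * noise + (1 - sigma) * sample.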
Returns:
torch.FloatTensor:
A scaled input sample.
def set_timesteps(
    self, num_inference_steps: int, device: Union[str, torch.device] = None
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Every timestep except the first is stored twice (one entry per first-/second-order `step` call), and a
    final sigma of 0 is appended to `sigmas`. All per-run state is cleared.

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
    """
    self.num_inference_steps = num_inference_steps

    timesteps = np.linspace(
        self._sigma_to_t(self.sigma_max),
        self._sigma_to_t(self.sigma_min),
        num_inference_steps,
    )

    sigmas = timesteps / self.config.num_train_timesteps
    sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
    sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

    timesteps = sigmas * self.config.num_train_timesteps
    timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
    self.timesteps = timesteps.to(device=device)

    sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
    self.sigmas = torch.cat(
        [sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]
    )

    # Clear everything a pending second-order step may have stored, including the
    # cached sample, so no stale tensor survives a re-schedule.
    self.prev_derivative = None
    self.dt = None
    self.sample = None

    self._step_index = None
    self._begin_index = None
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (int):
The number of diffusion steps used when generating samples with a pre-trained model.
device (str or torch.device, optional):
The device to which the timesteps should be moved to. If None, the timesteps are not moved.
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Look up the position of `timestep` in a schedule by exact equality.

    Args:
        timestep:
            The value to search for.
        schedule_timesteps (*optional*):
            The schedule to search; `self.timesteps` is used when omitted.

    Returns:
        `int`: the second matching index when the value occurs more than once, otherwise the only match.
    """
    candidates = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = (candidates == timestep).nonzero()

    # A duplicated timestep resolves to its second occurrence so that the very first
    # `step` does not skip a sigma when starting in the middle of the schedule
    # (e.g. for image-to-image). A unique timestep resolves to its single position.
    if len(matches) > 1:
        return matches[1].item()
    return matches[0].item()
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
    omega: Union[float, np.ndarray, torch.Tensor] = 0.0,
) -> Union["FlowMatchHeunDiscreteSchedulerOutput", Tuple]:
    """
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
    process from the learned model outputs (most often the predicted noise).

    Calls alternate between a first-order pass, which stores its derivative, step size and sample, and a
    second-order pass, which averages the stored derivative with the new one and applies it to the stored sample.

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model.
        timestep (`float`):
            The current discrete timestep in the diffusion chain.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        s_churn (`float`, defaults to 0.0):
            Churn amount; sets `gamma = min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)` and raises the current
            sigma to `sigma * (1 + gamma)`.
        s_tmin (`float`, defaults to 0.0):
            Lower bound of the sigma range in which churn is applied (`gamma` is 0 outside it).
        s_tmax (`float`, defaults to `inf`):
            Upper bound of the sigma range in which churn is applied.
        s_noise (`float`, defaults to 1.0):
            Scaling factor for noise added to the sample.
        generator (`torch.Generator`, *optional*):
            A random number generator.
        return_dict (`bool`):
            Whether or not to return a [`FlowMatchHeunDiscreteSchedulerOutput`] or tuple.
        omega (`float`, `np.ndarray` or `torch.Tensor`, defaults to 0.0):
            Raw control value. It is squashed by a logistic curve (steepness 0.1) into the range (0.9, 1.1) and
            then scales the deviation of the update from its mean. 0.0 maps to 1.0, i.e. the plain update.

    Returns:
        [`FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`FlowMatchHeunDiscreteSchedulerOutput`] is returned, otherwise a tuple is
            returned where the first element is the sample tensor.

    Raises:
        `ValueError`: if `timestep` is an integer or an integer tensor.
    """

    def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
        # L = Lower bound
        # U = Upper bound
        # x_0 = Midpoint (x corresponding to y = 1.0)
        # k = Steepness, can adjust based on preference

        # NumPy inputs carry no device, so their result is placed next to `sample`.
        device_ = sample.device
        if isinstance(x, torch.Tensor):
            device_ = x.device
            x = x.to(torch.float).cpu().numpy()
        elif isinstance(x, np.ndarray):
            x = x.astype(np.float32)

        new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))

        if isinstance(new_x, np.ndarray):
            new_x = torch.from_numpy(new_x).to(device_)
        return new_x

    self.omega_bef_rescale = omega
    omega = logistic_function(omega, k=0.1)
    self.omega_aft_rescale = omega

    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    if self.state_in_first_order:
        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]
    else:
        # 2nd order / Heun's method
        sigma = self.sigmas[self.step_index - 1]
        sigma_next = self.sigmas[self.step_index]

    gamma = (
        min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
        if s_tmin <= sigma <= s_tmax
        else 0.0
    )

    sigma_hat = sigma * (gamma + 1)

    if gamma > 0:
        noise = randn_tensor(
            model_output.shape,
            dtype=model_output.dtype,
            device=model_output.device,
            generator=generator,
        )
        eps = noise * s_noise
        sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

    if self.state_in_first_order:
        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        denoised = sample - model_output * sigma
        # 2. convert to an ODE derivative for 1st order
        derivative = (sample - denoised) / sigma_hat
        # 3. Delta timestep
        dt = sigma_next - sigma_hat

        # store for 2nd order step
        self.prev_derivative = derivative
        self.dt = dt
        self.sample = sample
    else:
        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        denoised = sample - model_output * sigma_next
        # 2. 2nd order / Heun's method
        derivative = (sample - denoised) / sigma_next
        derivative = 0.5 * (self.prev_derivative + derivative)

        # 3. take prev timestep & sample
        dt = self.dt
        sample = self.sample

        # free dt and derivative
        # Note, this puts the scheduler in "first order mode"
        self.prev_derivative = None
        self.dt = None
        self.sample = None

    # The plain update would be `sample + derivative * dt`; here the deviation of the
    # update from its mean is additionally scaled by the rescaled omega.
    dx = derivative * dt
    m = dx.mean()
    dx_ = (dx - m) * omega + m
    prev_sample = sample + dx_

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise).
Args:
model_output (torch.FloatTensor):
The direct output from learned diffusion model.
timestep (float):
The current discrete timestep in the diffusion chain.
sample (torch.FloatTensor):
A current instance of a sample created by the diffusion process.
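s_churn (float, defaults to 0.0):
Amount of churn. It sets gamma = min(s_churn / (len(sigmas) - 1), sqrt(2) - 1), and the current sigma is raised to sigma * (1 + gamma); noise is added to the sample only when gamma > 0.
s_tmin (float, defaults to 0.0):
Lower bound of the sigma range in which churn is applied; outside the range gamma is 0.
s_tmax (float, defaults to inf):
Upper bound of the sigma range in which churn is applied.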
s_noise (float, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (torch.Generator, *optional*):
A random number generator.
return_dict (bool):
Whether or not to return a FlowMatchHeunDiscreteSchedulerOutput or
tuple.
Returns:
FlowMatchHeunDiscreteSchedulerOutput or tuple:
If return_dict is True, a FlowMatchHeunDiscreteSchedulerOutput is
returned, otherwise a tuple is returned where the first element is the sample tensor.