divisor.acestep.schedulers.scheduling_flow_match_euler_discrete
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler for flow matching whose update step can be rescaled through an `omega` argument.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule. Only applied when `use_dynamic_shifting` is `False`.
        use_dynamic_shifting (`bool`, defaults to `False`):
            When `True`, the static `shift` is skipped and `set_timesteps` requires a `mu` value that is fed to
            `time_shift` instead.
        base_shift, max_shift, base_image_seq_len, max_image_seq_len:
            Accepted by the constructor but not read anywhere else in this class.
        sigma_max (`float`, defaults to 1.0):
            Scales the upper end of the training timestep ramp (`sigma_max * num_train_timesteps`).
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
        sigma_max: Optional[float] = 1.0,
    ):
        # Descending ramp of training timesteps: sigma_max * N, ..., 1.
        timesteps = np.linspace(
            1.0, sigma_max * num_train_timesteps, num_train_timesteps, dtype=np.float32
        )[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # With dynamic shifting the shift is applied later, in `set_timesteps`, from `mu`.
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: blends `sample` with `noise` as `sigma * noise + (1 - sigma) * sample`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`torch.FloatTensor`):
                Batch of timesteps, one per sample. It must be a tensor: `.to(...)` and `.shape[0]` are used on it.
            noise (`torch.FloatTensor`):
                The noise to mix in. Although the signature defaults it to `None`, it is multiplied directly,
                so a tensor must be supplied.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as the sample
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps) for t in timestep
            ]
        elif self.step_index is not None:
            # called after the first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # called before the first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Append singleton axes so sigma broadcasts against the sample.
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        """Convert a sigma to a timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """Return `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`, the dynamic shift of `t`."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. Ignored when
                `sigmas` is given.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Custom sigma schedule. The configured shift is still applied to it, and its length becomes
                `num_inference_steps`.
            mu (`float`, *optional*):
                Shift parameter passed to `time_shift`; required when `use_dynamic_shifting` is `True`.

        Raises:
            ValueError: If `use_dynamic_shifting` is enabled and `mu` is `None`.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError(
                " you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
            )

        if sigmas is None:
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max),
                self._sigma_to_t(self.sigma_min),
                num_inference_steps,
            )
            sigmas = timesteps / self.config.num_train_timesteps
        else:
            # A plain list cannot go through the array arithmetic or `torch.from_numpy` below.
            # float64 matches the precision of the `np.linspace` path above.
            sigmas = np.array(sigmas, dtype=np.float64)
            num_inference_steps = len(sigmas)

        self.num_inference_steps = num_inference_steps

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        self.timesteps = timesteps.to(device=device)
        # A trailing zero lets the last `step` read `sigmas[step_index + 1]`.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the schedule index whose value equals `timestep` (exact comparison).

        Args:
            timestep: The timestep value to look up.
            schedule_timesteps: Schedule to search; defaults to `self.timesteps`.

        Raises:
            IndexError: If `timestep` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index`, or by looking `timestep` up when no begin index is set."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @staticmethod
    def _rescale_omega(omega, lower=0.9, upper=1.1, midpoint=0.0, steepness=0.1):
        """Map `omega` into `(lower, upper)` with a logistic curve; `midpoint` maps to the centre of the range."""
        device = None
        if isinstance(omega, torch.Tensor):
            device = omega.device
            omega = omega.to(torch.float).cpu().numpy()
        elif isinstance(omega, np.ndarray):
            # Same float32 precision as the tensor path.
            omega = omega.astype(np.float32)

        scaled = lower + (upper - lower) / (1 + np.exp(-steepness * (omega - midpoint)))

        if isinstance(scaled, np.ndarray):
            scaled = torch.from_numpy(scaled)
            if device is not None:
                scaled = scaled.to(device)
        return scaled

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        omega: Union[float, np.ndarray, torch.Tensor] = 0.0,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one Euler step along the sigma schedule.

        The raw update `dx = (sigma_next - sigma) * model_output` is rescaled around its global mean `m`:
        `prev_sample = sample + (dx - m) * omega' + m`, where `omega'` is `omega` squashed into `(0.9, 1.1)` by a
        logistic curve. The default `omega=0.0` maps to `omega' = 1.0`, i.e. a plain Euler step.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`), s_tmin (`float`), s_tmax (`float`), s_noise (`float`):
                Accepted but not used by this implementation.
            generator (`torch.Generator`, *optional*):
                Accepted but not used by this implementation.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
            omega (`float`, `np.ndarray` or `torch.Tensor`, defaults to 0.0):
                Step-scale control. The raw and rescaled values are stored on `self.omega_bef_rescale` and
                `self.omega_aft_rescale`.

        Returns:
            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer or an integer tensor.
        """
        self.omega_bef_rescale = omega
        omega = self._rescale_omega(omega)
        self.omega_aft_rescale = omega
        if isinstance(omega, torch.Tensor):
            # Array-valued omega must sit on the same device as the update it scales.
            omega = omega.to(model_output.device)

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        # Mean-preserving rescale: only the deviation of the update from its global mean is scaled.
        dx = (sigma_next - sigma) * model_output
        m = dx.mean()
        dx_ = (dx - m) * omega + m
        prev_sample = sample + dx_

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
@dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` method.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample `(x_{t-1})` computed for the previous timestep. Feed it back to the model as the next input
            of the denoising loop.
    """

    prev_sample: torch.FloatTensor
Output class for the scheduler's step function output.
Args:
prev_sample (torch.FloatTensor of shape (batch_size, num_channels, height, width) for images):
Computed sample (x_{t-1}) of previous timestep. prev_sample should be used as next model input in the
denoising loop.
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Euler scheduler for flow matching whose update step can be rescaled through an `omega` argument.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule. Only applied when `use_dynamic_shifting` is `False`.
        use_dynamic_shifting (`bool`, defaults to `False`):
            When `True`, the static `shift` is skipped and `set_timesteps` requires a `mu` value that is fed to
            `time_shift` instead.
        base_shift, max_shift, base_image_seq_len, max_image_seq_len:
            Accepted by the constructor but not read anywhere else in this class.
        sigma_max (`float`, defaults to 1.0):
            Scales the upper end of the training timestep ramp (`sigma_max * num_train_timesteps`).
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting=False,
        base_shift: Optional[float] = 0.5,
        max_shift: Optional[float] = 1.15,
        base_image_seq_len: Optional[int] = 256,
        max_image_seq_len: Optional[int] = 4096,
        sigma_max: Optional[float] = 1.0,
    ):
        # Descending ramp of training timesteps: sigma_max * N, ..., 1.
        timesteps = np.linspace(
            1.0, sigma_max * num_train_timesteps, num_train_timesteps, dtype=np.float32
        )[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # With dynamic shifting the shift is applied later, in `set_timesteps`, from `mu`.
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for the current timestep. It increases by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: blends `sample` with `noise` as `sigma * noise + (1 - sigma) * sample`.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`torch.FloatTensor`):
                Batch of timesteps, one per sample. It must be a tensor: `.to(...)` and `.shape[0]` are used on it.
            noise (`torch.FloatTensor`):
                The noise to mix in. Although the signature defaults it to `None`, it is multiplied directly,
                so a tensor must be supplied.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        # Make sure sigmas and timesteps have the same device and dtype as the sample
        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)

        if sample.device.type == "mps" and torch.is_floating_point(timestep):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
            timestep = timestep.to(sample.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps) for t in timestep
            ]
        elif self.step_index is not None:
            # called after the first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # called before the first denoising step to create the initial latent (img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        # Append singleton axes so sigma broadcasts against the sample.
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        """Convert a sigma to a timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """Return `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`, the dynamic shift of `t`."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        sigmas: Optional[List[float]] = None,
        mu: Optional[float] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. Ignored when
                `sigmas` is given.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            sigmas (`List[float]`, *optional*):
                Custom sigma schedule. The configured shift is still applied to it, and its length becomes
                `num_inference_steps`.
            mu (`float`, *optional*):
                Shift parameter passed to `time_shift`; required when `use_dynamic_shifting` is `True`.

        Raises:
            ValueError: If `use_dynamic_shifting` is enabled and `mu` is `None`.
        """

        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError(
                " you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
            )

        if sigmas is None:
            timesteps = np.linspace(
                self._sigma_to_t(self.sigma_max),
                self._sigma_to_t(self.sigma_min),
                num_inference_steps,
            )
            sigmas = timesteps / self.config.num_train_timesteps
        else:
            # A plain list cannot go through the array arithmetic or `torch.from_numpy` below.
            # float64 matches the precision of the `np.linspace` path above.
            sigmas = np.array(sigmas, dtype=np.float64)
            num_inference_steps = len(sigmas)

        self.num_inference_steps = num_inference_steps

        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        timesteps = sigmas * self.config.num_train_timesteps

        self.timesteps = timesteps.to(device=device)
        # A trailing zero lets the last `step` read `sigmas[step_index + 1]`.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the schedule index whose value equals `timestep` (exact comparison).

        Args:
            timestep: The timestep value to look up.
            schedule_timesteps: Schedule to search; defaults to `self.timesteps`.

        Raises:
            IndexError: If `timestep` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index`, or by looking `timestep` up when no begin index is set."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @staticmethod
    def _rescale_omega(omega, lower=0.9, upper=1.1, midpoint=0.0, steepness=0.1):
        """Map `omega` into `(lower, upper)` with a logistic curve; `midpoint` maps to the centre of the range."""
        device = None
        if isinstance(omega, torch.Tensor):
            device = omega.device
            omega = omega.to(torch.float).cpu().numpy()
        elif isinstance(omega, np.ndarray):
            # Same float32 precision as the tensor path.
            omega = omega.astype(np.float32)

        scaled = lower + (upper - lower) / (1 + np.exp(-steepness * (omega - midpoint)))

        if isinstance(scaled, np.ndarray):
            scaled = torch.from_numpy(scaled)
            if device is not None:
                scaled = scaled.to(device)
        return scaled

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        omega: Union[float, np.ndarray, torch.Tensor] = 0.0,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one Euler step along the sigma schedule.

        The raw update `dx = (sigma_next - sigma) * model_output` is rescaled around its global mean `m`:
        `prev_sample = sample + (dx - m) * omega' + m`, where `omega'` is `omega` squashed into `(0.9, 1.1)` by a
        logistic curve. The default `omega=0.0` maps to `omega' = 1.0`, i.e. a plain Euler step.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`), s_tmin (`float`), s_tmax (`float`), s_noise (`float`):
                Accepted but not used by this implementation.
            generator (`torch.Generator`, *optional*):
                Accepted but not used by this implementation.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
            omega (`float`, `np.ndarray` or `torch.Tensor`, defaults to 0.0):
                Step-scale control. The raw and rescaled values are stored on `self.omega_bef_rescale` and
                `self.omega_aft_rescale`.

        Returns:
            [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an integer or an integer tensor.
        """
        self.omega_bef_rescale = omega
        omega = self._rescale_omega(omega)
        self.omega_aft_rescale = omega
        if isinstance(omega, torch.Tensor):
            # Array-valued omega must sit on the same device as the update it scales.
            omega = omega.to(model_output.device)

        if (
            isinstance(timestep, int)
            or isinstance(timestep, torch.IntTensor)
            or isinstance(timestep, torch.LongTensor)
        ):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        # Mean-preserving rescale: only the deviation of the update from its global mean is scaled.
        dx = (sigma_next - sigma) * model_output
        m = dx.mean()
        dx_ = (dx - m) * omega + m
        prev_sample = sample + dx_

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
Euler scheduler.
This model inherits from [SchedulerMixin] and [ConfigMixin]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (int, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (str, defaults to "linspace"):
The way the timesteps should be scaled. Refer to Table 2 of the Common Diffusion Noise Schedules and
Sample Steps are Flawed for more information.
shift (float, defaults to 1.0):
The shift value for the timestep schedule.
@register_to_config
def __init__(
    self,
    num_train_timesteps: int = 1000,
    shift: float = 1.0,
    use_dynamic_shifting=False,
    base_shift: Optional[float] = 0.5,
    max_shift: Optional[float] = 1.15,
    base_image_seq_len: Optional[int] = 256,
    max_image_seq_len: Optional[int] = 4096,
    sigma_max: Optional[float] = 1.0,
):
    """
    Build the training-time schedule.

    A float32 ramp running from `sigma_max * num_train_timesteps` down to 1 is normalised to sigmas; the static
    `shift` is applied unless `use_dynamic_shifting` is set. The result is stored as `self.sigmas` (on CPU) and
    `self.timesteps`, and `self.sigma_min` / `self.sigma_max` record the schedule's end points.
    """
    # No step taken yet, and no pipeline-provided starting index.
    self._step_index = None
    self._begin_index = None

    ascending = np.linspace(
        1.0, sigma_max * num_train_timesteps, num_train_timesteps, dtype=np.float32
    )
    ramp = torch.from_numpy(ascending[::-1].copy()).to(dtype=torch.float32)

    schedule = ramp / num_train_timesteps
    if not use_dynamic_shifting:
        # With dynamic shifting the shift is deferred to `set_timesteps`.
        schedule = shift * schedule / (1 + (shift - 1) * schedule)

    self.timesteps = schedule * num_train_timesteps

    # Kept on CPU to avoid too much CPU/GPU communication.
    self.sigmas = schedule.to("cpu")
    self.sigma_min = self.sigmas[-1].item()
    self.sigma_max = self.sigmas[0].item()
@property
def step_index(self):
    """
    Position of the current timestep in the schedule; advanced by one after every scheduler step.
    """
    current = self._step_index
    return current
The index counter for the current timestep. It increases by 1 after each scheduler step.
@property
def begin_index(self):
    """
    Schedule index of the first timestep; expected to be set by the pipeline through `set_begin_index`.
    """
    first = self._begin_index
    return first
The index for the first timestep. It should be set from pipeline with set_begin_index method.
def set_begin_index(self, begin_index: int = 0):
    """
    Record the schedule index at which denoising starts. Meant to be called by the pipeline before inference.

    Args:
        begin_index (`int`):
            The begin index for the scheduler.
    """
    setattr(self, "_begin_index", begin_index)
Sets the begin index for the scheduler. This function should be run from the pipeline before inference.
Args:
begin_index (int):
The begin index for the scheduler.
def scale_noise(
    self,
    sample: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    noise: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    """
    Flow-matching forward process: interpolate between `sample` and `noise` with the sigma of each timestep.

    Args:
        sample (`torch.FloatTensor`):
            The clean input sample.
        timestep (`torch.FloatTensor`):
            One timestep per batch element (used through `.to(...)` and `.shape[0]`).
        noise (`torch.FloatTensor`):
            The noise to blend in; it is multiplied directly, so a tensor has to be supplied.

    Returns:
        `torch.FloatTensor`:
            `sigma * noise + (1 - sigma) * sample`.
    """
    target = sample.device

    # The sigma table follows the sample's device and dtype.
    sigma_table = self.sigmas.to(device=target, dtype=sample.dtype)

    if target.type == "mps" and torch.is_floating_point(timestep):
        # mps does not support float64
        schedule_timesteps = self.timesteps.to(target, dtype=torch.float32)
        timestep = timestep.to(target, dtype=torch.float32)
    else:
        schedule_timesteps = self.timesteps.to(target)
        timestep = timestep.to(target)

    if self.begin_index is None:
        # Training, or a pipeline without `set_begin_index`: look every timestep up individually.
        step_indices = [
            self.index_for_timestep(t, schedule_timesteps) for t in timestep
        ]
    else:
        # After the first denoising step (inpainting) reuse the running step index,
        # otherwise (img2img initial latent) fall back to the begin index.
        shared = self.step_index if self.step_index is not None else self.begin_index
        step_indices = [shared] * timestep.shape[0]

    sigma = sigma_table[step_indices].flatten()
    # Append singleton axes until sigma broadcasts against the sample.
    while sigma.dim() < sample.dim():
        sigma = sigma.unsqueeze(-1)

    return sigma * noise + (1.0 - sigma) * sample
Forward process in flow-matching
Args:
sample (torch.FloatTensor):
The input sample.
timestep (float or torch.FloatTensor):
The current timestep in the diffusion chain.
Returns:
torch.FloatTensor:
A scaled input sample.
def set_timesteps(
    self,
    num_inference_steps: int = None,
    device: Union[str, torch.device] = None,
    sigmas: Optional[List[float]] = None,
    mu: Optional[float] = None,
):
    """
    Sets the discrete timesteps used for the diffusion chain (to be run before inference).

    Args:
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Ignored when
            `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. The configured shift is still applied to it, and its length becomes
            `num_inference_steps`.
        mu (`float`, *optional*):
            Shift parameter passed to `time_shift`; required when `use_dynamic_shifting` is `True`.

    Raises:
        ValueError: If `use_dynamic_shifting` is enabled and `mu` is `None`.
    """

    if self.config.use_dynamic_shifting and mu is None:
        raise ValueError(
            " you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`"
        )

    if sigmas is None:
        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max),
            self._sigma_to_t(self.sigma_min),
            num_inference_steps,
        )
        sigmas = timesteps / self.config.num_train_timesteps
    else:
        # A plain list cannot go through the array arithmetic or `torch.from_numpy` below.
        # float64 matches the precision of the `np.linspace` path above.
        sigmas = np.array(sigmas, dtype=np.float64)
        num_inference_steps = len(sigmas)

    # Recorded on both paths so the attribute never goes stale when custom sigmas are used.
    self.num_inference_steps = num_inference_steps

    if self.config.use_dynamic_shifting:
        sigmas = self.time_shift(mu, 1.0, sigmas)
    else:
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

    sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
    timesteps = sigmas * self.config.num_train_timesteps

    self.timesteps = timesteps.to(device=device)
    # A trailing zero lets the last `step` read `sigmas[step_index + 1]`.
    self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

    # A new schedule invalidates any step bookkeeping.
    self._step_index = None
    self._begin_index = None
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (int):
The number of diffusion steps used when generating samples with a pre-trained model.
device (str or torch.device, optional):
The device to which the timesteps should be moved to. If None, the timesteps are not moved.
def index_for_timestep(self, timestep, schedule_timesteps=None):
    """
    Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

    When the value occurs more than once, the second occurrence is returned; with a single occurrence
    that one is returned. Picking the second match for the very first `step` ensures a sigma is not
    accidentally skipped when denoising starts in the middle of the schedule (e.g. for image-to-image).
    """
    candidates = self.timesteps if schedule_timesteps is None else schedule_timesteps
    matches = torch.nonzero(candidates == timestep)

    if len(matches) > 1:
        chosen = matches[1]
    else:
        chosen = matches[0]

    return chosen.item()
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    s_churn: float = 0.0,
    s_tmin: float = 0.0,
    s_tmax: float = float("inf"),
    s_noise: float = 1.0,
    generator: Optional[torch.Generator] = None,
    return_dict: bool = True,
    omega: Union[float, np.ndarray, torch.Tensor] = 0.0,
) -> Union["FlowMatchEulerDiscreteSchedulerOutput", Tuple]:
    """
    Advance `sample` by one Euler step, with an `omega`-controlled rescaling of the update.

    The raw update is `dx = (sigma_next - sigma) * model_output`. Its deviation from its global mean is
    multiplied by a factor obtained by squashing `omega` through a logistic curve bounded by 0.9 and 1.1
    (`omega = 0` gives exactly 1.0, i.e. a plain Euler step), and the mean is added back unchanged.

    Args:
        model_output (`torch.FloatTensor`):
            The direct output from learned diffusion model.
        timestep (`float` or `torch.FloatTensor`):
            The current timestep; must be one of `scheduler.timesteps`, not an integer index.
        sample (`torch.FloatTensor`):
            A current instance of a sample created by the diffusion process.
        s_churn (`float`):
            Not used by this scheduler.
        s_tmin (`float`):
            Not used by this scheduler.
        s_tmax (`float`):
            Not used by this scheduler.
        s_noise (`float`, defaults to 1.0):
            Not used by this scheduler.
        generator (`torch.Generator`, *optional*):
            Not used by this scheduler.
        return_dict (`bool`):
            Whether or not to return a [`FlowMatchEulerDiscreteSchedulerOutput`] or tuple.
        omega (`float`, `np.ndarray` or `torch.Tensor`, defaults to 0.0):
            Raw update-scale control. The raw and rescaled values are stored on the scheduler as
            `omega_bef_rescale` and `omega_aft_rescale`.

    Returns:
        [`FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
            If return_dict is `True`, [`FlowMatchEulerDiscreteSchedulerOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: If `timestep` is an `int` or an integer tensor.
    """

    def logistic_function(x, L=0.9, U=1.1, x_0=0.0, k=1):
        # L = Lower bound
        # U = Upper bound
        # x_0 = Midpoint (x corresponding to y = 1.0)
        # k = Steepness, can adjust based on preference
        if isinstance(x, torch.Tensor):
            target_device = x.device
            x = x.to(torch.float).cpu().numpy()
        else:
            # Non-tensor input (e.g. np.ndarray) has no device of its own; follow the sample.
            target_device = sample.device

        new_x = L + (U - L) / (1 + np.exp(-k * (x - x_0)))

        # Scalars stay numpy/Python scalars; only real arrays are turned back into tensors.
        if isinstance(new_x, np.ndarray):
            new_x = torch.from_numpy(new_x).to(
                device=target_device, dtype=torch.float32
            )
        return new_x

    self.omega_bef_rescale = omega
    omega = logistic_function(omega, k=0.1)
    self.omega_aft_rescale = omega

    if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
        raise ValueError(
            (
                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                " one of the `scheduler.timesteps` as a timestep."
            ),
        )

    if self.step_index is None:
        self._init_step_index(timestep)

    # Upcast to avoid precision issues when computing prev_sample
    sample = sample.to(torch.float32)

    sigma = self.sigmas[self.step_index]
    sigma_next = self.sigmas[self.step_index + 1]

    # Mean-preserving rescale: only the deviation of the update from its global mean is
    # scaled by omega, so the mean of the update is unaffected.
    dx = (sigma_next - sigma) * model_output
    dx_mean = dx.mean()
    prev_sample = sample + ((dx - dx_mean) * omega + dx_mean)

    # Cast sample back to model compatible dtype
    prev_sample = prev_sample.to(model_output.dtype)

    # upon completion increase step index by one
    self._step_index += 1

    if not return_dict:
        return (prev_sample,)

    return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion process from the learned model outputs (most often the predicted noise).
Args:
model_output (torch.FloatTensor):
The direct output from learned diffusion model.
timestep (float):
The current discrete timestep in the diffusion chain.
sample (torch.FloatTensor):
A current instance of a sample created by the diffusion process.
s_churn (float):
Not used by this scheduler.
s_tmin (float):
Not used by this scheduler.
s_tmax (float):
Not used by this scheduler.
s_noise (float, defaults to 1.0):
Scaling factor for noise added to the sample.
generator (torch.Generator, *optional*):
A random number generator.
return_dict (bool):
Whether or not to return a [FlowMatchEulerDiscreteSchedulerOutput] or
tuple.
Returns:
[FlowMatchEulerDiscreteSchedulerOutput] or tuple:
If return_dict is True, [FlowMatchEulerDiscreteSchedulerOutput] is
returned, otherwise a tuple is returned where the first element is the sample tensor.