import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from diffusers.configuration_utils import (
ConfigMixin,
register_to_config,
)
from diffusers.schedulers.scheduling_utils import (
SchedulerMixin,
SchedulerOutput,
)
from einops import rearrange
@dataclass
class AnnealedLangevinDynamicsOutput(SchedulerOutput):
r"""Annealed Langevin Dynamics output class."""
class AnnealedLangevinDynamicsScheduler(SchedulerMixin, ConfigMixin): # type: ignore
r"""Annealed Langevin Dynamics scheduler for Noise Conditional Score Networks (NCSN).
    This scheduler inherits from :py:class:`~diffusers.SchedulerMixin`. Check the superclass documentation for its generic methods implemented for all schedulers (such as downloading or saving).
Args:
        num_train_timesteps (`int`): Number of noise levels used during training.
        num_annealed_steps (`int`): Number of Langevin steps to run at each noise level.
        sigma_min (`float`): Minimum standard deviation for the isotropic Gaussian noise.
        sigma_max (`float`): Maximum standard deviation for the isotropic Gaussian noise.
        sampling_eps (`float`): Step-size scale epsilon for the Langevin dynamics.
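    Examples:
        A minimal configuration sketch (the values are illustrative, loosely following
        the original NCSN setup, not tuned recommendations):
        >>> scheduler = AnnealedLangevinDynamicsScheduler(
        ...     num_train_timesteps=10,
        ...     num_annealed_steps=100,
        ...     sigma_min=0.01,
        ...     sigma_max=1.0,
        ...     sampling_eps=2e-5,
        ... )
        >>> scheduler.sigmas.shape
        torch.Size([10])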
"""
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int,
num_annealed_steps: int,
sigma_min: float,
sigma_max: float,
sampling_eps: float,
) -> None:
self.num_train_timesteps = num_train_timesteps
self.num_annealed_steps = num_annealed_steps
self._sigma_min = sigma_min
self._sigma_max = sigma_max
self._sampling_eps = sampling_eps
self._sigmas: Optional[torch.Tensor] = None
self._step_size: Optional[torch.Tensor] = None
self._timesteps: Optional[torch.Tensor] = None
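        # Precompute sigmas, step sizes, and timesteps from the training config;
        # pipelines may override these later via set_sigmas() / set_timesteps().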
self.set_sigmas(num_inference_steps=num_train_timesteps)
@property
def sigmas(self) -> torch.Tensor:
assert self._sigmas is not None
return self._sigmas
@property
def step_size(self) -> torch.Tensor:
assert self._step_size is not None
return self._step_size
@property
def timesteps(self) -> torch.Tensor:
assert self._timesteps is not None
return self._timesteps
def scale_model_input(
self, sample: torch.Tensor, timestep: Optional[int] = None
    ) -> torch.Tensor:
        r"""Returns the sample unchanged; defined for API compatibility with other schedulers, as NCSN requires no input scaling."""
        return sample
def set_timesteps(
self,
num_inference_steps: int,
sampling_eps: Optional[float] = None,
device: Optional[Union[str, torch.device]] = None,
) -> None:
r"""Sets the timesteps for the scheduler (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
Defined to maintain compatibility with other pipelines, but this argument is not actually used.
sampling_eps (`float`, *optional*):
The sampling epsilon for the Langevin dynamics. If `None`, the default value is used.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
`num_inference_steps` must be `None`.
"""
        if sampling_eps is not None:
            self._sampling_eps = sampling_eps
        self._timesteps = torch.arange(start=0, end=num_inference_steps)
def set_sigmas(
self,
num_inference_steps: int,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
sampling_eps: Optional[float] = None,
) -> None:
r"""Sets the sigmas and step sizes for the scheduler (to be run before inference).
Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, *optional*):
The minimum standard deviation for the isotropic Gaussian noise. If `None`, the default value is used.
sigma_max (`float`, *optional*):
The maximum standard deviation for the isotropic Gaussian noise. If `None`, the default value is used.
sampling_eps (`float`, *optional*):
The sampling epsilon for the Langevin dynamics. If `None`, the default value is used.
"""
if self._timesteps is None:
self.set_timesteps(
num_inference_steps=num_inference_steps,
sampling_eps=sampling_eps,
)
sigma_min = sigma_min or self._sigma_min
sigma_max = sigma_max or self._sigma_max
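        # Noise levels form a geometric sequence from sigma_max down to sigma_min:
        # sigma_i = sigma_max * (sigma_min / sigma_max) ** (i / (N - 1)).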
self._sigmas = torch.exp(
torch.linspace(
start=math.log(sigma_max),
end=math.log(sigma_min),
steps=num_inference_steps,
)
)
sampling_eps = sampling_eps or self._sampling_eps
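        # Per-level Langevin step size, alpha_i = eps * (sigma_i / sigma_L) ** 2,
        # where sigma_L = sigmas[-1] is the smallest noise level (Song & Ermon, 2019).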
self._step_size = sampling_eps * (self.sigmas / self.sigmas[-1]) ** 2
def step(
self,
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
return_dict: bool = True,
**kwargs,
) -> Union[AnnealedLangevinDynamicsOutput, Tuple]:
r"""Perform one step following Langevin dynamics. Annealing must be done separately.
Args:
            model_output (`torch.Tensor`):
                The score estimated by the learned neural score network.
timestep (`int`):
The current timestep.
sample (`torch.Tensor`):
The current sample.
return_dict (`bool`, *optional*):
Whether or not to return :py:class:`~ncsn.scheduler.AnnealedLangevinDynamicsOutput` or `tuple`.
Returns:
:py:class:`~ncsn.scheduler.AnnealedLangevinDynamicsOutput` or `tuple`:
if `return_dict` is `True`, :py:class:`~ncsn.scheduler.AnnealedLangevinDynamicsOutput` is returned,
otherwise a tuple is returned where the first element is the updated sample.
"""
z = torch.randn_like(sample)
step_size = self.step_size[timestep]
sample = sample + 0.5 * step_size * model_output + torch.sqrt(step_size) * z
if return_dict:
return AnnealedLangevinDynamicsOutput(prev_sample=sample)
else:
return (sample,)
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
r"""Add noise to the original samples.
Args:
original_samples (`torch.Tensor`):
The original samples.
noise (`torch.Tensor`):
The noise to be added.
timesteps (`torch.Tensor`):
The timesteps.
Returns:
`torch.Tensor`:
The noisy samples.
"""
timesteps = timesteps.to(original_samples.device)
sigmas = self.sigmas.to(original_samples.device)[timesteps]
sigmas = rearrange(sigmas, "b -> b 1 1 1")
noisy_samples = original_samples + noise * sigmas
return noisy_samples
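

if __name__ == "__main__":
    # A minimal, self-contained sketch of annealed Langevin sampling with this
    # scheduler. `_ToyScore` is a hypothetical stand-in for a trained NCSN score
    # model (it returns the score of a standard Gaussian, for demonstration
    # only), and the hyperparameters are illustrative rather than tuned.
    import torch.nn as nn

    class _ToyScore(nn.Module):
        def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
            return -x

    scheduler = AnnealedLangevinDynamicsScheduler(
        num_train_timesteps=10,
        num_annealed_steps=100,
        sigma_min=0.01,
        sigma_max=1.0,
        sampling_eps=2e-5,
    )
    scheduler.set_sigmas(num_inference_steps=10)
    model = _ToyScore()
    sample = torch.randn(1, 3, 32, 32)
    for t in scheduler.timesteps:
        # Inner loop: several Langevin steps at the current noise level.
        for _ in range(scheduler.num_annealed_steps):
            score = model(sample, t)
            sample = scheduler.step(score, int(t), sample).prev_sample
    print(sample.shape)  # torch.Size([1, 3, 32, 32])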