Source code for ncsn.unet.unet_2d_ncsn

import math
from typing import Optional, Tuple, Union

import torch
from diffusers import UNet2DModel
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class UNet2DModelForNCSN(UNet2DModel, ModelMixin, ConfigMixin):  # type: ignore[misc]
    r"""A 2D UNet model for Noise Conditional Score Networks (NCSN).

    This model inherits from :py:class:`~diffusers.UNet2DModel`, which is a 2D UNet
    model that takes a noisy sample and a timestep and returns a sample-shaped output.

    This model also inherits from :py:class:`~diffusers.ModelMixin`. Check the
    superclass documentation for its generic methods implemented for all models
    (such as downloading or saving).

    In addition to the UNet itself, the constructor registers a ``sigmas`` buffer
    holding ``num_train_timesteps`` noise levels that decrease geometrically
    (evenly spaced in log space) from ``sigma_max`` (index 0) to ``sigma_min``
    (last index).

    Args:
        sigma_min (`float`):
            Minimum standard deviation for the isotropic Gaussian noise. Must be
            strictly positive, because its logarithm is taken.
        sigma_max (`float`):
            Maximum standard deviation for the isotropic Gaussian noise. Must be
            greater than or equal to `sigma_min`.
        num_train_timesteps (`int`):
            Number of noise levels stored in the `sigmas` buffer. Must be at least 1.
            This value is *not* forwarded to :py:class:`~diffusers.UNet2DModel`.
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of
            `2 ** (len(block_out_channels) - 1)`.
        in_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3):
            Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`):
            Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`):
            Type of time embedding to use.
        freq_shift (`int`, *optional*, defaults to 0):
            Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`):
            The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`):
            The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`):
            The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and
            "resnet".
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and
            "resnet".
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`):
            The attention head dimension.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            The number of groups for normalization.
        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
            If set to an integer, a group norm layer will be created in the mid
            block's [`Attention`] layer with the given number of groups. If left as
            `None`, the group norm layer will only be created if
            `resnet_time_scale_shift` is set to `default`, and if created will have
            `norm_num_groups` groups.
        norm_eps (`float`, *optional*, defaults to `1e-5`):
            The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see
            :py:class:`~diffusers.ResnetBlock2D`). Choose from `default` or
            `scale_shift`.
        add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to :py:class:`~diffusers.UNet2DModel`. In diffusers
            this presumably toggles attention in the mid block; verify against the
            installed diffusers version.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the
            time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to
            `time_embed_dim` when performing class conditioning with
            `class_embed_type` equal to `None`.

    Raises:
        ValueError: If `sigma_min` is not positive, if `sigma_max` is smaller than
            `sigma_min`, or if `num_train_timesteps` is less than 1.
    """

    # Noise levels ordered from sigma_max (index 0) down to sigma_min; populated
    # via ``register_buffer`` in ``__init__``.
    sigmas: torch.Tensor

    @register_to_config
    def __init__(
        self,
        sigma_min: float,
        sigma_max: float,
        num_train_timesteps: int,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str, ...] = (
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types: Tuple[str, ...] = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
        ),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
    ) -> None:
        # Validate the noise schedule before building the (potentially large) UNet.
        # Without these checks a non-positive sigma fails inside math.log with an
        # opaque "math domain error", a zero step count silently yields an empty
        # schedule, and swapped bounds silently invert the schedule.
        if sigma_min <= 0:
            raise ValueError(f"`sigma_min` must be positive, got {sigma_min}.")
        if sigma_max < sigma_min:
            raise ValueError(
                f"`sigma_max` ({sigma_max}) must be greater than or equal to "
                f"`sigma_min` ({sigma_min})."
            )
        if num_train_timesteps < 1:
            raise ValueError(
                "`num_train_timesteps` must be at least 1, "
                f"got {num_train_timesteps}."
            )

        super().__init__(
            sample_size=sample_size,
            in_channels=in_channels,
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            time_embedding_type=time_embedding_type,
            freq_shift=freq_shift,
            flip_sin_to_cos=flip_sin_to_cos,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            mid_block_scale_factor=mid_block_scale_factor,
            downsample_padding=downsample_padding,
            downsample_type=downsample_type,
            upsample_type=upsample_type,
            dropout=dropout,
            act_fn=act_fn,
            attention_head_dim=attention_head_dim,
            norm_num_groups=norm_num_groups,
            attn_norm_num_groups=attn_norm_num_groups,
            norm_eps=norm_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            add_attention=add_attention,
            class_embed_type=class_embed_type,
            num_class_embeds=num_class_embeds,
            # TODO: forwarding `num_train_timesteps=num_train_timesteps` here has
            # been observed to sharply degrade the quality of generated samples,
            # so it is intentionally withheld from the parent constructor.
        )

        # Geometric noise schedule: evenly spaced in log space, running from
        # sigma_max (index 0) down to sigma_min (last index).
        sigmas = torch.exp(
            torch.linspace(
                start=math.log(sigma_max),
                end=math.log(sigma_min),
                steps=num_train_timesteps,
            )
        )
        # Registered as a buffer (not a trainable parameter) so it follows
        # device/dtype moves of the module and is stored in its state dict.
        self.register_buffer("sigmas", sigmas)