import math
from typing import Optional, Tuple, Union
import torch
from diffusers import UNet2DModel
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
class UNet2DModelForNCSN(UNet2DModel, ModelMixin, ConfigMixin):  # type: ignore[misc]
    r"""A 2D UNet model for Noise Conditional Score Networks (NCSN).

    This model inherits from :py:class:`~diffusers.UNet2DModel`, which is a 2D UNet model that takes a noisy
    sample and a timestep and returns a sample shaped output. This model also inherits from
    :py:class:`~diffusers.ModelMixin`. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    In addition to the UNet itself, the model owns the noise-level schedule: a registered buffer ``sigmas`` of
    length ``num_train_timesteps`` holding a geometric (log-spaced) sequence that starts at ``sigma_max`` and
    ends at ``sigma_min``.

    Args:
        sigma_min (`float`): Minimum standard deviation for the isotropic Gaussian noise. Must be positive.
        sigma_max (`float`): Maximum standard deviation for the isotropic Gaussian noise. Must be positive.
        num_train_timesteps (`int`): Number of noise levels in the ``sigmas`` schedule. Must be at least 1.
            This value is used only to build ``sigmas``; it is not forwarded to
            :py:class:`~diffusers.UNet2DModel`.
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample. Dimensions must be a multiple of
            `2 ** (len(block_out_channels) - 1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to
            `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to
            `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `conv`):
            The downsample type for downsampling layers. Choose between "conv" and "resnet".
        upsample_type (`str`, *optional*, defaults to `conv`):
            The upsample type for upsampling layers. Choose between "conv" and "resnet".
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with
            the given number of groups. If left as `None`, the group norm layer will only be created if
            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see :py:class:`~diffusers.ResnetBlock2D`).
            Choose from `default` or `scale_shift`.
        add_attention (`bool`, *optional*, defaults to `True`):
            Forwarded unchanged to :py:class:`~diffusers.UNet2DModel`; see the superclass documentation.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when
            performing class conditioning with `class_embed_type` equal to `None`.

    Raises:
        ValueError: If ``sigma_min`` or ``sigma_max`` is not positive, or ``num_train_timesteps`` is less than 1.
    """

    # Noise-level schedule of length ``num_train_timesteps``, registered as a buffer in ``__init__``.
    sigmas: torch.Tensor

    @register_to_config
    def __init__(
        self,
        sigma_min: float,
        sigma_max: float,
        num_train_timesteps: int,
        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
        in_channels: int = 3,
        out_channels: int = 3,
        center_input_sample: bool = False,
        time_embedding_type: str = "positional",
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        down_block_types: Tuple[str, ...] = (
            "DownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
            "AttnDownBlock2D",
        ),
        up_block_types: Tuple[str, ...] = (
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
        ),
        block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
        layers_per_block: int = 2,
        mid_block_scale_factor: float = 1,
        downsample_padding: int = 1,
        downsample_type: str = "conv",
        upsample_type: str = "conv",
        dropout: float = 0.0,
        act_fn: str = "silu",
        attention_head_dim: Optional[int] = 8,
        norm_num_groups: int = 32,
        attn_norm_num_groups: Optional[int] = None,
        norm_eps: float = 1e-5,
        resnet_time_scale_shift: str = "default",
        add_attention: bool = True,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
    ) -> None:
        # Validate the schedule arguments before building the (potentially large) UNet, so bad inputs fail
        # fast with a clear message instead of a "math domain error" from math.log or an empty buffer.
        if sigma_min <= 0 or sigma_max <= 0:
            raise ValueError(
                "sigma_min and sigma_max must be positive, "
                f"got sigma_min={sigma_min} and sigma_max={sigma_max}."
            )
        if num_train_timesteps < 1:
            raise ValueError(f"num_train_timesteps must be at least 1, got {num_train_timesteps}.")

        super().__init__(
            sample_size=sample_size,
            in_channels=in_channels,
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            time_embedding_type=time_embedding_type,
            freq_shift=freq_shift,
            flip_sin_to_cos=flip_sin_to_cos,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            mid_block_scale_factor=mid_block_scale_factor,
            downsample_padding=downsample_padding,
            downsample_type=downsample_type,
            upsample_type=upsample_type,
            dropout=dropout,
            act_fn=act_fn,
            attention_head_dim=attention_head_dim,
            norm_num_groups=norm_num_groups,
            attn_norm_num_groups=attn_norm_num_groups,
            norm_eps=norm_eps,
            resnet_time_scale_shift=resnet_time_scale_shift,
            add_attention=add_attention,
            class_embed_type=class_embed_type,
            num_class_embeds=num_class_embeds,
            # TODO: If you specify num_train_timesteps here, the quality of the generated product will suddenly decrease ;(
            # num_train_timesteps=num_train_timesteps,
            # NOTE(review): the parent __init__ is itself decorated with @register_to_config, so it may record
            # its own default for `num_train_timesteps` in `self.config` after this class registered the real
            # value -- confirm that a saved config round-trips through `from_pretrained`.
        )

        # Geometric schedule: interpolate linearly in log-space from sigma_max down to sigma_min, then
        # exponentiate, so consecutive sigmas share a constant ratio.
        sigmas = torch.exp(
            torch.linspace(
                start=math.log(sigma_max),
                end=math.log(sigma_min),
                steps=num_train_timesteps,
            )
        )
        # Registered as a buffer so the schedule follows the module across devices (.to/.cuda) and is
        # included in the state dict.
        self.register_buffer("sigmas", sigmas)