Unverified commit a6790a65, authored by chenjian, committed by GitHub

Add stable diffusion module

Parent 1007c416
# stable_diffusion
|Model Name|stable_diffusion|
| :--- | :---: |
|Category|Multimodal - text-to-image generation|
|Network|CLIP Text Encoder+UNet+VAE|
|Dataset|-|
|Fine-tuning supported|No|
|Module Size|4.0GB|
|Latest update date|2022-08-26|
|Data indicators|-|
## I. Basic Information
### Application Effect Display
- Input text: "in the morning light,Overlooking TOKYO city by greg rutkowski and thomas kinkade,Trending on artstation."
- Output image
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/186873437-2e426acd-7656-4d37-9ee4-8cafa48f097f.png" width = "80%" hspace='10'/>
<br />
</p>
- Generation process
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/186873216-d2a9761a-78b0-4f6a-97ec-919768f324f5.gif" width = "80%" hspace='10'/>
<br />
</p>
### Module Introduction
Stable Diffusion is a latent diffusion model, a kind of generative model that obtains the desired image by iteratively denoising and sampling from random noise, step by step; it currently achieves impressive results. Compared with Disco Diffusion, Stable Diffusion performs the iterations in a lower-dimensional latent space instead of the original pixel space, which greatly reduces the memory and compute requirements, and an image can be rendered within one minute on a V100. You are welcome to try it.
For more details, please refer to the paper: [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)
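The sketch below only illustrates the idea of iterative denoising in latent space; it is a conceptual example, and names such as `unet`, `text_embedding` and `sample_latents` are illustrative rather than this module's actual internals.
```python
# A minimal conceptual sketch of latent diffusion sampling (illustrative only,
# not the real implementation of this module).
import paddle

def sample_latents(unet, text_embedding, steps=50, latent_shape=(1, 4, 64, 64)):
    # Start from pure Gaussian noise in the low-dimensional latent space.
    latents = paddle.randn(latent_shape)
    for t in reversed(range(steps)):
        # The UNet predicts the noise present at step t, conditioned on the text.
        noise_pred = unet(latents, t, text_embedding)
        # A scheduler-style update removes part of the predicted noise.
        latents = latents - noise_pred / steps
    # Only the final latents are decoded back to pixel space by the VAE decoder.
    return latents
```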
## II. Installation
- ### 1. Environment Dependencies
- paddlepaddle >= 2.0.0
- paddlehub >= 2.0.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2. Installation
- ```shell
$ hub install stable_diffusion
```
- If you encounter problems during installation, please refer to: [Windows quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1. Command line Prediction
- ```shell
$ hub run stable_diffusion --text_prompts "in the morning light,Overlooking TOKYO city by greg rutkowski and thomas kinkade,Trending on artstation." --output_dir stable_diffusion_out
```
- ### 2. Prediction Code Example
- ```python
import paddlehub as hub
module = hub.Module(name="stable_diffusion")
text_prompts = ["in the morning light,Overlooking TOKYO city by greg rutkowski and thomas kinkade,Trending on artstation."]
# Generate images; by default they are saved to the stable_diffusion_out directory.
# The returned da is a DocumentArray object that stores all results, including the final images and the intermediate results of every iteration step.
# You can post-process, save or analyse the generated images by operating on this DocumentArray object.
# Set batch_size to generate several images per prompt at once.
da = module.generate_image(text_prompts=text_prompts, batch_size=3, output_dir='./stable_diffusion_out/')
# Show all the intermediate results
da[0].chunks[-1].chunks.plot_image_sprites(skip_empty=True, show_index=True, keep_aspect_ratio=True)
# Save the whole generation process as an animated gif
da[0].chunks[-1].chunks.save_gif('stable_diffusion_out-merged-result.gif')
# da is indexed by prompt; da[0].chunks indexes the images generated for that prompt (more than one when batch_size > 1)
# You can also show a single image in the same way, e.g. the generation process of image 0
da[0].chunks[0].chunks.plot_image_sprites(skip_empty=True, show_index=True, keep_aspect_ratio=True)
da[0].chunks[0].chunks.save_gif('stable_diffusion_out-image-0-result.gif')
```
- ### 3. API
- ```python
def generate_image(
text_prompts,
style: Optional[str] = None,
artist: Optional[str] = None,
width_height: Optional[List[int]] = [512, 512],
seed: Optional[int] = None,
batch_size: Optional[int] = 1,
output_dir: Optional[str] = 'stable_diffusion_out'):
```
- Text-to-image generation API, which generates images matching the text description.
- **Parameters**
- text_prompts(str): the input sentence describing the image you want to generate. A commonly effective pattern is "a descriptive sentence" + "the name of an artist", e.g. "in the morning light,Overlooking TOKYO city by greg rutkowski and thomas kinkade,Trending on artstation.". For prompt construction, see this [guide](https://docs.google.com/document/d/1XUT2G9LmkZataHFzmuOtRXnuWBfhvXDAo8DkS--8tec/edit#).
- style(Optional[str]): the painting style, e.g. 'watercolor' or 'Chinese painting'. If not set, the style is determined entirely by your prompt.
- artist(Optional[str]): a specific artist, e.g. Greg Rutkowski or krenz, whose painting style the output will follow. If not set, the style is determined entirely by your prompt. For artist styles, see this [reference](https://weirdwonderfulai.art/resources/disco-diffusion-70-plus-artist-studies/).
- width_height(Optional[List[int]]): the width and height of the output image; both must be multiples of 64. Larger images take longer to compute.
- seed(Optional[int]): random seed. Since the default input is random Gaussian noise, different seeds give different initial inputs and therefore different results; set this parameter to obtain different output images.
- batch_size(Optional[int]): the number of images generated per prompt in one call.
- output_dir(Optional[str]): the directory where the output images are saved, "stable_diffusion_out" by default.
- **Return**
- da(DocumentArray): a DocumentArray object containing `batch_size` Documents, each of which stores all intermediate results of the iteration. For details, see the [DocumentArray documentation](https://docarray.jina.ai/fundamentals/documentarray/index.html).
## IV. Server Deployment
- PaddleHub Serving can deploy an online service of text-to-image generation.
- ### Step 1: Start PaddleHub Serving
- Run the startup command:
- ```shell
$ hub serving start -m stable_diffusion
```
- This deploys an online text-to-image generation service API; the default port is 8866.
- **NOTE:** If you predict on GPU, set the CUDA\_VISIBLE\_DEVICES environment variable before starting the service; otherwise it does not need to be set.
- ### Step 2: Send a prediction request
- With the server configured, the few lines of code below send a prediction request and retrieve the result. After deserialization, the returned result is the DocumentArray type described in the API above and can be handled exactly as with the generate_image interface.
- ```python
import requests
import json
import cv2
import base64
from docarray import DocumentArray
# Send an HTTP request
data = {'text_prompts': 'in the morning light,Overlooking TOKYO city by greg rutkowski and thomas kinkade,Trending on artstation.'}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/stable_diffusion"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
# Get the returned result
r.json()["results"]
da = DocumentArray.from_base64(r.json()["results"])
# Save the result image
da[0].save_uri_to_file('stable_diffusion_out.png')
# Save the generation process as an animated gif
da[0].chunks[0].chunks.save_gif('stable_diffusion_out.gif')
```
## V. Release Note
* 1.0.0
First release
```shell
$ hub install stable_diffusion==1.0.0
```
# OpenAI CLIP implemented in Paddle.
The original implementation repo is [ranchlai/clip.paddle](https://github.com/ranchlai/clip.paddle). We use this repo here for text encoder in stable diffusion.
from typing import Optional
import paddle
import paddle.nn as nn
from paddle import Tensor
from paddle.nn import functional as F
from paddle.nn import Linear
__all__ = ['ResidualAttentionBlock', 'AttentionPool2d', 'multi_head_attention_forward', 'MultiHeadAttention']
def multi_head_attention_forward(x: Tensor,
num_heads: int,
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
c_proj: Linear,
attn_mask: Optional[Tensor] = None):
max_len, batch_size, emb_dim = x.shape
head_dim = emb_dim // num_heads
scaling = float(head_dim)**-0.5
q = q_proj(x) # L, N, E
k = k_proj(x) # L, N, E
v = v_proj(x) # L, N, E
v = v.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
k = k.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
q = q.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
q = q * scaling
qk = paddle.bmm(q, k.transpose((0, 2, 1)))
if attn_mask is not None:
if attn_mask.ndim == 2:
attn_mask.unsqueeze_(0)
#assert str(attn_mask.dtype) == 'VarType.FP32' and attn_mask.ndim == 3
assert attn_mask.shape[0] == 1 and attn_mask.shape[1] == max_len and attn_mask.shape[2] == max_len
qk += attn_mask
qk = paddle.nn.functional.softmax(qk, axis=-1)
atten = paddle.bmm(qk, v)
atten = atten.transpose((1, 0, 2))
atten = atten.reshape((max_len, batch_size, emb_dim))
atten = c_proj(atten)
return atten
class MultiHeadAttention(nn.Layer):  # attn_mask is optional and passed in forward, not stored on the layer
def __init__(self, emb_dim: int, num_heads: int):
super().__init__()
self.q_proj = nn.Linear(emb_dim, emb_dim, bias_attr=True)
self.k_proj = nn.Linear(emb_dim, emb_dim, bias_attr=True)
self.v_proj = nn.Linear(emb_dim, emb_dim, bias_attr=True)
self.c_proj = nn.Linear(emb_dim, emb_dim, bias_attr=True)
self.head_dim = emb_dim // num_heads
self.emb_dim = emb_dim
self.num_heads = num_heads
assert self.head_dim * num_heads == emb_dim, "embed_dim must be divisible by num_heads"
#self.scaling = float(self.head_dim) ** -0.5
def forward(self, x, attn_mask=None): # x is in shape[max_len,batch_size,emb_dim]
atten = multi_head_attention_forward(x,
self.num_heads,
self.q_proj,
self.k_proj,
self.v_proj,
self.c_proj,
attn_mask=attn_mask)
return atten
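# Usage sketch for the attention layers above (illustrative values, not taken from
# this repo): inputs are laid out as [max_len, batch_size, emb_dim], e.g.
#   attn = MultiHeadAttention(emb_dim=768, num_heads=12)
#   y = attn(paddle.randn([77, 2, 768]))   # y.shape == [77, 2, 768]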
class Identity(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x):
return x
class Bottleneck(nn.Layer):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2D(inplanes, planes, 1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes)
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes)
self.avgpool = nn.AvgPool2D(stride) if stride > 1 else Identity()
self.conv3 = nn.Conv2D(planes, planes * self.expansion, 1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(planes * self.expansion)
self.relu = nn.ReLU()
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
self.downsample = nn.Sequential(
("-1", nn.AvgPool2D(stride)),
("0", nn.Conv2D(inplanes, planes * self.expansion, 1, stride=1, bias_attr=False)),
("1", nn.BatchNorm2D(planes * self.expansion)))
def forward(self, x):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Layer):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = paddle.create_parameter((spacial_dim**2 + 1, embed_dim), dtype='float32')
self.q_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim, bias_attr=True)
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
def forward(self, x):
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])).transpose((2, 0, 1)) # NCHW -> (HW)NC
max_len, batch_size, emb_dim = x.shape
head_dim = self.head_dim
x = paddle.concat([paddle.mean(x, axis=0, keepdim=True), x], axis=0)
x = x + paddle.unsqueeze(self.positional_embedding, 1)
out = multi_head_attention_forward(x, self.num_heads, self.q_proj, self.k_proj, self.v_proj, self.c_proj)
return out[0]
class QuickGELU(nn.Layer):
def forward(self, x):
return x * paddle.nn.functional.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Layer):
def __init__(self, d_model: int, n_head: int, attn_mask=None):
super().__init__()
self.attn = MultiHeadAttention(d_model, n_head)
self.ln_1 = nn.LayerNorm(d_model)
self.mlp = nn.Sequential(("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)))
self.ln_2 = nn.LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x):
x = self.attn(x, self.attn_mask)
assert isinstance(x, paddle.Tensor)  # not a tuple here
return x
def forward(self, x):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
from typing import Tuple
from typing import Union
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import AttentionPool2d
from .layers import Bottleneck
from .layers import MultiHeadAttention
from .layers import ResidualAttentionBlock
class ModifiedResNet(nn.Layer):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2D(3, width // 2, kernel_size=3, stride=2, padding=1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(width // 2)
self.conv2 = nn.Conv2D(width // 2, width // 2, kernel_size=3, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(width // 2)
self.conv3 = nn.Conv2D(width // 2, width, kernel_size=3, padding=1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(width)
self.avgpool = nn.AvgPool2D(2)
self.relu = nn.ReLU()
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
#x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
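# Construction sketch (illustrative hyper-parameters, not a config shipped with this
# module): a (3, 4, 6, 3) layout with width=64 gives a ResNet-50-like backbone.
#   visual = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
#                           input_resolution=224, width=64)
#   features = visual(paddle.randn([1, 3, 224, 224]))   # shape [1, 1024]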
class Transformer(nn.Layer):
def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x):
return self.resblocks(x)
class VisualTransformer(nn.Layer):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
# used patch_size x patch_size, stride patch_size to do linear projection
self.conv1 = nn.Conv2D(in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias_attr=False)
# scale = width ** -0.5
self.class_embedding = paddle.create_parameter((width, ), 'float32')
self.positional_embedding = paddle.create_parameter(((input_resolution // patch_size)**2 + 1, width), 'float32')
self.ln_pre = nn.LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = nn.LayerNorm(width)
self.proj = paddle.create_parameter((width, output_dim), 'float32')
def forward(self, x):
x = self.conv1(x)
x = x.reshape((x.shape[0], x.shape[1], -1))
x = x.transpose((0, 2, 1))
x = paddle.concat([self.class_embedding + paddle.zeros((x.shape[0], 1, x.shape[-1]), dtype=x.dtype), x], axis=1)
x = x + self.positional_embedding
x = self.ln_pre(x)
x = x.transpose((1, 0, 2))
x = self.transformer(x)
x = x.transpose((1, 0, 2))
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = paddle.matmul(x, self.proj)
return x
class TextTransformer(nn.Layer):
def __init__(self, context_length: int, vocab_size: int, transformer_width: int, transformer_heads: int,
transformer_layers: int):
super().__init__()
self.context_length = context_length
self.transformer = Transformer(width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask())
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = paddle.create_parameter((self.context_length, transformer_width), 'float32')
self.ln_final = nn.LayerNorm(transformer_width)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# mask = paddle.empty((self.context_length, self.context_length),dtype='float32')
# mask.fill_(float("-inf"))
#mask.triu_(1) # zero out the lower diagonal
mask = paddle.ones((self.context_length, self.context_length)) * float("-inf")
mask = paddle.triu(mask, diagonal=1)
return mask
def forward(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.transpose((1, 0, 2)) # NLD -> LND
x = self.transformer(x)
x = x.transpose((1, 0, 2)) # LND -> NLD
x = self.ln_final(x)
return x
class CLIP(nn.Layer):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim)
self.transformer = Transformer(width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask())
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = paddle.create_parameter((self.context_length, transformer_width), 'float32')
self.ln_final = nn.LayerNorm(transformer_width)
self.text_projection = paddle.create_parameter((transformer_width, embed_dim), 'float32')
self.logit_scale = paddle.create_parameter((1, ), 'float32')
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# mask = paddle.empty((self.context_length, self.context_length),dtype='float32')
# mask.fill_(float("-inf"))
#mask.triu_(1) # zero out the lower diagonal
mask = paddle.ones((self.context_length, self.context_length)) * float("-inf")
mask = paddle.triu(mask, diagonal=1)
return mask
def encode_image(self, image):
return self.visual(image)
def encode_text(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.transpose((1, 0, 2)) # NLD -> LND
x = self.transformer(x)
x = x.transpose((1, 0, 2)) # LND -> NLD
x = self.ln_final(x)
idx = text.numpy().argmax(-1)
idx = list(idx)
x = [x[i:i + 1, int(j), :] for i, j in enumerate(idx)]
x = paddle.concat(x, 0)
x = paddle.matmul(x, self.text_projection)
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(axis=-1, keepdim=True)
text_features = text_features / text_features.norm(axis=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = paddle.matmul(logit_scale * image_features, text_features.t())
logits_per_text = paddle.matmul(logit_scale * text_features, image_features.t())
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "../assets/bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
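# Example of the mapping above: printable ASCII bytes map to themselves
# (bytes_to_unicode()[ord('!')] == '!'), while bytes outside the three printable
# ranges are remapped past 255, e.g. byte 0 maps to chr(256).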
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
import os
from typing import List
from typing import Union
import numpy as np
import paddle
from paddle.utils import download
from paddle.vision.transforms import CenterCrop
from paddle.vision.transforms import Compose
from paddle.vision.transforms import Normalize
from paddle.vision.transforms import Resize
from paddle.vision.transforms import ToTensor
from .model import CLIP
from .model import TextTransformer
from .simple_tokenizer import SimpleTokenizer
__all__ = ['transform', 'tokenize', 'build_model']
MODEL_NAMES = ['VITL14']
URL = {'VITL14': os.path.join(os.path.dirname(__file__), 'pre_trained', 'vitl14_textencoder.pdparams')}
MEAN, STD = (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
_tokenizer = SimpleTokenizer()
transform = Compose([
Resize(224, interpolation='bicubic'),
CenterCrop(224), lambda image: image.convert('RGB'),
ToTensor(),
Normalize(mean=MEAN, std=STD), lambda t: t.unsqueeze_(0)
])
def tokenize(texts: Union[str, List[str]], context_length: int = 77):
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = paddle.zeros((len(all_tokens), context_length), dtype='int64')
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = paddle.Tensor(np.array(tokens))
return result
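# Example (sketch): tokenize(["a painting of a fox"]) returns an int64 tensor of
# shape [1, 77]; the row starts with the <|startoftext|> id, is followed by the BPE
# ids of the text and the <|endoftext|> id, and is zero-padded to context_length.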
def build_model(name='VITL14'):
assert name in MODEL_NAMES, f"model name must be one of {MODEL_NAMES}"
name2model = {'VITL14': build_vitl14_language_model}
model = name2model[name]()
weight = URL[name]
sd = paddle.load(weight)
state_dict = model.state_dict()
for key, value in sd.items():
if key in state_dict:
state_dict[key] = value
model.load_dict(state_dict)
model.eval()
return model
def build_vitl14_language_model():
model = TextTransformer(context_length=77,
vocab_size=49408,
transformer_width=768,
transformer_heads=12,
transformer_layers=12)
return model
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.4"
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .schedulers import (DDIMScheduler, DDPMScheduler, KarrasVeScheduler, PNDMScheduler, SchedulerMixin,
ScoreSdeVeScheduler, LMSDiscreteScheduler)
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ConfigMixinuration base class and utilities."""
import functools
import inspect
import json
import os
import re
from collections import OrderedDict
from typing import Any
from typing import Dict
from typing import Tuple
from typing import Union
from requests import HTTPError
from paddlehub.common.logger import logger
HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = "./caches"
_re_configuration_file = re.compile(r"config\.(.*)\.json")
class ConfigMixin:
r"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.
"""
config_name = "model_config.json"
ignore_for_config = []
def register_to_config(self, **kwargs):
if self.config_name is None:
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
kwargs["_class_name"] = self.__class__.__name__
kwargs["_diffusers_version"] = "0.0.1"
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
if not hasattr(self, "_internal_dict"):
internal_dict = kwargs
else:
previous_dict = dict(self._internal_dict)
internal_dict = {**self._internal_dict, **kwargs}
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
self._internal_dict = FrozenDict(internal_dict)
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~ConfigMixin.from_config`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Directory where the configuration JSON file will be saved (will be created if it does not exist).
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
os.makedirs(save_directory, exist_ok=True)
# If we save using the predefined names, we can load using `from_config`
output_config_file = os.path.join(save_directory, self.config_name)
self.to_json_file(output_config_file)
logger.info(f"ConfigMixinuration saved in {output_config_file}")
@classmethod
def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
model = cls(**init_dict)
if return_unused_kwargs:
return model, unused_kwargs
else:
return model
@classmethod
def get_config_dict(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = {"file_type": "config"}
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if cls.config_name is None:
raise ValueError(
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`")
if os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
# Load config from a local directory
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
else:
raise EnvironmentError(
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}.")
else:
try:
# Load from URL or cache if already cached
from huggingface_hub import hf_hub_download
config_file = hf_hub_download(
pretrained_model_name_or_path,
filename=cls.config_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
)
except HTTPError as err:
raise EnvironmentError("There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}")
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
" run the library in offline mode at"
" 'https://huggingface.co/docs/diffusers/installation#offline-mode'.")
except EnvironmentError:
raise EnvironmentError(
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {cls.config_name} file")
try:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
return config_dict
@classmethod
def extract_init_dict(cls, config_dict, **kwargs):
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
# remove general kwargs if present in dict
if "kwargs" in expected_keys:
expected_keys.remove("kwargs")
# remove keys to be ignored
if len(cls.ignore_for_config) > 0:
expected_keys = expected_keys - set(cls.ignore_for_config)
init_dict = {}
for key in expected_keys:
if key in kwargs:
# overwrite key
init_dict[key] = kwargs.pop(key)
elif key in config_dict:
# use value from config dict
init_dict[key] = config_dict.pop(key)
unused_kwargs = {**config_dict, **kwargs}
passed_keys = set(init_dict.keys())
if len(expected_keys - passed_keys) > 0:
logger.warning(
f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values.")
return init_dict, unused_kwargs
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
@property
def config(self) -> Dict[str, Any]:
return self._internal_dict
def to_json_string(self) -> str:
"""
Serializes this instance to a JSON string.
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string())
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)
def register_to_config(init):
"""
Decorator to apply on the init of classes inheriting from `ConfigMixin` so that all the arguments are automatically
sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that shouldn't be
registered in the config, use the `ignore_for_config` class variable
Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
"""
@functools.wraps(init)
def inner_init(self, *args, **kwargs):
# Ignore private kwargs in the init.
init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
init(self, *args, **init_kwargs)
if not isinstance(self, ConfigMixin):
raise RuntimeError(
f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
"not inherit from `ConfigMixin`.")
ignore = getattr(self, "ignore_for_config", [])
# Get positional arguments aligned with kwargs
new_kwargs = {}
signature = inspect.signature(init)
parameters = {
name: p.default
for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
}
for arg, name in zip(args, parameters.keys()):
new_kwargs[name] = arg
# Then add all kwargs
new_kwargs.update({
k: init_kwargs.get(k, default)
for k, default in parameters.items() if k not in ignore and k not in new_kwargs
})
getattr(self, "register_to_config")(**new_kwargs)
return inner_init
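# Usage sketch for the decorator above (MyModel is a hypothetical class, not part of
# this file): every keyword argument of the decorated __init__ is recorded in the
# instance's frozen config.
#
#   class MyModel(ConfigMixin):
#       config_name = "model_config.json"
#
#       @register_to_config
#       def __init__(self, hidden_size=64, num_layers=2):
#           super().__init__()
#
#   MyModel(hidden_size=128).config   # FrozenDict with hidden_size=128, num_layers=2
#                                     # plus the bookkeeping keys added automatically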
# Models
- Models: Neural network that models $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$ (see image below) and is trained end-to-end to denoise a noisy input to an image. Examples: UNet, Conditioned UNet, 3D UNet, Transformer UNet
## API
TODO(Suraj, Patrick)
## Examples
TODO(Suraj, Patrick)
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL
from .vae import VQModel
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from inspect import isfunction
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def finfo(dtype):
if dtype == paddle.float32:
return np.finfo(np.float32)
if dtype == paddle.float16:
return np.finfo(np.float16)
if dtype == paddle.float64:
return np.finfo(np.float64)
paddle.finfo = finfo
class AttentionBlockNew(nn.Layer):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Uses three q, k, v linear layers to compute attention
"""
def __init__(
self,
channels,
num_head_channels=None,
num_groups=32,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, epsilon=eps)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
self.key = nn.Linear(channels, channels)
self.value = nn.Linear(channels, channels)
self.rescale_output_factor = rescale_output_factor
self.proj_attn = nn.Linear(channels, channels)
def transpose_for_scores(self, projection: paddle.Tensor) -> paddle.Tensor:
new_projection_shape = projection.shape[:-1] + [self.num_heads, -1]
# move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
new_projection = projection.reshape(new_projection_shape).transpose([0, 2, 1, 3])
return new_projection
def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
# norm
hidden_states = self.group_norm(hidden_states)
hidden_states = hidden_states.reshape([batch, channel, height * width]).transpose([0, 2, 1])
# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)
# transpose
query_states = self.transpose_for_scores(query_proj)
key_states = self.transpose_for_scores(key_proj)
value_states = self.transpose_for_scores(value_proj)
# get scores
scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
attention_scores = paddle.matmul(query_states * scale, key_states * scale, transpose_y=True)
attention_probs = F.softmax(attention_scores.astype("float32"), axis=-1).astype(attention_scores.dtype)
# compute attention output
context_states = paddle.matmul(attention_probs, value_states)
context_states = context_states.transpose([0, 2, 1, 3])
new_context_states_shape = context_states.shape[:-2] + [
self.channels,
]
context_states = context_states.reshape(new_context_states_shape)
# compute next hidden_states
hidden_states = self.proj_attn(context_states)
hidden_states = hidden_states.transpose([0, 2, 1]).reshape([batch, channel, height, width])
# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
def set_weight(self, attn_layer):
self.group_norm.weight.set_value(attn_layer.norm.weight)
self.group_norm.bias.set_value(attn_layer.norm.bias)
if hasattr(attn_layer, "q"):
self.query.weight.set_value(attn_layer.q.weight[:, :, 0, 0])
self.key.weight.set_value(attn_layer.k.weight[:, :, 0, 0])
self.value.weight.set_value(attn_layer.v.weight[:, :, 0, 0])
self.query.bias.set_value(attn_layer.q.bias)
self.key.bias.set_value(attn_layer.k.bias)
self.value.bias.set_value(attn_layer.v.bias)
self.proj_attn.weight.set_value(attn_layer.proj_out.weight[:, :, 0, 0])
self.proj_attn.bias.set_value(attn_layer.proj_out.bias)
elif hasattr(attn_layer, "NIN_0"):
self.query.weight.set_value(attn_layer.NIN_0.W.t())
self.key.weight.set_value(attn_layer.NIN_1.W.t())
self.value.weight.set_value(attn_layer.NIN_2.W.t())
self.query.bias.set_value(attn_layer.NIN_0.b)
self.key.bias.set_value(attn_layer.NIN_1.b)
self.value.bias.set_value(attn_layer.NIN_2.b)
self.proj_attn.weight.set_value(attn_layer.NIN_3.W.t())
self.proj_attn.bias.set_value(attn_layer.NIN_3.b)
self.group_norm.weight.set_value(attn_layer.GroupNorm_0.weight)
self.group_norm.bias.set_value(attn_layer.GroupNorm_0.bias)
else:
qkv_weight = attn_layer.qkv.weight.reshape(
[self.num_heads, 3 * self.channels // self.num_heads, self.channels])
qkv_bias = attn_layer.qkv.bias.reshape([self.num_heads, 3 * self.channels // self.num_heads])
q_w, k_w, v_w = paddle.split(qkv_weight, 3, axis=1)
q_b, k_b, v_b = paddle.split(qkv_bias, 3, axis=1)
self.query.weight.set_value(q_w.reshape([-1, self.channels]))
self.key.weight.set_value(k_w.reshape([-1, self.channels]))
self.value.weight.set_value(v_w.reshape([-1, self.channels]))
self.query.bias.set_value(q_b.flatten())
self.key.bias.set_value(k_b.flatten())
self.value.bias.set_value(v_b.flatten())
self.proj_attn.weight.set_value(attn_layer.proj.weight[:, :, 0])
self.proj_attn.bias.set_value(attn_layer.proj.bias)
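# Shape sketch for AttentionBlockNew (illustrative values): the block is
# shape-preserving over NCHW feature maps, e.g.
#   block = AttentionBlockNew(channels=64)
#   out = block(paddle.randn([1, 64, 16, 16]))   # out.shape == [1, 64, 16, 16]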
class SpatialTransformer(nn.Layer):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, epsilon=1e-6)
self.proj_in = nn.Conv2D(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.LayerList([
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
])
self.proj_out = nn.Conv2D(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = x.transpose([0, 2, 3, 1]).reshape([b, h * w, c])
for block in self.transformer_blocks:
x = block(x, context=context)
x = x.reshape([b, h, w, c]).transpose([0, 3, 1, 2])
x = self.proj_out(x)
return x + x_in
def set_weight(self, layer):
self.norm = layer.norm
self.proj_in = layer.proj_in
self.transformer_blocks = layer.transformer_blocks
self.proj_out = layer.proj_out
class BasicTransformerBlock(nn.Layer):
def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class CrossAttention(nn.Layer):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias_attr=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias_attr=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias_attr=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape([batch_size, seq_len, head_size, dim // head_size])
tensor = tensor.transpose([0, 2, 1, 3]).reshape([batch_size * head_size, seq_len, dim // head_size])
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape([batch_size // head_size, head_size, seq_len, dim])
tensor = tensor.transpose([0, 2, 1, 3]).reshape([batch_size // head_size, seq_len, dim * head_size])
return tensor
def forward(self, x, context=None, mask=None):
batch_size, sequence_length, dim = x.shape
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q = self.reshape_heads_to_batch_dim(q)
k = self.reshape_heads_to_batch_dim(k)
v = self.reshape_heads_to_batch_dim(v)
sim = paddle.einsum("b i d, b j d -> b i j", q * self.scale, k)
if exists(mask):
mask = mask.reshape([batch_size, -1]).astype("bool")
max_neg_value = -paddle.finfo(sim.dtype).max
# broadcast the padding mask over heads and query positions, then mask out invalid keys
mask = mask[:, None, :].tile([1, h, 1]).reshape([batch_size * h, 1, -1]).tile([1, sim.shape[1], 1])
sim = paddle.where(mask, sim, paddle.full_like(sim, float(max_neg_value)))
# attention, what we cannot get enough of
attn = F.softmax(sim, axis=-1)
out = paddle.einsum("b i j, b j d -> b i d", attn, v)
out = self.reshape_batch_dim_to_heads(out)
return self.to_out(out)
class FeedForward(nn.Layer):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
# feedforward
class GEGLU(nn.Layer):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, axis=-1)
return x * F.gelu(gate)
# TODO(Patrick) - remove once all weights have been converted -> not needed anymore then
class NIN(nn.Layer):
def __init__(self, in_dim, num_units, init_scale=0.1):
super().__init__()
self.W = self.create_parameter(shape=[in_dim, num_units], default_initializer=nn.initializer.Constant(0.))
self.b = self.create_parameter(shape=[
num_units,
],
is_bias=True,
default_initializer=nn.initializer.Constant(0.))
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
# the main attention block that is used for all models
class AttentionBlock(nn.Layer):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=None,
num_groups=32,
encoder_channels=None,
overwrite_qkv=False,
overwrite_linear=False,
rescale_output_factor=1.0,
eps=1e-5,
):
super().__init__()
self.channels = channels
if num_head_channels is None:
self.num_heads = num_heads
else:
assert (channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, epsilon=eps)
self.qkv = nn.Conv1D(channels, channels * 3, 1)
self.n_heads = self.num_heads
self.rescale_output_factor = rescale_output_factor
if encoder_channels is not None:
self.encoder_kv = nn.Conv1D(encoder_channels, channels * 2, 1)
self.proj = nn.Conv1D(channels, channels, 1)
self.overwrite_qkv = overwrite_qkv
self.overwrite_linear = overwrite_linear
if overwrite_qkv:
in_channels = channels
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, epsilon=1e-6)
self.q = nn.Conv2D(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2D(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2D(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2D(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
elif self.overwrite_linear:
num_groups = min(channels // 4, 32)
self.norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, epsilon=1e-6)
self.NIN_0 = NIN(channels, channels)
self.NIN_1 = NIN(channels, channels)
self.NIN_2 = NIN(channels, channels)
self.NIN_3 = NIN(channels, channels)
self.GroupNorm_0 = nn.GroupNorm(num_groups=num_groups, num_channels=channels, epsilon=1e-6)
else:
self.proj_out = nn.Conv1D(channels, channels, 1)
self.set_weights(self)
self.is_overwritten = False
def set_weights(self, layer):
if self.overwrite_qkv:
qkv_weight = paddle.concat([layer.q.weight, layer.k.weight, layer.v.weight], axis=0)[:, :, :, 0]
qkv_bias = paddle.concat([layer.q.bias, layer.k.bias, layer.v.bias], axis=0)
self.qkv.weight.set_value(qkv_weight)
self.qkv.bias.set_value(qkv_bias)
proj_out = nn.Conv1D(self.channels, self.channels, 1)
proj_out.weight.set_value(layer.proj_out.weight[:, :, :, 0])
proj_out.bias.set_value(layer.proj_out.bias)
self.proj = proj_out
elif self.overwrite_linear:
self.qkv.weight.set_value(
paddle.concat([self.NIN_0.W.t(), self.NIN_1.W.t(), self.NIN_2.W.t()], axis=0)[:, :, None])
self.qkv.bias.set_value(paddle.concat([self.NIN_0.b, self.NIN_1.b, self.NIN_2.b], axis=0))
self.proj.weight.set_value(self.NIN_3.W.t()[:, :, None])
self.proj.bias.set_value(self.NIN_3.b)
self.norm.weight.set_value(self.GroupNorm_0.weight)
self.norm.bias.set_value(self.GroupNorm_0.bias)
else:
self.proj.weight.set_value(self.proj_out.weight)
self.proj.bias.set_value(self.proj_out.bias)
def forward(self, x, encoder_out=None):
if not self.is_overwritten and (self.overwrite_qkv or self.overwrite_linear):
self.set_weights(self)
self.is_overwritten = True
b, c, *spatial = x.shape
hid_states = self.norm(x).reshape([b, c, -1])
qkv = self.qkv(hid_states)
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape([bs * self.n_heads, ch * 3, length]).split(3, axis=1)
if encoder_out is not None:
encoder_kv = self.encoder_kv(encoder_out)
assert encoder_kv.shape[1] == self.n_heads * ch * 2
ek, ev = encoder_kv.reshape([bs * self.n_heads, ch * 2, -1]).split(2, axis=1)
k = paddle.concat([ek, k], axis=-1)
v = paddle.concat([ev, v], axis=-1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = paddle.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards
weight = F.softmax(weight.astype("float32"), axis=-1).astype(weight.dtype)
a = paddle.einsum("bts,bcs->bct", weight, v)
h = a.reshape([bs, -1, length])
h = self.proj(h)
h = h.reshape([b, c, *spatial])
result = x + h
result = result / self.rescale_output_factor
return result
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def get_timestep_embedding(timesteps,
embedding_dim,
flip_sin_to_cos=False,
downscale_freq_shift=1,
scale=1,
max_period=10000):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * paddle.arange(start=0, end=half_dim, dtype="float32")
exponent = exponent / (half_dim - downscale_freq_shift)
emb = paddle.exp(exponent)
emb = timesteps[:, None].astype("float32") * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], axis=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = paddle.concat([emb[:, half_dim:], emb[:, :half_dim]], axis=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = paddle.concat([emb, paddle.zeros([emb.shape[0], 1])], axis=-1)
return emb
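# Example (sketch): get_timestep_embedding(paddle.to_tensor([0., 25.]), 128) returns
# a [2, 128] tensor whose first 64 columns are sines and last 64 are cosines of the
# frequency-scaled timesteps (the halves are swapped when flip_sin_to_cos=True).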
class TimestepEmbedding(nn.Layer):
def __init__(self, channel, time_embed_dim, act_fn="silu"):
super().__init__()
self.linear_1 = nn.Linear(channel, time_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.Silu()
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def forward(self, sample):
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
return sample
class Timesteps(nn.Layer):
def __init__(self, num_channels, flip_sin_to_cos, downscale_freq_shift):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class GaussianFourierProjection(nn.Layer):
"""Gaussian Fourier embeddings for noise levels."""
def __init__(self, embedding_size=256, scale=1.0):
super().__init__()
self.register_buffer("weight", paddle.randn((embedding_size, )) * scale)
# to delete later
self.register_buffer("W", paddle.randn((embedding_size, )) * scale)
self.weight = self.W
def forward(self, x):
x = paddle.log(x)
x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
out = paddle.concat([paddle.sin(x_proj), paddle.cos(x_proj)], axis=-1)
return out
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
def pad_new(x, pad, mode="constant", value=0):
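# Helper around paddle.nn.functional.pad: non-negative pad entries are applied as usual,
# while a negative entry on the leading side of an axis is emulated by slicing after padding.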
new_pad = []
for _ in range(x.ndim * 2 - len(pad)):
new_pad.append(0)
ndim = list(range(x.ndim - 1, 0, -1))
axes_start = {}
for i, _pad in enumerate(pad):
if _pad < 0:
new_pad.append(0)
zhengshu, yushu = divmod(i, 2)
if yushu == 0:
axes_start[ndim[zhengshu]] = -_pad
else:
new_pad.append(_pad)
padded = paddle.nn.functional.pad(x, new_pad, mode=mode, value=value)
padded_shape = paddle.shape(padded)
axes = []
starts = []
ends = []
for k, v in axes_start.items():
axes.append(k)
starts.append(v)
ends.append(padded_shape[k])
assert v < padded_shape[k]
if axes:
return padded.slice(axes=axes, starts=starts, ends=ends)
else:
return padded
class Upsample2D(nn.Layer):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
conv = nn.Conv2DTranspose(channels, self.out_channels, 4, 2, 1)
elif use_conv:
conv = nn.Conv2D(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.conv = conv
else:
self.Conv2d_0 = conv
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(x)
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
x = self.conv(x)
else:
x = self.Conv2d_0(x)
return x
class Downsample2D(nn.Layer):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
conv = nn.Conv2D(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2D(kernel_size=stride, stride=stride)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
self.Conv2d_0 = conv
self.conv = conv
elif name == "Conv2d_0":
self.conv = conv
else:
self.conv = conv
def forward(self, x):
assert x.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
x = pad_new(x, pad, mode="constant", value=0)
assert x.shape[1] == self.channels
x = self.conv(x)
return x
class FirUpsample2D(nn.Layer):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2D(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.use_conv = use_conv
self.fir_kernel = fir_kernel
self.out_channels = out_channels
def _upsample_2d(self, x, w=None, k=None, factor=2, gain=1):
"""Fused `upsample_2d()` followed by `Conv2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
w: Weight tensor of the shape `[filterH, filterW, inChannels,
outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
`x`.
"""
assert isinstance(factor, int) and factor >= 1
# Setup filter kernel.
if k is None:
k = [1] * factor
# setup kernel
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
k = k * (gain * (factor**2))
if self.use_conv:
convH = w.shape[2]
convW = w.shape[3]
inC = w.shape[1]
p = (k.shape[0] - factor) - (convW - 1)
stride = (factor, factor)
# Determine data dimensions.
stride = [1, 1, factor, factor]
output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
output_padding = (
output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
)
assert output_padding[0] >= 0 and output_padding[1] >= 0
inC = w.shape[1]
num_groups = x.shape[1] // inC
# Transpose weights.
w = paddle.reshape(w, (num_groups, -1, inC, convH, convW))
w = w[..., ::-1, ::-1].transpose([0, 2, 1, 3, 4])
w = paddle.reshape(w, (num_groups * inC, -1, convH, convW))
x = F.conv2d_transpose(x, w, stride=stride, output_padding=output_padding, padding=0)
x = upfirdn2d_native(x, paddle.to_tensor(k), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
else:
p = k.shape[0] - factor
x = upfirdn2d_native(x, paddle.to_tensor(k), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
return x
def forward(self, x):
if self.use_conv:
h = self._upsample_2d(x, self.Conv2d_0.weight, k=self.fir_kernel)
h = h + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
else:
h = self._upsample_2d(x, k=self.fir_kernel, factor=2)
return h
class FirDownsample2D(nn.Layer):
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
super().__init__()
out_channels = out_channels if out_channels else channels
if use_conv:
self.Conv2d_0 = nn.Conv2D(channels, out_channels, kernel_size=3, stride=1, padding=1)
self.fir_kernel = fir_kernel
self.use_conv = use_conv
self.out_channels = out_channels
def _downsample_2d(self, x, w=None, k=None, factor=2, gain=1):
"""Fused `Conv2d()` followed by `downsample_2d()`.
Args:
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
order.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
datatype as `x`.
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
# setup kernel
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
k = k * gain
if self.use_conv:
_, _, convH, convW = w.shape
p = (k.shape[0] - factor) + (convW - 1)
s = [factor, factor]
x = upfirdn2d_native(x, paddle.to_tensor(k), pad=((p + 1) // 2, p // 2))
x = F.conv2d(x, w, stride=s, padding=0)
else:
p = k.shape[0] - factor
x = upfirdn2d_native(x, paddle.to_tensor(k), down=factor, pad=((p + 1) // 2, p // 2))
return x
def forward(self, x):
if self.use_conv:
x = self._downsample_2d(x, w=self.Conv2d_0.weight, k=self.fir_kernel)
x = x + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
else:
x = self._downsample_2d(x, k=self.fir_kernel, factor=2)
return x
class ResnetBlock(nn.Layer):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
kernel=None,
output_scale_factor=1.0,
use_nin_shortcut=None,
up=False,
down=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.up = up
self.down = down
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, epsilon=eps)
self.conv1 = nn.Conv2D(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
else:
self.time_emb_proj = None
self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, epsilon=eps)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2D(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.Silu()
self.upsample = self.downsample = None
if self.up:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.upsample = lambda x: upsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
else:
self.upsample = Upsample2D(in_channels, use_conv=False)
elif self.down:
if kernel == "fir":
fir_kernel = (1, 3, 3, 1)
self.downsample = lambda x: downsample_2d(x, k=fir_kernel)
elif kernel == "sde_vp":
self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
else:
self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
self.conv_shortcut = None
if self.use_nin_shortcut:
self.conv_shortcut = nn.Conv2D(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
# make sure hidden states is in float32
# when running in half-precision
h = self.norm1(h.astype("float32")).astype(h.dtype)
h = self.nonlinearity(h)
if self.upsample is not None:
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
x = self.downsample(x)
h = self.downsample(h)
h = self.conv1(h)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
h = h + temb
# make sure hidden states is in float32
# when running in half-precision
h = self.norm2(h.astype("float32")).astype(h.dtype)
h = self.nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.conv_shortcut is not None:
x = self.conv_shortcut(x)
out = (x + h) / self.output_scale_factor
return out
class Mish(nn.Layer):
def forward(self, x):
return x * F.tanh(F.softplus(x))
def upsample_2d(x, k=None, factor=2, gain=1):
r"""Upsample2D a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
multiple of the upsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H * factor, W * factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
k = k * (gain * (factor**2))
p = k.shape[0] - factor
return upfirdn2d_native(x, paddle.to_tensor(k), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2))
def downsample_2d(x, k=None, factor=2, gain=1):
r"""Downsample2D a batch of 2D images with the given filter.
Args:
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
shape is a multiple of the downsampling factor.
x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
C]`.
k: FIR filter of the shape `[firH, firW]` or `[firN]`
(separable). The default is `[1] * factor`, which corresponds to average pooling.
factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns:
Tensor of the shape `[N, C, H // factor, W // factor]`
"""
assert isinstance(factor, int) and factor >= 1
if k is None:
k = [1] * factor
k = np.asarray(k, dtype=np.float32)
if k.ndim == 1:
k = np.outer(k, k)
k /= np.sum(k)
k = k * gain
p = k.shape[0] - factor
return upfirdn2d_native(x, paddle.to_tensor(k), down=factor, pad=((p + 1) // 2, p // 2))
def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
up_x = up_y = up
down_x = down_y = down
pad_x0 = pad_y0 = pad[0]
pad_x1 = pad_y1 = pad[1]
_, channel, in_h, in_w = input.shape
input = input.reshape([-1, in_h, in_w, 1])
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.reshape([-1, in_h, 1, in_w, 1, minor])
# TODO
out = pad_new(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.reshape([-1, in_h * up_y, in_w * up_x, minor])
out = pad_new(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
out = out.transpose([0, 3, 1, 2])
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = paddle.flip(kernel, [0, 1]).reshape([1, 1, kernel_h, kernel_w])
out = F.conv2d(out, w)
out = out.reshape(
[-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1])
out = out.transpose([0, 2, 3, 1])
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.reshape([-1, channel, out_h, out_w])
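# Shape summary (illustration only): with the default FIR kernel (1, 3, 3, 1),
#   upsample_2d(x)   maps [N, C, H, W] -> [N, C, 2 * H, 2 * W]
#   downsample_2d(x) maps [N, C, H, W] -> [N, C, H // 2, W // 2]
# upfirdn2d_native performs the underlying upsample -> FIR filter -> downsample sequence.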
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Union
import paddle
import paddle.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import register_to_config
from .embeddings import GaussianFourierProjection
from .embeddings import TimestepEmbedding
from .embeddings import Timesteps
from .unet_blocks import get_down_block
from .unet_blocks import get_up_block
from .unet_blocks import UNetMidBlock2D
class UNet2DModel(nn.Layer, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size=None,
in_channels=3,
out_channels=3,
center_input_sample=False,
time_embedding_type="positional",
freq_shift=0,
flip_sin_to_cos=True,
down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels=(224, 448, 672, 896),
layers_per_block=2,
mid_block_scale_factor=1,
downsample_padding=1,
act_fn="silu",
attention_head_dim=8,
norm_num_groups=32,
norm_eps=1e-5,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv2D(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.down_blocks = nn.LayerList([])
self.mid_block = None
self.up_blocks = nn.LayerList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0],
num_groups=num_groups_out,
epsilon=norm_eps)
self.conv_act = nn.Silu()
self.conv_out = nn.Conv2D(block_out_channels[0], out_channels, 3, padding=1)
def forward(self, sample: paddle.Tensor, timestep: Union[paddle.Tensor, float, int]) -> Dict[str, paddle.Tensor]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not paddle.is_tensor(timesteps):
timesteps = paddle.to_tensor([timesteps], dtype="int64")
elif paddle.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None]
# broadcast to batch dimension
timesteps = paddle.broadcast_to(timesteps, [sample.shape[0]])
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
skip_sample = sample
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample, )
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "skip_conv"):
sample, res_samples, skip_sample = downsample_block(hidden_states=sample,
temb=emb,
skip_sample=skip_sample)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]
if hasattr(upsample_block, "skip_conv"):
sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
else:
sample = upsample_block(sample, res_samples, emb)
# 6. post-process
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.astype("float32")).astype(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if skip_sample is not None:
sample += skip_sample
if self.config.time_embedding_type == "fourier":
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
output = {"sample": sample}
return output
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from typing import Union
import paddle
import paddle.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import register_to_config
from .embeddings import TimestepEmbedding
from .embeddings import Timesteps
from .unet_blocks import get_down_block
from .unet_blocks import get_up_block
from .unet_blocks import UNetMidBlock2DCrossAttn
class UNet2DConditionModel(nn.Layer, ConfigMixin):
@register_to_config
def __init__(
self,
sample_size=64,
in_channels=4,
out_channels=4,
center_input_sample=False,
flip_sin_to_cos=True,
freq_shift=0,
down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
block_out_channels=(320, 640, 1280, 1280),
layers_per_block=2,
downsample_padding=1,
mid_block_scale_factor=1,
act_fn="silu",
norm_num_groups=32,
norm_eps=1e-5,
cross_attention_dim=768,
attention_head_dim=8,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
# input
self.conv_in = nn.Conv2D(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
# time
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
self.down_blocks = nn.LayerList([])
self.mid_block = None
self.up_blocks = nn.LayerList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0],
num_groups=norm_num_groups,
epsilon=norm_eps)
self.conv_act = nn.Silu()
self.conv_out = nn.Conv2D(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self,
sample: paddle.Tensor,
timestep: Union[paddle.Tensor, float, int],
encoder_hidden_states: paddle.Tensor,
) -> Dict[str, paddle.Tensor]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not paddle.is_tensor(timesteps):
timesteps = paddle.to_tensor([timesteps], dtype="int64")
elif paddle.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None]
# broadcast to batch dimension
timesteps = paddle.broadcast_to(timesteps, [sample.shape[0]])
t_emb = self.time_proj(timesteps)
emb = self.time_embedding(t_emb)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
down_block_res_samples = (sample, )
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
# 5. up
for upsample_block in self.up_blocks:
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]
if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
# 6. post-process
# make sure hidden states is in float32
# when running in half-precision
sample = self.conv_norm_out(sample.astype("float32")).astype(sample.dtype)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
output = {"sample": sample}
return output
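# Illustrative call (shapes follow the default config above; variable names are examples only):
#   unet = UNet2DConditionModel()
#   out = unet(latents, timestep, text_features)["sample"]
#   # latents: [N, 4, 64, 64], timestep: int or 0-d tensor,
#   # text_features: [N, seq_len, 768] -> out: [N, 4, 64, 64]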
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
from ..configuration_utils import ConfigMixin
from ..configuration_utils import register_to_config
from .unet_blocks import get_down_block
from .unet_blocks import get_up_block
from .unet_blocks import UNetMidBlock2D
class Encoder(nn.Layer):
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D", ),
block_out_channels=(64, ),
layers_per_block=2,
act_fn="silu",
double_z=True,
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2D(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
self.mid_block = None
self.down_blocks = nn.LayerList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=self.layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
attn_num_head_channels=None,
temb_channels=None,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=32,
temb_channels=None,
)
# out
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, epsilon=1e-6)
self.conv_act = nn.Silu()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2D(block_out_channels[-1], conv_out_channels, 3, padding=1)
def forward(self, x):
sample = x
sample = self.conv_in(sample)
# down
for down_block in self.down_blocks:
sample = down_block(sample)
# middle
sample = self.mid_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class Decoder(nn.Layer):
def __init__(
self,
in_channels=3,
out_channels=3,
up_block_types=("UpDecoderBlock2D", ),
block_out_channels=(64, ),
layers_per_block=2,
act_fn="silu",
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2D(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
self.mid_block = None
self.up_blocks = nn.LayerList([])
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
resnet_time_scale_shift="default",
attn_num_head_channels=None,
resnet_groups=32,
temb_channels=None,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = get_up_block(
up_block_type,
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
prev_output_channel=None,
add_upsample=not is_final_block,
resnet_eps=1e-6,
resnet_act_fn=act_fn,
attn_num_head_channels=None,
temb_channels=None,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = 32
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, epsilon=1e-6)
self.conv_act = nn.Silu()
self.conv_out = nn.Conv2D(block_out_channels[0], out_channels, 3, padding=1)
def forward(self, z):
sample = z
sample = self.conv_in(sample)
# middle
sample = self.mid_block(sample)
# up
for up_block in self.up_blocks:
sample = up_block(sample)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
class VectorQuantizer(nn.Layer):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding.weight.set_value(paddle.uniform(shape=[self.n_e, self.e_dim], min=-1.0 / self.n_e, max=1.0 / self.n_e))
self.remap = remap
if self.remap is not None:
self.register_buffer("used", paddle.to_tensor(np.load(self.remap)))
self.re_embed = self.used.shape[0]
self.unknown_index = unknown_index # "random" or "extra" or integer
if self.unknown_index == "extra":
self.unknown_index = self.re_embed
self.re_embed = self.re_embed + 1
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
f"Using {self.unknown_index} for unknown indices.")
else:
self.re_embed = n_e
self.sane_index_shape = sane_index_shape
def remap_to_used(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape([ishape[0], -1])
used = self.used
match = (inds[:, :, None] == used[None, None, ...]).astype("int64")
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = paddle.randint(0, self.re_embed, shape=new[unknown].shape)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
def unmap_to_all(self, inds):
ishape = inds.shape
assert len(ishape) > 1
inds = inds.reshape([ishape[0], -1])
used = self.used
if self.re_embed > self.used.shape[0]: # extra token
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
back = paddle.gather(used[None, :][inds.shape[0] * [0], :], inds, axis=1)
return back.reshape(ishape)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.transpose([0, 2, 3, 1])
z_flattened = z.reshape([-1, self.e_dim])
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (paddle.sum(z_flattened**2, axis=1, keepdim=True) + paddle.sum(self.embedding.weight**2, axis=1) -
2 * paddle.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t()))
min_encoding_indices = paddle.argmin(d, axis=1)
z_q = self.embedding(min_encoding_indices).reshape(z.shape)
perplexity = None
min_encodings = None
# compute loss for embedding
if not self.legacy:
loss = self.beta * paddle.mean((z_q.detach() - z)**2) + paddle.mean((z_q - z.detach())**2)
else:
loss = paddle.mean((z_q.detach() - z)**2) + self.beta * paddle.mean((z_q - z.detach())**2)
# preserve gradients
z_q = z + (z_q - z).detach()
# reshape back to match original input shape
z_q = z_q.transpose([0, 3, 1, 2])
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape([z.shape[0], -1]) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape([-1, 1]) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape([z_q.shape[0], z_q.shape[2], z_q.shape[3]])
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices, shape):
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape([shape[0], -1]) # add batch axis
indices = self.unmap_to_all(indices)
indices = indices.flatten() # flatten again
# get quantized latent vectors
z_q = self.embedding(indices)
if shape is not None:
z_q = z_q.reshape(shape)
# reshape back to match original input shape
z_q = z_q.transpose([0, 3, 1, 2])
return z_q
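# Illustrative usage (names are examples only):
#   vq = VectorQuantizer(n_e=256, e_dim=3, beta=0.25)
#   z_q, codebook_loss, (_, _, indices) = vq(z)  # z: [N, 3, H, W] continuous latents
#   # z_q has the same shape as z; indices hold the selected codebook entries.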
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = paddle.chunk(parameters, 2, axis=1)
self.logvar = paddle.clip(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = paddle.exp(0.5 * self.logvar)
self.var = paddle.exp(self.logvar)
if self.deterministic:
self.var = self.std = paddle.zeros_like(self.mean)
def sample(self):
x = self.mean + self.std * paddle.randn(self.mean.shape)
return x
def kl(self, other=None):
if self.deterministic:
return paddle.to_tensor([0.0])
else:
if other is None:
return 0.5 * paddle.sum(paddle.pow(self.mean, 2) + self.var - 1.0 - self.logvar, axis=[1, 2, 3])
else:
return 0.5 * paddle.sum(
paddle.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar +
other.logvar,
axis=[1, 2, 3],
)
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return paddle.to_tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * paddle.sum(logtwopi + self.logvar + paddle.pow(sample - self.mean, 2) / self.var, axis=dims)
def mode(self):
return self.mean
class VQModel(nn.Layer, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D", ),
up_block_types=("UpDecoderBlock2D", ),
block_out_channels=(64, ),
layers_per_block=1,
act_fn="silu",
latent_channels=3,
sample_size=32,
num_vq_embeddings=256,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
double_z=False,
)
self.quant_conv = nn.Conv2D(latent_channels, latent_channels, 1)
self.quantize = VectorQuantizer(num_vq_embeddings,
latent_channels,
beta=0.25,
remap=None,
sane_index_shape=False)
self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def forward(self, sample):
x = sample
h = self.encode(x)
dec = self.decode(h)
return dec
class AutoencoderKL(nn.Layer, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"),
block_out_channels=(128, 256, 512, 512),
layers_per_block=2,
act_fn="silu",
latent_channels=4,
sample_size=512,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
double_z=True,
)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
)
self.quant_conv = nn.Conv2D(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2D(latent_channels, latent_channels, 1)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, sample, sample_posterior=False):
x = sample
posterior = self.encode(x)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec
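# Illustrative round trip (shapes follow the default config above; weights are untrained at construction):
#   vae = AutoencoderKL()
#   posterior = vae.encode(images)      # images: [N, 3, 512, 512]
#   z = posterior.sample()              # latents: [N, 4, 64, 64]
#   recon = vae.decode(z)               # reconstruction: [N, 3, 512, 512]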
# Schedulers
- Schedulers are the algorithms used to run diffusion models, both at inference time and during training. They include the noise schedules and define algorithm-specific diffusion steps.
- Schedulers can be used interchangeably between diffusion models at inference time to find the preferred trade-off between speed and generation quality.
- Schedulers are available in numpy, but can easily be transformed into PyTorch.
## API
- Schedulers should provide one or more `def step(...)` functions that are called iteratively to unroll the diffusion loop during the forward pass (see the sketch below).
- Schedulers should be framework-agnostic, but provide simple functionality to convert the scheduler into a specific framework, such as PyTorch, with a `set_format(...)` method.
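A minimal sampling sketch, assuming a trained `unet` (for example the `UNet2DConditionModel` defined elsewhere in this diff) and precomputed `text_embeddings`; the beta values and step count are only example settings, not a prescribed configuration:

```python
import paddle

# Illustrative denoising loop: the scheduler unrolls the diffusion process step by step.
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
scheduler.set_timesteps(50)

latents = paddle.randn([1, 4, 64, 64])  # random latent sample to be denoised
for t in scheduler.timesteps:
    noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings)["sample"]  # assumed model call
    latents = scheduler.step(noise_pred, t, latents)["prev_sample"]
```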
## Examples
- The DDPM scheduler was proposed in [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and can be found in [scheduling_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py). An example of how to use this scheduler can be found in [pipeline_ddpm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddpm.py).
- The DDIM scheduler was proposed in [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) and can be found in [scheduling_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py). An example of how to use this scheduler can be found in [pipeline_ddim.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_ddim.py).
- The PNDM scheduler was proposed in [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) and can be found in [scheduling_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py). An example of how to use this scheduler can be found in [pipeline_pndm.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py).
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_lms_discrete import LMSDiscreteScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
# Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion
import math
from typing import Union
import numpy as np
import paddle
from ..configuration_utils import ConfigMixin
from ..configuration_utils import register_to_config
from .scheduling_utils import SchedulerMixin
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t
from 0 to 1 and
produces the cumulative product of (1-beta) up to that part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2)**2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas, dtype=np.float32)
class DDIMScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
clip_sample=True,
set_alpha_to_one=True,
tensor_format="pd",
):
if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32)**2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def set_timesteps(self, num_inference_steps, offset=0):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.config.num_train_timesteps,
self.config.num_train_timesteps // self.num_inference_steps)[::-1].copy()
self.timesteps += offset
self.set_format(tensor_format=self.tensor_format)
def step(
self,
model_output: Union[paddle.Tensor, np.ndarray],
timestep: int,
sample: Union[paddle.Tensor, np.ndarray],
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail to understand the notation below.
# Notation (<variable name> -> <name in paper>):
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t**(0.5) * model_output) / alpha_prod_t**(0.5)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = self.clip(pred_original_sample, -1, 1)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance**(0.5)
if use_clipped_model_output:
# the model_output is always re-derived from the clipped x_0 in Glide
model_output = (sample - alpha_prod_t**(0.5) * pred_original_sample) / beta_prod_t**(0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2)**(0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev**(0.5) * pred_original_sample + pred_sample_direction
if eta > 0:
noise = paddle.randn(model_output.shape)
variance = self._get_variance(timestep, prev_timestep)**(0.5) * eta * noise
if not paddle.is_tensor(model_output):
variance = variance.numpy()
prev_sample = prev_sample + variance
return {"prev_sample": prev_sample}
def add_noise(self, original_samples, noise, timesteps):
sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
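# Training-side sketch (names are examples only): add_noise implements the forward
# diffusion q(x_t | x_0) used to build noisy training targets.
#   noise = paddle.randn(x0.shape)
#   t = paddle.randint(0, scheduler.config.num_train_timesteps, [x0.shape[0]])
#   x_t = scheduler.add_noise(x0, noise, t)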
# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import numpy as np
import paddle
from ..configuration_utils import ConfigMixin
from ..configuration_utils import register_to_config
from .scheduling_utils import SchedulerMixin
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
the VE column of Table 1 from [1] for reference.
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
https://arxiv.org/abs/2206.00364 [2] Song, Yang, et al. "Score-based generative modeling through stochastic
differential equations." https://arxiv.org/abs/2011.13456
"""
@register_to_config
def __init__(
self,
sigma_min=0.02,
sigma_max=100,
s_noise=1.007,
s_churn=80,
s_min=0.05,
s_max=50,
tensor_format="pd",
):
"""
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the
optimal {s_noise, s_churn, s_min, s_max} for a specific model are described in Table 5 of the paper.
Args:
sigma_min (`float`): minimum noise magnitude
sigma_max (`float`): maximum noise magnitude
s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
A reasonable range is [1.000, 1.011].
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
A reasonable range is [0, 100].
s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
A reasonable range is [0, 10].
s_max (`float`): the end value of the sigma range where we add noise.
A reasonable range is [0.2, 80].
"""
# setable values
self.num_inference_steps = None
self.timesteps = None
self.schedule = None # sigma(t_i)
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [(self.sigma_max * (self.sigma_min**2 / self.sigma_max**2)**(i / (num_inference_steps - 1)))
for i in self.timesteps]
self.schedule = np.array(self.schedule, dtype=np.float32)
self.set_format(tensor_format=self.tensor_format)
def add_noise_to_input(self, sample, sigma, generator=None):
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
"""
if self.s_min <= sigma <= self.s_max:
gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
else:
gamma = 0
# sample eps ~ N(0, S_noise^2 * I)
eps = self.s_noise * paddle.randn(sample.shape)
sigma_hat = sigma + gamma * sigma
sample_hat = sample + ((sigma_hat**2 - sigma**2)**0.5 * eps)
return sample_hat, sigma_hat
def step(
self,
model_output: Union[paddle.Tensor, np.ndarray],
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[paddle.Tensor, np.ndarray],
):
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
return {"prev_sample": sample_prev, "derivative": derivative}
def step_correct(
self,
model_output: Union[paddle.Tensor, np.ndarray],
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[paddle.Tensor, np.ndarray],
sample_prev: Union[paddle.Tensor, np.ndarray],
derivative: Union[paddle.Tensor, np.ndarray],
):
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
return {"prev_sample": sample_prev, "derivative": derivative_corr}
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()
# Copyright 2022 Google Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean up a bit
import numpy as np
import paddle
from ..configuration_utils import ConfigMixin
from ..configuration_utils import register_to_config
from .scheduling_utils import SchedulerMixin
class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(self, num_train_timesteps=2000, beta_min=0.1, beta_max=20, sampling_eps=1e-3, tensor_format="np"):
self.sigmas = None
self.discrete_sigmas = None
self.timesteps = None
def set_timesteps(self, num_inference_steps):
self.timesteps = paddle.linspace(1, self.config.sampling_eps, num_inference_steps)
def step_pred(self, score, x, t):
# TODO(Patrick) better comments + non-PyTorch
# postprocess model score
log_mean_coeff = (-0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min)
std = paddle.sqrt(1.0 - paddle.exp(2.0 * log_mean_coeff))
score = -score / std[:, None, None, None]
# compute
dt = -1.0 / len(self.timesteps)
beta_t = self.config.beta_min + t * (self.config.beta_max - self.config.beta_min)
drift = -0.5 * beta_t[:, None, None, None] * x
diffusion = paddle.sqrt(beta_t)
drift = drift - diffusion[:, None, None, None]**2 * score
x_mean = x + drift * dt
# add noise
noise = self.randn_like(x)
x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * noise
return x, x_mean
def __len__(self):
return self.config.num_train_timesteps