Unverified commit fc35d0c8, authored by chenjian, committed by GitHub

Add disco_diffusion_clip_vitb32 model

Parent f4d6e64c
# disco_diffusion_clip_vitb32
|Module Name|disco_diffusion_clip_vitb32|
| :--- | :---: |
|Category|Image - text to image generation|
|Network|disco diffusion + CLIP ViT-B/32|
|Dataset|-|
|Fine-tuning supported or not|No|
|Module Size|3.1GB|
|Latest update date|2022-08-02|
|Data indicators|-|
## I. Basic Information
### Application Effects
- Input text: "A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation."
- Output image
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/182298446-7feb530b-62cc-4e3f-a693-249ec8383daa.png" width = "80%" hspace='10'/>
<br />
- Generation process
<p align="center">
<img src="https://user-images.githubusercontent.com/22424850/182298453-9a8a8336-66e6-4adb-a46f-7a0fa211b467.gif" width = "80%" hspace='10'/>
<br />
### Module Introduction
disco_diffusion_clip_vitb32 is a text-to-image generation model that produces an image matching the semantics of an input sentence. The model consists of two parts. One is a diffusion model, a type of generative model that can reconstruct the original image from a noisy input. The other is a multimodal pre-trained model (CLIP), which maps text and images into the same feature space, where semantically similar text and images lie closer together. In this text-to-image pipeline, the diffusion model generates the target image starting from initial noise (or a specified initial image), while CLIP guides the semantics of the generated image to be as close as possible to the semantics of the input text. As the diffusion model iterates under CLIP's guidance, it eventually produces an image depicting the content described by the text. The CLIP architecture used in this module is ViT-B/32. A minimal sketch of this guidance loop follows the paper references below.
For more details, please refer to the papers [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233) and [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020).
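The guidance mechanism can be summarized by the following minimal sketch. It is illustrative only, not the module's actual implementation: `denoise_fn`, `clip_image_encoder`, and `text_features` are hypothetical placeholders for the diffusion denoiser, the CLIP image encoder, and the pre-computed CLIP embedding of the prompt.
```python
import paddle

def clip_guided_step(x_t, t, denoise_fn, clip_image_encoder, text_features, guidance_scale=5000.0):
    """One reverse-diffusion step nudged toward the text prompt (illustrative sketch only)."""
    x_t.stop_gradient = False                       # allow gradients w.r.t. the noisy image
    x0_hat = denoise_fn(x_t, t)                     # denoised estimate of the image at step t
    img_feat = clip_image_encoder(x0_hat)           # CLIP image embedding of the estimate
    img_feat = img_feat / paddle.linalg.norm(img_feat, p=2, axis=-1, keepdim=True)
    similarity = (img_feat * text_features).sum()   # similarity to the prompt embedding
    grad = paddle.grad(outputs=similarity, inputs=x_t)[0]
    # move the sample in the direction that increases text-image similarity
    return x0_hat + guidance_scale * grad
```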
## II. Installation
- ### 1. Environmental Dependence
- paddlepaddle >= 2.0.0
- paddlehub >= 2.2.0 | [How to install PaddleHub](../../../../docs/docs_ch/get_start/installation.rst)
- ### 2. Installation
- ```shell
$ hub install disco_diffusion_clip_vitb32
```
- In case of any problems during installation, please refer to: [Windows installation guide](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux installation guide](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS installation guide](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
- ### 1. Command line Prediction
- ```shell
$ hub run disco_diffusion_clip_vitb32 --text_prompts "A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation." --output_dir disco_diffusion_clip_vitb32_out
```
- ### 2. Prediction Code Example
- ```python
import paddlehub as hub
module = hub.Module(name="disco_diffusion_clip_vitb32")
text_prompts = ["A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation."]
# Generate images; by default they are saved to the disco_diffusion_clip_vitb32_out directory
# The returned da is a DocumentArray object that holds all results, including the final image and the intermediate results of the iterative process
# You can post-process, save, or analyze the generated images by operating on the DocumentArray object
da = module.generate_image(text_prompts=text_prompts, output_dir='./disco_diffusion_clip_vitb32_out/')
# Manually save the final generated image to a specified path
da[0].save_uri_to_file('disco_diffusion_clip_vitb32_out-result.png')
# Display all intermediate results
da[0].chunks.plot_image_sprites(skip_empty=True, show_index=True, keep_aspect_ratio=True)
# Save the whole generation process as an animated gif
da[0].chunks.save_gif('disco_diffusion_clip_vitb32_out-result.gif', show_index=True, inline_display=True, size_ratio=0.5)
```
- ### 3. API
- ```python
def generate_image(
text_prompts,
style: Optional[str] = None,
artist: Optional[str] = None,
width_height: Optional[List[int]] = [1280, 768],
seed: Optional[int] = None,
output_dir: Optional[str] = 'disco_diffusion_clip_vitb32_out'):
```
- Text-to-image generation API that produces an image depicting the content described by the text.
- **Parameters**
- text_prompts(str): The input sentence describing the content of the image to generate. A usually effective way to construct it is "a descriptive piece of text" + "the name of a specific artist", e.g. "a beautiful painting of Chinese architecture, by krenz, sunny, super wide angle, artstation.". For guidance on constructing prompts, refer to this [prompt guide](https://docs.google.com/document/d/1XUT2G9LmkZataHFzmuOtRXnuWBfhvXDAo8DkS--8tec/edit#).
- style(Optional[str]): The painting style, e.g. 'watercolor' or 'Chinese painting'. If not specified, the style is determined entirely by your prompt.
- artist(Optional[str]): A specific artist, e.g. Greg Rutkowski or krenz; the image will be generated in that artist's style. If not specified, the style is determined entirely by your prompt. For artist styles, see this [artist reference](https://weirdwonderfulai.art/resources/disco-diffusion-70-plus-artist-studies/).
- width_height(Optional[List[int]]): The width and height of the output image; both must be multiples of 64. The larger the image, the longer the generation takes.
- seed(Optional[int]): Random seed. Since the default input is random Gaussian noise, different seeds yield different initial inputs and therefore different final results; set this parameter to obtain different output images.
- output_dir(Optional[str]): The directory to save output images, default "disco_diffusion_clip_vitb32_out".
- **Return**
- da(DocumentArray): A DocumentArray object containing `n_batches` Documents, each of which stores all intermediate results of the iterative process. For details, see the [DocumentArray documentation](https://docarray.jina.ai/fundamentals/documentarray/index.html).
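- The sketch below combines the optional parameters described above; the style, artist, width_height, and seed values are illustrative examples, not recommended settings.
- ```python
import paddlehub as hub

module = hub.Module(name="disco_diffusion_clip_vitb32")
# Illustrative values only; any style/artist string and any seed may be used.
da = module.generate_image(
    text_prompts=["A beautiful painting of a lighthouse on a stormy sea."],
    style="watercolor",
    artist="krenz",
    width_height=[768, 768],
    seed=42,
    output_dir="./disco_diffusion_clip_vitb32_out/")
da[0].save_uri_to_file("lighthouse_watercolor.png")
```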
## IV. Release Note
* 1.0.0
First release
```shell
$ hub install disco_diffusion_clip_vitb32==1.0.0
```
# OpenAI CLIP implemented in Paddle.
The original implementation repo is [ranchlai/clip.paddle](https://github.com/ranchlai/clip.paddle). A copy of that repo is vendored here for guided diffusion.
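A minimal usage sketch of this CLIP port, assuming the `build_model`, `tokenize`, and `transform` helpers defined later in this package and locally available ViT-B/32 weights; the import path and the image file name are illustrative assumptions.
```python
import paddle
from PIL import Image

# Import path is an assumption; adjust to where the vendored clip package lives.
from disco_diffusion_clip_vitb32.clip.clip.utils import build_model, tokenize, transform

model = build_model('VIT32')                       # loads the ViT-B-32.pdparams weights
image = transform(Image.open('example.jpg'))       # preprocess to a [1, 3, 224, 224] tensor
text = tokenize(["a diagram", "a dog", "a cat"])   # int64 token ids of shape [3, 77]

with paddle.no_grad():
    image_features = model.encode_image(image)     # [1, 512]
    text_features = model.encode_text(text)        # [3, 512]
```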
from typing import Optional
import paddle
import paddle.nn as nn
from paddle import Tensor
from paddle.nn import functional as F
from paddle.nn import Linear
__all__ = ['ResidualAttentionBlock', 'AttentionPool2d', 'multi_head_attention_forward', 'MultiHeadAttention']
def multi_head_attention_forward(x: Tensor,
num_heads: int,
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
c_proj: Linear,
attn_mask: Optional[Tensor] = None):
max_len, batch_size, emb_dim = x.shape
head_dim = emb_dim // num_heads
scaling = float(head_dim)**-0.5
q = q_proj(x) # L, N, E
k = k_proj(x) # L, N, E
v = v_proj(x) # L, N, E
v = v.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
k = k.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
q = q.reshape((-1, batch_size * num_heads, head_dim)).transpose((1, 0, 2))
q = q * scaling
qk = paddle.bmm(q, k.transpose((0, 2, 1)))
if attn_mask is not None:
if attn_mask.ndim == 2:
attn_mask.unsqueeze_(0)
#assert str(attn_mask.dtype) == 'VarType.FP32' and attn_mask.ndim == 3
assert attn_mask.shape[0] == 1 and attn_mask.shape[1] == max_len and attn_mask.shape[2] == max_len
qk += attn_mask
qk = paddle.nn.functional.softmax(qk, axis=-1)
atten = paddle.bmm(qk, v)
atten = atten.transpose((1, 0, 2))
atten = atten.reshape((max_len, batch_size, emb_dim))
atten = c_proj(atten)
return atten
class MultiHeadAttention(nn.Layer):  # multi-head attention; attn_mask is optional
def __init__(self, emb_dim: int, num_heads: int):
super().__init__()
self.q_proj = nn.Linear(emb_dim, emb_dim, bias_attr=True)
self.k_proj = nn.Linear(emb_dim, emb_dim, bias_attr=True)
self.v_proj = nn.Linear(emb_dim, emb_dim, bias_attr=True)
self.c_proj = nn.Linear(emb_dim, emb_dim, bias_attr=True)
self.head_dim = emb_dim // num_heads
self.emb_dim = emb_dim
self.num_heads = num_heads
assert self.head_dim * num_heads == emb_dim, "embed_dim must be divisible by num_heads"
#self.scaling = float(self.head_dim) ** -0.5
def forward(self, x, attn_mask=None): # x is in shape[max_len,batch_size,emb_dim]
atten = multi_head_attention_forward(x,
self.num_heads,
self.q_proj,
self.k_proj,
self.v_proj,
self.c_proj,
attn_mask=attn_mask)
return atten
class Identity(nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x):
return x
class Bottleneck(nn.Layer):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2D(inplanes, planes, 1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes)
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes)
self.avgpool = nn.AvgPool2D(stride) if stride > 1 else Identity()
self.conv3 = nn.Conv2D(planes, planes * self.expansion, 1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(planes * self.expansion)
self.relu = nn.ReLU()
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
self.downsample = nn.Sequential(
("-1", nn.AvgPool2D(stride)),
("0", nn.Conv2D(inplanes, planes * self.expansion, 1, stride=1, bias_attr=False)),
("1", nn.BatchNorm2D(planes * self.expansion)))
def forward(self, x):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Layer):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = paddle.create_parameter((spacial_dim**2 + 1, embed_dim), dtype='float32')
self.q_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
self.k_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias_attr=True)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim, bias_attr=True)
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
def forward(self, x):
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])).transpose((2, 0, 1)) # NCHW -> (HW)NC
max_len, batch_size, emb_dim = x.shape
head_dim = self.head_dim
x = paddle.concat([paddle.mean(x, axis=0, keepdim=True), x], axis=0)
x = x + paddle.unsqueeze(self.positional_embedding, 1)
out = multi_head_attention_forward(x, self.num_heads, self.q_proj, self.k_proj, self.v_proj, self.c_proj)
return out[0]
class QuickGELU(nn.Layer):
def forward(self, x):
return x * paddle.nn.functional.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Layer):
def __init__(self, d_model: int, n_head: int, attn_mask=None):
super().__init__()
self.attn = MultiHeadAttention(d_model, n_head)
self.ln_1 = nn.LayerNorm(d_model)
self.mlp = nn.Sequential(("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)))
self.ln_2 = nn.LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x):
x = self.attn(x, self.attn_mask)
assert isinstance(x, paddle.Tensor)  # not a tuple here
return x
def forward(self, x):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
from typing import Tuple
from typing import Union
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import nn
from .layers import AttentionPool2d
from .layers import Bottleneck
from .layers import MultiHeadAttention
from .layers import ResidualAttentionBlock
class ModifiedResNet(nn.Layer):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2D(3, width // 2, kernel_size=3, stride=2, padding=1, bias_attr=False)
self.bn1 = nn.BatchNorm2D(width // 2)
self.conv2 = nn.Conv2D(width // 2, width // 2, kernel_size=3, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(width // 2)
self.conv3 = nn.Conv2D(width // 2, width, kernel_size=3, padding=1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(width)
self.avgpool = nn.AvgPool2D(2)
self.relu = nn.ReLU()
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
#x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class Transformer(nn.Layer):
def __init__(self, width: int, layers: int, heads: int, attn_mask=None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x):
return self.resblocks(x)
class VisualTransformer(nn.Layer):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
# used patch_size x patch_size, stride patch_size to do linear projection
self.conv1 = nn.Conv2D(in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias_attr=False)
# scale = width ** -0.5
self.class_embedding = paddle.create_parameter((width, ), 'float32')
self.positional_embedding = paddle.create_parameter(((input_resolution // patch_size)**2 + 1, width), 'float32')
self.ln_pre = nn.LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = nn.LayerNorm(width)
self.proj = paddle.create_parameter((width, output_dim), 'float32')
def forward(self, x):
x = self.conv1(x)
x = x.reshape((x.shape[0], x.shape[1], -1))
x = x.transpose((0, 2, 1))
x = paddle.concat([self.class_embedding + paddle.zeros((x.shape[0], 1, x.shape[-1]), dtype=x.dtype), x], axis=1)
x = x + self.positional_embedding
x = self.ln_pre(x)
x = x.transpose((1, 0, 2))
x = self.transformer(x)
x = x.transpose((1, 0, 2))
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = paddle.matmul(x, self.proj)
return x
class CLIP(nn.Layer):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim)
self.transformer = Transformer(width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask())
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = paddle.create_parameter((self.context_length, transformer_width), 'float32')
self.ln_final = nn.LayerNorm(transformer_width)
self.text_projection = paddle.create_parameter((transformer_width, embed_dim), 'float32')
self.logit_scale = paddle.create_parameter((1, ), 'float32')
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# mask = paddle.empty((self.context_length, self.context_length),dtype='float32')
# mask.fill_(float("-inf"))
#mask.triu_(1) # zero out the lower diagonal
mask = paddle.ones((self.context_length, self.context_length)) * float("-inf")
mask = paddle.triu(mask, diagonal=1)
return mask
def encode_image(self, image):
return self.visual(image)
def encode_text(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.transpose((1, 0, 2)) # NLD -> LND
x = self.transformer(x)
x = x.transpose((1, 0, 2)) # LND -> NLD
x = self.ln_final(x)
idx = text.numpy().argmax(-1)
idx = list(idx)
x = [x[i:i + 1, int(j), :] for i, j in enumerate(idx)]
x = paddle.concat(x, 0)
x = paddle.matmul(x, self.text_projection)
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(p=2, axis=-1, keepdim=True)
text_features = text_features / text_features.norm(p=2, axis=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = paddle.matmul(logit_scale * image_features, text_features.t())
logits_per_text = paddle.matmul(logit_scale * text_features, image_features.t())
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "../assets/bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns a list of utf-8 bytes and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
import os
from typing import List
from typing import Union
import numpy as np
import paddle
from paddle.utils import download
from paddle.vision.transforms import CenterCrop
from paddle.vision.transforms import Compose
from paddle.vision.transforms import Normalize
from paddle.vision.transforms import Resize
from paddle.vision.transforms import ToTensor
from .model import CLIP
from .simple_tokenizer import SimpleTokenizer
__all__ = ['transform', 'tokenize', 'build_model']
MODEL_NAMES = ['RN50', 'RN101', 'VIT32']
URL = {
'RN50': os.path.join(os.path.dirname(__file__), 'pre_trained', 'RN50.pdparams'),
'RN101': os.path.join(os.path.dirname(__file__), 'pre_trained', 'RN101.pdparams'),
'VIT32': os.path.join(os.path.dirname(__file__), 'pre_trained', 'ViT-B-32.pdparams')
}
MEAN, STD = (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
_tokenizer = SimpleTokenizer()
transform = Compose([
Resize(224, interpolation='bicubic'),
CenterCrop(224), lambda image: image.convert('RGB'),
ToTensor(),
Normalize(mean=MEAN, std=STD), lambda t: t.unsqueeze_(0)
])
def tokenize(texts: Union[str, List[str]], context_length: int = 77):
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = paddle.zeros((len(all_tokens), context_length), dtype='int64')
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
result[i, :len(tokens)] = paddle.Tensor(np.array(tokens))
return result
def build_model(name='VIT32'):
assert name in MODEL_NAMES, f"model name must be one of {MODEL_NAMES}"
name2model = {'RN101': build_rn101_model, 'VIT32': build_vit_model, 'RN50': build_rn50_model}
model = name2model[name]()
weight = URL[name]
sd = paddle.load(weight)
model.load_dict(sd)
model.eval()
return model
def build_vit_model():
model = CLIP(embed_dim=512,
image_resolution=224,
vision_layers=12,
vision_width=768,
vision_patch_size=32,
context_length=77,
vocab_size=49408,
transformer_width=512,
transformer_heads=8,
transformer_layers=12)
return model
def build_rn101_model():
model = CLIP(
embed_dim=512,
image_resolution=224,
vision_layers=(3, 4, 23, 3),
vision_width=64,
vision_patch_size=0, #Not used in resnet
context_length=77,
vocab_size=49408,
transformer_width=512,
transformer_heads=8,
transformer_layers=12)
return model
def build_rn50_model():
model = CLIP(embed_dim=1024,
image_resolution=224,
vision_layers=(3, 4, 6, 3),
vision_width=64,
vision_patch_size=None,
context_length=77,
vocab_size=49408,
transformer_width=512,
transformer_heads=8,
transformer_layers=12)
return model
numpy
paddle_lpips==0.1.2
ftfy
docarray>=0.13.29
pyyaml
regex
tqdm
ipywidgets
# ResizeRight (Paddle)
A fully differentiable resize function implemented in Paddle.
This module is based on [assafshocher/ResizeRight](https://github.com/assafshocher/ResizeRight).
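A minimal usage sketch of the `resize` function defined below. The tensor sizes are illustrative, the import path is an assumption based on this module's package layout, and note that the Paddle code path calls `paddle.device.cuda.empty_cache()`, so a GPU build of Paddle is typically assumed.
```python
import paddle
# Import path is an assumption based on this module's package layout.
from disco_diffusion_clip_vitb32.resize_right.resize_right import resize

x = paddle.rand([1, 3, 256, 256])            # NCHW input tensor
y = resize(x, scale_factors=0.5)             # downscale both spatial dims with antialiasing
z = resize(x, out_shape=(1, 3, 320, 480))    # or specify the target shape directly
print(y.shape, z.shape)                      # [1, 3, 128, 128] [1, 3, 320, 480]
```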
from math import pi
try:
import paddle
except ImportError:
paddle = None
try:
import numpy
import numpy as np
except ImportError:
numpy = None
if numpy is None and paddle is None:
raise ImportError("Must have either Numpy or PyTorch but both not found")
def set_framework_dependencies(x):
if type(x) is numpy.ndarray:
to_dtype = lambda a: a
fw = numpy
else:
to_dtype = lambda a: paddle.cast(a, x.dtype)
fw = paddle
# eps = fw.finfo(fw.float32).eps
eps = paddle.to_tensor(np.finfo(np.float32).eps)
return fw, to_dtype, eps
def support_sz(sz):
def wrapper(f):
f.support_sz = sz
return f
return wrapper
@support_sz(4)
def cubic(x):
fw, to_dtype, eps = set_framework_dependencies(x)
absx = fw.abs(x)
absx2 = absx**2
absx3 = absx**3
return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
(-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) * to_dtype((1. < absx) & (absx <= 2.)))
@support_sz(4)
def lanczos2(x):
fw, to_dtype, eps = set_framework_dependencies(x)
return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) / ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))
@support_sz(6)
def lanczos3(x):
fw, to_dtype, eps = set_framework_dependencies(x)
return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) / ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))
@support_sz(2)
def linear(x):
fw, to_dtype, eps = set_framework_dependencies(x)
return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) * to_dtype((0 <= x) & (x <= 1)))
@support_sz(1)
def box(x):
fw, to_dtype, eps = set_framework_dependencies(x)
return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
import warnings
from fractions import Fraction
from math import ceil
from typing import Tuple
import disco_diffusion_clip_vitb32.resize_right.interp_methods as interp_methods
class NoneClass:
pass
try:
import paddle
from paddle import nn
nnModuleWrapped = nn.Layer
except ImportError:
warnings.warn('No Paddle found, will work only with Numpy')
paddle = None
nnModuleWrapped = NoneClass
try:
import numpy
import numpy as np
except ImportError:
warnings.warn('No Numpy found, will work only with Paddle')
numpy = None
if numpy is None and paddle is None:
raise ImportError("Must have either Numpy or PyTorch but both not found")
def resize(input,
scale_factors=None,
out_shape=None,
interp_method=interp_methods.cubic,
support_sz=None,
antialiasing=True,
by_convs=False,
scale_tolerance=None,
max_numerator=10,
pad_mode='constant'):
# get properties of the input tensor
in_shape, n_dims = input.shape, input.ndim
# fw stands for framework that can be either numpy or paddle,
# determined by the input type
fw = numpy if type(input) is numpy.ndarray else paddle
eps = np.finfo(np.float32).eps if fw == numpy else paddle.to_tensor(np.finfo(np.float32).eps)
device = input.place if fw is paddle else None
# set missing scale factors or output shape, one according to the other,
# scream if both missing. this is also where all the default policies
# take place. also handling the by_convs attribute carefully.
scale_factors, out_shape, by_convs = set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs,
scale_tolerance, max_numerator, eps, fw)
# sort indices of dimensions according to scale of each dimension.
# since we are going dim by dim this is efficient
sorted_filtered_dims_and_scales = [(dim, scale_factors[dim], by_convs[dim], in_shape[dim], out_shape[dim])
for dim in sorted(range(n_dims), key=lambda ind: scale_factors[ind])
if scale_factors[dim] != 1.]
# unless support size is specified by the user, it is an attribute
# of the interpolation method
if support_sz is None:
support_sz = interp_method.support_sz
# output begins identical to input and changes with each iteration
output = input
# iterate over dims
for (dim, scale_factor, dim_by_convs, in_sz, out_sz) in sorted_filtered_dims_and_scales:
# STEP 1- PROJECTED GRID: The non-integer locations of the projection
# of output pixel locations to the input tensor
projected_grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, dim_by_convs, device)
# STEP 1.5: ANTIALIASING- If antialiasing is taking place, we modify
# the window size and the interpolation method (see inside function)
cur_interp_method, cur_support_sz = apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
antialiasing)
# STEP 2- FIELDS OF VIEW: for each output pixel, map the input pixels
# that influence it. Also calculate needed padding and update grid
# accordingly
field_of_view = get_field_of_view(projected_grid, cur_support_sz, fw, eps, device)
# STEP 2.5- CALCULATE PAD AND UPDATE: according to the field of view,
# the input should be padded to handle the boundaries, coordinates
# should be updated. actual padding only occurs when weights are
# applied (step 4). if using by_convs for this dim, then we need to
# calc right and left boundaries for each filter instead.
pad_sz, projected_grid, field_of_view = calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor,
dim_by_convs, fw, device)
# STEP 3- CALCULATE WEIGHTS: Match a set of weights to the pixels in
# the field of view for each output pixel
weights = get_weights(cur_interp_method, projected_grid, field_of_view)
# STEP 4- APPLY WEIGHTS: Each output pixel is calculated by multiplying
# its set of weights with the pixel values in its field of view.
# We now multiply the fields of view with their matching weights.
# We do this by tensor multiplication and broadcasting.
# if by_convs is true for this dim, then we do this action by
# convolutions. this is equivalent but faster.
if not dim_by_convs:
output = apply_weights(output, field_of_view, weights, dim, n_dims, pad_sz, pad_mode, fw)
else:
output = apply_convs(output, scale_factor, in_sz, out_sz, weights, dim, pad_sz, pad_mode, fw)
return output
def get_projected_grid(in_sz, out_sz, scale_factor, fw, by_convs, device=None):
# we start by having the output coordinates which are just integer locations
# in the special case when using by_convs, we only need two cycles of grid
# points. the first and last.
grid_sz = out_sz if not by_convs else scale_factor.numerator
out_coordinates = fw_arange(grid_sz, fw, device)
# This is projecting the output pixel locations in 1d to the input tensor,
# as non-integer locations.
# the following formula is derived in the paper
# "From Discrete to Continuous Convolutions" by Shocher et al.
return (out_coordinates / float(scale_factor) + (in_sz - 1) / 2 - (out_sz - 1) / (2 * float(scale_factor)))
def get_field_of_view(projected_grid, cur_support_sz, fw, eps, device):
# for each output pixel, map which input pixels influence it, in 1d.
# we start by calculating the leftmost neighbor, using half of the window
# size (eps is for when boundary is exact int)
left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)
# then we simply take all the pixel centers in the field by counting
# window size pixels from the left boundary
ordinal_numbers = fw_arange(ceil(cur_support_sz - eps), fw, device)
return left_boundaries[:, None] + ordinal_numbers
def calc_pad_sz(in_sz, out_sz, field_of_view, projected_grid, scale_factor, dim_by_convs, fw, device):
if not dim_by_convs:
# determine padding according to neighbor coords out of bound.
# this is a generalized notion of padding, when pad<0 it means crop
pad_sz = [-field_of_view[0, 0].item(), field_of_view[-1, -1].item() - in_sz + 1]
# since input image will be changed by padding, coordinates of both
# field_of_view and projected_grid need to be updated
field_of_view += pad_sz[0]
projected_grid += pad_sz[0]
else:
# only used for by_convs, to calc the boundaries of each filter the
# number of distinct convolutions is the numerator of the scale factor
num_convs, stride = scale_factor.numerator, scale_factor.denominator
# calculate left and right boundaries for each conv. left can also be
# negative right can be bigger than in_sz. such cases imply padding if
# needed. however if both are in-bounds, it means we need to crop,
# practically apply the conv only on part of the image.
left_pads = -field_of_view[:, 0]
# next calc is tricky, explanation by rows:
# 1) counting output pixels between the first position of each filter
# to the right boundary of the input
# 2) dividing it by number of filters to count how many 'jumps'
# each filter does
# 3) multiplying by the stride gives us the distance over the input
# coords done by all these jumps for each filter
# 4) to this distance we add the right boundary of the filter when
# placed in its leftmost position. so now we get the right boundary
# of that filter in input coord.
# 5) the padding size needed is obtained by subtracting the rightmost
# input coordinate. if the result is positive padding is needed. if
# negative then negative padding means shaving off pixel columns.
right_pads = (((out_sz - fw_arange(num_convs, fw, device) - 1) # (1)
// num_convs) # (2)
* stride # (3)
+ field_of_view[:, -1] # (4)
- in_sz + 1) # (5)
# in the by_convs case pad_sz is a list of left-right pairs. one per
# each filter
pad_sz = list(zip(left_pads, right_pads))
return pad_sz, projected_grid, field_of_view
def get_weights(interp_method, projected_grid, field_of_view):
# the set of weights per each output pixels is the result of the chosen
# interpolation method applied to the distances between projected grid
# locations and the pixel-centers in the field of view (distances are
# directed, can be positive or negative)
weights = interp_method(projected_grid[:, None] - field_of_view)
# we now carefully normalize the weights to sum to 1 per each output pixel
sum_weights = weights.sum(1, keepdim=True)
sum_weights[sum_weights == 0] = 1
return weights / sum_weights
def apply_weights(input, field_of_view, weights, dim, n_dims, pad_sz, pad_mode, fw):
# for this operation we assume the resized dim is the first one.
# so we transpose and will transpose back after multiplying
tmp_input = fw_swapaxes(input, dim, 0, fw)
# apply padding
tmp_input = fw_pad(tmp_input, fw, pad_sz, pad_mode)
# field_of_view is a tensor of order 2: for each output (1d location
# along cur dim)- a list of 1d neighbors locations.
# note that this whole operations is applied to each dim separately,
# this is why it is all in 1d.
# neighbors = tmp_input[field_of_view] is a tensor of order image_dims+1:
# for each output pixel (this time indicated in all dims), these are the
# values of the neighbors in the 1d field of view. note that we only
# consider neighbors along the current dim, but such set exists for every
# multi-dim location, hence the final tensor order is image_dims+1.
paddle.device.cuda.empty_cache()
neighbors = tmp_input[field_of_view]
# weights is an order 2 tensor: for each output location along 1d- a list
# of weights matching the field of view. we augment it with ones, for
# broadcasting, so that when multiplies some tensor the weights affect
# only its first dim.
tmp_weights = fw.reshape(weights, (*weights.shape, *[1] * (n_dims - 1)))
# now we simply multiply the weights with the neighbors, and then sum
# along the field of view, to get a single value per out pixel
tmp_output = (neighbors * tmp_weights).sum(1)
# we transpose back the resized dim to its original position
return fw_swapaxes(tmp_output, 0, dim, fw)
def apply_convs(input, scale_factor, in_sz, out_sz, weights, dim, pad_sz, pad_mode, fw):
# for this operation we assume the resized dim is the last one.
# so we transpose and will transpose back after multiplying
input = fw_swapaxes(input, dim, -1, fw)
# the stride for all convs is the denominator of the scale factor
stride, num_convs = scale_factor.denominator, scale_factor.numerator
# prepare an empty tensor for the output
tmp_out_shape = list(input.shape)
tmp_out_shape[-1] = out_sz
tmp_output = fw_empty(tuple(tmp_out_shape), fw, input.place if fw is paddle else None)
# iterate over the conv operations. we have as many as the numerator
# of the scale-factor. for each we need boundaries and a filter.
for conv_ind, (pad_sz, filt) in enumerate(zip(pad_sz, weights)):
# apply padding (we pad last dim, padding can be negative)
pad_dim = input.ndim - 1
tmp_input = fw_pad(input, fw, pad_sz, pad_mode, dim=pad_dim)
# apply convolution over last dim. store in the output tensor with
# positional strides so that when the loop is complete conv results are
# interleaved
tmp_output[..., conv_ind::num_convs] = fw_conv(tmp_input, filt, stride)
return fw_swapaxes(tmp_output, -1, dim, fw)
def set_scale_and_out_sz(in_shape, out_shape, scale_factors, by_convs, scale_tolerance, max_numerator, eps, fw):
# eventually we must have both scale-factors and out-sizes for all in/out
# dims. however, we support many possible partial arguments
if scale_factors is None and out_shape is None:
raise ValueError("either scale_factors or out_shape should be "
"provided")
if out_shape is not None:
# if out_shape has less dims than in_shape, we by default resize the
# first dims for numpy and last dims for paddle
out_shape = (list(out_shape) +
list(in_shape[len(out_shape):]) if fw is numpy else list(in_shape[:-len(out_shape)]) +
list(out_shape))
if scale_factors is None:
# if no scale given, we calculate it as the out to in ratio
# (not recommended)
scale_factors = [out_sz / in_sz for out_sz, in_sz in zip(out_shape, in_shape)]
if scale_factors is not None:
# by default, if a single number is given as scale, we assume resizing
# two dims (most common are images with 2 spatial dims)
scale_factors = (scale_factors if isinstance(scale_factors, (list, tuple)) else [scale_factors, scale_factors])
# if less scale_factors than in_shape dims, we by default resize the
# first dims for numpy and last dims for paddle
scale_factors = (list(scale_factors) + [1] * (len(in_shape) - len(scale_factors)) if fw is numpy else [1] *
(len(in_shape) - len(scale_factors)) + list(scale_factors))
if out_shape is None:
# when no out_shape given, it is calculated by multiplying the
# scale by the in_shape (not recommended)
out_shape = [ceil(scale_factor * in_sz) for scale_factor, in_sz in zip(scale_factors, in_shape)]
# next part intentionally after out_shape determined for stability
# we fix by_convs to be a list of truth values in case it is not
if not isinstance(by_convs, (list, tuple)):
by_convs = [by_convs] * len(out_shape)
# next loop fixes the scale for each dim to be either frac or float.
# this is determined by by_convs and by tolerance for scale accuracy.
for ind, (sf, dim_by_convs) in enumerate(zip(scale_factors, by_convs)):
# first we convert the scale to a fraction
if dim_by_convs:
frac = Fraction(1 / sf).limit_denominator(max_numerator)
frac = Fraction(numerator=frac.denominator, denominator=frac.numerator)
# if accuracy is within tolerance scale will be frac. if not, then
# it will be float and the by_convs attr will be set false for
# this dim
if scale_tolerance is None:
scale_tolerance = eps
if dim_by_convs and abs(frac - sf) < scale_tolerance:
scale_factors[ind] = frac
else:
scale_factors[ind] = float(sf)
by_convs[ind] = False
return scale_factors, out_shape, by_convs
def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor, antialiasing):
# antialiasing is "stretching" the field of view according to the scale
# factor (only for downscaling). this is low-pass filtering. this
# requires modifying both the interpolation (stretching the 1d
# function and multiplying by the scale-factor) and the window size.
scale_factor = float(scale_factor)
if scale_factor >= 1.0 or not antialiasing:
return interp_method, support_sz
cur_interp_method = (lambda arg: scale_factor * interp_method(scale_factor * arg))
cur_support_sz = support_sz / scale_factor
return cur_interp_method, cur_support_sz
def fw_ceil(x, fw):
if fw is numpy:
return fw.int_(fw.ceil(x))
else:
return paddle.cast(x.ceil(), dtype='int64')
def fw_floor(x, fw):
if fw is numpy:
return fw.int_(fw.floor(x))
else:
return paddle.cast(x.floor(), dtype='int64')
def fw_cat(x, fw):
if fw is numpy:
return fw.concatenate(x)
else:
return fw.concat(x)
def fw_swapaxes(x, ax_1, ax_2, fw):
if fw is numpy:
return fw.swapaxes(x, ax_1, ax_2)
else:
if ax_1 == -1:
ax_1 = len(x.shape) - 1
if ax_2 == -1:
ax_2 = len(x.shape) - 1
perm0 = list(range(len(x.shape)))
temp = ax_1
perm0[temp] = ax_2
perm0[ax_2] = temp
return fw.transpose(x, perm0)
def fw_pad(x, fw, pad_sz, pad_mode, dim=0):
if pad_sz == (0, 0):
return x
if fw is numpy:
pad_vec = [(0, 0)] * x.ndim
pad_vec[dim] = pad_sz
return fw.pad(x, pad_width=pad_vec, mode=pad_mode)
else:
if x.ndim < 3:
x = x[None, None, ...]
pad_vec = [0] * ((x.ndim - 2) * 2)
pad_vec[0:2] = pad_sz
return fw_swapaxes(fw.nn.functional.pad(fw_swapaxes(x, dim, -1, fw), pad=pad_vec, mode=pad_mode), dim, -1, fw)
def fw_conv(input, filter, stride):
# we want to apply 1d conv to any nd array. the way to do it is to reshape
# the input to a 4D tensor. first two dims are singletons, 3rd dim stores
# all the spatial dims that we are not convolving along now. then we can
# apply conv2d with a 1xK filter. This convolves the same way all the other
# dims stored in the 3d dim. like depthwise conv over these.
# TODO: numpy support
reshaped_input = input.reshape((1, 1, -1, input.shape[-1]))
reshaped_output = paddle.nn.functional.conv2d(reshaped_input, filter.reshape((1, 1, 1, -1)), stride=(1, stride))
return reshaped_output.reshape((*input.shape[:-1], -1))
def fw_arange(upper_bound, fw, device):
if fw is numpy:
return fw.arange(upper_bound)
else:
return fw.arange(upper_bound)
def fw_empty(shape, fw, device):
if fw is numpy:
return fw.empty(shape)
else:
return fw.empty(shape=shape)
# Diffusion model (Paddle)
This module implements a diffusion model that accepts a text prompt and outputs images semantically close to the text. The code is rewritten in Paddle and mainly refers to two projects: [jina-ai/discoart](https://github.com/jina-ai/discoart) and [openai/guided-diffusion](https://github.com/openai/guided-diffusion). Thanks for their wonderful work.
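A minimal usage sketch of the `create` entry point defined below. The import path is an assumption based on this module's layout, and the parameter values are illustrative; in the PaddleHub module the call is normally made through `module.generate_image` instead.
```python
from docarray import DocumentArray

# Import path is an assumption; `create` is exported by the __init__.py shown below.
from disco_diffusion_clip_vitb32.reverse_diffusion import create

da: DocumentArray = create(
    text_prompts=['A beautiful painting of a lighthouse, Trending on artstation.'],
    clip_models=['ViTB32'],     # this module ships the ViT-B/32 CLIP weights
    n_batches=1,
    output_dir='discoart_output')
da[0].save_uri_to_file('result.png')
```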
'''
https://github.com/jina-ai/discoart/blob/main/discoart/__init__.py
'''
import os
import warnings
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
__all__ = ['create']
import sys
__resources_path__ = os.path.join(
os.path.dirname(sys.modules.get(__package__).__file__ if __package__ in sys.modules else __file__),
'resources',
)
import gc
# check if GPU is available
import paddle
# download and load models, this will take some time on the first load
from .helper import load_all_models, load_diffusion_model, load_clip_models
model_config, secondary_model = load_all_models('512x512_diffusion_uncond_finetune_008100', use_secondary_model=True)
from typing import TYPE_CHECKING, overload, List, Optional
if TYPE_CHECKING:
from docarray import DocumentArray, Document
_clip_models_cache = {}
# begin_create_overload
@overload
def create(text_prompts: Optional[List[str]] = [
'A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.',
'yellow color scheme',
],
init_image: Optional[str] = None,
width_height: Optional[List[int]] = [1280, 768],
skip_steps: Optional[int] = 10,
steps: Optional[int] = 250,
cut_ic_pow: Optional[int] = 1,
init_scale: Optional[int] = 1000,
clip_guidance_scale: Optional[int] = 5000,
tv_scale: Optional[int] = 0,
range_scale: Optional[int] = 150,
sat_scale: Optional[int] = 0,
cutn_batches: Optional[int] = 4,
diffusion_model: Optional[str] = '512x512_diffusion_uncond_finetune_008100',
use_secondary_model: Optional[bool] = True,
diffusion_sampling_mode: Optional[str] = 'ddim',
perlin_init: Optional[bool] = False,
perlin_mode: Optional[str] = 'mixed',
seed: Optional[int] = None,
eta: Optional[float] = 0.8,
clamp_grad: Optional[bool] = True,
clamp_max: Optional[float] = 0.05,
randomize_class: Optional[bool] = True,
clip_denoised: Optional[bool] = False,
fuzzy_prompt: Optional[bool] = False,
rand_mag: Optional[float] = 0.05,
cut_overview: Optional[str] = '[12]*400+[4]*600',
cut_innercut: Optional[str] = '[4]*400+[12]*600',
cut_icgray_p: Optional[str] = '[0.2]*400+[0]*600',
display_rate: Optional[int] = 10,
n_batches: Optional[int] = 4,
batch_size: Optional[int] = 1,
batch_name: Optional[str] = '',
clip_models: Optional[list] = ['ViTB32', 'ViTB16', 'RN50'],
output_dir: Optional[str] = 'discoart_output') -> 'DocumentArray':
"""
Create Disco Diffusion artworks and save the result into a DocumentArray.
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like. The words will be analyzed by the AI and will guide the diffusion process toward the image(s) you describe. These can include commas and weights to adjust the relative importance of each element. E.g. "A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation." Notice that this prompt loosely follows a structure: [subject], [prepositional details], [setting], [meta modifiers and artist]; this is a good starting point for your experiments. Developing text prompts takes practice and experience, and is not the subject of this guide. If you are a beginner to writing text prompts, a good place to start is on a simple AI art app like Night Cafe, starry ai or WOMBO prior to using DD, to get a feel for how text gets translated into images by GAN tools. These other apps use different technologies, but many of the same principles apply.
:param init_image: Recall that in the image sequence above, the first image shown is just noise. If an init_image is provided, diffusion will replace the noise with the init_image as its starting state. To use an init_image, upload the image to the Colab instance or your Google Drive, and enter the full image path here. If using an init_image, you may need to increase skip_steps to ~ 50% of total steps to retain the character of the init. See skip_steps above for further discussion.
:param width_height: Desired final image size, in pixels. You can have a square, wide, or tall image, but each edge length should be set to a multiple of 64px, and a minimum of 512px on the default CLIP model setting. If you forget to use multiples of 64px in your dimensions, DD will adjust the dimensions of your image to make it so.
:param skip_steps: Consider the chart shown here. Noise scheduling (denoise strength) starts very high and progressively gets lower and lower as diffusion steps progress. The noise levels in the first few steps are very high, so images change dramatically in early steps. As DD moves along the curve, noise levels (and thus the amount an image changes per step) declines, and image coherence from one step to the next increases. The first few steps of denoising are often so dramatic that some steps (maybe 10-15% of total) can be skipped without affecting the final image. You can experiment with this as a way to cut render times. If you skip too many steps, however, the remaining noise may not be high enough to generate new content, and thus may not have ‘time left’ to finish an image satisfactorily. Also, depending on your other settings, you may need to skip steps to prevent CLIP from overshooting your goal, resulting in ‘blown out’ colors (hyper saturated, solid white, or solid black regions) or otherwise poor image quality. Consider that the denoising process is at its strongest in the early steps, so skipping steps can sometimes mitigate other problems. Lastly, if using an init_image, you will need to skip ~50% of the diffusion steps to retain the shapes in the original init image. However, if you’re using an init_image, you can also adjust skip_steps up or down for creative reasons. With low skip_steps you can get a result "inspired by" the init_image which will retain the colors and rough layout and shapes but look quite different. With high skip_steps you can preserve most of the init_image contents and just do fine tuning of the texture.
:param steps: When creating an image, the denoising curve is subdivided into steps for processing. Each step (or iteration) involves the AI looking at subsets of the image called ‘cuts’ and calculating the ‘direction’ the image should be guided to be more like the prompt. Then it adjusts the image with the help of the diffusion denoiser, and moves to the next step. Increasing steps will provide more opportunities for the AI to adjust the image, and each adjustment will be smaller, and thus will yield a more precise, detailed image. Increasing steps comes at the expense of longer render times. Also, while increasing steps should generally increase image quality, there is a diminishing return on additional steps beyond 250 - 500 steps. However, some intricate images can take 1000, 2000, or more steps. It is really up to the user. Just know that the render time is directly related to the number of steps, and many other parameters have a major impact on image quality, without costing additional time.
:param cut_ic_pow: This sets the size of the border used for inner cuts. High cut_ic_pow values have larger borders, and therefore the cuts themselves will be smaller and provide finer details. If you have too many or too-small inner cuts, you may lose overall image coherency and/or it may cause an undesirable ‘mosaic’ effect. Low cut_ic_pow values will allow the inner cuts to be larger, helping image coherency while still helping with some details.
:param init_scale: This controls how strongly CLIP will try to match the init_image provided. This is balanced against the clip_guidance_scale (CGS) above. Too much init scale, and the image won’t change much during diffusion. Too much CGS and the init image will be lost.
:param clip_guidance_scale: CGS is one of the most important parameters you will use. It tells DD how strongly you want CLIP to move toward your prompt each timestep. Higher is generally better, but if CGS is too strong it will overshoot the goal and distort the image. So a happy medium is needed, and it takes experience to learn how to adjust CGS. Note that this parameter generally scales with image dimensions. In other words, if you increase your total dimensions by 50% (e.g. a change from 512 x 512 to 512 x 768), then to maintain the same effect on the image, you’d want to increase clip_guidance_scale from 5000 to 7500. Of the basic settings, clip_guidance_scale, steps and skip_steps are the most important contributors to image quality, so learn them well.
:param tv_scale: Total variance denoising. Optional, set to zero to turn off. Controls ‘smoothness’ of final output. If used, tv_scale will try to smooth out your final image to reduce overall noise. If your image is too ‘crunchy’, increase tv_scale. TV denoising is good at preserving edges while smoothing away noise in flat regions. See https://en.wikipedia.org/wiki/Total_variation_denoising
:param range_scale: Optional, set to zero to turn off. Used for adjustment of color contrast. Lower range_scale will increase contrast. Very low numbers create a reduced color palette, resulting in more vibrant or poster-like images. Higher range_scale will reduce contrast, for more muted images.
:param sat_scale: Saturation scale. Optional, set to zero to turn off. If used, sat_scale will help mitigate oversaturation. If your image is too saturated, increase sat_scale to reduce the saturation.
:param cutn_batches: Each iteration, the AI cuts the image into smaller pieces known as cuts, and compares each cut to the prompt to decide how to guide the next diffusion step. More cuts can generally lead to better images, since DD has more chances to fine-tune the image precision in each timestep. Additional cuts are memory intensive, however, and if DD tries to evaluate too many cuts at once, it can run out of memory. You can use cutn_batches to increase cuts per timestep without increasing memory usage. At the default settings, DD is scheduled to do 16 cuts per timestep. If cutn_batches is set to 1, there will indeed only be 16 cuts total per timestep. However, if cutn_batches is increased to 4, DD will do 64 cuts total in each timestep, divided into 4 sequential batches of 16 cuts each. Because the cuts are being evaluated only 16 at a time, DD uses the memory required for only 16 cuts, but gives you the quality benefit of 64 cuts. The tradeoff, of course, is that this will take ~4 times as long to render each image. So, (scheduled cuts) x (cutn_batches) = (total cuts per timestep). Increasing cutn_batches will increase render times, however, as the work is being done sequentially. DD’s default cut schedule is a good place to start, but the cut schedule can be adjusted in the Cutn Scheduling section, explained below.
:param diffusion_model: Diffusion_model of choice.
:param use_secondary_model: Option to use a secondary purpose-made diffusion model to clean up interim diffusion images for CLIP evaluation. If this option is turned off, DD will use the regular (large) diffusion model. Using the secondary model is faster - one user reported a 50% improvement in render speed! However, the secondary model is much smaller, and may reduce image quality and detail. I suggest you experiment with this.
:param diffusion_sampling_mode: Two alternate diffusion denoising algorithms. ddim has been around longer, and is more established and tested. plms is a newly added alternate method that promises good diffusion results in fewer steps, but has not been as fully tested and may have side effects. This new plms mode is actively being researched in the #settings-and-techniques channel in the DD Discord.
:param perlin_init: Normally, DD will use an image filled with random noise as a starting point for the diffusion curve. If perlin_init is selected, DD will instead use a Perlin noise model as an initial state. Perlin has very interesting characteristics, distinct from random noise, so it’s worth experimenting with this for your projects. Beyond perlin, you can, of course, generate your own noise images (such as with GIMP, etc) and use them as an init_image (without skipping steps). Choosing perlin_init does not affect the actual diffusion process, just the starting point for the diffusion. Please note that selecting a perlin_init will replace and override any init_image you may have specified. Further, because the 2D, 3D and video animation systems all rely on the init_image system, if you enable Perlin while using animation modes, the perlin_init will jump in front of any previous image or video input, and DD will NOT give you the expected sequence of coherent images. All of that said, using Perlin and animation modes together do make a very colorful rainbow effect, which can be used creatively.
:param perlin_mode: sets type of Perlin noise: colored, gray, or a mix of both, giving you additional options for noise types. Experiment to see what these do in your projects.
:param seed: Deep in the diffusion code, there is a random number ‘seed’ which is used as the basis for determining the initial state of the diffusion. By default, this is random, but you can also specify your own seed. This is useful if you like a particular result and would like to run more iterations that will be similar. After each run, the actual seed value used will be reported in the parameters report, and can be reused if desired by entering seed # here. If a specific numerical seed is used repeatedly, the resulting images will be quite similar but not identical.
:param eta: eta (greek letter η) is a diffusion model variable that mixes in a random amount of scaled noise into each timestep. 0 is no noise, 1.0 is more noise. As with most DD parameters, you can go below zero for eta, but it may give you unpredictable results. The steps parameter has a close relationship with the eta parameter. If you set eta to 0, then you can get decent output with only 50-75 steps. Setting eta to 1.0 favors higher step counts, ideally around 250 and up. eta has a subtle, unpredictable effect on image, so you’ll need to experiment to see how this affects your projects.
:param clamp_grad: As I understand it, clamp_grad is an internal limiter that stops DD from producing extreme results. Try your images with and without clamp_grad. If the image changes drastically with clamp_grad turned off, it probably means your clip_guidance_scale is too high and should be reduced.
:param clamp_max: Sets the value of the clamp_grad limitation. Default is 0.05, providing for smoother, more muted coloration in images, but setting higher values (0.15-0.3) can provide interesting contrast and vibrancy.
:param fuzzy_prompt: Controls whether to add multiple noisy prompts to the prompt losses. If True, can increase variability of image output. Experiment with this.
:param rand_mag: Affects only the fuzzy_prompt. Controls the magnitude of the random noise added by fuzzy_prompt.
:param cut_overview: The schedule of overview cuts
:param cut_innercut: The schedule of inner cuts
:param cut_icgray_p: This sets the size of the border used for inner cuts. High cut_ic_pow values have larger borders, and therefore the cuts themselves will be smaller and provide finer details. If you have too many or too-small inner cuts, you may lose overall image coherency and/or it may cause an undesirable ‘mosaic’ effect. Low cut_ic_pow values will allow the inner cuts to be larger, helping image coherency while still helping with some details.
:param display_rate: During a diffusion run, you can monitor the progress of each image being created with this variable. If display_rate is set to 50, DD will show you the in-progress image every 50 timesteps. Setting this to a lower value, like 5 or 10, is a good way to get an early peek at where your image is heading. If you don’t like the progression, just interrupt execution, change some settings, and re-run. If you are planning a long, unmonitored batch, it’s better to set display_rate equal to steps, because displaying interim images does slow Colab down slightly.
:param n_batches: This variable sets the number of still images you want DD to create. If you are using an animation mode (see below for details) DD will ignore n_batches and create a single set of animated frames based on the animation settings.
:param batch_name: The name of the batch; the batch ID will be named "discoart-[batch_name]-seed". To avoid your artworks being overwritten by other users, please use a unique name.
:param clip_models: CLIP model selectors: ViTB32, ViTB16, ViTL14, RN101, RN50, RN50x4, RN50x16, RN50x64. These CLIP models are available for you to use during image generation. Models have different styles or 'flavors', so look around. You can also mix in multiple models for different results. However, keep in mind that some models are extremely memory-hungry, and turning on additional models will take additional memory and may cause a crash. The rough order of speed/memory usage, from smallest/fastest to largest/slowest, is: ViTB32, RN50, RN101, ViTB16, RN50x4, RN50x16, RN50x64, ViTL14. For RN50x64 and ViTL14 you may need to use fewer cuts, depending on your VRAM.
:return: a DocumentArray object that has `n_batches` Documents
"""
# end_create_overload
@overload
def create(init_document: 'Document') -> 'DocumentArray':
"""
Create an artwork using a DocArray ``Document`` object as initial state.
:param init_document: its ``.tags`` will be used as parameters, ``.uri`` (if present) will be used as init image.
:return: a DocumentArray object that has `n_batches` Documents
"""
def create(**kwargs) -> 'DocumentArray':
from .config import load_config
from .runner import do_run
if 'init_document' in kwargs:
d = kwargs['init_document']
_kwargs = d.tags
if not _kwargs:
warnings.warn('init_document has no .tags, fallback to default config')
if d.uri:
_kwargs['init_image'] = kwargs['init_document'].uri
else:
warnings.warn('init_document has no .uri, fallback to no init image')
kwargs.pop('init_document')
if kwargs:
warnings.warn('init_document has .tags and .uri, but kwargs are also present, will override .tags')
_kwargs.update(kwargs)
_args = load_config(user_config=_kwargs)
else:
_args = load_config(user_config=kwargs)
model, diffusion = load_diffusion_model(model_config, _args.diffusion_model, steps=_args.steps)
clip_models = load_clip_models(enabled=_args.clip_models, clip_models=_clip_models_cache)
gc.collect()
paddle.device.cuda.empty_cache()
try:
return do_run(_args, (model, diffusion, clip_models, secondary_model))
except KeyboardInterrupt:
pass
'''
https://github.com/jina-ai/discoart/blob/main/discoart/config.py
'''
import copy
import random
import warnings
from types import SimpleNamespace
from typing import Dict
import yaml
from yaml import Loader
from . import __resources_path__
with open(f'{__resources_path__}/default.yml') as ymlfile:
default_args = yaml.load(ymlfile, Loader=Loader)
def load_config(user_config: Dict, ):
cfg = copy.deepcopy(default_args)
if user_config:
cfg.update(**user_config)
        for k in user_config.keys():
            # warn about keys that are not part of the default config (checking
            # against cfg here would never fire, since cfg was just updated)
            if k not in default_args:
                warnings.warn(f'unknown argument {k}, ignored')
for k, v in cfg.items():
if k in ('batch_size', 'display_rate', 'seed', 'skip_steps', 'steps', 'n_batches',
'cutn_batches') and isinstance(v, float):
cfg[k] = int(v)
if k == 'width_height':
cfg[k] = [int(vv) for vv in v]
cfg.update(**{
'seed': cfg['seed'] or random.randint(0, 2**32),
})
if cfg['batch_name']:
da_name = f'{__package__}-{cfg["batch_name"]}-{cfg["seed"]}'
else:
da_name = f'{__package__}-{cfg["seed"]}'
        warnings.warn('you did not set `batch_name`, set it to get a unique session ID')
cfg.update(**{'name_docarray': da_name})
print_args_table(cfg)
return SimpleNamespace(**cfg)
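# A minimal sketch of what load_config does with user input (values illustrative):
#   load_config({'steps': 250.0, 'width_height': [1280.0, 768.0], 'batch_name': 'demo'})
# fills the remaining keys from default.yml, coerces the float 250.0 to the int 250
# and the width/height entries to ints, draws a random seed if none was given, sets
# name_docarray to '<package>-demo-<seed>', prints the argument table and returns
# everything as a SimpleNamespace.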
def print_args_table(cfg):
from rich.table import Table
from rich import box
from rich.console import Console
console = Console()
param_str = Table(
title=cfg['name_docarray'],
box=box.ROUNDED,
highlight=True,
title_justify='left',
)
param_str.add_column('Argument', justify='right')
param_str.add_column('Value', justify='left')
for k, v in sorted(cfg.items()):
value = str(v)
if not default_args.get(k, None) == v:
value = f'[b]{value}[/]'
param_str.add_row(k, value)
console.print(param_str)
'''
This code is rewritten by Paddle based on Jina-ai/discoart.
https://github.com/jina-ai/discoart/blob/main/discoart/helper.py
'''
import hashlib
import logging
import os
import subprocess
import sys
from os.path import expanduser
from pathlib import Path
from typing import Any
from typing import Dict
from typing import List
import paddle
def _get_logger():
logger = logging.getLogger(__package__)
logger.setLevel("INFO")
ch = logging.StreamHandler()
ch.setLevel("INFO")
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger.addHandler(ch)
return logger
logger = _get_logger()
def load_clip_models(enabled: List[str], clip_models: Dict[str, Any] = {}):
import disco_diffusion_clip_vitb32.clip.clip as clip
from disco_diffusion_clip_vitb32.clip.clip import build_model, tokenize, transform
# load enabled models
for k in enabled:
if k not in clip_models:
clip_models[k] = build_model(name=k)
clip_models[k].eval()
for parameter in clip_models[k].parameters():
parameter.stop_gradient = True
# disable not enabled models to save memory
    for k in list(clip_models.keys()):  # copy keys so the dict is not mutated while iterating
        if k not in enabled:
            clip_models.pop(k)
return list(clip_models.values())
def load_all_models(diffusion_model, use_secondary_model):
from .model.script_util import (
model_and_diffusion_defaults, )
model_config = model_and_diffusion_defaults()
if diffusion_model == '512x512_diffusion_uncond_finetune_008100':
model_config.update({
'attention_resolutions': '32, 16, 8',
'class_cond': False,
'diffusion_steps': 1000, # No need to edit this, it is taken care of later.
'rescale_timesteps': True,
'timestep_respacing': 250, # No need to edit this, it is taken care of later.
'image_size': 512,
'learn_sigma': True,
'noise_schedule': 'linear',
'num_channels': 256,
'num_head_channels': 64,
'num_res_blocks': 2,
'resblock_updown': True,
'use_fp16': False,
'use_scale_shift_norm': True,
})
elif diffusion_model == '256x256_diffusion_uncond':
model_config.update({
'attention_resolutions': '32, 16, 8',
'class_cond': False,
'diffusion_steps': 1000, # No need to edit this, it is taken care of later.
'rescale_timesteps': True,
'timestep_respacing': 250, # No need to edit this, it is taken care of later.
'image_size': 256,
'learn_sigma': True,
'noise_schedule': 'linear',
'num_channels': 256,
'num_head_channels': 64,
'num_res_blocks': 2,
'resblock_updown': True,
'use_fp16': False,
'use_scale_shift_norm': True,
})
secondary_model = None
if use_secondary_model:
from .model.sec_diff import SecondaryDiffusionImageNet2
secondary_model = SecondaryDiffusionImageNet2()
model_dict = paddle.load(
os.path.join(os.path.dirname(__file__), 'pre_trained', 'secondary_model_imagenet_2.pdparams'))
secondary_model.set_state_dict(model_dict)
secondary_model.eval()
for parameter in secondary_model.parameters():
parameter.stop_gradient = True
return model_config, secondary_model
def load_diffusion_model(model_config, diffusion_model, steps):
from .model.script_util import (
create_model_and_diffusion, )
timestep_respacing = f'ddim{steps}'
diffusion_steps = (1000 // steps) * steps if steps < 1000 else steps
model_config.update({
'timestep_respacing': timestep_respacing,
'diffusion_steps': diffusion_steps,
})
model, diffusion = create_model_and_diffusion(**model_config)
model.set_state_dict(
paddle.load(os.path.join(os.path.dirname(__file__), 'pre_trained', f'{diffusion_model}.pdparams')))
model.eval()
for name, param in model.named_parameters():
param.stop_gradient = True
return model, diffusion
def parse_prompt(prompt):
if prompt.startswith('http://') or prompt.startswith('https://'):
vals = prompt.rsplit(':', 2)
vals = [vals[0] + ':' + vals[1], *vals[2:]]
else:
vals = prompt.rsplit(':', 1)
vals = vals + ['', '1'][len(vals):]
return vals[0], float(vals[1])
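# Examples of the prompt/weight parsing above:
#   parse_prompt('a watercolor of a fox:3')           -> ('a watercolor of a fox', 3.0)
#   parse_prompt('a watercolor of a fox')             -> ('a watercolor of a fox', 1.0)  # default weight
#   parse_prompt('https://example.com/init.png:0.5')  -> ('https://example.com/init.png', 0.5)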
"""
Codebase for "Improved Denoising Diffusion Probabilistic Models" implemented by Paddle.
"""
"""
Helpers for various likelihood-based losses implemented by Paddle. These are ported from the original
Ho et al. diffusion models codebase:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
"""
import numpy as np
import paddle
import paddle.nn.functional as F
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, paddle.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for th.exp().
logvar1, logvar2 = [x if isinstance(x, paddle.Tensor) else paddle.to_tensor(x) for x in (logvar1, logvar2)]
return 0.5 * (-1.0 + logvar2 - logvar1 + paddle.exp(logvar1 - logvar2) +
((mean1 - mean2)**2) * paddle.exp(-logvar2))
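# Sanity check of the closed form above: the KL divergence between two identical
# Gaussians is zero, e.g.
#   normal_kl(paddle.zeros([4]), paddle.zeros([4]), 0.0, 0.0)
# evaluates elementwise to 0.5 * (-1 + 0 - 0 + exp(0) + 0) = 0.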
def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
"""
return 0.5 * (1.0 + paddle.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * paddle.pow(x, 3))))
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
inv_stdv = paddle.exp(-log_scales)
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
cdf_plus = approx_standard_normal_cdf(plus_in)
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
cdf_min = approx_standard_normal_cdf(min_in)
log_cdf_plus = paddle.log(cdf_plus.clip(min=1e-12))
log_one_minus_cdf_min = paddle.log((1.0 - cdf_min).clip(min=1e-12))
cdf_delta = cdf_plus - cdf_min
log_probs = paddle.where(
x < -0.999,
log_cdf_plus,
paddle.where(x > 0.999, log_one_minus_cdf_min, paddle.log(cdf_delta.clip(min=1e-12))),
)
assert log_probs.shape == x.shape
return log_probs
def spherical_dist_loss(x, y):
x = F.normalize(x, axis=-1)
y = F.normalize(y, axis=-1)
return (x - y).norm(axis=-1).divide(paddle.to_tensor(2.0)).asin().pow(2).multiply(paddle.to_tensor(2.0))
def tv_loss(input):
"""L2 total variation loss, as in Mahendran et al."""
input = F.pad(input, (0, 1, 0, 1), 'replicate')
x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
return (x_diff**2 + y_diff**2).mean([1, 2, 3])
def range_loss(input):
return (input - input.clip(-1, 1)).pow(2).mean([1, 2, 3])
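# Quick intuition for the three guidance losses above:
#   spherical_dist_loss(x, x) is 0 for identical (CLIP) embeddings and grows with angular distance;
#   tv_loss penalizes differences between neighbouring pixels, i.e. high-frequency noise;
#   range_loss(paddle.full([1, 3, 8, 8], 1.5)) returns 0.25, since only values outside
#   [-1, 1] are penalized, quadratically.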
'''
This code is rewritten by Paddle based on Jina-ai/discoart.
https://github.com/jina-ai/discoart/blob/main/discoart/nn/make_cutouts.py
'''
import math
import paddle
import paddle.nn as nn
from disco_diffusion_clip_vitb32.resize_right.resize_right import resize
from paddle.nn import functional as F
from . import transforms as T
skip_augs = False # @param{type: 'boolean'}
def sinc(x):
    # paddle-native ones_like/zeros_like replace the torch-style new_ones/new_zeros calls
    return paddle.where(x != 0, paddle.sin(math.pi * x) / (math.pi * x), paddle.ones_like(x))
def lanczos(x, a):
    cond = paddle.logical_and(-a < x, x < a)
    out = paddle.where(cond, sinc(x) * sinc(x / a), paddle.zeros_like(x))
    return out / out.sum()
def ramp(ratio, width):
n = math.ceil(width / ratio + 1)
out = paddle.empty([n])
cur = 0
for i in range(out.shape[0]):
out[i] = cur
cur += ratio
return paddle.concat([-out[1:].flip([0]), out])[1:-1]
class MakeCutouts(nn.Layer):
def __init__(self, cut_size, cutn, skip_augs=False):
super().__init__()
self.cut_size = cut_size
self.cutn = cutn
self.skip_augs = skip_augs
self.augs = nn.Sequential(*[
T.RandomHorizontalFlip(prob=0.5),
T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
T.RandomPerspective(distortion_scale=0.4, p=0.7),
T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
T.RandomGrayscale(p=0.15),
T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
])
def forward(self, input):
input = T.Pad(input.shape[2] // 4, fill=0)(input)
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
cutouts = []
for ch in range(self.cutn):
if ch > self.cutn - self.cutn // 4:
cutout = input.clone()
else:
size = int(max_size *
paddle.zeros(1, ).normal_(mean=0.8, std=0.3).clip(float(self.cut_size / max_size), 1.0))
offsetx = paddle.randint(0, abs(sideX - size + 1), ())
offsety = paddle.randint(0, abs(sideY - size + 1), ())
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
if not self.skip_augs:
cutout = self.augs(cutout)
cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
del cutout
cutouts = paddle.concat(cutouts, axis=0)
return cutouts
class MakeCutoutsDango(nn.Layer):
def __init__(self, cut_size, Overview=4, InnerCrop=0, IC_Size_Pow=0.5, IC_Grey_P=0.2):
super().__init__()
self.cut_size = cut_size
self.Overview = Overview
self.InnerCrop = InnerCrop
self.IC_Size_Pow = IC_Size_Pow
self.IC_Grey_P = IC_Grey_P
self.augs = nn.Sequential(*[
T.RandomHorizontalFlip(prob=0.5),
T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
T.RandomAffine(
degrees=10,
translate=(0.05, 0.05),
interpolation=T.InterpolationMode.BILINEAR,
),
T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
T.RandomGrayscale(p=0.1),
T.Lambda(lambda x: x + paddle.randn(x.shape) * 0.01),
T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
])
def forward(self, input):
cutouts = []
gray = T.Grayscale(3)
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
min_size = min(sideX, sideY, self.cut_size)
output_shape = [1, 3, self.cut_size, self.cut_size]
pad_input = F.pad(
input,
(
(sideY - max_size) // 2,
(sideY - max_size) // 2,
(sideX - max_size) // 2,
(sideX - max_size) // 2,
),
**padargs,
)
cutout = resize(pad_input, out_shape=output_shape)
if self.Overview > 0:
if self.Overview <= 4:
if self.Overview >= 1:
cutouts.append(cutout)
if self.Overview >= 2:
cutouts.append(gray(cutout))
if self.Overview >= 3:
cutouts.append(cutout[:, :, :, ::-1])
if self.Overview == 4:
cutouts.append(gray(cutout[:, :, :, ::-1]))
else:
cutout = resize(pad_input, out_shape=output_shape)
for _ in range(self.Overview):
cutouts.append(cutout)
if self.InnerCrop > 0:
for i in range(self.InnerCrop):
size = int(paddle.rand([1])**self.IC_Size_Pow * (max_size - min_size) + min_size)
offsetx = paddle.randint(0, sideX - size + 1)
offsety = paddle.randint(0, sideY - size + 1)
cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
if i <= int(self.IC_Grey_P * self.InnerCrop):
cutout = gray(cutout)
cutout = resize(cutout, out_shape=output_shape)
cutouts.append(cutout)
cutouts = paddle.concat(cutouts)
if skip_augs is not True:
cutouts = self.augs(cutouts)
return cutouts
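# Shape sketch for the cutout pipeline above (values illustrative): with
# cut_size=224, Overview=4 and InnerCrop=2, a [1, 3, H, W] input produces
# 4 overview cutouts (original, grayscale, horizontally flipped, grayscale-flipped)
# plus 2 random inner crops, i.e. a [6, 3, 224, 224] batch that is then scored by CLIP.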
def resample(input, size, align_corners=True):
n, c, h, w = input.shape
dh, dw = size
input = input.reshape([n * c, 1, h, w])
    if dh < h:
        # cast the lanczos kernel to the input dtype (replaces the torch-style .to(device, dtype) call)
        kernel_h = paddle.cast(lanczos(ramp(dh / h, 2), 2), input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])
    if dw < w:
        kernel_w = paddle.cast(lanczos(ramp(dw / w, 2), 2), input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])
    # the spatial dims may have shrunk after the convolutions, so do not reshape
    # back to the original (h, w)
    input = input.reshape([n, c, *input.shape[2:]])
return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
padargs = {}
"""
Various utilities for neural networks implemented by Paddle. This code is rewritten based on:
https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py
"""
import math
import paddle
import paddle.nn as nn
class SiLU(nn.Layer):
def forward(self, x):
return x * nn.functional.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1D(*args, **kwargs)
elif dims == 2:
return nn.Conv2D(*args, **kwargs)
elif dims == 3:
return nn.Conv3D(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1D(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2D(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3D(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(axis=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = paddle.exp(-math.log(max_period) * paddle.arange(start=0, end=half, dtype=paddle.float32) / half)
args = paddle.cast(timesteps[:, None], 'float32') * freqs[None]
embedding = paddle.concat([paddle.cos(args), paddle.sin(args)], axis=-1)
if dim % 2:
embedding = paddle.concat([embedding, paddle.zeros_like(embedding[:, :1])], axis=-1)
return embedding
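# Example of the embedding above:
#   timestep_embedding(paddle.to_tensor([0.0, 10.0, 250.0]), dim=128)
# returns a [3, 128] tensor whose first 64 columns are cos(t * freq_k) and last
# 64 columns are sin(t * freq_k), with frequencies decaying geometrically from 1
# down to roughly 1 / max_period.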
def checkpoint(func, inputs, params, flag):
"""
Gradient checkpointing is disabled in this port; the function simply runs a plain forward pass.
"""
return func(*inputs)
'''
Perlin noise implementation by Paddle.
This code is rewritten based on:
https://github.com/jina-ai/discoart/blob/main/discoart/nn/perlin_noises.py
'''
import numpy as np
import paddle
import paddle.vision.transforms as TF
from PIL import Image
from PIL import ImageOps
def interp(t):
return 3 * t**2 - 2 * t**3
def perlin(width, height, scale=10):
gx, gy = paddle.randn([2, width + 1, height + 1, 1, 1])
xs = paddle.linspace(0, 1, scale + 1)[:-1, None]
ys = paddle.linspace(0, 1, scale + 1)[None, :-1]
wx = 1 - interp(xs)
wy = 1 - interp(ys)
dots = 0
dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)
dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)
dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))
dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))
return dots.transpose([0, 2, 1, 3]).reshape([width * scale, height * scale])
def perlin_ms(octaves, width, height, grayscale):
out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]
# out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]
for i in range(1 if grayscale else 3):
scale = 2**len(octaves)
oct_width = width
oct_height = height
for oct in octaves:
p = perlin(oct_width, oct_height, scale)
out_array[i] += p * oct
scale //= 2
oct_width *= 2
oct_height *= 2
return paddle.concat(out_array)
def create_perlin_noise(octaves, width, height, grayscale, side_y, side_x):
out = perlin_ms(octaves, width, height, grayscale)
if grayscale:
out = TF.resize(size=(side_y, side_x), img=out.numpy())
out = np.uint8(out)
out = Image.fromarray(out).convert('RGB')
else:
out = out.reshape([-1, 3, out.shape[0] // 3, out.shape[1]])
out = out.squeeze().transpose([1, 2, 0]).numpy()
out = TF.resize(size=(side_y, side_x), img=out)
out = out.clip(0, 1) * 255
out = np.uint8(out)
out = Image.fromarray(out)
out = ImageOps.autocontrast(out)
return out
def regen_perlin(perlin_mode, side_y, side_x, batch_size):
if perlin_mode == 'color':
init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, False, side_y, side_x)
init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, False, side_y, side_x)
elif perlin_mode == 'gray':
init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, True, side_y, side_x)
init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, True, side_y, side_x)
else:
init = create_perlin_noise([1.5**-i * 0.5 for i in range(12)], 1, 1, False, side_y, side_x)
init2 = create_perlin_noise([1.5**-i * 0.5 for i in range(8)], 4, 4, True, side_y, side_x)
init = (TF.to_tensor(init).add(TF.to_tensor(init2)).divide(paddle.to_tensor(2.0)).unsqueeze(0) * 2 - 1)
del init2
return init.expand([batch_size, -1, -1, -1])
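# The helper above returns a [batch_size, 3, side_y, side_x] tensor scaled to
# roughly [-1, 1]; e.g. regen_perlin('mixed', 512, 448, 1) yields a single
# Perlin-noise image that the sampler can use as its starting point when
# perlin_init is enabled.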
'''
This code is rewritten by Paddle based on
https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py
'''
import numpy as np
import paddle
from .gaussian_diffusion import GaussianDiffusion
def space_timesteps(num_timesteps, section_counts):
"""
Create a list of timesteps to use from an original diffusion process,
given the number of timesteps we want to take from equally-sized portions
of the original process.
For example, if there's 300 timesteps and the section counts are [10,15,20]
then the first 100 timesteps are strided to be 10 timesteps, the second 100
are strided to be 15 timesteps, and the final 100 are strided to be 20.
If the stride is a string starting with "ddim", then the fixed striding
from the DDIM paper is used, and only one section is allowed.
:param num_timesteps: the number of diffusion steps in the original
process to divide up.
:param section_counts: either a list of numbers, or a string containing
comma-separated numbers, indicating the step count
per section. As a special case, use "ddimN" where N
is a number of steps to use the striding from the
DDIM paper.
:return: a set of diffusion steps from the original process to use.
"""
if isinstance(section_counts, str):
if section_counts.startswith("ddim"):
desired_count = int(section_counts[len("ddim"):])
for i in range(1, num_timesteps):
if len(range(0, num_timesteps, i)) == desired_count:
return set(range(0, num_timesteps, i))
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
section_counts = [int(x) for x in section_counts.split(",")]
size_per = num_timesteps // len(section_counts)
extra = num_timesteps % len(section_counts)
start_idx = 0
all_steps = []
for i, section_count in enumerate(section_counts):
size = size_per + (1 if i < extra else 0)
if size < section_count:
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
if section_count <= 1:
frac_stride = 1
else:
frac_stride = (size - 1) / (section_count - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(section_count):
taken_steps.append(start_idx + round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
start_idx += size
return set(all_steps)
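# Examples of the respacing above:
#   space_timesteps(1000, 'ddim250')   -> {0, 4, 8, ..., 996}, i.e. 250 evenly strided steps
#   space_timesteps(300, [10, 15, 20]) -> 45 steps in total, with 10/15/20 of them drawn
#                                         from the first/second/third block of 100 steps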
class SpacedDiffusion(GaussianDiffusion):
"""
A diffusion process which can skip steps in a base diffusion process.
:param use_timesteps: a collection (sequence or set) of timesteps from the
original diffusion process to retain.
:param kwargs: the kwargs to create the base diffusion process.
"""
def __init__(self, use_timesteps, **kwargs):
self.use_timesteps = set(use_timesteps)
self.timestep_map = []
self.original_num_steps = len(kwargs["betas"])
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
last_alpha_cumprod = 1.0
new_betas = []
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
if i in self.use_timesteps:
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
self.timestep_map.append(i)
kwargs["betas"] = np.array(new_betas)
super().__init__(**kwargs)
def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
return super().training_losses(self._wrap_model(model), *args, **kwargs)
def condition_mean(self, cond_fn, *args, **kwargs):
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
def condition_score(self, cond_fn, *args, **kwargs):
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
def _wrap_model(self, model):
if isinstance(model, _WrappedModel):
return model
return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps)
def _scale_timesteps(self, t):
# Scaling is done by the wrapped model.
return t
class _WrappedModel:
def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
self.model = model
self.timestep_map = timestep_map
self.rescale_timesteps = rescale_timesteps
self.original_num_steps = original_num_steps
def __call__(self, x, ts, **kwargs):
map_tensor = paddle.to_tensor(self.timestep_map, place=ts.place, dtype=ts.dtype)
new_ts = map_tensor[ts]
if self.rescale_timesteps:
new_ts = paddle.cast(new_ts, 'float32') * (1000.0 / self.original_num_steps)
return self.model(x, new_ts, **kwargs)
'''
This code is based on
https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/script_util.py
'''
import argparse
import inspect
from . import gaussian_diffusion as gd
from .respace import space_timesteps
from .respace import SpacedDiffusion
from .unet import EncoderUNetModel
from .unet import SuperResModel
from .unet import UNetModel
NUM_CLASSES = 1000
def diffusion_defaults():
"""
Defaults for image and classifier training.
"""
return dict(
learn_sigma=False,
diffusion_steps=1000,
noise_schedule="linear",
timestep_respacing="",
use_kl=False,
predict_xstart=False,
rescale_timesteps=False,
rescale_learned_sigmas=False,
)
def model_and_diffusion_defaults():
"""
Defaults for image training.
"""
res = dict(
image_size=64,
num_channels=128,
num_res_blocks=2,
num_heads=4,
num_heads_upsample=-1,
num_head_channels=-1,
attention_resolutions="16,8",
channel_mult="",
dropout=0.0,
class_cond=False,
use_checkpoint=False,
use_scale_shift_norm=True,
resblock_updown=False,
use_fp16=False,
use_new_attention_order=False,
)
res.update(diffusion_defaults())
return res
def create_model_and_diffusion(
image_size,
class_cond,
learn_sigma,
num_channels,
num_res_blocks,
channel_mult,
num_heads,
num_head_channels,
num_heads_upsample,
attention_resolutions,
dropout,
diffusion_steps,
noise_schedule,
timestep_respacing,
use_kl,
predict_xstart,
rescale_timesteps,
rescale_learned_sigmas,
use_checkpoint,
use_scale_shift_norm,
resblock_updown,
use_fp16,
use_new_attention_order,
):
model = create_model(
image_size,
num_channels,
num_res_blocks,
channel_mult=channel_mult,
learn_sigma=learn_sigma,
class_cond=class_cond,
use_checkpoint=use_checkpoint,
attention_resolutions=attention_resolutions,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
dropout=dropout,
resblock_updown=resblock_updown,
use_fp16=use_fp16,
use_new_attention_order=use_new_attention_order,
)
diffusion = create_gaussian_diffusion(
steps=diffusion_steps,
learn_sigma=learn_sigma,
noise_schedule=noise_schedule,
use_kl=use_kl,
predict_xstart=predict_xstart,
rescale_timesteps=rescale_timesteps,
rescale_learned_sigmas=rescale_learned_sigmas,
timestep_respacing=timestep_respacing,
)
return model, diffusion
def create_model(
image_size,
num_channels,
num_res_blocks,
channel_mult="",
learn_sigma=False,
class_cond=False,
use_checkpoint=False,
attention_resolutions="16",
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
dropout=0,
resblock_updown=False,
use_fp16=False,
use_new_attention_order=False,
):
if channel_mult == "":
if image_size == 512:
channel_mult = (0.5, 1, 1, 2, 2, 4, 4)
elif image_size == 256:
channel_mult = (1, 1, 2, 2, 4, 4)
elif image_size == 128:
channel_mult = (1, 1, 2, 3, 4)
elif image_size == 64:
channel_mult = (1, 2, 3, 4)
else:
raise ValueError(f"unsupported image size: {image_size}")
else:
channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(","))
attention_ds = []
for res in attention_resolutions.split(","):
attention_ds.append(image_size // int(res))
return UNetModel(
image_size=image_size,
in_channels=3,
model_channels=num_channels,
out_channels=(3 if not learn_sigma else 6),
num_res_blocks=num_res_blocks,
attention_resolutions=tuple(attention_ds),
dropout=dropout,
channel_mult=channel_mult,
num_classes=(NUM_CLASSES if class_cond else None),
use_checkpoint=use_checkpoint,
use_fp16=use_fp16,
num_heads=num_heads,
num_head_channels=num_head_channels,
num_heads_upsample=num_heads_upsample,
use_scale_shift_norm=use_scale_shift_norm,
resblock_updown=resblock_updown,
use_new_attention_order=use_new_attention_order,
)
def create_gaussian_diffusion(
*,
steps=1000,
learn_sigma=False,
sigma_small=False,
noise_schedule="linear",
use_kl=False,
predict_xstart=False,
rescale_timesteps=False,
rescale_learned_sigmas=False,
timestep_respacing="",
):
betas = gd.get_named_beta_schedule(noise_schedule, steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
elif rescale_learned_sigmas:
loss_type = gd.LossType.RESCALED_MSE
else:
loss_type = gd.LossType.MSE
if not timestep_respacing:
timestep_respacing = [steps]
return SpacedDiffusion(
use_timesteps=space_timesteps(steps, timestep_respacing),
betas=betas,
model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
model_var_type=((gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
if not learn_sigma else gd.ModelVarType.LEARNED_RANGE),
loss_type=loss_type,
rescale_timesteps=rescale_timesteps,
)
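# For reference, this is effectively what load_diffusion_model above ends up
# requesting for the 512x512 checkpoint with 250 sampling steps:
#   create_gaussian_diffusion(steps=1000, learn_sigma=True, noise_schedule='linear',
#                             timestep_respacing='ddim250', rescale_timesteps=True)
# which builds a SpacedDiffusion that keeps 250 of the 1000 training timesteps.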
'''
This code is rewritten by Paddle based on
https://github.com/jina-ai/discoart/blob/main/discoart/nn/sec_diff.py
'''
import math
from dataclasses import dataclass
from functools import partial
import paddle
import paddle.nn as nn
@dataclass
class DiffusionOutput:
v: paddle.Tensor
pred: paddle.Tensor
eps: paddle.Tensor
class SkipBlock(nn.Layer):
def __init__(self, main, skip=None):
super().__init__()
self.main = nn.Sequential(*main)
self.skip = skip if skip else nn.Identity()
def forward(self, input):
return paddle.concat([self.main(input), self.skip(input)], axis=1)
def append_dims(x, n):
return x[(Ellipsis, *(None, ) * (n - x.ndim))]
def expand_to_planes(x, shape):
return paddle.tile(append_dims(x, len(shape)), [1, 1, *shape[2:]])
def alpha_sigma_to_t(alpha, sigma):
return paddle.atan2(sigma, alpha) * 2 / math.pi
def t_to_alpha_sigma(t):
return paddle.cos(t * math.pi / 2), paddle.sin(t * math.pi / 2)
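# The two conversions above are inverse maps on the unit quarter-circle:
#   t_to_alpha_sigma(paddle.to_tensor(0.0)) -> (1, 0)   # pure signal
#   t_to_alpha_sigma(paddle.to_tensor(1.0)) -> (0, 1)   # pure noise
#   alpha_sigma_to_t recovers t via atan2(sigma, alpha) * 2 / pi.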
class SecondaryDiffusionImageNet2(nn.Layer):
def __init__(self):
super().__init__()
c = 64 # The base channel count
cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]
self.timestep_embed = FourierFeatures(1, 16)
self.down = nn.AvgPool2D(2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.net = nn.Sequential(
ConvBlock(3 + 16, cs[0]),
ConvBlock(cs[0], cs[0]),
SkipBlock([
self.down,
ConvBlock(cs[0], cs[1]),
ConvBlock(cs[1], cs[1]),
SkipBlock([
self.down,
ConvBlock(cs[1], cs[2]),
ConvBlock(cs[2], cs[2]),
SkipBlock([
self.down,
ConvBlock(cs[2], cs[3]),
ConvBlock(cs[3], cs[3]),
SkipBlock([
self.down,
ConvBlock(cs[3], cs[4]),
ConvBlock(cs[4], cs[4]),
SkipBlock([
self.down,
ConvBlock(cs[4], cs[5]),
ConvBlock(cs[5], cs[5]),
ConvBlock(cs[5], cs[5]),
ConvBlock(cs[5], cs[4]),
self.up,
]),
ConvBlock(cs[4] * 2, cs[4]),
ConvBlock(cs[4], cs[3]),
self.up,
]),
ConvBlock(cs[3] * 2, cs[3]),
ConvBlock(cs[3], cs[2]),
self.up,
]),
ConvBlock(cs[2] * 2, cs[2]),
ConvBlock(cs[2], cs[1]),
self.up,
]),
ConvBlock(cs[1] * 2, cs[1]),
ConvBlock(cs[1], cs[0]),
self.up,
]),
ConvBlock(cs[0] * 2, cs[0]),
nn.Conv2D(cs[0], 3, 3, padding=1),
)
def forward(self, input, t):
timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
v = self.net(paddle.concat([input, timestep_embed], axis=1))
alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
pred = input * alphas - v * sigmas
eps = input * sigmas + v * alphas
return DiffusionOutput(v, pred, eps)
class FourierFeatures(nn.Layer):
def __init__(self, in_features, out_features, std=1.0):
super().__init__()
assert out_features % 2 == 0
# self.weight = nn.Parameter(paddle.randn([out_features // 2, in_features]) * std)
self.weight = paddle.create_parameter([out_features // 2, in_features],
dtype='float32',
default_initializer=nn.initializer.Normal(mean=0.0, std=std))
def forward(self, input):
f = 2 * math.pi * input @ self.weight.T
return paddle.concat([f.cos(), f.sin()], axis=-1)
class ConvBlock(nn.Sequential):
def __init__(self, c_in, c_out):
super().__init__(
nn.Conv2D(c_in, c_out, 3, padding=1),
nn.ReLU(),
)