未验证 提交 818fe6b7 编写于 作者: jm_12138's avatar jm_12138 提交者: GitHub

Add U2Net_Portrait Module

上级 b58fa8dc
## 概述
* ![](http://latex.codecogs.com/svg.latex?U^2Net) 的网络结构如下图,其类似于编码-解码(Encoder-Decoder)结构的 U-Net
* 每个 stage 由新提出的 RSU模块(residual U-block) 组成. 例如,En_1 即为基于 RSU 构建的
* ![](https://latex.codecogs.com/svg.latex?U^2Net_{Portrait}) 是基于![](http://latex.codecogs.com/svg.latex?U^2Net) 的人脸画像生成模型
![](https://ai-studio-static-online.cdn.bcebos.com/999d37b4ffdd49dc9e3315b7cec7b2c6918fdd57c8594ced9dded758a497913d)
## 效果展示
![](https://ai-studio-static-online.cdn.bcebos.com/07f73466f3294373965e06c141c4781992f447104a94471dadfabc1c3d920861)
![](https://ai-studio-static-online.cdn.bcebos.com/c6ab02cf27414a5ba5921d9e6b079b487f6cda6026dc4d6dbca8f0167ad7cae3)
## API
```python
def Portrait_GEN(
images=None,
paths=None,
scale=1,
batch_size=1,
output_dir='output',
face_detection=True,
visualization=False):
```
人脸画像生成 API
**参数**
* images (list[np.ndarray]) : 输入图像数据列表(BGR)
* paths (list[str]) : 输入图像路径列表
* scale (float) : 缩放因子(与face_detection相关联)
* batch_size (int) : 数据批大小
* output_dir (str) : 可视化图像输出目录
* face_detection (bool) : 是否开启人脸检测,开启后会检测人脸并使用人脸中心点进行图像缩放裁切
* visualization (bool) : 是否可视化
**返回**
* results (list[np.ndarray]): 输出图像数据列表
**代码示例**
```python
import cv2
import paddlehub as hub
model = hub.Module(name='U2Net_Portrait')
result = model.Portrait_GEN(
images=[cv2.imread('/PATH/TO/IMAGE')],
paths=None,
scale=1,
batch_size=1,
output_dir='output',
face_detection=True,
visualization=True)
```
## 查看代码
https://github.com/NathanUA/U-2-Net
## 依赖
paddlepaddle >= 2.0.0rc0
paddlehub >= 2.0.0b1
import os
import paddle
import paddle.nn as nn
import numpy as np
from U2Net_Portrait.u2net import U2NET
from U2Net_Portrait.processor import Processor
from paddlehub.module.module import moduleinfo
@moduleinfo(
name="U2Net_Portrait", # 模型名称
type="CV", # 模型类型
author="jm12138", # 作者名称
author_email="jm12138@qq.com", # 作者邮箱
summary="U2Net_Portrait", # 模型介绍
version="1.0.0" # 版本号
)
class U2Net_Portrait(nn.Layer):
def __init__(self):
super(U2Net_Portrait, self).__init__()
self.model = U2NET(3,1)
state_dict = paddle.load(os.path.join(self.directory, 'u2net_portrait.pdparams'))
self.model.set_dict(state_dict)
self.model.eval()
def predict(self, input_datas):
outputs = []
for data in input_datas:
data = paddle.to_tensor(data, 'float32')
d1,d2,d3,d4,d5,d6,d7= self.model(data)
outputs.append(d1.numpy())
outputs = np.concatenate(outputs, 0)
return outputs
def Portrait_GEN(
self,
images=None,
paths=None,
scale=1,
batch_size=1,
output_dir='output',
face_detection=True,
visualization=False):
# 初始化数据处理器
processor = Processor(paths, images, batch_size, face_detection, scale)
# 模型预测
outputs = self.predict(processor.input_datas)
# 预测结果后处理
results = processor.postprocess(outputs, visualization=visualization, output_dir=output_dir)
return results
import os
import cv2
import numpy as np
import paddlehub as hub
__all__ = ['Processor']
class Processor():
def __init__(self, paths, images, batch_size, face_detection=True, scale=1):
# 图像列表
self.imgs = self.load_datas(paths, images)
# 输入数据
self.input_datas = self.preprocess(self.imgs, batch_size, face_detection, scale)
# 读取数据函数
def load_datas(self, paths, images):
datas = []
# 读取数据列表
if paths is not None:
for im_path in paths:
assert os.path.isfile(im_path), "The {} isn't a valid file path.".format(im_path)
im = cv2.imread(im_path)
datas.append(im)
if images is not None:
datas = images
# 返回数据列表
return datas
# 预处理
def preprocess(self, imgs, batch_size=1, face_detection=True, scale=1):
if face_detection:
# face detection
face_detector = hub.Module(name="pyramidbox_lite_mobile")
results = face_detector.face_detection(images=imgs,
use_gpu=False,
visualization=False,
confs_threshold=0.5)
im_faces = []
for datas, img in zip(results, imgs):
for face in datas['data']:
# get detection result
l, r, t, b = [face['left'], face['right'], face['top'], face['bottom']]
# square crop
pad = max(int(scale*(r-l)), int(scale*(b-t)))
c_w, c_h = (r-l)//2+l, (b-t)//2+t
top = 0 if c_h-pad<0 else c_h-pad
bottom = pad + c_h
left = 0 if c_w-pad<0 else c_w-pad
right = pad + c_w
crop = img[top:bottom, left:right]
# resize
im_face = cv2.resize(crop, (512,512), interpolation = cv2.INTER_AREA)
im_faces.append(im_face)
else:
im_faces = []
for img in imgs:
h, w = img.shape[:2]
if h>w:
if (h-w)%2==0:
img = np.pad(img,((0,0),((h-w)//2,(h-w)//2),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
else:
img = np.pad(img,((0,0),((h-w)//2,(h-w)//2+1),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
else:
if (w-h)%2==0:
img = np.pad(img,(((w-h)//2,(w-h)//2),(0,0),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
else:
img = np.pad(img,(((w-h)//2,(w-h)//2+1),(0,0),(0,0)),mode='constant',constant_values=((255,255),(255,255),(255,255)))
im_face = cv2.resize(img, (512,512), interpolation = cv2.INTER_AREA)
im_faces.append(im_face)
input_datas = []
for im_face in im_faces:
tmpImg = np.zeros((im_face.shape[0],im_face.shape[1],3))
im_face = im_face/np.max(im_face)
tmpImg[:,:,0] = (im_face[:,:,2]-0.406)/0.225
tmpImg[:,:,1] = (im_face[:,:,1]-0.456)/0.224
tmpImg[:,:,2] = (im_face[:,:,0]-0.485)/0.229
# convert BGR to RGB
tmpImg = tmpImg.transpose((2, 0, 1))
tmpImg = tmpImg[np.newaxis,:,:,:]
input_datas.append(tmpImg)
input_datas = np.concatenate(input_datas, 0)
datas_num = input_datas.shape[0]
split_num = datas_num//batch_size+1 if datas_num%batch_size!=0 else datas_num//batch_size
input_datas = np.array_split(input_datas, split_num)
return input_datas
def normPRED(self, d):
ma = np.max(d)
mi = np.min(d)
dn = (d-mi)/(ma-mi)
return dn
# 后处理
def postprocess(self, outputs, visualization=False, output_dir='output'):
results = []
if visualization and not os.path.exists(output_dir):
os.mkdir(output_dir)
for i in range(outputs.shape[0]):
# normalization
pred = 1.0 - outputs[i,0,:,:]
pred = self.normPRED(pred)
# convert torch tensor to numpy array
pred = pred.squeeze()
pred = (pred*255).astype(np.uint8)
results.append(pred)
if visualization:
cv2.imwrite(os.path.join(output_dir, 'result_%d.png' % i), pred)
return results
\ No newline at end of file
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
__all__ = ['U2NETP', 'U2NET']
class REBNCONV(nn.Layer):
def __init__(self,in_ch=3,out_ch=3,dirate=1):
super(REBNCONV,self).__init__()
self.conv_s1 = nn.Conv2D(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
self.bn_s1 = nn.BatchNorm2D(out_ch)
self.relu_s1 = nn.ReLU()
def forward(self,x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):
src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
return src
### RSU-7 ###
class RSU7(nn.Layer):#UNet07DRES(nn.Layer):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU7,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool5 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(paddle.concat((hx7,hx6),1))
hx6dup = _upsample_like(hx6d,hx5)
hx5d = self.rebnconv5d(paddle.concat((hx6dup,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(paddle.concat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(paddle.concat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(paddle.concat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(paddle.concat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-6 ###
class RSU6(nn.Layer):#UNet06DRES(nn.Layer):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool4 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(paddle.concat((hx6,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.rebnconv4d(paddle.concat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(paddle.concat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(paddle.concat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(paddle.concat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-5 ###
class RSU5(nn.Layer):#UNet05DRES(nn.Layer):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool3 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(paddle.concat((hx5,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.rebnconv3d(paddle.concat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(paddle.concat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(paddle.concat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4 ###
class RSU4(nn.Layer):#UNet04DRES(nn.Layer):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.pool1 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.pool2 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(paddle.concat((hx4,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.rebnconv2d(paddle.concat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.rebnconv1d(paddle.concat((hx2dup,hx1),1))
return hx1d + hxin
### RSU-4F ###
class RSU4F(nn.Layer):#UNet04FRES(nn.Layer):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F,self).__init__()
self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
def forward(self,x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(paddle.concat((hx4,hx3),1))
hx2d = self.rebnconv2d(paddle.concat((hx3d,hx2),1))
hx1d = self.rebnconv1d(paddle.concat((hx2d,hx1),1))
return hx1d + hxin
##### U^2-Net ####
class U2NET(nn.Layer):
def __init__(self,in_ch=3,out_ch=1):
super(U2NET,self).__init__()
self.stage1 = RSU7(in_ch,32,64)
self.pool12 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,32,128)
self.pool23 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(128,64,256)
self.pool34 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(256,128,512)
self.pool45 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(512,256,512)
self.pool56 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(512,256,512)
# decoder
self.stage5d = RSU4F(1024,256,512)
self.stage4d = RSU4(1024,128,256)
self.stage3d = RSU5(512,64,128)
self.stage2d = RSU6(256,32,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2D(64,out_ch,3,padding=1)
self.side2 = nn.Conv2D(64,out_ch,3,padding=1)
self.side3 = nn.Conv2D(128,out_ch,3,padding=1)
self.side4 = nn.Conv2D(256,out_ch,3,padding=1)
self.side5 = nn.Conv2D(512,out_ch,3,padding=1)
self.side6 = nn.Conv2D(512,out_ch,3,padding=1)
self.outconv = nn.Conv2D(6,out_ch,1)
def forward(self,x):
hx = x
#stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6,hx5)
#-------------------- decoder --------------------
hx5d = self.stage5d(paddle.concat((hx6up,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(paddle.concat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(paddle.concat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(paddle.concat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(paddle.concat((hx2dup,hx1),1))
#side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6,d1)
d0 = self.outconv(paddle.concat((d1,d2,d3,d4,d5,d6),1))
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
### U^2-Net small ###
class U2NETP(nn.Layer):
def __init__(self,in_ch=3,out_ch=1):
super(U2NETP,self).__init__()
self.stage1 = RSU7(in_ch,16,64)
self.pool12 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage2 = RSU6(64,16,64)
self.pool23 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage3 = RSU5(64,16,64)
self.pool34 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage4 = RSU4(64,16,64)
self.pool45 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage5 = RSU4F(64,16,64)
self.pool56 = nn.MaxPool2D(2,stride=2,ceil_mode=True)
self.stage6 = RSU4F(64,16,64)
# decoder
self.stage5d = RSU4F(128,16,64)
self.stage4d = RSU4(128,16,64)
self.stage3d = RSU5(128,16,64)
self.stage2d = RSU6(128,16,64)
self.stage1d = RSU7(128,16,64)
self.side1 = nn.Conv2D(64,out_ch,3,padding=1)
self.side2 = nn.Conv2D(64,out_ch,3,padding=1)
self.side3 = nn.Conv2D(64,out_ch,3,padding=1)
self.side4 = nn.Conv2D(64,out_ch,3,padding=1)
self.side5 = nn.Conv2D(64,out_ch,3,padding=1)
self.side6 = nn.Conv2D(64,out_ch,3,padding=1)
self.outconv = nn.Conv2D(6,out_ch,1)
def forward(self,x):
hx = x
#stage 1
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
#stage 2
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
#stage 3
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
#stage 4
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
#stage 5
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
#stage 6
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6,hx5)
#decoder
hx5d = self.stage5d(paddle.concat((hx6up,hx5),1))
hx5dup = _upsample_like(hx5d,hx4)
hx4d = self.stage4d(paddle.concat((hx5dup,hx4),1))
hx4dup = _upsample_like(hx4d,hx3)
hx3d = self.stage3d(paddle.concat((hx4dup,hx3),1))
hx3dup = _upsample_like(hx3d,hx2)
hx2d = self.stage2d(paddle.concat((hx3dup,hx2),1))
hx2dup = _upsample_like(hx2d,hx1)
hx1d = self.stage1d(paddle.concat((hx2dup,hx1),1))
#side output
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2,d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3,d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4,d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5,d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6,d1)
d0 = self.outconv(paddle.concat((d1,d2,d3,d4,d5,d6),1))
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册