提交 30e44c5e 编写于 作者: C chenguowei01

update unet.py to 2.0beta

上级 a3bfb074
...@@ -14,15 +14,19 @@ ...@@ -14,15 +14,19 @@
import os import os
import paddle.fluid as fluid import paddle
from paddle.fluid.dygraph import Conv2D, Pool2D import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d
from paddle.nn import SyncBatchNorm as BatchNorm from paddle.nn import SyncBatchNorm as BatchNorm
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg import utils from paddleseg import utils
from paddleseg.models.common import layer_libs
class UNet(fluid.dygraph.Layer): @manager.MODELS.add_component
class UNet(nn.Layer):
""" """
U-Net: Convolutional Networks for Biomedical Image Segmentation. U-Net: Convolutional Networks for Biomedical Image Segmentation.
https://arxiv.org/abs/1505.04597 https://arxiv.org/abs/1505.04597
...@@ -35,62 +39,35 @@ class UNet(fluid.dygraph.Layer): ...@@ -35,62 +39,35 @@ class UNet(fluid.dygraph.Layer):
def __init__(self, num_classes, model_pretrained=None, ignore_index=255): def __init__(self, num_classes, model_pretrained=None, ignore_index=255):
super(UNet, self).__init__() super(UNet, self).__init__()
self.model_pretrained = model_pretrained
self.ignore_index = ignore_index
self.encode = UnetEncoder() self.encode = UnetEncoder()
self.decode = UnetDecode() self.decode = UnetDecode()
self.get_logit = GetLogit(64, num_classes) self.get_logit = GetLogit(64, num_classes)
self.ignore_index = ignore_index
self.EPS = 1e-5 self.EPS = 1e-5
self.init_weight(model_pretrained) self.init_weight()
def forward(self, x, label=None): def forward(self, x, label=None):
encode_data, short_cuts = self.encode(x) encode_data, short_cuts = self.encode(x)
decode_data = self.decode(encode_data, short_cuts) decode_data = self.decode(encode_data, short_cuts)
logit = self.get_logit(decode_data) logit = self.get_logit(decode_data)
if self.training: return [logit]
return self._get_loss(logit, label)
else: def init_weight(self):
score_map = fluid.layers.softmax(logit, axis=1)
score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
pred = fluid.layers.argmax(score_map, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
return pred, score_map
def init_weight(self, pretrained_model=None):
""" """
Initialize the parameters of model parts. Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
""" """
if pretrained_model is not None: if self.model_pretrained is not None:
if os.path.exists(pretrained_model): if os.path.exists(self.model_pretrained):
utils.load_pretrained_model(self, pretrained_model) utils.load_pretrained_model(self, self.model_pretrained)
else: else:
raise Exception('Pretrained model is not found: {}'.format( raise Exception('Pretrained model is not found: {}'.format(
pretrained_model)) self.model_pretrained))
def _get_loss(self, logit, label):
logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) class UnetEncoder(nn.Layer):
label = fluid.layers.transpose(label, [0, 2, 3, 1])
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=-1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
class UnetEncoder(fluid.dygraph.Layer):
def __init__(self): def __init__(self):
super(UnetEncoder, self).__init__() super(UnetEncoder, self).__init__()
self.double_conv = DoubleConv(3, 64) self.double_conv = DoubleConv(3, 64)
...@@ -113,7 +90,7 @@ class UnetEncoder(fluid.dygraph.Layer): ...@@ -113,7 +90,7 @@ class UnetEncoder(fluid.dygraph.Layer):
return x, short_cuts return x, short_cuts
class UnetDecode(fluid.dygraph.Layer): class UnetDecode(nn.Layer):
def __init__(self): def __init__(self):
super(UnetDecode, self).__init__() super(UnetDecode, self).__init__()
self.up1 = Up(512, 256) self.up1 = Up(512, 256)
...@@ -129,20 +106,20 @@ class UnetDecode(fluid.dygraph.Layer): ...@@ -129,20 +106,20 @@ class UnetDecode(fluid.dygraph.Layer):
return x return x
class DoubleConv(fluid.dygraph.Layer): class DoubleConv(nn.Layer):
def __init__(self, num_channels, num_filters): def __init__(self, num_channels, num_filters):
super(DoubleConv, self).__init__() super(DoubleConv, self).__init__()
self.conv0 = Conv2D( self.conv0 = Conv2d(
num_channels=num_channels, in_channels=num_channels,
num_filters=num_filters, out_channels=num_filters,
filter_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
self.bn0 = BatchNorm(num_filters) self.bn0 = BatchNorm(num_filters)
self.conv1 = Conv2D( self.conv1 = Conv2d(
num_channels=num_filters, in_channels=num_filters,
num_filters=num_filters, out_channels=num_filters,
filter_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
self.bn1 = BatchNorm(num_filters) self.bn1 = BatchNorm(num_filters)
...@@ -150,18 +127,17 @@ class DoubleConv(fluid.dygraph.Layer): ...@@ -150,18 +127,17 @@ class DoubleConv(fluid.dygraph.Layer):
def forward(self, x): def forward(self, x):
x = self.conv0(x) x = self.conv0(x)
x = self.bn0(x) x = self.bn0(x)
x = fluid.layers.relu(x) x = F.relu(x)
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
x = fluid.layers.relu(x) x = F.relu(x)
return x return x
class Down(fluid.dygraph.Layer): class Down(nn.Layer):
def __init__(self, num_channels, num_filters): def __init__(self, num_channels, num_filters):
super(Down, self).__init__() super(Down, self).__init__()
self.max_pool = Pool2D( self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
pool_size=2, pool_type='max', pool_stride=2, pool_padding=0)
self.double_conv = DoubleConv(num_channels, num_filters) self.double_conv = DoubleConv(num_channels, num_filters)
def forward(self, x): def forward(self, x):
...@@ -170,34 +146,28 @@ class Down(fluid.dygraph.Layer): ...@@ -170,34 +146,28 @@ class Down(fluid.dygraph.Layer):
return x return x
class Up(fluid.dygraph.Layer): class Up(nn.Layer):
def __init__(self, num_channels, num_filters): def __init__(self, num_channels, num_filters):
super(Up, self).__init__() super(Up, self).__init__()
self.double_conv = DoubleConv(2 * num_channels, num_filters) self.double_conv = DoubleConv(2 * num_channels, num_filters)
def forward(self, x, short_cut): def forward(self, x, short_cut):
short_cut_shape = fluid.layers.shape(short_cut) x = F.resize_bilinear(x, short_cut.shape[2:])
x = fluid.layers.resize_bilinear(x, short_cut_shape[2:]) x = paddle.concat([x, short_cut], axis=1)
x = fluid.layers.concat([x, short_cut], axis=1)
x = self.double_conv(x) x = self.double_conv(x)
return x return x
class GetLogit(fluid.dygraph.Layer): class GetLogit(nn.Layer):
def __init__(self, num_channels, num_classes): def __init__(self, num_channels, num_classes):
super(GetLogit, self).__init__() super(GetLogit, self).__init__()
self.conv = Conv2D( self.conv = Conv2d(
num_channels=num_channels, in_channels=num_channels,
num_filters=num_classes, out_channels=num_classes,
filter_size=3, kernel_size=3,
stride=1, stride=1,
padding=1) padding=1)
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
return x return x
@manager.MODELS.add_component
def unet(*args, **kwargs):
return UNet(*args, **kwargs)
...@@ -87,7 +87,7 @@ def parse_args(): ...@@ -87,7 +87,7 @@ def parse_args():
def main(args): def main(args):
env_info = get_environ_info() env_info = get_environ_info()
info = ['{}: {}'.format(k, v) for k, v in env_info.items()] info = ['{}: {}'.format(k, v) for k, v in env_info.items()]
info = '\n'.join(['\n', format('Environment Information', '-^48s')] + info + info = '\n'.join(['', format('Environment Information', '-^48s')] + info +
['-' * 48]) ['-' * 48])
logger.info(info) logger.info(info)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册