提交 c9a2c16c 编写于 作者: C chenguowei01

update models

上级 9148fa5c
......@@ -20,12 +20,7 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.initializer import Normal
try:
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
print('using sync batch norm')
except:
from paddle.fluid.dygraph import BatchNorm
print('using batch norm')
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
__all__ = [
"HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
......
......@@ -14,15 +14,12 @@
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Pool2D
try:
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
except:
from paddle.fluid.dygraph import BatchNorm
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
class UNet(fluid.dygraph.Layer):
def __init__(self, num_classes, ignore_index=255):
super().__init__()
super(UNet, self).__init__()
self.encode = UnetEncoder()
self.decode = UnetDecode()
self.get_logit = GetLogit(64, num_classes)
......@@ -65,7 +62,7 @@ class UNet(fluid.dygraph.Layer):
class UnetEncoder(fluid.dygraph.Layer):
def __init__(self):
super().__init__()
super(UnetEncoder, self).__init__()
self.double_conv = DoubleConv(3, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
......@@ -88,7 +85,7 @@ class UnetEncoder(fluid.dygraph.Layer):
class UnetDecode(fluid.dygraph.Layer):
def __init__(self):
super().__init__()
super(UnetDecode, self).__init__()
self.up1 = Up(512, 256)
self.up2 = Up(256, 128)
self.up3 = Up(128, 64)
......@@ -104,7 +101,7 @@ class UnetDecode(fluid.dygraph.Layer):
class DoubleConv(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
super(DoubleConv, self).__init__()
self.conv0 = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
......@@ -132,7 +129,7 @@ class DoubleConv(fluid.dygraph.Layer):
class Down(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
super(Down, self).__init__()
self.max_pool = Pool2D(
pool_size=2, pool_type='max', pool_stride=2, pool_padding=0)
self.double_conv = DoubleConv(num_channels, num_filters)
......@@ -145,7 +142,7 @@ class Down(fluid.dygraph.Layer):
class Up(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
super(Up, self).__init__()
self.double_conv = DoubleConv(2 * num_channels, num_filters)
def forward(self, x, short_cut):
......@@ -158,7 +155,7 @@ class Up(fluid.dygraph.Layer):
class GetLogit(fluid.dygraph.Layer):
def __init__(self, num_channels, num_classes):
super().__init__()
super(GetLogit, self).__init__()
self.conv = Conv2D(
num_channels=num_channels,
num_filters=num_classes,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册