未验证 提交 e5336bb5 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #351 from michaelowenliu/develop

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dygraph.models
\ No newline at end of file
......@@ -12,36 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .architectures import *
from .unet import UNet
from .hrnet import *
from .deeplab import *
MODELS = {
"UNet": UNet,
"HRNet_W18_Small_V1": HRNet_W18_Small_V1,
"HRNet_W18_Small_V2": HRNet_W18_Small_V2,
"HRNet_W18": HRNet_W18,
"HRNet_W30": HRNet_W30,
"HRNet_W32": HRNet_W32,
"HRNet_W40": HRNet_W40,
"HRNet_W44": HRNet_W44,
"HRNet_W48": HRNet_W48,
"HRNet_W60": HRNet_W48,
"HRNet_W64": HRNet_W64,
"SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
"SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
"SE_HRNet_W18": SE_HRNet_W18,
"SE_HRNet_W30": SE_HRNet_W30,
"SE_HRNet_W32": SE_HRNet_W30,
"SE_HRNet_W40": SE_HRNet_W40,
"SE_HRNet_W44": SE_HRNet_W44,
"SE_HRNet_W48": SE_HRNet_W48,
"SE_HRNet_W60": SE_HRNet_W60,
"SE_HRNet_W64": SE_HRNet_W64,
"DeepLabV3P": DeepLabV3P,
"deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd,
"deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8,
"deeplabv3p_resnet50_vd": deeplabv3p_resnet50_vd,
"deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8,
"deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab
}
# MODELS = {
# "UNet": UNet,
# "HRNet_W18_Small_V1": HRNet_W18_Small_V1,
# "HRNet_W18_Small_V2": HRNet_W18_Small_V2,
# "HRNet_W18": HRNet_W18,
# "HRNet_W30": HRNet_W30,
# "HRNet_W32": HRNet_W32,
# "HRNet_W40": HRNet_W40,
# "HRNet_W44": HRNet_W44,
# "HRNet_W48": HRNet_W48,
# "HRNet_W60": HRNet_W48,
# "HRNet_W64": HRNet_W64,
# "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
# "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
# "SE_HRNet_W18": SE_HRNet_W18,
# "SE_HRNet_W30": SE_HRNet_W30,
# "SE_HRNet_W32": SE_HRNet_W30,
# "SE_HRNet_W40": SE_HRNet_W40,
# "SE_HRNet_W44": SE_HRNet_W44,
# "SE_HRNet_W48": SE_HRNet_W48,
# "SE_HRNet_W60": SE_HRNet_W60,
# "SE_HRNet_W64": SE_HRNet_W64,
# "DeepLabV3P": DeepLabV3P,
# "deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd,
# "deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8,
# "deeplabv3p_resnet50_vd": deeplabv3p_resnet50_vd,
# "deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8,
# "deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab,
# "deeplabv3p_mobilenetv3_large": deeplabv3p_mobilenetv3_large,
# "deeplabv3p_mobilenetv3_small": deeplabv3p_mobilenetv3_small
# }
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
import math
from dygraph.cvlibs import manager
__all__ = [
"MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
"MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
"MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
"MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25"
]
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
def get_padding_same(kernel_size, dilation_rate):
"""
SAME padding implementation given kernel_size and dilation_rate.
The calculation formula as following:
(F-(k+(k -1)*(r-1))+2*p)/s + 1 = F_new
where F: a feature map
k: kernel size, r: dilation rate, p: padding value, s: stride
F_new: new feature map
Args:
kernel_size (int)
dilation_rate (int)
Returns:
padding_same (int): padding value
"""
k = kernel_size
r = dilation_rate
padding_same = (k + (k - 1) * (r - 1) - 1)//2
return padding_same
class MobileNetV3(fluid.dygraph.Layer):
def __init__(self, scale=1.0, model_name="small", class_dim=1000, output_stride=None, **kwargs):
super(MobileNetV3, self).__init__()
inplanes = 16
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", 1],
[3, 64, 24, False, "relu", 2],
[3, 72, 24, False, "relu", 1], # output 1 -> out_index=2
[5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1], # output 2 -> out_index=5
[3, 240, 80, False, "hard_swish", 2],
[3, 200, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish", 1], # output 3 -> out_index=11
[5, 672, 160, True, "hard_swish", 2],
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish", 1], # output 3 -> out_index=14
]
self.out_indices = [2, 5, 11, 14]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", 2], # output 1 -> out_index=0
[3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1], # output 2 -> out_index=3
[5, 96, 40, True, "hard_swish", 2],
[5, 240, 40, True, "hard_swish", 1],
[5, 240, 40, True, "hard_swish", 1],
[5, 120, 48, True, "hard_swish", 1],
[5, 144, 48, True, "hard_swish", 1], # output 3 -> out_index=7
[5, 288, 96, True, "hard_swish", 2],
[5, 576, 96, True, "hard_swish", 1],
[5, 576, 96, True, "hard_swish", 1], # output 4 -> out_index=10
]
self.out_indices = [0, 3, 7, 10]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
else:
raise NotImplementedError(
"mode[{}_model] is not implemented!".format(model_name))
###################################################
# modify stride and dilation based on output_stride
self.dilation_cfg = [1] * len(self.cfg)
self.modify_bottle_params(output_stride=output_stride)
###################################################
self.conv1 = ConvBNLayer(
in_c=3,
out_c=make_divisible(inplanes * scale),
filter_size=3,
stride=2,
padding=1,
num_groups=1,
if_act=True,
act="hard_swish",
name="conv1")
self.block_list = []
inplanes = make_divisible(inplanes * scale)
for i, (k, exp, c, se, nl, s) in enumerate(self.cfg):
######################################
# add dilation rate
dilation_rate = self.dilation_cfg[i]
######################################
self.block_list.append(
ResidualUnit(
in_c=inplanes,
mid_c=make_divisible(scale * exp),
out_c=make_divisible(scale * c),
filter_size=k,
stride=s,
dilation=dilation_rate,
use_se=se,
act=nl,
name="conv" + str(i + 2)))
self.add_sublayer(
sublayer=self.block_list[-1], name="conv" + str(i + 2))
inplanes = make_divisible(scale * c)
self.last_second_conv = ConvBNLayer(
in_c=inplanes,
out_c=make_divisible(scale * self.cls_ch_squeeze),
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act="hard_swish",
name="conv_last")
self.pool = Pool2D(
pool_type="avg", global_pooling=True, use_cudnn=False)
self.last_conv = Conv2D(
num_channels=make_divisible(scale * self.cls_ch_squeeze),
num_filters=self.cls_ch_expand,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name="last_1x1_conv_weights"),
bias_attr=False)
self.out = Linear(
input_dim=self.cls_ch_expand,
output_dim=class_dim,
param_attr=ParamAttr("fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
def modify_bottle_params(self, output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must to be even number")
if output_stride is not None:
stride = 2
rate = 1
for i, _cfg in enumerate(self.cfg):
stride = stride * _cfg[-1]
if stride > output_stride:
rate = rate * _cfg[-1]
self.cfg[i][-1] = 1
self.dilation_cfg[i] = rate
def forward(self, inputs, label=None, dropout_prob=0.2):
x = self.conv1(inputs)
# A feature list saves each downsampling feature.
feat_list = []
for i, block in enumerate(self.block_list):
x = block(x)
if i in self.out_indices:
feat_list.append(x)
#print("block {}:".format(i),x.shape, self.dilation_cfg[i])
x = self.last_second_conv(x)
x = self.pool(x)
x = self.last_conv(x)
x = fluid.layers.hard_swish(x)
x = fluid.layers.dropout(x=x, dropout_prob=dropout_prob)
x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.out(x)
return x, feat_list
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
in_c,
out_c,
filter_size,
stride,
padding,
dilation=1,
num_groups=1,
if_act=True,
act=None,
use_cudnn=True,
name=""):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = fluid.dygraph.Conv2D(
num_channels=in_c,
num_filters=out_c,
filter_size=filter_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=num_groups,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
use_cudnn=use_cudnn,
act=None)
self.bn = fluid.dygraph.BatchNorm(
num_channels=out_c,
act=None,
param_attr=ParamAttr(
name=name + "_bn_scale",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=name + "_bn_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
if self.act == "relu":
x = fluid.layers.relu(x)
elif self.act == "hard_swish":
x = fluid.layers.hard_swish(x)
else:
print("The activation function is selected incorrectly.")
exit()
return x
class ResidualUnit(fluid.dygraph.Layer):
def __init__(self,
in_c,
mid_c,
out_c,
filter_size,
stride,
use_se,
dilation=1,
act=None,
name=''):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_c == out_c
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_c=in_c,
out_c=mid_c,
filter_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
filter_size=filter_size,
stride=stride,
padding= get_padding_same(filter_size, dilation), #int((filter_size - 1) // 2) + (dilation - 1),
dilation=dilation,
num_groups=mid_c,
if_act=True,
act=act,
name=name + "_depthwise")
if self.if_se:
self.mid_se = SEModule(mid_c, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_c=mid_c,
out_c=out_c,
filter_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear")
self.dilation = dilation
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = fluid.layers.elementwise_add(inputs, x)
return x
class SEModule(fluid.dygraph.Layer):
def __init__(self, channel, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = fluid.dygraph.Pool2D(
pool_type="avg", global_pooling=True, use_cudnn=False)
self.conv1 = fluid.dygraph.Conv2D(
num_channels=channel,
num_filters=channel // reduction,
filter_size=1,
stride=1,
padding=0,
act="relu",
param_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
self.conv2 = fluid.dygraph.Conv2D(
num_channels=channel // reduction,
num_filters=channel,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = self.conv2(outputs)
outputs = fluid.layers.hard_sigmoid(outputs)
return fluid.layers.elementwise_mul(x=inputs, y=outputs, axis=0)
def MobileNetV3_small_x0_35(**kwargs):
model = MobileNetV3(model_name="small", scale=0.35, **kwargs)
return model
def MobileNetV3_small_x0_5(**kwargs):
model = MobileNetV3(model_name="small", scale=0.5, **kwargs)
return model
def MobileNetV3_small_x0_75(**kwargs):
model = MobileNetV3(model_name="small", scale=0.75, **kwargs)
return model
@manager.BACKBONES.add_component
def MobileNetV3_small_x1_0(**kwargs):
model = MobileNetV3(model_name="small", scale=1.0, **kwargs)
return model
def MobileNetV3_small_x1_25(**kwargs):
model = MobileNetV3(model_name="small", scale=1.25, **kwargs)
return model
def MobileNetV3_large_x0_35(**kwargs):
model = MobileNetV3(model_name="large", scale=0.35, **kwargs)
return model
def MobileNetV3_large_x0_5(**kwargs):
model = MobileNetV3(model_name="large", scale=0.5, **kwargs)
return model
def MobileNetV3_large_x0_75(**kwargs):
model = MobileNetV3(model_name="large", scale=0.75, **kwargs)
return model
@manager.BACKBONES.add_component
def MobileNetV3_large_x1_0(**kwargs):
model = MobileNetV3(model_name="large", scale=1.0, **kwargs)
return model
def MobileNetV3_large_x1_25(**kwargs):
model = MobileNetV3(model_name="large", scale=1.25, **kwargs)
return model
......@@ -28,6 +28,8 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
from dygraph.utils import utils
from dygraph.cvlibs import manager
__all__ = [
"ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
]
......@@ -199,7 +201,7 @@ class BasicBlock(fluid.dygraph.Layer):
class ResNet_vd(fluid.dygraph.Layer):
def __init__(self, layers=50, class_dim=1000, dilation_dict=None, multi_grid=(1, 2, 4), **kwargs):
def __init__(self, layers=50, class_dim=1000, output_stride=None, multi_grid=(1, 2, 4), **kwargs):
super(ResNet_vd, self).__init__()
self.layers = layers
......@@ -222,6 +224,12 @@ class ResNet_vd(fluid.dygraph.Layer):
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
dilation_dict=None
if output_stride == 8:
dilation_dict = {2: 2, 3: 4}
elif output_stride == 16:
dilation_dict = {3: 2}
self.conv1_1 = ConvBNLayer(
num_channels=3,
num_filters=32,
......@@ -359,12 +367,12 @@ def ResNet34_vd(**args):
model = ResNet_vd(layers=34, **args)
return model
@manager.BACKBONES.add_component
def ResNet50_vd(**args):
model = ResNet_vd(layers=50, **args)
return model
@manager.BACKBONES.add_component
def ResNet101_vd(**args):
model = ResNet_vd(layers=101, **args)
return model
......
......@@ -4,6 +4,8 @@ from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
from dygraph.cvlibs import manager
__all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
......@@ -400,7 +402,7 @@ def Xception41_deeplab(**args):
model = XceptionDeeplab('xception_41', **args)
return model
@manager.BACKBONES.add_component
def Xception65_deeplab(**args):
model = XceptionDeeplab("xception_65", **args)
return model
......
......@@ -13,22 +13,21 @@
# limitations under the License.
import os
import numpy as np
import paddle
from dygraph.cvlibs import manager
from dygraph.models.architectures import layer_utils
from paddle import fluid
from paddle.fluid import dygraph
from paddle.fluid.dygraph import Conv2D
from .architectures import layer_utils, xception_deeplab, resnet_vd
from dygraph.utils import utils
__all__ = ['DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8",
"deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8",
"deeplabv3p_xception65_deeplab"]
"deeplabv3p_xception65_deeplab",
"deeplabv3p_mobilenetv3_large", "deeplabv3p_mobilenetv3_small"]
class ImageAverage(dygraph.Layer):
"""
......@@ -104,7 +103,6 @@ class ASPP(dygraph.Layer):
dilation=aspp_ratios[2],
padding=aspp_ratios[2])
# After concat op, using 1*1 conv
self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=1280,
num_filters=256,
......@@ -197,11 +195,12 @@ class DeepLabV3P(dygraph.Layer):
in ASPP and Decoder components. Default True.
pretrained_model (str): the pretrained_model path of backbone.
"""
def __init__(self,
backbone,
num_classes=2,
output_stride=16,
backbone_indices=(0,3),
backbone_indices=(0, 3),
backbone_channels=(256, 2048),
ignore_index=255,
using_sep_conv=True,
......@@ -209,7 +208,7 @@ class DeepLabV3P(dygraph.Layer):
super(DeepLabV3P, self).__init__()
self.backbone = build_backbone(backbone, output_stride)
self.backbone = manager.BACKBONES[backbone](output_stride=output_stride)
self.aspp = ASPP(output_stride, backbone_channels[1], using_sep_conv)
self.decoder = Decoder(num_classes, backbone_channels[0], using_sep_conv)
self.ignore_index = ignore_index
......@@ -217,7 +216,8 @@ class DeepLabV3P(dygraph.Layer):
self.backbone_indices = backbone_indices
self.init_weight(pretrained_model)
def forward(self, input, label=None, mode='train'):
def forward(self, input, label=None):
_, feat_list = self.backbone(input)
low_level_feat = feat_list[self.backbone_indices[0]]
x = feat_list[self.backbone_indices[1]]
......@@ -279,50 +279,63 @@ class DeepLabV3P(dygraph.Layer):
return avg_loss
def build_backbone(backbone, output_stride):
if output_stride == 8:
dilation_dict = {2: 2, 3: 4}
elif output_stride == 16:
dilation_dict = {3: 2}
else:
raise Exception("deeplab only support stride 8 or 16")
model_dict = {"ResNet50_vd":resnet_vd.ResNet50_vd,
"ResNet101_vd":resnet_vd.ResNet101_vd,
"Xception65_deeplab": xception_deeplab.Xception65_deeplab}
model = model_dict[backbone]
return model(dilation_dict=dilation_dict)
def build_aspp(output_stride, using_sep_conv):
return ASPP(output_stride=output_stride, using_sep_conv=using_sep_conv)
def build_decoder(num_classes, using_sep_conv):
return Decoder(num_classes, using_sep_conv=using_sep_conv)
@manager.MODELS.add_component
def deeplabv3p_resnet101_vd(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet101_vd_os8(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='ResNet101_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet50_vd(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_resnet50_vd_os8(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='ResNet50_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs)
@manager.MODELS.add_component
def deeplabv3p_xception65_deeplab(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='Xception65_deeplab',
pretrained_model=pretrained_model,
backbone_indices=(0,1),
backbone_indices=(0, 1),
backbone_channels=(128, 2048),
**kwargs)
@manager.MODELS.add_component
def deeplabv3p_mobilenetv3_large(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='MobileNetV3_large_x1_0',
pretrained_model=pretrained_model,
backbone_indices=(0, 3),
backbone_channels=(24, 160),
**kwargs)
@manager.MODELS.add_component
def deeplabv3p_mobilenetv3_small(*args, **kwargs):
pretrained_model = None
return DeepLabV3P(backbone='MobileNetV3_small_x1_0',
pretrained_model=pretrained_model,
backbone_indices=(0, 3),
backbone_channels=(16, 96),
**kwargs)
......@@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from dygraph.datasets import DATASETS
import dygraph.transforms as T
from dygraph.models import MODELS
#from dygraph.models import MODELS
from dygraph.cvlibs import manager
from dygraph.utils import get_environ_info
from dygraph.core import train
......@@ -32,7 +33,7 @@ def parse_args():
'--model_name',
dest='model_name',
help='Model type for training, which is one of {}'.format(
str(list(MODELS.keys()))),
str(list(manager.MODELS.components_dict.keys()))),
type=str,
default='UNet')
......@@ -160,11 +161,8 @@ def main(args):
transforms=eval_transforms,
mode='val')
if args.model_name not in MODELS:
raise Exception(
'`--model_name` is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
model = manager.MODELS[args.model_name](num_classes=train_dataset.num_classes)
# Creat optimizer
# todo, may less one than len(loader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册