提交 1921935e 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix xception_deeplab

上级 0e1789d4
...@@ -33,5 +33,6 @@ from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, Mobil ...@@ -33,5 +33,6 @@ from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, Mobil
from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
from .alexnet import AlexNet from .alexnet import AlexNet
from .inception_v4 import InceptionV4 from .inception_v4 import InceptionV4
from .xception_deeplab import Xception41_deeplab, Xception65_deeplab, Xception71_deeplab
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0 from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import contextlib
bn_regularizer = fluid.regularizer.L2DecayRegularizer(regularization_coeff=0.0)
name_scope = ""
@contextlib.contextmanager
def scope(name):
global name_scope
bk = name_scope
name_scope = name_scope + name + '/'
yield
name_scope = bk
def max_pool(input, kernel, stride, padding):
data = fluid.layers.pool2d(
input,
pool_size=kernel,
pool_type='max',
pool_stride=stride,
pool_padding=padding)
return data
def group_norm(input, G, eps=1e-5, param_attr=None, bias_attr=None):
N, C, H, W = input.shape
if C % G != 0:
# print "group can not divide channle:", C, G
for d in range(10):
for t in [d, -d]:
if G + t <= 0: continue
if C % (G + t) == 0:
G = G + t
break
if C % G == 0:
# print "use group size:", G
break
assert C % G == 0
x = fluid.layers.group_norm(
input,
groups=G,
param_attr=param_attr,
bias_attr=bias_attr,
name=name_scope + 'group_norm')
return x
def bn(*args, **kargs):
with scope('BatchNorm'):
return fluid.layers.batch_norm(
*args,
epsilon=1e-3,
momentum=0.99,
param_attr=fluid.ParamAttr(
name=name_scope + 'gamma', regularizer=bn_regularizer),
bias_attr=fluid.ParamAttr(
name=name_scope + 'beta', regularizer=bn_regularizer),
moving_mean_name=name_scope + 'moving_mean',
moving_variance_name=name_scope + 'moving_variance',
**kargs)
def bn_relu(data):
return fluid.layers.relu(bn(data))
def relu(data):
return fluid.layers.relu(data)
def conv(*args, **kargs):
kargs['param_attr'] = name_scope + 'weights'
if 'bias_attr' in kargs and kargs['bias_attr']:
kargs['bias_attr'] = fluid.ParamAttr(
name=name_scope + 'biases',
regularizer=None,
initializer=fluid.initializer.ConstantInitializer(value=0.0))
else:
kargs['bias_attr'] = False
return fluid.layers.conv2d(*args, **kargs)
def deconv(*args, **kargs):
kargs['param_attr'] = name_scope + 'weights'
if 'bias_attr' in kargs and kargs['bias_attr']:
kargs['bias_attr'] = name_scope + 'biases'
else:
kargs['bias_attr'] = False
return fluid.layers.conv2d_transpose(*args, **kargs)
def seperate_conv(input, channel, stride, filter, dilation=1, act=None):
param_attr = fluid.ParamAttr(
name=name_scope + 'weights',
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0),
initializer=fluid.initializer.TruncatedNormal(
loc=0.0, scale=0.33))
with scope('depthwise'):
input = conv(
input,
input.shape[1],
filter,
stride,
groups=input.shape[1],
padding=(filter // 2) * dilation,
dilation=dilation,
use_cudnn=False,
param_attr=param_attr)
input = bn(input)
if act: input = act(input)
param_attr = fluid.ParamAttr(
name=name_scope + 'weights',
regularizer=None,
initializer=fluid.initializer.TruncatedNormal(
loc=0.0, scale=0.06))
with scope('pointwise'):
input = conv(
input, channel, 1, 1, groups=1, padding=0, param_attr=param_attr)
input = bn(input)
if act: input = act(input)
return input
import paddle import paddle
import paddle.fluid as fluid from paddle import ParamAttr
from paddle.fluid.param_attr import ParamAttr import paddle.nn as nn
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout import paddle.nn.functional as F
from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
__all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"] __all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
...@@ -56,7 +57,7 @@ def gen_bottleneck_params(backbone='xception_65'): ...@@ -56,7 +57,7 @@ def gen_bottleneck_params(backbone='xception_65'):
return bottleneck_params return bottleneck_params
class ConvBNLayer(fluid.dygraph.Layer): class ConvBNLayer(nn.Layer):
def __init__(self, def __init__(self,
input_channels, input_channels,
output_channels, output_channels,
...@@ -67,13 +68,13 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -67,13 +68,13 @@ class ConvBNLayer(fluid.dygraph.Layer):
name=None): name=None):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self._conv = Conv2D( self._conv = Conv2d(
num_channels=input_channels, in_channels=input_channels,
num_filters=output_channels, out_channels=output_channels,
filter_size=filter_size, kernel_size=filter_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
param_attr=ParamAttr(name=name + "/weights"), weight_attr=ParamAttr(name=name + "/weights"),
bias_attr=False) bias_attr=False)
self._bn = BatchNorm( self._bn = BatchNorm(
num_channels=output_channels, num_channels=output_channels,
...@@ -89,7 +90,7 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -89,7 +90,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return self._bn(self._conv(inputs)) return self._bn(self._conv(inputs))
class Seperate_Conv(fluid.dygraph.Layer): class Seperate_Conv(nn.Layer):
def __init__(self, def __init__(self,
input_channels, input_channels,
output_channels, output_channels,
...@@ -100,15 +101,15 @@ class Seperate_Conv(fluid.dygraph.Layer): ...@@ -100,15 +101,15 @@ class Seperate_Conv(fluid.dygraph.Layer):
name=None): name=None):
super(Seperate_Conv, self).__init__() super(Seperate_Conv, self).__init__()
self._conv1 = Conv2D( self._conv1 = Conv2d(
num_channels=input_channels, in_channels=input_channels,
num_filters=input_channels, out_channels=input_channels,
filter_size=filter, kernel_size=filter,
stride=stride, stride=stride,
groups=input_channels, groups=input_channels,
padding=(filter) // 2 * dilation, padding=(filter) // 2 * dilation,
dilation=dilation, dilation=dilation,
param_attr=ParamAttr(name=name + "/depthwise/weights"), weight_attr=ParamAttr(name=name + "/depthwise/weights"),
bias_attr=False) bias_attr=False)
self._bn1 = BatchNorm( self._bn1 = BatchNorm(
input_channels, input_channels,
...@@ -119,14 +120,14 @@ class Seperate_Conv(fluid.dygraph.Layer): ...@@ -119,14 +120,14 @@ class Seperate_Conv(fluid.dygraph.Layer):
bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta"), bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta"),
moving_mean_name=name + "/depthwise/BatchNorm/moving_mean", moving_mean_name=name + "/depthwise/BatchNorm/moving_mean",
moving_variance_name=name + "/depthwise/BatchNorm/moving_variance") moving_variance_name=name + "/depthwise/BatchNorm/moving_variance")
self._conv2 = Conv2D( self._conv2 = Conv2d(
input_channels, input_channels,
output_channels, output_channels,
1, 1,
stride=1, stride=1,
groups=1, groups=1,
padding=0, padding=0,
param_attr=ParamAttr(name=name + "/pointwise/weights"), weight_attr=ParamAttr(name=name + "/pointwise/weights"),
bias_attr=False) bias_attr=False)
self._bn2 = BatchNorm( self._bn2 = BatchNorm(
output_channels, output_channels,
...@@ -146,7 +147,7 @@ class Seperate_Conv(fluid.dygraph.Layer): ...@@ -146,7 +147,7 @@ class Seperate_Conv(fluid.dygraph.Layer):
return x return x
class Xception_Block(fluid.dygraph.Layer): class Xception_Block(nn.Layer):
def __init__(self, def __init__(self,
input_channels, input_channels,
output_channels, output_channels,
...@@ -226,11 +227,11 @@ class Xception_Block(fluid.dygraph.Layer): ...@@ -226,11 +227,11 @@ class Xception_Block(fluid.dygraph.Layer):
def forward(self, inputs): def forward(self, inputs):
if not self.activation_fn_in_separable_conv: if not self.activation_fn_in_separable_conv:
x = fluid.layers.relu(inputs) x = F.relu(inputs)
x = self._conv1(x) x = self._conv1(x)
x = fluid.layers.relu(x) x = F.relu(x)
x = self._conv2(x) x = self._conv2(x)
x = fluid.layers.relu(x) x = F.relu(x)
x = self._conv3(x) x = self._conv3(x)
else: else:
x = self._conv1(inputs) x = self._conv1(inputs)
...@@ -242,10 +243,10 @@ class Xception_Block(fluid.dygraph.Layer): ...@@ -242,10 +243,10 @@ class Xception_Block(fluid.dygraph.Layer):
skip = self._short(inputs) skip = self._short(inputs)
else: else:
skip = inputs skip = inputs
return fluid.layers.elementwise_add(x, skip) return paddle.elementwise_add(x, skip)
class XceptionDeeplab(fluid.dygraph.Layer): class XceptionDeeplab(nn.Layer):
def __init__(self, backbone, class_dim=1000): def __init__(self, backbone, class_dim=1000):
super(XceptionDeeplab, self).__init__() super(XceptionDeeplab, self).__init__()
...@@ -349,7 +350,7 @@ class XceptionDeeplab(fluid.dygraph.Layer): ...@@ -349,7 +350,7 @@ class XceptionDeeplab(fluid.dygraph.Layer):
self._fc = Linear( self._fc = Linear(
self.chns[1][-1], self.chns[1][-1],
class_dim, class_dim,
param_attr=ParamAttr(name="fc_weights"), weight_attr=ParamAttr(name="fc_weights"),
bias_attr=ParamAttr(name="fc_bias")) bias_attr=ParamAttr(name="fc_bias"))
def forward(self, inputs): def forward(self, inputs):
...@@ -363,7 +364,7 @@ class XceptionDeeplab(fluid.dygraph.Layer): ...@@ -363,7 +364,7 @@ class XceptionDeeplab(fluid.dygraph.Layer):
x = self._exit_flow_2(x) x = self._exit_flow_2(x)
x = self._drop(x) x = self._drop(x)
x = self._pool(x) x = self._pool(x)
x = fluid.layers.squeeze(x, axes=[2, 3]) x = paddle.squeeze(x, axis=[2, 3])
x = self._fc(x) x = self._fc(x)
return x return x
...@@ -380,4 +381,4 @@ def Xception65_deeplab(**args): ...@@ -380,4 +381,4 @@ def Xception65_deeplab(**args):
def Xception71_deeplab(**args): def Xception71_deeplab(**args):
model = XceptionDeeplab("xception_71", **args) model = XceptionDeeplab("xception_71", **args)
return model return model
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册