提交 0e1789d4 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix mv2 and mv3

上级 515c9c99
......@@ -18,9 +18,10 @@ from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
import math
......@@ -30,7 +31,7 @@ __all__ = [
]
class ConvBNLayer(fluid.dygraph.Layer):
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
filter_size,
......@@ -43,16 +44,14 @@ class ConvBNLayer(fluid.dygraph.Layer):
use_cudnn=True):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
self._conv = Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + "_weights"),
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
......@@ -66,11 +65,11 @@ class ConvBNLayer(fluid.dygraph.Layer):
y = self._conv(inputs)
y = self._batch_norm(y)
if if_act:
y = fluid.layers.relu6(y)
y = F.relu6(y)
return y
class InvertedResidualUnit(fluid.dygraph.Layer):
class InvertedResidualUnit(nn.Layer):
def __init__(self, num_channels, num_in_filter, num_filters, stride,
filter_size, padding, expansion_factor, name):
super(InvertedResidualUnit, self).__init__()
......@@ -108,11 +107,11 @@ class InvertedResidualUnit(fluid.dygraph.Layer):
y = self._bottleneck_conv(y, if_act=True)
y = self._linear_conv(y, if_act=False)
if ifshortcut:
y = fluid.layers.elementwise_add(inputs, y)
y = paddle.elementwise_add(inputs, y)
return y
class InvresiBlocks(fluid.dygraph.Layer):
class InvresiBlocks(nn.Layer):
def __init__(self, in_c, t, c, n, s, name):
super(InvresiBlocks, self).__init__()
......@@ -148,7 +147,7 @@ class InvresiBlocks(fluid.dygraph.Layer):
return y
class MobileNet(fluid.dygraph.Layer):
class MobileNet(nn.Layer):
def __init__(self, class_dim=1000, scale=1.0):
super(MobileNet, self).__init__()
self.scale = scale
......@@ -204,7 +203,7 @@ class MobileNet(fluid.dygraph.Layer):
self.out = Linear(
self.out_c,
class_dim,
param_attr=ParamAttr(name="fc10_weights"),
weight_attr=ParamAttr(name="fc10_weights"),
bias_attr=ParamAttr(name="fc10_offset"))
def forward(self, inputs):
......@@ -213,7 +212,7 @@ class MobileNet(fluid.dygraph.Layer):
y = block(y)
y = self.conv9(y, if_act=True)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.out_c])
y = paddle.reshape(y, shape=[-1, self.out_c])
y = self.out(y)
return y
......
......@@ -18,9 +18,12 @@ from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
# TODO: need to be removed later!
from paddle.fluid.regularizer import L2Decay
import math
......@@ -42,7 +45,7 @@ def make_divisible(v, divisor=8, min_value=None):
return new_v
class MobileNetV3(fluid.dygraph.Layer):
class MobileNetV3(nn.Layer):
def __init__(self, scale=1.0, model_name="small", class_dim=1000):
super(MobileNetV3, self).__init__()
......@@ -133,20 +136,19 @@ class MobileNetV3(fluid.dygraph.Layer):
self.pool = Pool2D(
pool_type="avg", global_pooling=True, use_cudnn=False)
self.last_conv = Conv2D(
num_channels=make_divisible(scale * self.cls_ch_squeeze),
num_filters=self.cls_ch_expand,
filter_size=1,
self.last_conv = Conv2d(
in_channels=make_divisible(scale * self.cls_ch_squeeze),
out_channels=self.cls_ch_expand,
kernel_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name="last_1x1_conv_weights"),
weight_attr=ParamAttr(name="last_1x1_conv_weights"),
bias_attr=False)
self.out = Linear(
input_dim=self.cls_ch_expand,
output_dim=class_dim,
param_attr=ParamAttr("fc_weights"),
self.cls_ch_expand,
class_dim,
weight_attr=ParamAttr("fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
def forward(self, inputs, label=None, dropout_prob=0.2):
......@@ -156,15 +158,15 @@ class MobileNetV3(fluid.dygraph.Layer):
x = self.last_second_conv(x)
x = self.pool(x)
x = self.last_conv(x)
x = fluid.layers.hard_swish(x)
x = fluid.layers.dropout(x=x, dropout_prob=dropout_prob)
x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
x = F.hard_swish(x)
x = F.dropout(x=x, p=dropout_prob)
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.out(x)
return x
class ConvBNLayer(fluid.dygraph.Layer):
class ConvBNLayer(nn.Layer):
def __init__(self,
in_c,
out_c,
......@@ -179,28 +181,24 @@ class ConvBNLayer(fluid.dygraph.Layer):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = fluid.dygraph.Conv2D(
num_channels=in_c,
num_filters=out_c,
filter_size=filter_size,
self.conv = Conv2d(
in_channels=in_c,
out_channels=out_c,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
use_cudnn=use_cudnn,
act=None)
self.bn = fluid.dygraph.BatchNorm(
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
self.bn = BatchNorm(
num_channels=out_c,
act=None,
param_attr=ParamAttr(
name=name + "_bn_scale",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
regularizer=L2Decay(regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=name + "_bn_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
regularizer=L2Decay(regularization_coeff=0.0)),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
......@@ -209,16 +207,16 @@ class ConvBNLayer(fluid.dygraph.Layer):
x = self.bn(x)
if self.if_act:
if self.act == "relu":
x = fluid.layers.relu(x)
x = F.relu(x)
elif self.act == "hard_swish":
x = fluid.layers.hard_swish(x)
x = F.hard_swish(x)
else:
print("The activation function is selected incorrectly.")
exit()
return x
class ResidualUnit(fluid.dygraph.Layer):
class ResidualUnit(nn.Layer):
def __init__(self,
in_c,
mid_c,
......@@ -270,40 +268,38 @@ class ResidualUnit(fluid.dygraph.Layer):
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = fluid.layers.elementwise_add(inputs, x)
x = paddle.elementwise_add(inputs, x)
return x
class SEModule(fluid.dygraph.Layer):
class SEModule(nn.Layer):
def __init__(self, channel, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = fluid.dygraph.Pool2D(
pool_type="avg", global_pooling=True, use_cudnn=False)
self.conv1 = fluid.dygraph.Conv2D(
num_channels=channel,
num_filters=channel // reduction,
filter_size=1,
self.avg_pool = Pool2D(pool_type="avg", global_pooling=True)
self.conv1 = Conv2d(
in_channels=channel,
out_channels=channel // reduction,
kernel_size=1,
stride=1,
padding=0,
act="relu",
param_attr=ParamAttr(name=name + "_1_weights"),
weight_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
self.conv2 = fluid.dygraph.Conv2D(
num_channels=channel // reduction,
num_filters=channel,
filter_size=1,
self.conv2 = Conv2d(
in_channels=channel // reduction,
out_channels=channel,
kernel_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name + "_2_weights"),
weight_attr=ParamAttr(name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = fluid.layers.hard_sigmoid(outputs)
return fluid.layers.elementwise_mul(x=inputs, y=outputs, axis=0)
outputs = F.hard_sigmoid(outputs)
return paddle.multiply(x=inputs, y=outputs, axis=0)
def MobileNetV3_small_x0_35(**args):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册