提交 7bcdf7ad 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix xception

上级 9ecc07df
......@@ -33,6 +33,7 @@ from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, Mobil
from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
from .alexnet import AlexNet
from .inception_v4 import InceptionV4
from .xception import Xception41, Xception65, Xception71
from .xception_deeplab import Xception41_deeplab, Xception65_deeplab, Xception71_deeplab
from .resnext101_wsl import ResNeXt101_32x8d_wsl, ResNeXt101_32x16d_wsl, ResNeXt101_32x32d_wsl, ResNeXt101_32x48d_wsl
from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
......
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2d, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2d, MaxPool2d, AvgPool2d
from paddle.nn.initializer import Uniform
import math
__all__ = ['Xception41', 'Xception65', 'Xception71']
class ConvBNLayer(fluid.dygraph.Layer):
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
......@@ -18,15 +21,14 @@ class ConvBNLayer(fluid.dygraph.Layer):
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
self._conv = Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
bn_name = "bn_" + name
self._batch_norm = BatchNorm(
......@@ -43,7 +45,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return y
class SeparableConv(fluid.dygraph.Layer):
class SeparableConv(nn.Layer):
def __init__(self, input_channels, output_channels, stride=1, name=None):
super(SeparableConv, self).__init__()
......@@ -63,7 +65,7 @@ class SeparableConv(fluid.dygraph.Layer):
return x
class EntryFlowBottleneckBlock(fluid.dygraph.Layer):
class EntryFlowBottleneckBlock(nn.Layer):
def __init__(self,
input_channels,
output_channels,
......@@ -73,14 +75,13 @@ class EntryFlowBottleneckBlock(fluid.dygraph.Layer):
super(EntryFlowBottleneckBlock, self).__init__()
self.relu_first = relu_first
self._short = Conv2D(
num_channels=input_channels,
num_filters=output_channels,
filter_size=1,
self._short = Conv2d(
in_channels=input_channels,
out_channels=output_channels,
kernel_size=1,
stride=stride,
padding=0,
act=None,
param_attr=ParamAttr(name + "_branch1_weights"),
weight_attr=ParamAttr(name + "_branch1_weights"),
bias_attr=False)
self._conv1 = SeparableConv(
input_channels,
......@@ -92,22 +93,21 @@ class EntryFlowBottleneckBlock(fluid.dygraph.Layer):
output_channels,
stride=1,
name=name + "_branch2b_weights")
self._pool = Pool2D(
pool_size=3, pool_stride=stride, pool_padding=1, pool_type="max")
self._pool = MaxPool2d(kernel_size=3, stride=stride, padding=1)
def forward(self, inputs):
conv0 = inputs
short = self._short(inputs)
if self.relu_first:
conv0 = fluid.layers.relu(conv0)
conv0 = F.relu(conv0)
conv1 = self._conv1(conv0)
conv2 = fluid.layers.relu(conv1)
conv2 = F.relu(conv1)
conv2 = self._conv2(conv2)
pool = self._pool(conv2)
return fluid.layers.elementwise_add(x=short, y=pool)
return paddle.elementwise_add(x=short, y=pool)
class EntryFlow(fluid.dygraph.Layer):
class EntryFlow(nn.Layer):
def __init__(self, block_num=3):
super(EntryFlow, self).__init__()
......@@ -154,7 +154,7 @@ class EntryFlow(fluid.dygraph.Layer):
return x
class MiddleFlowBottleneckBlock(fluid.dygraph.Layer):
class MiddleFlowBottleneckBlock(nn.Layer):
def __init__(self, input_channels, output_channels, name):
super(MiddleFlowBottleneckBlock, self).__init__()
......@@ -175,16 +175,16 @@ class MiddleFlowBottleneckBlock(fluid.dygraph.Layer):
name=name + "_branch2c_weights")
def forward(self, inputs):
conv0 = fluid.layers.relu(inputs)
conv0 = F.relu(inputs)
conv0 = self._conv_0(conv0)
conv1 = fluid.layers.relu(conv0)
conv1 = F.relu(conv0)
conv1 = self._conv_1(conv1)
conv2 = fluid.layers.relu(conv1)
conv2 = F.relu(conv1)
conv2 = self._conv_2(conv2)
return fluid.layers.elementwise_add(x=inputs, y=conv2)
return paddle.elementwise_add(x=inputs, y=conv2)
class MiddleFlow(fluid.dygraph.Layer):
class MiddleFlow(nn.Layer):
def __init__(self, block_num=8):
super(MiddleFlow, self).__init__()
......@@ -244,19 +244,18 @@ class MiddleFlow(fluid.dygraph.Layer):
return x
class ExitFlowBottleneckBlock(fluid.dygraph.Layer):
class ExitFlowBottleneckBlock(nn.Layer):
def __init__(self, input_channels, output_channels1, output_channels2,
name):
super(ExitFlowBottleneckBlock, self).__init__()
self._short = Conv2D(
num_channels=input_channels,
num_filters=output_channels2,
filter_size=1,
self._short = Conv2d(
in_channels=input_channels,
out_channels=output_channels2,
kernel_size=1,
stride=2,
padding=0,
act=None,
param_attr=ParamAttr(name + "_branch1_weights"),
weight_attr=ParamAttr(name + "_branch1_weights"),
bias_attr=False)
self._conv_1 = SeparableConv(
input_channels,
......@@ -268,20 +267,19 @@ class ExitFlowBottleneckBlock(fluid.dygraph.Layer):
output_channels2,
stride=1,
name=name + "_branch2b_weights")
self._pool = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type="max")
self._pool = MaxPool2d(kernel_size=3, stride=2, padding=1)
def forward(self, inputs):
short = self._short(inputs)
conv0 = fluid.layers.relu(inputs)
conv0 = F.relu(inputs)
conv1 = self._conv_1(conv0)
conv2 = fluid.layers.relu(conv1)
conv2 = F.relu(conv1)
conv2 = self._conv_2(conv2)
pool = self._pool(conv2)
return fluid.layers.elementwise_add(x=short, y=pool)
return paddle.elementwise_add(x=short, y=pool)
class ExitFlow(fluid.dygraph.Layer):
class ExitFlow(nn.Layer):
def __init__(self, class_dim):
super(ExitFlow, self).__init__()
......@@ -291,29 +289,28 @@ class ExitFlow(fluid.dygraph.Layer):
728, 728, 1024, name=name + "_1")
self._conv_1 = SeparableConv(1024, 1536, stride=1, name=name + "_2")
self._conv_2 = SeparableConv(1536, 2048, stride=1, name=name + "_3")
self._pool = Pool2D(pool_type="avg", global_pooling=True)
self._pool = AdaptiveAvgPool2d(1)
stdv = 1.0 / math.sqrt(2048 * 1.0)
self._out = Linear(
2048,
class_dim,
param_attr=ParamAttr(
name="fc_weights",
initializer=fluid.initializer.Uniform(-stdv, stdv)),
weight_attr=ParamAttr(
name="fc_weights", initializer=Uniform(-stdv, stdv)),
bias_attr=ParamAttr(name="fc_offset"))
def forward(self, inputs):
conv0 = self._conv_0(inputs)
conv1 = self._conv_1(conv0)
conv1 = fluid.layers.relu(conv1)
conv1 = F.relu(conv1)
conv2 = self._conv_2(conv1)
conv2 = fluid.layers.relu(conv2)
conv2 = F.relu(conv2)
pool = self._pool(conv2)
pool = fluid.layers.reshape(pool, [0, -1])
pool = paddle.reshape(pool, [0, -1])
out = self._out(pool)
return out
class Xception(fluid.dygraph.Layer):
class Xception(nn.Layer):
def __init__(self,
entry_flow_block_num=3,
middle_flow_block_num=8,
......@@ -344,4 +341,4 @@ def Xception65(**args):
def Xception71(**args):
model = Xception(entry_flow_block_num=5, middle_flow_block_num=16, **args)
return model
\ No newline at end of file
return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册