提交 03bbd585 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix hrnet

上级 2bdf8a9b
...@@ -32,5 +32,6 @@ from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75 ...@@ -32,5 +32,6 @@ from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75
from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25 from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
from .alexnet import AlexNet from .alexnet import AlexNet
from .inception_v4 import InceptionV4
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0 from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle import paddle
import paddle.fluid as fluid from paddle import ParamAttr
from paddle.fluid.param_attr import ParamAttr import paddle.nn as nn
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout import paddle.nn.functional as F
from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
from paddle.nn.initializer import Uniform
import math import math
__all__ = ["InceptionV4"] __all__ = ["InceptionV4"]
class ConvBNLayer(fluid.dygraph.Layer):
class ConvBNLayer(nn.Layer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
...@@ -18,15 +35,14 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -18,15 +35,14 @@ class ConvBNLayer(fluid.dygraph.Layer):
name=None): name=None):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self._conv = Conv2D( self._conv = Conv2d(
num_channels=num_channels, in_channels=num_channels,
num_filters=num_filters, out_channels=num_filters,
filter_size=filter_size, kernel_size=filter_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
groups=groups, groups=groups,
act=None, weight_attr=ParamAttr(name=name + "_weights"),
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False) bias_attr=False)
bn_name = name + "_bn" bn_name = name + "_bn"
self._batch_norm = BatchNorm( self._batch_norm = BatchNorm(
...@@ -43,7 +59,7 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -43,7 +59,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return y return y
class InceptionStem(fluid.dygraph.Layer): class InceptionStem(nn.Layer):
def __init__(self): def __init__(self):
super(InceptionStem, self).__init__() super(InceptionStem, self).__init__()
self._conv_1 = ConvBNLayer( self._conv_1 = ConvBNLayer(
...@@ -84,7 +100,7 @@ class InceptionStem(fluid.dygraph.Layer): ...@@ -84,7 +100,7 @@ class InceptionStem(fluid.dygraph.Layer):
pool1 = self._pool(conv) pool1 = self._pool(conv)
conv2 = self._conv2(conv) conv2 = self._conv2(conv)
concat = fluid.layers.concat([pool1, conv2], axis=1) concat = paddle.concat([pool1, conv2], axis=1)
conv1 = self._conv1_1(concat) conv1 = self._conv1_1(concat)
conv1 = self._conv1_2(conv1) conv1 = self._conv1_2(conv1)
...@@ -94,16 +110,16 @@ class InceptionStem(fluid.dygraph.Layer): ...@@ -94,16 +110,16 @@ class InceptionStem(fluid.dygraph.Layer):
conv2 = self._conv2_3(conv2) conv2 = self._conv2_3(conv2)
conv2 = self._conv2_4(conv2) conv2 = self._conv2_4(conv2)
concat = fluid.layers.concat([conv1, conv2], axis=1) concat = paddle.concat([conv1, conv2], axis=1)
conv1 = self._conv3(concat) conv1 = self._conv3(concat)
pool1 = self._pool(concat) pool1 = self._pool(concat)
concat = fluid.layers.concat([conv1, pool1], axis=1) concat = paddle.concat([conv1, pool1], axis=1)
return concat return concat
class InceptionA(fluid.dygraph.Layer): class InceptionA(nn.Layer):
def __init__(self, name): def __init__(self, name):
super(InceptionA, self).__init__() super(InceptionA, self).__init__()
self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1) self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1)
...@@ -154,11 +170,11 @@ class InceptionA(fluid.dygraph.Layer): ...@@ -154,11 +170,11 @@ class InceptionA(fluid.dygraph.Layer):
conv4 = self._conv4_2(conv4) conv4 = self._conv4_2(conv4)
conv4 = self._conv4_3(conv4) conv4 = self._conv4_3(conv4)
concat = fluid.layers.concat([conv1, conv2, conv3, conv4], axis=1) concat = paddle.concat([conv1, conv2, conv3, conv4], axis=1)
return concat return concat
class ReductionA(fluid.dygraph.Layer): class ReductionA(nn.Layer):
def __init__(self): def __init__(self):
super(ReductionA, self).__init__() super(ReductionA, self).__init__()
self._pool = Pool2D(pool_size=3, pool_type="max", pool_stride=2) self._pool = Pool2D(pool_size=3, pool_type="max", pool_stride=2)
...@@ -177,11 +193,11 @@ class ReductionA(fluid.dygraph.Layer): ...@@ -177,11 +193,11 @@ class ReductionA(fluid.dygraph.Layer):
conv3 = self._conv3_1(inputs) conv3 = self._conv3_1(inputs)
conv3 = self._conv3_2(conv3) conv3 = self._conv3_2(conv3)
conv3 = self._conv3_3(conv3) conv3 = self._conv3_3(conv3)
concat = fluid.layers.concat([pool1, conv2, conv3], axis=1) concat = paddle.concat([pool1, conv2, conv3], axis=1)
return concat return concat
class InceptionB(fluid.dygraph.Layer): class InceptionB(nn.Layer):
def __init__(self, name=None): def __init__(self, name=None):
super(InceptionB, self).__init__() super(InceptionB, self).__init__()
self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1) self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1)
...@@ -254,11 +270,11 @@ class InceptionB(fluid.dygraph.Layer): ...@@ -254,11 +270,11 @@ class InceptionB(fluid.dygraph.Layer):
conv4 = self._conv4_4(conv4) conv4 = self._conv4_4(conv4)
conv4 = self._conv4_5(conv4) conv4 = self._conv4_5(conv4)
concat = fluid.layers.concat([conv1, conv2, conv3, conv4], axis=1) concat = paddle.concat([conv1, conv2, conv3, conv4], axis=1)
return concat return concat
class ReductionB(fluid.dygraph.Layer): class ReductionB(nn.Layer):
def __init__(self): def __init__(self):
super(ReductionB, self).__init__() super(ReductionB, self).__init__()
self._pool = Pool2D(pool_size=3, pool_type="max", pool_stride=2) self._pool = Pool2D(pool_size=3, pool_type="max", pool_stride=2)
...@@ -294,12 +310,12 @@ class ReductionB(fluid.dygraph.Layer): ...@@ -294,12 +310,12 @@ class ReductionB(fluid.dygraph.Layer):
conv3 = self._conv3_3(conv3) conv3 = self._conv3_3(conv3)
conv3 = self._conv3_4(conv3) conv3 = self._conv3_4(conv3)
concat = fluid.layers.concat([pool1, conv2, conv3], axis=1) concat = paddle.concat([pool1, conv2, conv3], axis=1)
return concat return concat
class InceptionC(fluid.dygraph.Layer): class InceptionC(nn.Layer):
def __init__(self, name=None): def __init__(self, name=None):
super(InceptionC, self).__init__() super(InceptionC, self).__init__()
self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1) self._pool = Pool2D(pool_size=3, pool_type="avg", pool_padding=1)
...@@ -364,13 +380,13 @@ class InceptionC(fluid.dygraph.Layer): ...@@ -364,13 +380,13 @@ class InceptionC(fluid.dygraph.Layer):
conv4_1 = self._conv4_1(conv4) conv4_1 = self._conv4_1(conv4)
conv4_2 = self._conv4_2(conv4) conv4_2 = self._conv4_2(conv4)
concat = fluid.layers.concat( concat = paddle.concat(
[conv1, conv2, conv3_1, conv3_2, conv4_1, conv4_2], axis=1) [conv1, conv2, conv3_1, conv3_2, conv4_1, conv4_2], axis=1)
return concat return concat
class InceptionV4DY(fluid.dygraph.Layer): class InceptionV4DY(nn.Layer):
def __init__(self, class_dim=1000): def __init__(self, class_dim=1000):
super(InceptionV4DY, self).__init__() super(InceptionV4DY, self).__init__()
self._inception_stem = InceptionStem() self._inception_stem = InceptionStem()
...@@ -400,9 +416,8 @@ class InceptionV4DY(fluid.dygraph.Layer): ...@@ -400,9 +416,8 @@ class InceptionV4DY(fluid.dygraph.Layer):
self.out = Linear( self.out = Linear(
1536, 1536,
class_dim, class_dim,
param_attr=ParamAttr( weight_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv), initializer=Uniform(-stdv, stdv), name="final_fc_weights"),
name="final_fc_weights"),
bias_attr=ParamAttr(name="final_fc_offset")) bias_attr=ParamAttr(name="final_fc_offset"))
def forward(self, inputs): def forward(self, inputs):
...@@ -428,7 +443,7 @@ class InceptionV4DY(fluid.dygraph.Layer): ...@@ -428,7 +443,7 @@ class InceptionV4DY(fluid.dygraph.Layer):
x = self._inceptionC_3(x) x = self._inceptionC_3(x)
x = self.avg_pool(x) x = self.avg_pool(x)
x = fluid.layers.squeeze(x, axes=[2, 3]) x = paddle.squeeze(x, axis=[2, 3])
x = self._drop(x) x = self._drop(x)
x = self.out(x) x = self.out(x)
return x return x
...@@ -436,4 +451,4 @@ class InceptionV4DY(fluid.dygraph.Layer): ...@@ -436,4 +451,4 @@ class InceptionV4DY(fluid.dygraph.Layer):
def InceptionV4(**args): def InceptionV4(**args):
model = InceptionV4DY(**args) model = InceptionV4DY(**args)
return model return model
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册