提交 0ab81b37 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix densenet

上级 3962b385
......@@ -18,9 +18,10 @@ from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
from paddle.nn.initializer import Uniform
import math
......@@ -29,7 +30,7 @@ __all__ = [
]
class BNACConvLayer(fluid.dygraph.Layer):
class BNACConvLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
......@@ -49,15 +50,14 @@ class BNACConvLayer(fluid.dygraph.Layer):
moving_mean_name=name + '_bn_mean',
moving_variance_name=name + '_bn_variance')
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
self._conv = Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=pad,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
def forward(self, input):
......@@ -66,7 +66,7 @@ class BNACConvLayer(fluid.dygraph.Layer):
return y
class DenseLayer(fluid.dygraph.Layer):
class DenseLayer(nn.Layer):
def __init__(self, num_channels, growth_rate, bn_size, dropout, name=None):
super(DenseLayer, self).__init__()
self.dropout = dropout
......@@ -95,11 +95,11 @@ class DenseLayer(fluid.dygraph.Layer):
conv = self.bn_ac_func2(conv)
if self.dropout:
conv = self.dropout_func(conv)
conv = fluid.layers.concat([input, conv], axis=1)
conv = paddle.concat([input, conv], axis=1)
return conv
class DenseBlock(fluid.dygraph.Layer):
class DenseBlock(nn.Layer):
def __init__(self,
num_channels,
num_layers,
......@@ -132,7 +132,7 @@ class DenseBlock(fluid.dygraph.Layer):
return conv
class TransitionLayer(fluid.dygraph.Layer):
class TransitionLayer(nn.Layer):
def __init__(self, num_channels, num_output_features, name=None):
super(TransitionLayer, self).__init__()
......@@ -152,7 +152,7 @@ class TransitionLayer(fluid.dygraph.Layer):
return y
class ConvBNLayer(fluid.dygraph.Layer):
class ConvBNLayer(nn.Layer):
def __init__(self,
num_channels,
num_filters,
......@@ -164,15 +164,14 @@ class ConvBNLayer(fluid.dygraph.Layer):
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
self._conv = Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=pad,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
self._batch_norm = BatchNorm(
num_filters,
......@@ -188,7 +187,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return y
class DenseNet(fluid.dygraph.Layer):
class DenseNet(nn.Layer):
def __init__(self, layers=60, bn_size=4, dropout=0, class_dim=1000):
super(DenseNet, self).__init__()
......@@ -264,9 +263,8 @@ class DenseNet(fluid.dygraph.Layer):
self.out = Linear(
num_features,
class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name="fc_weights"),
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name="fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
def forward(self, input):
......@@ -280,7 +278,7 @@ class DenseNet(fluid.dygraph.Layer):
conv = self.batch_norm(conv)
y = self.pool2d_avg(conv)
y = fluid.layers.reshape(y, shape=[0, -1])
y = paddle.reshape(y, shape=[0, -1])
y = self.out(y)
return y
......
......@@ -21,6 +21,7 @@ import paddle
from paddle import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
from paddle.nn.initializer import Uniform
import math
......@@ -248,8 +249,7 @@ class ResNet(nn.Layer):
self.pool2d_avg_channels,
class_dim,
weight_attr=ParamAttr(
initializer=paddle.nn.initializer.Uniform(-stdv, stdv),
name="fc_0.w_0"),
initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
def forward(self, inputs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册