提交 ba6a4f95 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix resnet

上级 a52efec3
......@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
import paddle
from paddle import ParamAttr
# from paddle.fluid.param_attr import ParamAttr
import paddle.nn as nn
from paddle.nn import Conv2d, Pool2D, BatchNorm, Linear, Dropout
......@@ -39,14 +40,13 @@ class ConvBNLayer(nn.Layer):
super(ConvBNLayer, self).__init__()
self._conv = Conv2d(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
weight_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
......@@ -248,8 +248,8 @@ class ResNet(nn.Layer):
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
param_attr=ParamAttr(
initializer=paddle.distribution.Uniform(-stdv, stdv),
weight_attr=ParamAttr(
initializer=paddle.nn.initializer.Uniform(-stdv, stdv),
name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册