未验证 提交 917235be 编写于 作者: Z zhangyikun02 提交者: GitHub

add ResNetBasicBlock python api for kunlun, test=kunlun (#44171)

上级 cb4eea92
......@@ -959,12 +959,8 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(
resnet_basic_block,
ops::ResNetBasicBlockXPUKernel<float>,
ops::ResNetBasicBlockXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
resnet_basic_block_grad,
ops::ResNetBasicBlockGradXPUKernel<float>,
ops::ResNetBasicBlockGradXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(resnet_basic_block,
ops::ResNetBasicBlockXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(resnet_basic_block_grad,
ops::ResNetBasicBlockGradXPUKernel<float>);
#endif
......@@ -519,11 +519,9 @@ XPUOpMap& get_kl2_ops() {
// Fused op
{"resnet_basic_block_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"resnet_basic_block",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
};
return s_xpu2_kernels;
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
from paddle.fluid import core
from paddle.incubate.xpu.resnet_block import ResNetBasicBlock
from paddle.fluid.framework import default_main_program
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static()
class XPUTestResNetBasicBlockOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = "resnet_basic_block"
self.use_dynamic_create_class = False
class TestResNetBasicBlockOp(OpTest):
def setUp(self):
paddle.disable_static()
self.dtype = self.in_type
self.place = paddle.XPUPlace(0)
self.__class__.op_type = "resnet_basic_block"
self.__class__.no_need_check_grad = True
self.getShape()
self.getDiff()
self.getShortcut()
paddle.set_default_dtype(self.dtype)
self.src = np.random.random(self.input_size).astype(self.dtype)
self.dout = np.random.random(self.output_size).astype(self.dtype)
def getShape(self):
self.in_channels = 8
self.out_channels = 8
self.stride = 1
self.input_size = [2, 8, 32, 32] # NCHW
self.output_size = [2, 8, 32, 32] # NCHW
def getDiff(self):
self.rtol = 1e-3
self.atol = 1e-3
def getShortcut(self):
self.has_shortcut = False
def Base(self):
paddle.disable_static()
conv1_weight = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv2_weight = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
conv3_weight = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
bn1_weight = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=1.0))
bn1_bias = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.0))
bn2_weight = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=1.0))
bn2_bias = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.0))
bn3_weight = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=1.0))
bn3_bias = fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.0))
self.conv1 = nn.Conv2D(in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=self.stride,
padding=1,
weight_attr=conv1_weight,
bias_attr=None,
data_format='NCHW')
self.bn1 = nn.BatchNorm(self.out_channels,
act='relu',
param_attr=bn1_weight,
bias_attr=bn1_bias,
data_layout='NCHW')
self.conv2 = nn.Conv2D(in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
weight_attr=conv2_weight,
bias_attr=None,
data_format='NCHW')
self.bn2 = nn.BatchNorm(self.out_channels,
act=None,
param_attr=bn2_weight,
bias_attr=bn2_bias,
data_layout='NCHW')
self.conv3 = nn.Conv2D(in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=self.stride,
padding=0,
weight_attr=conv3_weight,
bias_attr=None,
data_format='NCHW')
self.bn3 = nn.BatchNorm(self.out_channels,
act=None,
param_attr=bn3_weight,
bias_attr=bn3_bias,
data_layout='NCHW')
self.relu = nn.ReLU()
tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
if self.has_shortcut:
z_out = self.bn3(self.conv3(tensor_src))
else:
z_out = tensor_src
bn1_out = self.bn1(self.conv1(tensor_src))
bn2_out = self.bn2(self.conv2(bn1_out))
result = self.relu(bn2_out + z_out)
paddle.autograd.backward([result], [paddle.to_tensor(self.dout)],
True)
return result, tensor_src.grad
def FusedResNetBasicBlock(self):
paddle.disable_static()
fused_conv1_weight = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
fused_conv2_weight = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
fused_conv3_weight = fluid.ParamAttr(
initializer=fluid.initializer.Xavier(uniform=False),
learning_rate=0.001)
fused_bn1_weight = fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0))
fused_bn1_bias = fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.0))
fused_bn2_weight = fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0))
fused_bn2_bias = fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.0))
fused_bn3_weight = fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0))
fused_bn3_bias = fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.0))
if self.has_shortcut:
self.resnet_basic_block = ResNetBasicBlock(
num_channels1=self.in_channels,
num_filter1=self.out_channels,
filter1_size=3,
num_channels2=self.out_channels,
num_filter2=self.out_channels,
filter2_size=3,
num_channels3=self.in_channels,
num_filter3=self.out_channels,
filter3_size=1,
filter1_attr=fused_conv1_weight,
scale1_attr=fused_bn1_weight,
bias1_attr=fused_bn1_bias,
filter2_attr=fused_conv2_weight,
scale2_attr=fused_bn2_weight,
bias2_attr=fused_bn2_bias,
filter3_attr=fused_conv3_weight,
scale3_attr=fused_bn3_weight,
bias3_attr=fused_bn3_bias,
stride1=self.stride,
stride2=1,
stride3=self.stride,
act='relu',
padding1=1,
padding2=1,
padding3=0,
has_shortcut=True)
else:
self.resnet_basic_block = ResNetBasicBlock(
num_channels1=self.in_channels,
num_filter1=self.out_channels,
filter1_size=3,
num_channels2=self.out_channels,
num_filter2=self.out_channels,
filter2_size=3,
num_channels3=self.in_channels,
num_filter3=self.out_channels,
filter3_size=1,
filter1_attr=fused_conv1_weight,
scale1_attr=fused_bn1_weight,
bias1_attr=fused_bn1_bias,
filter2_attr=fused_conv2_weight,
scale2_attr=fused_bn2_weight,
bias2_attr=fused_bn2_bias,
filter3_attr=fused_conv3_weight,
scale3_attr=fused_bn3_weight,
bias3_attr=fused_bn3_bias,
stride1=self.stride,
stride2=1,
stride3=self.stride,
act='relu',
padding1=1,
padding2=1,
padding3=1,
has_shortcut=False)
x = paddle.to_tensor(self.src, stop_gradient=False)
out = self.resnet_basic_block.forward(x)
paddle.autograd.backward([out], [paddle.to_tensor(self.dout)])
return out, x.grad
def test_out_and_grad_has_shortcut(self):
self.has_shortcut = True
default_main_program().random_seed = 1
base_out, base_grad = self.Base()
fused_out, fused_grad = self.FusedResNetBasicBlock()
np.testing.assert_allclose(base_out.numpy(),
fused_out.numpy(),
rtol=self.rtol,
atol=self.atol)
np.testing.assert_allclose(base_grad.numpy(),
fused_grad.numpy(),
rtol=self.rtol,
atol=self.atol)
def test_out_and_grad(self):
self.has_shortcut = False
default_main_program().random_seed = 1
base_out, base_grad = self.Base()
fused_out, fused_grad = self.FusedResNetBasicBlock()
np.testing.assert_allclose(base_out.numpy(),
fused_out.numpy(),
rtol=self.rtol,
atol=self.atol)
np.testing.assert_allclose(base_grad.numpy(),
fused_grad.numpy(),
rtol=self.rtol,
atol=self.atol)
support_types = get_xpu_op_support_types('resnet_basic_block')
for stype in support_types:
create_test_class(globals(),
XPUTestResNetBasicBlockOp,
stype,
ignore_deivce_version=[core.XPUVersion.XPU1])
if __name__ == '__main__':
unittest.main()
......@@ -38,6 +38,7 @@ from . import asp #noqa: F401
from ..fluid.layers.loss import identity_loss
from ..fluid.incubate import fleet
from . import xpu
__all__ = [
'LookAhead',
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .resnet_block import ResNetBasicBlock
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import collections
import itertools
import six
import math
import sys
import warnings
from functools import partial, reduce
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.nn import initializer as I
from paddle.nn import Layer, LayerList
from paddle.fluid.layers import utils
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle import _C_ops
__all__ = ['resnet_basic_block', 'ResNetBasicBlock']
def resnet_basic_block(x,
filter1,
scale1,
bias1,
mean1,
var1,
filter2,
scale2,
bias2,
mean2,
var2,
filter3,
scale3,
bias3,
mean3,
var3,
stride1,
stride2,
stride3,
padding1,
padding2,
padding3,
dilation1,
dilation2,
dilation3,
groups,
momentum,
eps,
data_format,
has_shortcut,
use_global_stats=None,
training=False,
trainable_statistics=False,
find_conv_max=True):
if fluid.framework.in_dygraph_mode():
attrs = ('stride1', stride1, 'stride2', stride2, 'stride3', stride3,
'padding1', padding1, 'padding2', padding2, 'padding3',
padding3, 'dilation1', dilation1, 'dilation2', dilation2,
'dilation3', dilation3, 'group', groups, 'momentum', momentum,
'epsilon', eps, 'data_format', data_format, 'has_shortcut',
has_shortcut, 'use_global_stats', use_global_stats,
"trainable_statistics", trainable_statistics, 'is_test',
not training, 'act_type', "relu", 'find_conv_input_max',
find_conv_max)
out, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _ = \
getattr(_C_ops, "resnet_basic_block")(x, filter1, scale1, bias1, mean1, var1, filter2, scale2, bias2, mean2, var2, \
filter3, scale3, bias3, mean3, var3, mean1, var1, mean2, var2, mean3, var3, *attrs)
return out
helper = LayerHelper('resnet_basic_block', **locals())
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
max_dtype = fluid.core.VarDesc.VarType.FP32
out = helper.create_variable_for_type_inference(dtype=x.dtype,
stop_gradient=True)
conv1 = helper.create_variable_for_type_inference(dtype=x.dtype,
stop_gradient=True)
saved_mean1 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd1 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean1 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if mean1 is None else mean1
running_var1 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if var1 is None else var1
conv2 = helper.create_variable_for_type_inference(dtype=x.dtype,
stop_gradient=True)
conv2_input = helper.create_variable_for_type_inference(dtype=x.dtype,
stop_gradient=True)
saved_mean2 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd2 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean2 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if mean2 is None else mean2
running_var2 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if var2 is None else var2
conv3 = helper.create_variable_for_type_inference(dtype=x.dtype,
stop_gradient=True)
saved_mean3 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd3 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean3 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if mean3 is None else mean3
running_var3 = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if var3 is None else var3
conv1_input_max = helper.create_variable_for_type_inference(
dtype=max_dtype, stop_gradient=True)
conv1_filter_max = helper.create_variable_for_type_inference(
dtype=max_dtype, stop_gradient=True)
conv2_input_max = helper.create_variable_for_type_inference(
dtype=max_dtype, stop_gradient=True)
conv2_filter_max = helper.create_variable_for_type_inference(
dtype=max_dtype, stop_gradient=True)
conv3_input_max = helper.create_variable_for_type_inference(
dtype=max_dtype, stop_gradient=True)
conv3_filter_max = helper.create_variable_for_type_inference(
dtype=max_dtype, stop_gradient=True)
inputs = {
'X': x,
'Filter1': filter1,
'Scale1': scale1,
'Bias1': bias1,
'Mean1': mean1,
'Var1': var1,
'Filter2': filter2,
'Scale2': scale2,
'Bias2': bias2,
'Mean2': mean2,
'Var2': var2,
'Filter3': filter3,
'Scale3': scale3,
'Bias3': bias3,
'Mean3': mean3,
'Var3': var3,
}
attrs = {
'stride1': stride1,
'stride2': stride2,
'stride3': stride3,
'padding1': padding1,
'padding2': padding2,
'padding3': padding3,
'dilation1': dilation1,
'dilation2': dilation2,
'dilation3': dilation3,
'group': groups,
'momentum': momentum,
'epsilon': eps,
'data_format': data_format,
'has_shortcut': has_shortcut,
'use_global_stats': use_global_stats,
"trainable_statistics": trainable_statistics,
'is_test': not training,
'act_type': "relu",
'find_conv_input_max': find_conv_max
}
outputs = {
'Y': out,
'Conv1': conv1,
'SavedMean1': saved_mean1,
'SavedInvstd1': saved_invstd1,
'Mean1Out': running_mean1,
'Var1Out': running_var1,
'Conv2': conv2,
'SavedMean2': saved_mean2,
'SavedInvstd2': saved_invstd2,
'Mean2Out': running_mean2,
'Var2Out': running_var2,
'Conv2Input': conv2_input,
'Conv3': conv3,
'SavedMean3': saved_mean3,
'SavedInvstd3': saved_invstd3,
'Mean3Out': running_mean3,
'Var3Out': running_var3,
'MaxInput1': conv1_input_max,
'MaxFilter1': conv1_filter_max,
'MaxInput2': conv2_input_max,
'MaxFilter2': conv2_filter_max,
'MaxInput3': conv3_input_max,
'MaxFilter3': conv3_filter_max,
}
helper.append_op(type='resnet_basic_block',
inputs=inputs,
outputs=outputs,
attrs=attrs)
return out
class ResNetBasicBlock(Layer):
"""
ResNetBasicBlock is designed for optimize the performence of the basic unit of ssd resnet block.
The fusion op architecture like this:
has_shortcut = True: else:
X X
/ /
| | | |
CONV1 | CONV1 |
| | | |
BN1 | BN1 |
| | | |
RELU1 | RELU1 |
| | | |
CONV2 CONV3 CONV2 |
| | | |
BN2 BN3 BN2 |
\ / \ /
ADD ADD
| |
RELU RELU
| |
Y Y
"""
def __init__(self,
num_channels1,
num_filter1,
filter1_size,
num_channels2,
num_filter2,
filter2_size,
num_channels3,
num_filter3,
filter3_size,
stride1=1,
stride2=1,
stride3=1,
act='relu',
momentum=0.9,
eps=1e-5,
data_format='NCHW',
has_shortcut=False,
use_global_stats=False,
is_test=False,
filter1_attr=None,
scale1_attr=None,
bias1_attr=None,
moving_mean1_name=None,
moving_var1_name=None,
filter2_attr=None,
scale2_attr=None,
bias2_attr=None,
moving_mean2_name=None,
moving_var2_name=None,
filter3_attr=None,
scale3_attr=None,
bias3_attr=None,
moving_mean3_name=None,
moving_var3_name=None,
padding1=0,
padding2=0,
padding3=0,
dilation1=1,
dilation2=1,
dilation3=1,
trainable_statistics=False,
find_conv_max=True):
super(ResNetBasicBlock, self).__init__()
self._stride1 = stride1
self._stride2 = stride2
self._kernel1_size = utils.convert_to_list(filter1_size, 2,
'filter1_size')
self._kernel2_size = utils.convert_to_list(filter2_size, 2,
'filter2_size')
self._dilation1 = dilation1
self._dilation2 = dilation2
self._padding1 = padding1
self._padding2 = padding2
self._groups = 1
self._momentum = momentum
self._eps = eps
self._data_format = data_format
self._act = act
self._has_shortcut = has_shortcut
self._use_global_stats = use_global_stats
self._is_test = is_test
self._trainable_statistics = trainable_statistics
self._find_conv_max = find_conv_max
if has_shortcut:
self._kernel3_size = utils.convert_to_list(filter3_size, 2,
'filter3_size')
self._padding3 = padding3
self._stride3 = stride3
self._dilation3 = dilation3
else:
self._kernel3_size = None
self._padding3 = 1
self._stride3 = 1
self._dilation3 = 1
# check format
valid_format = {'NCHW'}
if data_format not in valid_format:
raise ValueError(
"conv_format must be one of {}, but got conv_format={}".format(
valid_format, data_format))
def _get_default_param_initializer(channels, kernel_size):
filter_elem_num = np.prod(kernel_size) * channels
std = (2.0 / filter_elem_num)**0.5
return I.Normal(0.0, std)
# init filter
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bn1_param_shape = [1, 1, num_filter1]
bn2_param_shape = [1, 1, num_filter2]
filter1_shape = [num_filter1, num_channels1, filter1_size, filter1_size]
filter2_shape = [num_filter2, num_channels2, filter2_size, filter2_size]
self.filter_1 = self.create_parameter(
shape=filter1_shape,
attr=filter1_attr,
default_initializer=_get_default_param_initializer(
num_channels1, self._kernel1_size))
self.scale_1 = self.create_parameter(
shape=bn1_param_shape,
attr=scale1_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_1 = self.create_parameter(shape=bn1_param_shape,
attr=bias1_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_1 = self.create_parameter(attr=ParamAttr(
name=moving_mean1_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn1_param_shape,
dtype=bn_param_dtype)
self.mean_1.stop_gradient = True
self.var_1 = self.create_parameter(
attr=ParamAttr(name=moving_var1_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn1_param_shape,
dtype=bn_param_dtype)
self.var_1.stop_gradient = True
self.filter_2 = self.create_parameter(
shape=filter2_shape,
attr=filter2_attr,
default_initializer=_get_default_param_initializer(
num_channels2, self._kernel2_size))
self.scale_2 = self.create_parameter(
shape=bn2_param_shape,
attr=scale2_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_2 = self.create_parameter(shape=bn2_param_shape,
attr=bias2_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_2 = self.create_parameter(attr=ParamAttr(
name=moving_mean2_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn2_param_shape,
dtype=bn_param_dtype)
self.mean_2.stop_gradient = True
self.var_2 = self.create_parameter(
attr=ParamAttr(name=moving_var2_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn2_param_shape,
dtype=bn_param_dtype)
self.var_2.stop_gradient = True
if has_shortcut:
bn3_param_shape = [1, 1, num_filter3]
filter3_shape = [
num_filter3, num_channels3, filter3_size, filter3_size
]
self.filter_3 = self.create_parameter(
shape=filter3_shape,
attr=filter3_attr,
default_initializer=_get_default_param_initializer(
num_channels3, self._kernel3_size))
self.scale_3 = self.create_parameter(
shape=bn3_param_shape,
attr=scale3_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_3 = self.create_parameter(shape=bn3_param_shape,
attr=bias3_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_3 = self.create_parameter(attr=ParamAttr(
name=moving_mean3_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn3_param_shape,
dtype=bn_param_dtype)
self.mean_3.stop_gradient = True
self.var_3 = self.create_parameter(attr=ParamAttr(
name=moving_var3_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn3_param_shape,
dtype=bn_param_dtype)
self.var_3.stop_gradient = True
else:
self.filter_3 = None
self.scale_3 = None
self.bias_3 = None
self.mean_3 = None
self.var_3 = None
def forward(self, x):
out = resnet_basic_block(
x,
self.filter_1,
self.scale_1,
self.bias_1,
self.mean_1,
self.var_1,
self.filter_2,
self.scale_2,
self.bias_2,
self.mean_2,
self.var_2,
self.filter_3,
self.scale_3,
self.bias_3,
self.mean_3,
self.var_3,
self._stride1,
self._stride2,
self._stride3,
self._padding1,
self._padding2,
self._padding3,
self._dilation1,
self._dilation2,
self._dilation3,
self._groups,
self._momentum,
self._eps,
self._data_format,
self._has_shortcut,
use_global_stats=self._use_global_stats,
training=self.training,
trainable_statistics=self._trainable_statistics,
find_conv_max=self._find_conv_max)
return out
......@@ -379,6 +379,7 @@ packages=['paddle',
'paddle.incubate.sparse.nn',
'paddle.incubate.sparse.nn.layer',
'paddle.incubate.sparse.nn.functional',
'paddle.incubate.xpu',
'paddle.io',
'paddle.optimizer',
'paddle.nn',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册