Unverified commit 12882b2f authored by Zhang Zheng, committed by GitHub

Add ResNetUnit Python API (#35426)

Parent 2de0b58e
@@ -179,7 +179,8 @@ void InplaceAddToOpPass::Run(Graph *graph) const {
                            out_var_ptr->GeneratedOp());
   // NOTE(zhiqiu): currently, only conv2d_grad supports addto strategy
-  if (right_generated_op->Name() != "conv2d_grad") {
+  if (right_generated_op->Name() != "conv2d_grad" &&
+      right_generated_op->Name() != "resnet_unit_grad") {
     continue;
   }
@@ -224,11 +225,13 @@ static bool IsValidConv2DGradDataGradNode(const Node &node) {
   if (node.inputs.empty()) return false;
   auto *generated_op = node.inputs[0];
   auto *op_desc = generated_op->Op();
-  if (op_desc == nullptr || op_desc->Type() != "conv2d_grad") {
+  if (op_desc == nullptr || (op_desc->Type() != "conv2d_grad" &&
+                             op_desc->Type() != "resnet_unit_grad")) {
     return false;
   }
   const auto &outputs = op_desc->Outputs();
-  auto iter = outputs.find(GradVarName("Input"));
+  std::string grad_var_name = op_desc->Type() == "conv2d_grad" ? "Input" : "X";
+  auto iter = outputs.find(GradVarName(grad_var_name));
   return iter != outputs.end() && !iter->second.empty() &&
          iter->second[0] == node.Name() &&
          !op_desc->GetAttrIfExists<bool>("use_addto");
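
The two hunks above extend the in-place add-to pass, which previously recognised only conv2d_grad, to also cover resnet_unit_grad (whose input-gradient output is named X rather than Input). The pass only takes effect when the add-to strategy is enabled at build time. A minimal sketch of enabling it, assuming a GPU build where BuildStrategy.enable_addto is available (not part of this commit):

# Sketch only: turn on the add-to strategy so eligible *_grad ops
# (conv2d_grad, and with this change resnet_unit_grad) can accumulate
# gradients in place instead of going through a separate sum op.
import paddle
import paddle.static as static

paddle.enable_static()

build_strategy = static.BuildStrategy()
build_strategy.enable_addto = True  # intended to trigger the inplace add-to pass

compiled = static.CompiledProgram(
    static.default_main_program(), build_strategy=build_strategy)
# Assumption: some versions also require FLAGS_max_inplace_grad_add to be set
# to a positive value before the pass rewrites anything.
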
@@ -232,6 +232,7 @@ class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
              "(bool, default false) Set to true for inference only, false "
              "for training. Some layers may run faster when this is true.")
         .SetDefault(false);
+    AddAttr<bool>("use_addto", "").SetDefault(false);
     AddAttr<std::string>("act_type", "The activation type to be fused.")
         .SetDefault("relu");
     AddComment(R"DOC(
@@ -55,7 +55,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
     int padding = ctx.Attr<int>("padding");
     int stride = ctx.Attr<int>("stride");
     int stride_z = ctx.Attr<int>("stride_z");
-    int dilate = ctx.Attr<int>("dilate");
+    int dilation = ctx.Attr<int>("dilation");
     int group = ctx.Attr<int>("group");
     double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
     double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
@@ -87,7 +87,7 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
     sum_x.Resize(param_dims);
     sum_of_squares_x.Resize(param_dims);
     CudnnNormConvolution<T> conv_x_op(dev_ctx, input_x_shape, filter_x_shape,
-                                      output_shape, padding, stride, dilate,
+                                      output_shape, padding, stride, dilation,
                                       group);
     conv_x_op.Forward(dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x,
                       &sum_of_squares_x);
@@ -129,8 +129,8 @@ class ResNetUnitKernel : public framework::OpKernel<T> {
     sum_z.Resize(param_dims);
     sum_of_squares_z.Resize(param_dims);
     CudnnNormConvolution<T> conv_z_op(dev_ctx, input_z_shape, filter_z_shape,
-                                      output_shape, padding, stride_z, dilate,
-                                      group);
+                                      output_shape, padding, stride_z,
+                                      dilation, group);
     conv_z_op.Forward(dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z,
                       &sum_of_squares_z);
@@ -189,7 +189,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
     int padding = ctx.Attr<int>("padding");
     int stride = ctx.Attr<int>("stride");
     int stride_z = ctx.Attr<int>("stride_z");
-    int dilate = ctx.Attr<int>("dilate");
+    int dilation = ctx.Attr<int>("dilation");
     int group = ctx.Attr<int>("group");
     double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
     double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
@@ -263,7 +263,7 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
       auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
       CudnnNormConvolutionGrad<T> conv_z_op(dev_ctx, z_shape, filter_z_shape,
                                             output_shape, padding, stride_z,
-                                            dilate, group);
+                                            dilation, group);
       conv_z_op.Backward(dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad,
                          filter_z_grad);
     } else {
@@ -278,11 +278,12 @@ class ResNetUnitGradKernel : public framework::OpKernel<T> {
     }
     // 2. Backward of Conv for x, get x_grad and filter_x_grad
+    bool use_addto = ctx.Attr<bool>("use_addto");
     CudnnNormConvolutionGrad<T> conv_x_op(dev_ctx, x_shape, filter_x_shape,
-                                          output_shape, padding, stride, dilate,
-                                          group);
+                                          output_shape, padding, stride,
+                                          dilation, group);
     conv_x_op.Backward(dev_ctx, *x, *filter_x, conv_out_x_grad, x_grad,
-                       filter_x_grad);
+                       filter_x_grad, use_addto);
   }
 };
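
Besides renaming the dilate attribute to dilation, the grad kernel now reads use_addto and forwards it to CudnnNormConvolutionGrad::Backward, so the data gradient can be accumulated into an existing buffer rather than overwriting it. A toy illustration of that distinction, with made-up names and shapes:

import numpy as np

# x_grad already holds a previously computed gradient contribution.
x_grad = np.ones((2, 3), dtype=np.float32)
dgrad = np.full((2, 3), 0.5, dtype=np.float32)  # result of this conv backward

use_addto = True
if use_addto:
    x_grad += dgrad   # accumulate in place; the later sum op becomes unnecessary
else:
    x_grad = dgrad    # plain overwrite; accumulation must happen elsewhere
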
@@ -14,3 +14,4 @@
 from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle  # noqa: F401
 from .softmax_mask_fuse import softmax_mask_fuse  # noqa: F401
+from .resnet_unit import ResNetUnit #noqa: F401
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import collections
import itertools
import six
import math
import sys
import warnings
from functools import partial, reduce
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.device import get_device, get_cudnn_version
from paddle.nn import initializer as I
from paddle.nn import Layer, LayerList
from paddle.fluid.layers import utils
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.param_attr import ParamAttr
from paddle import _C_ops

__all__ = ['resnet_unit', 'ResNetUnit']


def resnet_unit(x, filter_x, scale_x, bias_x, mean_x, var_x, z, filter_z,
scale_z, bias_z, mean_z, var_z, stride, stride_z, padding,
dilation, groups, momentum, eps, data_format, fuse_add,
has_shortcut, use_global_stats, is_test, act):
helper = LayerHelper('resnet_unit', **locals())
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bit_mask_dtype = fluid.core.VarDesc.VarType.INT32
out = helper.create_variable_for_type_inference(x.dtype)
bit_mask = helper.create_variable_for_type_inference(
dtype=bit_mask_dtype, stop_gradient=True)
# intermediate_out for x
conv_x = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
saved_mean_x = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd_x = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean_x = mean_x
running_var_x = var_x
# intermediate_out for z
conv_z = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
saved_mean_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
saved_invstd_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True)
running_mean_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if mean_z is None else mean_z
running_var_z = helper.create_variable_for_type_inference(
dtype=bn_param_dtype, stop_gradient=True) if var_z is None else var_z
inputs = {
'X': x,
'FilterX': filter_x,
'ScaleX': scale_x,
'BiasX': bias_x,
'MeanX': mean_x,
'VarX': var_x,
'Z': z,
'FilterZ': filter_z,
'ScaleZ': scale_z,
'BiasZ': bias_z,
'MeanZ': mean_z,
'VarZ': var_z
}
attrs = {
'stride': stride,
'stride_z': stride_z,
'padding': padding,
'dilation': dilation,
'group': groups,
'momentum': momentum,
'epsilon': eps,
'data_format': data_format,
'fuse_add': fuse_add,
'has_shortcut': has_shortcut,
'use_global_stats': use_global_stats,
'is_test': is_test,
'act_type': act
}
outputs = {
'Y': out,
'BitMask': bit_mask,
'ConvX': conv_x,
'SavedMeanX': saved_mean_x,
'SavedInvstdX': saved_invstd_x,
'RunningMeanX': running_mean_x,
'RunningVarX': running_var_x,
'ConvZ': conv_z,
'SavedMeanZ': saved_mean_z,
'SavedInvstdZ': saved_invstd_z,
'RunningMeanZ': running_mean_z,
'RunningVarZ': running_var_z,
}
helper.append_op(
type='resnet_unit', inputs=inputs, outputs=outputs, attrs=attrs)
return out


class ResNetUnit(Layer):
r"""
******Temporary version******.
ResNetUnit is designed to optimize performance by using the cuDNN v8 API.
"""
def __init__(self,
num_channels_x,
num_filters,
filter_size,
stride=1,
momentum=0.9,
eps=1e-5,
data_format='NHWC',
act='relu',
fuse_add=False,
has_shortcut=False,
use_global_stats=False,
is_test=False,
filter_x_attr=None,
scale_x_attr=None,
bias_x_attr=None,
moving_mean_x_name=None,
moving_var_x_name=None,
num_channels_z=1,
stride_z=1,
filter_z_attr=None,
scale_z_attr=None,
bias_z_attr=None,
moving_mean_z_name=None,
moving_var_z_name=None):
super(ResNetUnit, self).__init__()
self._stride = stride
self._stride_z = stride_z
self._dilation = 1
self._kernel_size = utils.convert_to_list(filter_size, 2, 'kernel_size')
self._padding = (filter_size - 1) // 2
self._groups = 1
self._momentum = momentum
self._eps = eps
self._data_format = data_format
self._act = act
self._fuse_add = fuse_add
self._has_shortcut = has_shortcut
self._use_global_stats = use_global_stats
self._is_test = is_test
# check format
valid_format = {'NHWC'}
if data_format not in valid_format:
raise ValueError(
"conv_format must be one of {}, but got conv_format='{}'".
format(valid_format, data_format))
def _get_default_param_initializer(channels):
filter_elem_num = np.prod(self._kernel_size) * channels
std = (2.0 / filter_elem_num)**0.5
return I.Normal(0.0, std)
# initial filter
bn_param_dtype = fluid.core.VarDesc.VarType.FP32
bn_param_shape = [1, 1, 1, num_filters]
filter_x_shape = [num_filters, filter_size, filter_size, num_channels_x]
filter_z_shape = [num_filters, filter_size, filter_size, num_channels_z]
self.filter_x = self.create_parameter(
shape=filter_x_shape,
attr=filter_x_attr,
default_initializer=_get_default_param_initializer(num_channels_x))
self.scale_x = self.create_parameter(
shape=bn_param_shape,
attr=scale_x_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_x = self.create_parameter(
shape=bn_param_shape,
attr=bias_x_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_x = self.create_parameter(
attr=ParamAttr(
name=moving_mean_x_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.mean_x.stop_gradient = True
self.var_x = self.create_parameter(
attr=ParamAttr(
name=moving_var_x_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.var_x.stop_gradient = True
if has_shortcut:
self.filter_z = self.create_parameter(
shape=filter_z_shape,
attr=filter_z_attr,
default_initializer=_get_default_param_initializer(
num_channels_z))
self.scale_z = self.create_parameter(
shape=bn_param_shape,
attr=scale_z_attr,
dtype=bn_param_dtype,
default_initializer=I.Constant(1.0))
self.bias_z = self.create_parameter(
shape=bn_param_shape,
attr=bias_z_attr,
dtype=bn_param_dtype,
is_bias=True)
self.mean_z = self.create_parameter(
attr=ParamAttr(
name=moving_mean_z_name,
initializer=I.Constant(0.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.mean_z.stop_gradient = True
self.var_z = self.create_parameter(
attr=ParamAttr(
name=moving_var_z_name,
initializer=I.Constant(1.0),
trainable=False),
shape=bn_param_shape,
dtype=bn_param_dtype)
self.var_z.stop_gradient = True
else:
self.filter_z = None
self.scale_z = None
self.bias_z = None
self.mean_z = None
self.var_z = None
def forward(self, x, z=None):
if self._fuse_add and z is None:
raise ValueError("z can not be None")
out = resnet_unit(
x, self.filter_x, self.scale_x, self.bias_x, self.mean_x,
self.var_x, z, self.filter_z, self.scale_z, self.bias_z,
self.mean_z, self.var_z, self._stride, self._stride_z,
self._padding, self._dilation, self._groups, self._momentum,
self._eps, self._data_format, self._fuse_add, self._has_shortcut,
self._use_global_stats, self._is_test, self._act)
return out
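
For context, a minimal usage sketch of the new layer follows. It is not part of the commit and assumes a GPU build with cuDNN 8 (the fused kernels are only registered there), NHWC float16 input, and that ResNetUnit is importable from the package whose __init__.py is patched above; the exact import path is not shown in this diff.

import numpy as np
import paddle
# Assumed import path; the diff only shows "from .resnet_unit import ResNetUnit".
from paddle.incubate.operators.resnet_unit import ResNetUnit

paddle.set_device('gpu')

# One fused Conv + BN + ReLU block: 8 input channels, 32 filters, 1x1 kernel.
unit = ResNetUnit(num_channels_x=8, num_filters=32, filter_size=1,
                  stride=1, act='relu', data_format='NHWC')

x = paddle.to_tensor(np.random.random((2, 5, 5, 8)).astype('float16'))  # NHWC
y = unit(x)        # fuse_add/has_shortcut default to False, so z is omitted
print(y.shape)     # expected [2, 5, 5, 32]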