未验证 提交 12e6dbbc 编写于 作者: Z Zhang Zheng 提交者: GitHub

Add the complete code and related files of resnet_unit_op (#36366)

上级 bed4fb27
......@@ -216,7 +216,7 @@ function(op_library TARGET)
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
"fused_bn_add_activation_op")
"fused_bn_add_activation_op" "resnet_unit_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -16,7 +16,8 @@ register_operators(EXCLUDES
fusion_gru_op
fusion_lstm_op
fused_bn_add_activation_op
fused_transformer_op)
fused_transformer_op
resnet_unit_op)
# fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op)
......@@ -78,7 +79,10 @@ if (WITH_GPU OR WITH_ROCM)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
endif()
# resnet_unit needs cudnn 8.0 above
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
op_library(resnet_unit_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n")
cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory)
endif()
......
......@@ -631,8 +631,8 @@ class CudnnBNAddReluTester {
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
has_shortcut_, data_shape, param_shape,
bitmask_shape);
sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, z, equiv_scale_z,
equiv_bias_z, &y, &bitmask);
sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, &z, &equiv_scale_z,
&equiv_bias_z, &y, &bitmask);
TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
......@@ -690,7 +690,7 @@ class CudnnBNAddReluTester {
op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false, data_shape,
param_shape, bitmask_shape);
sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
bitmask, &dx, &dz, &dscale, &dbias, eps_);
&bitmask, &dx, &dz, &dscale, &dbias, eps_);
TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
......
......@@ -38,10 +38,12 @@ class CudnnFusionOp {
&op_variant_params_, op_id));
}
~CudnnFusionOp() {
dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_);
dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_);
dynload::cudnnDestroyFusedOpsPlan(op_);
~CudnnFusionOp() PADDLE_MAY_THROW {
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_));
PADDLE_ENFORCE_CUDA_SUCCESS(
dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_));
PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyFusedOpsPlan(op_));
}
// Execute fused op
......
......@@ -94,13 +94,13 @@ template <typename T>
class CudnnScaleBiasAddRelu {
public:
CudnnScaleBiasAddRelu(const platform::CUDADeviceContext &ctx,
const std::string &act_type, bool fused_add,
const std::string &act_type, bool fuse_add,
bool has_shortcut, const std::vector<int> &data_shape,
const std::vector<int> &param_shape,
const std::vector<int> &bitmask_shape)
: fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK),
bwd_op_(CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM) {
fused_add_ = fused_add;
fuse_add_ = fuse_add;
has_shortcut_ = has_shortcut;
args_.Set(act_type, data_shape, param_shape, bitmask_shape);
}
......@@ -108,8 +108,8 @@ class CudnnScaleBiasAddRelu {
~CudnnScaleBiasAddRelu() {}
void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x,
const Tensor &x_scale, const Tensor &x_bias, const Tensor &z,
const Tensor &z_scale, const Tensor &z_bias, Tensor *out,
const Tensor &x_scale, const Tensor &x_bias, const Tensor *z,
const Tensor *z_scale, const Tensor *z_bias, Tensor *out,
Tensor *bitmask) {
ForwardInit(ctx);
auto handle = ctx.cudnn_handle();
......@@ -125,15 +125,15 @@ class CudnnScaleBiasAddRelu {
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
if (has_shortcut_) {
T *z_ptr = const_cast<T *>(z.data<T>());
T *z_scale_ptr = const_cast<T *>(z_scale.data<T>());
T *z_bias_ptr = const_cast<T *>(z_bias.data<T>());
T *z_ptr = const_cast<T *>(z->data<T>());
T *z_scale_ptr = const_cast<T *>(z_scale->data<T>());
T *z_bias_ptr = const_cast<T *>(z_bias->data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
} else {
if (fused_add_) {
T *z_ptr = const_cast<T *>(z.data<T>());
if (fuse_add_) {
T *z_ptr = const_cast<T *>(z->data<T>());
fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
}
}
......@@ -160,7 +160,7 @@ class CudnnScaleBiasAddRelu {
void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy,
const Tensor &x, const Tensor &scale, const Tensor &bias,
const Tensor &saved_mean, const Tensor &saved_invstd,
const Tensor &bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
const Tensor *bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
Tensor *dbias, double eps) {
BackwardInit(ctx);
auto handle = ctx.cudnn_handle();
......@@ -175,7 +175,8 @@ class CudnnScaleBiasAddRelu {
float *bias_ptr = const_cast<float *>(bias.data<float>());
float *saved_mean_ptr = const_cast<float *>(saved_mean.data<float>());
float *saved_invstd_ptr = const_cast<float *>(saved_invstd.data<float>());
int32_t *bitmask_ptr = const_cast<int32_t *>(bitmask.data<int32_t>());
int32_t *bitmask_ptr =
bitmask ? const_cast<int32_t *>(bitmask->data<int32_t>()) : nullptr;
T *dx_ptr = dx->mutable_data<T>(place);
T *dz_ptr = dz ? dz->mutable_data<T>(place) : nullptr;
float *dscale_ptr = dscale ? dscale->mutable_data<float>(place) : nullptr;
......@@ -199,7 +200,7 @@ class CudnnScaleBiasAddRelu {
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DBIAS, dbias_ptr);
bwd_op_.SetOpVariantParamAttrPtr<double>(CUDNN_SCALAR_DOUBLE_BN_EPSILON,
&eps);
if (has_shortcut_ || fused_add_) {
if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DZDATA, dz_ptr);
}
......@@ -226,14 +227,14 @@ class CudnnScaleBiasAddRelu {
{CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER,
CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
} else if (fused_add_) {
} else if (fuse_add_) {
fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_ZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}
// input desc
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
if (has_shortcut_ || fused_add_) {
if (has_shortcut_ || fuse_add_) {
fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, args_.in_desc.desc());
}
......@@ -271,7 +272,7 @@ class CudnnScaleBiasAddRelu {
CUDNN_PARAM_BN_DSCALE_PLACEHOLDER, CUDNN_PARAM_BN_DBIAS_PLACEHOLDER,
CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
CUDNN_PTR_16B_ALIGNED);
if (has_shortcut_ || fused_add_) {
if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_DZDATA_PLACEHOLDER,
CUDNN_PTR_16B_ALIGNED);
}
......@@ -279,7 +280,7 @@ class CudnnScaleBiasAddRelu {
// input desc
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DXDESC, args_.in_desc.desc());
if (has_shortcut_ || fused_add_) {
if (has_shortcut_ || fuse_add_) {
bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DZDESC, args_.in_desc.desc());
}
......@@ -303,7 +304,7 @@ class CudnnScaleBiasAddRelu {
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
}
bool fused_add_ = false;
bool fuse_add_ = false;
bool has_shortcut_ = false;
size_t fwd_workspace_byte_;
size_t bwd_workspace_byte_;
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Shape of bitmask
static framework::DDim GetBitmaskDims(std::vector<int> out_shape) {
int c = out_shape.back();
int64_t nhw = std::accumulate(out_shape.begin(), out_shape.end(), 1,
std::multiplies<int>()) /
c;
int32_t c_int32_elems = ((c + 63) & ~63) / 32;
int32_t nhw_int32_elems = ((nhw + 31) & ~31);
std::vector<int> bitmask_shape = {nhw_int32_elems, c_int32_elems, 1};
return framework::make_ddim(bitmask_shape);
}
class ResNetUnitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const {
// Check input
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("FilterX"), "Input", "FilterX",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("MeanX"), "Input", "MeanX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("VarX"), "Input", "VarX", "ResNetUnitOp");
bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
if (fuse_add || has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitOp");
}
if (has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("FilterZ"), "Input", "FilterZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("ScaleZ"), "Input", "ScaleZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("MeanZ"), "Input", "MeanZ", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasInput("VarZ"), "Input", "VarZ", "ResNetUnitOp");
}
// Check output
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("BitMask"), "Output", "BitMask",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("ConvX"), "Output", "ConvX", "ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedMeanX"), "Output", "SavedMeanX",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdX"), "Output", "SavedInvstdX",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("RunningMeanX"), "Output", "RunningMeanX",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("RunningVarX"), "Output", "RunningVarX",
"ResNetUnitOp");
if (has_shortcut) {
OP_INOUT_CHECK(ctx->HasOutput("ConvZ"), "Output", "ConvZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedMeanZ"), "Output", "SavedMeanZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdZ"), "Output", "SavedInvstdZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("RunningMeanZ"), "Output", "RunningMeanZ",
"ResNetUnitOp");
OP_INOUT_CHECK(ctx->HasOutput("RunningVarZ"), "Output", "RunningVarZ",
"ResNetUnitOp");
}
// make sure Mean/RunningMean and Var/RunningVar share memory
PADDLE_ENFORCE_EQ(
ctx->Inputs("MeanX")[0], ctx->Outputs("RunningMeanX")[0],
platform::errors::InvalidArgument(
"MeanX and RunningMeanX should share the same memory"));
PADDLE_ENFORCE_EQ(ctx->Inputs("VarX")[0], ctx->Outputs("RunningVarX")[0],
platform::errors::InvalidArgument(
"VarX and RunningVarX should share the same memory"));
if (has_shortcut) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("MeanZ")[0], ctx->Outputs("RunningMeanZ")[0],
platform::errors::InvalidArgument(
"MeanZ and RunningMeanZ should share the same memory"));
PADDLE_ENFORCE_EQ(
ctx->Inputs("VarZ")[0], ctx->Outputs("RunningVarZ")[0],
platform::errors::InvalidArgument(
"VarZ and RunningVarZ should share the same memory"));
}
// Check dims of inputs
const auto x_dims = ctx->GetInputDim("X");
const auto w_dims = ctx->GetInputDim("FilterX");
const auto bn_param_dims = ctx->GetInputDim("ScaleX");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, platform::errors::InvalidArgument(
"The dimensions of input "
"must equal to 4."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]",
x_dims, x_dims.size()));
PADDLE_ENFORCE_EQ(w_dims.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of filter "
"must equal to 4."
"But received: the shape of filter "
"= [%s], the dimension of filter = [%d] ",
w_dims, w_dims.size()));
PADDLE_ENFORCE_EQ(bn_param_dims.size(), 4,
platform::errors::InvalidArgument(
"The dimensions of bn param "
"must equal to 4."
"But received: the shape of bn param "
"= [%s], the dimension of bn param = [%d] ",
bn_param_dims, bn_param_dims.size()));
auto data_format = ctx->Attrs().Get<std::string>("data_format");
PADDLE_ENFORCE_EQ(
data_format, "NHWC",
platform::errors::InvalidArgument("The data format must equal to NHWC. "
"But received: the data format "
"= [%s]",
data_format));
// Calculate the dims of outputs
int batch = x_dims[0];
int output_channel = w_dims[0];
int filter_size = w_dims[2];
int stride = ctx->Attrs().Get<int>("stride");
int padding = ctx->Attrs().Get<int>("padding");
int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1;
int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1;
std::vector<int> out_shape = {batch, out_h, out_w, output_channel};
auto y_dims = framework::make_ddim(out_shape);
auto bitmask_dims = GetBitmaskDims(out_shape);
// Set dims of outputs
ctx->SetOutputDim("Y", y_dims);
ctx->SetOutputDim("BitMask", bitmask_dims);
ctx->SetOutputDim("ConvX", y_dims);
ctx->SetOutputDim("SavedMeanX", bn_param_dims);
ctx->SetOutputDim("SavedInvstdX", bn_param_dims);
ctx->SetOutputDim("RunningMeanX", bn_param_dims);
ctx->SetOutputDim("RunningVarX", bn_param_dims);
if (has_shortcut) {
ctx->SetOutputDim("ConvZ", y_dims);
ctx->SetOutputDim("SavedMeanZ", bn_param_dims);
ctx->SetOutputDim("SavedInvstdZ", bn_param_dims);
ctx->SetOutputDim("RunningMeanZ", bn_param_dims);
ctx->SetOutputDim("RunningVarZ", bn_param_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should be float when input tensor's dtype is float16.
auto bn_param_type = framework::proto::VarType::FP32;
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("ScaleX")->type(),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("BiasX")->type(),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
};
class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "The input 1 tensor");
AddInput("FilterX", "Filter tensor of input 1");
AddInput("ScaleX", "Scale tensor of input 1 used in batchnorm");
AddInput("BiasX", "Bias tensor of input 1 used in batchnorm");
AddInput("MeanX", "Mean tensor of input 1 used in batchnorm");
AddInput("VarX", "Variance tensor of input 1 used in batchnorm");
AddInput("Z", "The input 2 tensor").AsDispensable();
AddInput("FilterZ", "Filter tensor of input 2").AsDispensable();
AddInput("ScaleZ", "Scale tensor of input 2").AsDispensable();
AddInput("BiasZ", "Bias tensor of input 2").AsDispensable();
AddInput("MeanZ", "Mean tensor of input 2").AsDispensable();
AddInput("VarZ", "Variance tensor of input 2").AsDispensable();
AddOutput("Y", "The result of the resnet unit");
AddOutput("BitMask", "The bitmask generated after relu");
AddOutput("ConvX", "The output of input 1 after conv");
AddOutput("SavedMeanX", "Mean of input 1 in the current batch");
AddOutput("SavedInvstdX", "Invstd of input 1 in the current batch");
AddOutput("RunningMeanX", "Shared memory with MeanX");
AddOutput("RunningVarX", "Shared memory with VarX");
AddOutput("ConvZ", "The output of input 2 after conv").AsDispensable();
AddOutput("SavedMeanZ", "Mean of input 1 in the current batch")
.AsDispensable();
AddOutput("SavedInvstdZ", "Invstd of input 1 in the current batch")
.AsDispensable();
AddOutput("RunningMeanZ", "Shared memory with MeanZ").AsDispensable();
AddOutput("RunningVarZ", "Shared memory with VarZ").AsDispensable();
AddAttr<int>("stride", "").SetDefault(1);
AddAttr<int>("stride_z", "").SetDefault(1);
AddAttr<int>("padding", "").SetDefault(0);
AddAttr<int>("dilation", "").SetDefault(1);
AddAttr<int>("group", "").SetDefault(1);
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "").SetDefault(1e-5);
AddAttr<std::string>("data_format", "").SetDefault("NHWC");
AddAttr<bool>("fuse_add", "").SetDefault(false);
AddAttr<bool>("has_shortcut", "").SetDefault(false);
AddAttr<bool>("use_global_stats", "").SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddAttr<std::string>("act_type", "The activation type to be fused.")
.SetDefault("relu");
AddComment(R"DOC(
Fusion op of the basic unit of resnet block.
The implementation is based on the latest fusion op interface in cuDNN v8.0.
For more details:
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnFusedOps_t
)DOC");
}
};
class ResNetUnitGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const {
// check input
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("FilterX"), "Input", "FilterX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("ConvX"), "Input", "ConvX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedMeanX"), "Input", "SavedMeanX",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedInvstdX"), "Input", "SavedInvstdX",
"ResNetUnitGradOp");
bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
if (fuse_add || has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitGradOp");
}
if (has_shortcut) {
OP_INOUT_CHECK(ctx->HasInput("FilterZ"), "Input", "FilterZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("ConvZ"), "Input", "ConvZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("ScaleZ"), "Input", "ScaleZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedMeanZ"), "Input", "SavedMeanZ",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("SavedInvstdZ"), "Input", "SavedInvstdZ",
"ResNetUnitGradOp");
}
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput("BitMask"), "Input", "BitMask",
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
framework::GradVarName("Y"), "ResNetUnitGradOp");
// check output
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterX")), "Output",
framework::GradVarName("FilterX"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleX")), "Output",
framework::GradVarName("ScaleX"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasX")), "Output",
framework::GradVarName("BiasX"), "ResNetUnitGradOp");
if (fuse_add) {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Z")), "Output",
framework::GradVarName("Z"), "ResNetUnitGradOp");
}
if (has_shortcut) {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterZ")),
"Output", framework::GradVarName("FilterZ"),
"ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleZ")), "Output",
framework::GradVarName("ScaleZ"), "ResNetUnitGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasZ")), "Output",
framework::GradVarName("BiasZ"), "ResNetUnitGradOp");
}
const auto x_dims = ctx->GetInputDim("X");
const auto filter_x_dims = ctx->GetInputDim("FilterX");
const auto param_dims = ctx->GetInputDim("ScaleX");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->SetOutputDim(framework::GradVarName("FilterX"), filter_x_dims);
ctx->SetOutputDim(framework::GradVarName("ScaleX"), param_dims);
ctx->SetOutputDim(framework::GradVarName("BiasX"), param_dims);
if (fuse_add || has_shortcut) {
const auto z_dims = ctx->GetInputDim("Z");
ctx->SetOutputDim(framework::GradVarName("Z"), z_dims);
}
if (has_shortcut) {
const auto filter_z_dims = ctx->GetInputDim("FilterZ");
ctx->SetOutputDim(framework::GradVarName("FilterZ"), filter_z_dims);
ctx->SetOutputDim(framework::GradVarName("ScaleZ"), param_dims);
ctx->SetOutputDim(framework::GradVarName("BiasZ"), param_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar(framework::GradVarName("Y")),
platform::errors::NotFound(
"Can not find Y@GRAD in the execution context."));
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
}
};
template <typename T>
class ResNetUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("resnet_unit_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("FilterX", this->Input("FilterX"));
op->SetInput("ConvX", this->Output("ConvX"));
op->SetInput("ScaleX", this->Input("ScaleX"));
op->SetInput("BiasX", this->Input("BiasX"));
op->SetInput("SavedMeanX", this->Output("SavedMeanX"));
op->SetInput("SavedInvstdX", this->Output("SavedInvstdX"));
op->SetInput("Z", this->Input("Z"));
op->SetInput("FilterZ", this->Input("FilterZ"));
op->SetInput("ConvZ", this->Output("ConvZ"));
op->SetInput("ScaleZ", this->Input("ScaleZ"));
op->SetInput("BiasZ", this->Input("BiasZ"));
op->SetInput("SavedMeanZ", this->Output("SavedMeanZ"));
op->SetInput("SavedInvstdZ", this->Output("SavedInvstdZ"));
op->SetInput("Y", this->Output("Y"));
op->SetInput("BitMask", this->Output("BitMask"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("FilterX"),
this->InputGrad("FilterX"));
op->SetOutput(framework::GradVarName("ScaleX"), this->InputGrad("ScaleX"));
op->SetOutput(framework::GradVarName("BiasX"), this->InputGrad("BiasX"));
op->SetOutput(framework::GradVarName("Z"), this->InputGrad("Z"));
op->SetOutput(framework::GradVarName("FilterZ"),
this->InputGrad("FilterZ"));
op->SetOutput(framework::GradVarName("ScaleZ"), this->InputGrad("ScaleZ"));
op->SetOutput(framework::GradVarName("BiasZ"), this->InputGrad("BiasZ"));
}
};
class ResNetUnitOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
const override {
static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
return m;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(resnet_unit, ops::ResNetUnitOp, ops::ResNetUnitOpMaker,
ops::ResNetUnitOpInferVarType,
ops::ResNetUnitGradOpMaker<paddle::framework::OpDesc>,
ops::ResNetUnitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(resnet_unit_grad, ops::ResNetUnitGradOp);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h"
#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ResNetUnitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("It must use CUDAPlace."));
PADDLE_ENFORCE_EQ(platform::CudnnDataType<T>::type, CUDNN_DATA_HALF,
platform::errors::Unavailable(
"ResNetUnitOp only supports float16 for now."));
// input x
const Tensor *input_x = ctx.Input<Tensor>("X");
const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
const Tensor *bias_x = ctx.Input<Tensor>("BiasX");
// norm conv
Tensor *conv_out_x = ctx.Output<Tensor>("ConvX");
// bn finalize
Tensor *saved_mean_x = ctx.Output<Tensor>("SavedMeanX");
Tensor *saved_invstd_x = ctx.Output<Tensor>("SavedInvstdX");
Tensor *running_mean_x = ctx.Output<Tensor>("RunningMeanX");
Tensor *running_var_x = ctx.Output<Tensor>("RunningVarX");
// sbar
Tensor *output = ctx.Output<Tensor>("Y");
Tensor *bitmask = ctx.Output<Tensor>("BitMask");
// attrs
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilate = ctx.Attr<int>("dilate");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
bool has_shortcut = ctx.Attr<bool>("has_shortcut");
bool fuse_add = ctx.Attr<bool>("fuse_add");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
bool is_test = ctx.Attr<bool>("is_test");
bool is_train = !is_test && !use_global_stats;
std::string act_type = ctx.Attr<std::string>("act_type");
auto input_x_shape = framework::vectorize<int>(input_x->dims());
auto filter_x_shape = framework::vectorize<int>(filter_x->dims());
auto param_dims = scale_x->dims();
auto param_shape = framework::vectorize<int>(scale_x->dims());
auto output_shape = framework::vectorize<int>(output->dims());
auto bitmask_shape = framework::vectorize<int>(bitmask->dims());
int output_channel = filter_x_shape[0];
int64_t ele_count =
std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>()) /
output_channel;
auto place = ctx.GetPlace();
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
// 1. Conv
Tensor sum_x;
Tensor sum_of_squares_x;
sum_x.Resize(param_dims);
sum_of_squares_x.Resize(param_dims);
CudnnNormConvolution<T> conv_x_op(dev_ctx, input_x_shape, filter_x_shape,
output_shape, padding, stride, dilate,
group);
conv_x_op.Forward(dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x,
&sum_of_squares_x);
// 2. BN
Tensor equiv_scale_x;
Tensor equiv_bias_x;
equiv_scale_x.Resize(param_dims);
equiv_bias_x.Resize(param_dims);
CudnnBNStatsFinalize<T> bn_x_op(dev_ctx, param_shape);
bn_x_op.Forward(dev_ctx, sum_x, sum_of_squares_x, *scale_x, *bias_x,
saved_mean_x, saved_invstd_x, running_mean_x, running_var_x,
&equiv_scale_x, &equiv_bias_x, eps, momentum, ele_count,
is_train);
// 3. scale + bias + add + relu
CudnnScaleBiasAddRelu<T> sbar_op(dev_ctx, act_type, fuse_add, has_shortcut,
output_shape, param_shape, bitmask_shape);
if (has_shortcut) {
// input z
const Tensor *input_z = ctx.Input<Tensor>("Z");
const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");
// norm conv
Tensor *conv_out_z = ctx.Output<Tensor>("ConvZ");
// bn finalize
Tensor *saved_mean_z = ctx.Output<Tensor>("SavedMeanZ");
Tensor *saved_invstd_z = ctx.Output<Tensor>("SavedInvstdZ");
Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ");
Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ");
auto input_z_shape = framework::vectorize<int>(input_z->dims());
auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
// 3.1 Conv for second input
Tensor sum_z;
Tensor sum_of_squares_z;
sum_z.Resize(param_dims);
sum_of_squares_z.Resize(param_dims);
CudnnNormConvolution<T> conv_z_op(dev_ctx, input_z_shape, filter_z_shape,
output_shape, padding, stride_z, dilate,
group);
conv_z_op.Forward(dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z,
&sum_of_squares_z);
// 3.2 BN for second input
Tensor equiv_scale_z;
Tensor equiv_bias_z;
equiv_scale_z.Resize(param_dims);
equiv_bias_z.Resize(param_dims);
CudnnBNStatsFinalize<T> bn_z_op(dev_ctx, param_shape);
bn_z_op.Forward(dev_ctx, sum_z, sum_of_squares_z, *scale_z, *bias_z,
saved_mean_z, saved_invstd_z, running_mean_z,
running_var_z, &equiv_scale_z, &equiv_bias_z, eps,
momentum, ele_count, is_train);
// 3.3 sbar
sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x,
conv_out_z, &equiv_scale_z, &equiv_bias_z, output,
bitmask);
} else {
const Tensor *input_z = fuse_add ? ctx.Input<Tensor>("Z") : nullptr;
sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x,
input_z, nullptr, nullptr, output, bitmask);
}
}
};
template <typename T>
class ResNetUnitGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("It must use CUDAPlace."));
PADDLE_ENFORCE_EQ(platform::CudnnDataType<T>::type, CUDNN_DATA_HALF,
platform::errors::Unavailable(
"ResNetUnitOp only supports float16 for now."));
const Tensor *y_grad = ctx.Input<Tensor>(framework::GradVarName("Y"));
const Tensor *x = ctx.Input<Tensor>("X");
const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
const Tensor *bias_x = ctx.Input<Tensor>("BiasX");
const Tensor *saved_mean_x = ctx.Input<Tensor>("SavedMeanX");
const Tensor *saved_invstd_x = ctx.Input<Tensor>("SavedInvstdX");
const Tensor *conv_out_x = ctx.Input<Tensor>("ConvX");
const Tensor *output = ctx.Input<Tensor>("Y");
const Tensor *bitmask = ctx.Input<Tensor>("BitMask");
Tensor *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor *filter_x_grad =
ctx.Output<Tensor>(framework::GradVarName("FilterX"));
Tensor *scale_x_grad = ctx.Output<Tensor>(framework::GradVarName("ScaleX"));
Tensor *bias_x_grad = ctx.Output<Tensor>(framework::GradVarName("BiasX"));
int padding = ctx.Attr<int>("padding");
int stride = ctx.Attr<int>("stride");
int stride_z = ctx.Attr<int>("stride_z");
int dilate = ctx.Attr<int>("dilate");
int group = ctx.Attr<int>("group");
double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
bool has_shortcut = ctx.Attr<bool>("has_shortcut");
bool fuse_add = ctx.Attr<bool>("fuse_add");
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
std::string act_type = ctx.Attr<std::string>("act_type");
auto x_shape = framework::vectorize<int>(x->dims());
auto filter_x_shape = framework::vectorize<int>(filter_x->dims());
auto param_shape = framework::vectorize<int>(scale_x->dims());
auto output_shape = framework::vectorize<int>(output->dims());
auto bitmask_shape = framework::vectorize<int>(bitmask->dims());
auto place = ctx.GetPlace();
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
// 1. Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad
Tensor conv_out_x_grad;
conv_out_x_grad.Resize(conv_out_x->dims());
CudnnScaleBiasAddRelu<T> sbar_x_op(dev_ctx, act_type, fuse_add,
has_shortcut, output_shape, param_shape,
bitmask_shape);
if (has_shortcut) {
// X Z
// | |
// NormConv NormConv
// | |
// BNStatsFinalize BNStatsFinalize
// \ /
// ScaleBiasAddRelu
// |
// Y
const Tensor *z = ctx.Input<Tensor>("Z");
const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");
const Tensor *saved_mean_z = ctx.Input<Tensor>("SavedMeanZ");
const Tensor *saved_invstd_z = ctx.Input<Tensor>("SavedInvstdZ");
const Tensor *conv_out_z = ctx.Input<Tensor>("ConvZ");
Tensor *z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
Tensor *filter_z_grad =
ctx.Output<Tensor>(framework::GradVarName("FilterZ"));
Tensor *scale_z_grad =
ctx.Output<Tensor>(framework::GradVarName("ScaleZ"));
Tensor *bias_z_grad = ctx.Output<Tensor>(framework::GradVarName("BiasZ"));
// 1.1 Backward of BN + Add (+ Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad and z_grad_temp
Tensor z_grad_temp;
z_grad_temp.Resize(conv_out_z->dims());
sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
*saved_mean_x, *saved_invstd_x, bitmask,
&conv_out_x_grad, &z_grad_temp, scale_x_grad,
bias_x_grad, eps);
// 1.2 bn backward for z, get conv_out_z_grad, dscale_z, dbias_z
Tensor conv_out_z_grad;
conv_out_z_grad.Resize(conv_out_z->dims());
CudnnScaleBiasAddRelu<T> sbar_z_op(
dev_ctx, "", false, false, output_shape, param_shape, bitmask_shape);
sbar_z_op.Backward(dev_ctx, z_grad_temp, *conv_out_z, *scale_z, *bias_z,
*saved_mean_z, *saved_invstd_z, nullptr,
&conv_out_z_grad, nullptr, scale_z_grad, bias_z_grad,
eps);
// 1.3 Backward of Conv for z, get z_grad and filter_z_grad
auto z_shape = framework::vectorize<int>(z->dims());
auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
CudnnNormConvolutionGrad<T> conv_z_op(dev_ctx, z_shape, filter_z_shape,
output_shape, padding, stride_z,
dilate, group);
conv_z_op.Backward(dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad,
filter_z_grad);
} else {
// 1.1 Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
// scale_x_grad, bias_x_grad (and z_grad)
Tensor *z_grad =
fuse_add ? ctx.Output<Tensor>(framework::GradVarName("Z")) : nullptr;
sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
*saved_mean_x, *saved_invstd_x, bitmask,
&conv_out_x_grad, z_grad, scale_x_grad, bias_x_grad,
eps);
}
// 2. Backward of Conv for x, get x_grad and filter_x_grad
CudnnNormConvolutionGrad<T> conv_x_op(dev_ctx, x_shape, filter_x_shape,
output_shape, padding, stride, dilate,
group);
conv_x_op.Backward(dev_ctx, *x, *filter_x, conv_out_x_grad, x_grad,
filter_x_grad);
}
};
} // namespace operators
} // namespace paddle
#if CUDNN_VERSION >= 8000
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(resnet_unit, ops::ResNetUnitKernel<plat::float16>);
REGISTER_OP_CUDA_KERNEL(resnet_unit_grad,
ops::ResNetUnitGradKernel<plat::float16>);
#endif
......@@ -80,6 +80,27 @@ def _dtype_to_str(dtype):
return 'fp32'
def _keep_fp32_input(op, in_name):
op_type = op.type
if op_type in ['batch_norm', 'layer_norm']:
# Scale, Bias, Mean, Variance should be float32.
return in_name != 'X'
if op_type == 'fused_bn_add_activation':
return in_name not in {'X', 'Z'}
if op_type == 'resnet_unit':
return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
return False
def _keep_fp32_output(op, out_name):
op_type = op.type
if op_type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']:
return out_name != 'Y'
if op_type == 'resnet_unit':
return out_name not in {'Y', 'ConvX', 'ConvZ'}
return False
def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
"""
Insert cast op and rename args of input and output.
......@@ -97,11 +118,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
num_cast_ops = 0
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
]:
if in_name not in {'X', 'Z'}:
continue
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(op,
in_name):
continue
for in_var_name in op.input(in_name):
in_var = block._find_var_recursive(in_var_name)
if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
......@@ -154,9 +173,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
if op.type in [
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
] and out_name != 'Y':
if _keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
out_var = block.var(out_var_name)
......@@ -371,9 +388,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
keep_fp32_ops.add(op)
continue # processed below
for in_name in op.input_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and in_name not in {'X', 'Z'}:
if _keep_fp32_input(op, in_name):
continue
for in_var_name in op.input(in_name):
in_var = None
......@@ -401,9 +416,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
format(op.type, in_var_name, in_var.dtype))
for out_name in op.output_names:
if op.type in {
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
} and out_name != 'Y':
if _keep_fp32_output(op, out_name):
continue
for out_var_name in op.output(out_name):
out_var = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册