diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 7541b234ceaa69ffb42bc153e56911cdf64561af..228da9f77739d7f83abdc8cdeab8b829cea4b6d5 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -216,7 +216,7 @@ function(op_library TARGET)
 "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op"
 "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "skip_layernorm_op"
 "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
-"fused_bn_add_activation_op")
+"fused_bn_add_activation_op" "resnet_unit_op")
     if ("${TARGET}" STREQUAL "${manual_pybind_op}")
       set(pybind_flag 1)
     endif()
diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 2630c12db2fc9a0ba7f2a718ba89ca738a02d3a3..2286aaaf85969fba754b30e140ddc35cbdb4e156 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -16,7 +16,8 @@ register_operators(EXCLUDES
 	fusion_gru_op
 	fusion_lstm_op
 	fused_bn_add_activation_op
-	fused_transformer_op)
+	fused_transformer_op
+	resnet_unit_op)
 
 # fusion_gru_op does not have CUDA kernel
 op_library(fusion_gru_op)
@@ -78,7 +79,10 @@ if (WITH_GPU OR WITH_ROCM)
     nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
     nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
   endif()
+  # resnet_unit needs cuDNN 8.0 or above
   if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
+    op_library(resnet_unit_op)
+    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(resnet_unit);\n")
     cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory)
     cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory)
   endif()
diff --git a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc
index 709d69214c603f1b1420d8d26d7c63c21ebac7fe..c5995fe3554b4efda49971e9f2429a58677b1919 100644
--- a/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc
+++ b/paddle/fluid/operators/fused/cudnn_bn_add_relu_test.cc
@@ -631,8 +631,8 @@ class CudnnBNAddReluTester {
     op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type_, fuse_add_,
                                          has_shortcut_, data_shape,
                                          param_shape, bitmask_shape);
-    sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, z, equiv_scale_z,
-                    equiv_bias_z, &y, &bitmask);
+    sbar_op.Forward(ctx, x, equiv_scale_x, equiv_bias_x, &z, &equiv_scale_z,
+                    &equiv_bias_z, &y, &bitmask);
 
     TensorCopySync(mean_x, platform::CPUPlace(), cpu_mean_x);
     TensorCopySync(var_x, platform::CPUPlace(), cpu_var_x);
@@ -690,7 +690,7 @@ class CudnnBNAddReluTester {
     op::CudnnScaleBiasAddRelu<T> sbar_op(ctx, act_type, true, false,
                                          data_shape, param_shape,
                                          bitmask_shape);
     sbar_op.Backward(ctx, dy, x, bn_scale, bn_bias, saved_mean, saved_var,
-                     bitmask, &dx, &dz, &dscale, &dbias, eps_);
+                     &bitmask, &dx, &dz, &dscale, &dbias, eps_);
 
     TensorCopySync(dx, platform::CPUPlace(), cpu_dx);
     TensorCopySync(dz, platform::CPUPlace(), cpu_dz);
diff --git a/paddle/fluid/operators/fused/cudnn_fusion_helper.h b/paddle/fluid/operators/fused/cudnn_fusion_helper.h
index fcd354df938ace35ebce577e1f77607b47e064f1..1de64cf5ad947d9b5e4185fcf79eedb5a612eca9 100644
--- a/paddle/fluid/operators/fused/cudnn_fusion_helper.h
+++ b/paddle/fluid/operators/fused/cudnn_fusion_helper.h
@@ -38,10 +38,12 @@ class CudnnFusionOp {
         &op_variant_params_, op_id));
   }
 
-  ~CudnnFusionOp() {
-    dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_);
-    dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_);
-    dynload::cudnnDestroyFusedOpsPlan(op_);
+  ~CudnnFusionOp() PADDLE_MAY_THROW {
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_));
+    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyFusedOpsPlan(op_));
   }
 
   // Execute fused op
diff --git a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h
index b48c964d264add5a8e6b8f6c303fa31866fc95ea..5166ff27234f237f74136862cb6a29c79860ea70 100644
--- a/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h
+++ b/paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h
@@ -94,13 +94,13 @@ template <typename T>
 class CudnnScaleBiasAddRelu {
  public:
   CudnnScaleBiasAddRelu(const platform::CUDADeviceContext &ctx,
-                        const std::string &act_type, bool fused_add,
+                        const std::string &act_type, bool fuse_add,
                         bool has_shortcut, const std::vector<int> &data_shape,
                         const std::vector<int> &param_shape,
                         const std::vector<int> &bitmask_shape)
       : fwd_op_(CUDNN_FUSED_SCALE_BIAS_ADD_ACTIVATION_GEN_BITMASK),
         bwd_op_(CUDNN_FUSED_DACTIVATION_FORK_DBATCHNORM) {
-    fused_add_ = fused_add;
+    fuse_add_ = fuse_add;
     has_shortcut_ = has_shortcut;
     args_.Set(act_type, data_shape, param_shape, bitmask_shape);
   }
@@ -108,8 +108,8 @@ class CudnnScaleBiasAddRelu {
   ~CudnnScaleBiasAddRelu() {}
 
   void Forward(const platform::CUDADeviceContext &ctx, const Tensor &x,
-               const Tensor &x_scale, const Tensor &x_bias, const Tensor &z,
-               const Tensor &z_scale, const Tensor &z_bias, Tensor *out,
+               const Tensor &x_scale, const Tensor &x_bias, const Tensor *z,
+               const Tensor *z_scale, const Tensor *z_bias, Tensor *out,
                Tensor *bitmask) {
     ForwardInit(ctx);
     auto handle = ctx.cudnn_handle();
@@ -125,15 +125,15 @@ class CudnnScaleBiasAddRelu {
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQSCALE, x_scale_ptr);
     fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_EQBIAS, x_bias_ptr);
     if (has_shortcut_) {
-      T *z_ptr = const_cast<T *>(z.data<T>());
-      T *z_scale_ptr = const_cast<T *>(z_scale.data<T>());
-      T *z_bias_ptr = const_cast<T *>(z_bias.data<T>());
+      T *z_ptr = const_cast<T *>(z->data<T>());
+      T *z_scale_ptr = const_cast<T *>(z_scale->data<T>());
+      T *z_bias_ptr = const_cast<T *>(z_bias->data<T>());
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQSCALE, z_scale_ptr);
       fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_Z_EQBIAS, z_bias_ptr);
     } else {
-      if (fused_add_) {
-        T *z_ptr = const_cast<T *>(z.data<T>());
+      if (fuse_add_) {
+        T *z_ptr = const_cast<T *>(z->data<T>());
         fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_ZDATA, z_ptr);
       }
     }
@@ -160,7 +160,7 @@ class CudnnScaleBiasAddRelu {
   void Backward(const platform::CUDADeviceContext &ctx, const Tensor &dy,
                 const Tensor &x, const Tensor &scale, const Tensor &bias,
                 const Tensor &saved_mean, const Tensor &saved_invstd,
-                const Tensor &bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
+                const Tensor *bitmask, Tensor *dx, Tensor *dz, Tensor *dscale,
                 Tensor *dbias, double eps) {
     BackwardInit(ctx);
     auto handle = ctx.cudnn_handle();
@@ -175,7 +175,8 @@ class CudnnScaleBiasAddRelu {
     float *bias_ptr = const_cast<float *>(bias.data<float>());
     float *saved_mean_ptr = const_cast<float *>(saved_mean.data<float>());
     float *saved_invstd_ptr = const_cast<float *>(saved_invstd.data<float>());
-    int32_t *bitmask_ptr = const_cast<int32_t *>(bitmask.data<int32_t>());
+    int32_t *bitmask_ptr =
+        bitmask ? const_cast<int32_t *>(bitmask->data<int32_t>()) : nullptr;
     T *dx_ptr = dx->mutable_data<T>(place);
     T *dz_ptr = dz ? dz->mutable_data<T>(place) : nullptr;
     float *dscale_ptr = dscale ? dscale->mutable_data<float>(place) : nullptr;
@@ -199,7 +200,7 @@ class CudnnScaleBiasAddRelu {
     bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_BN_DBIAS, dbias_ptr);
     bwd_op_.SetOpVariantParamAttrPtr(CUDNN_SCALAR_DOUBLE_BN_EPSILON, &eps);
 
-    if (has_shortcut_ || fused_add_) {
+    if (has_shortcut_ || fuse_add_) {
       bwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_DZDATA, dz_ptr);
     }
 
@@ -226,14 +227,14 @@ class CudnnScaleBiasAddRelu {
         {CUDNN_PARAM_ZDATA_PLACEHOLDER, CUDNN_PARAM_BN_Z_EQSCALE_PLACEHOLDER,
          CUDNN_PARAM_BN_Z_EQBIAS_PLACEHOLDER},
         CUDNN_PTR_16B_ALIGNED);
-    } else if (fused_add_) {
+    } else if (fuse_add_) {
       fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_ZDATA_PLACEHOLDER,
                                   CUDNN_PTR_16B_ALIGNED);
     }
 
     // input desc
     fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
-    if (has_shortcut_ || fused_add_) {
+    if (has_shortcut_ || fuse_add_) {
       fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_ZDESC, args_.in_desc.desc());
     }
 
@@ -271,7 +272,7 @@ class CudnnScaleBiasAddRelu {
          CUDNN_PARAM_BN_DSCALE_PLACEHOLDER, CUDNN_PARAM_BN_DBIAS_PLACEHOLDER,
          CUDNN_PARAM_ACTIVATION_BITMASK_PLACEHOLDER},
         CUDNN_PTR_16B_ALIGNED);
-    if (has_shortcut_ || fused_add_) {
+    if (has_shortcut_ || fuse_add_) {
       bwd_op_.SetOpConstParamAttr(CUDNN_PARAM_DZDATA_PLACEHOLDER,
                                   CUDNN_PTR_16B_ALIGNED);
     }
 
@@ -279,7 +280,7 @@ class CudnnScaleBiasAddRelu {
     // input desc
     bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
     bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DXDESC, args_.in_desc.desc());
-    if (has_shortcut_ || fused_add_) {
+    if (has_shortcut_ || fuse_add_) {
       bwd_op_.SetOpConstParamDesc(CUDNN_PARAM_DZDESC, args_.in_desc.desc());
     }
 
@@ -303,7 +304,7 @@ class CudnnScaleBiasAddRelu {
         CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
   }
 
-  bool fused_add_ = false;
+  bool fuse_add_ = false;
   bool has_shortcut_ = false;
   size_t fwd_workspace_byte_;
   size_t bwd_workspace_byte_;
diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cc b/paddle/fluid/operators/fused/resnet_unit_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..062fd3f1cf40884701fb9e05f191491a1e963164
--- /dev/null
+++ b/paddle/fluid/operators/fused/resnet_unit_op.cc
@@ -0,0 +1,410 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+// Shape of bitmask: one relu-mask bit per output element, with the channel
+// count rounded up to a multiple of 64 and NHW rounded up to a multiple of
+// 32, packed into int32 words.
+static framework::DDim GetBitmaskDims(std::vector<int> out_shape) {
+  int c = out_shape.back();
+  int64_t nhw = std::accumulate(out_shape.begin(), out_shape.end(), 1,
+                                std::multiplies<int>()) /
+                c;
+  int32_t c_int32_elems = ((c + 63) & ~63) / 32;
+  int32_t nhw_int32_elems = ((nhw + 31) & ~31);
+  std::vector<int> bitmask_shape = {nhw_int32_elems, c_int32_elems, 1};
+  return framework::make_ddim(bitmask_shape);
+}
+
+class ResNetUnitOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    // Check input
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasInput("FilterX"), "Input", "FilterX",
+                   "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX", "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX", "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasInput("MeanX"), "Input", "MeanX", "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasInput("VarX"), "Input", "VarX", "ResNetUnitOp");
+
+    bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
+    bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
+    if (fuse_add || has_shortcut) {
+      OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitOp");
+    }
+    if (has_shortcut) {
+      OP_INOUT_CHECK(ctx->HasInput("FilterZ"), "Input", "FilterZ",
+                     "ResNetUnitOp");
+      OP_INOUT_CHECK(ctx->HasInput("ScaleZ"), "Input", "ScaleZ",
+                     "ResNetUnitOp");
+      OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ", "ResNetUnitOp");
+      OP_INOUT_CHECK(ctx->HasInput("MeanZ"), "Input", "MeanZ", "ResNetUnitOp");
+      OP_INOUT_CHECK(ctx->HasInput("VarZ"), "Input", "VarZ", "ResNetUnitOp");
+    }
+
+    // Check output
+    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasOutput("BitMask"), "Output", "BitMask",
+                   "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasOutput("ConvX"), "Output", "ConvX", "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasOutput("SavedMeanX"), "Output", "SavedMeanX",
+                   "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdX"), "Output", "SavedInvstdX",
+                   "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasOutput("RunningMeanX"), "Output", "RunningMeanX",
+                   "ResNetUnitOp");
+    OP_INOUT_CHECK(ctx->HasOutput("RunningVarX"), "Output", "RunningVarX",
+                   "ResNetUnitOp");
+    if (has_shortcut) {
+      OP_INOUT_CHECK(ctx->HasOutput("ConvZ"), "Output", "ConvZ",
+                     "ResNetUnitOp");
+      OP_INOUT_CHECK(ctx->HasOutput("SavedMeanZ"), "Output", "SavedMeanZ",
+                     "ResNetUnitOp");
+      OP_INOUT_CHECK(ctx->HasOutput("SavedInvstdZ"), "Output", "SavedInvstdZ",
+                     "ResNetUnitOp");
+      OP_INOUT_CHECK(ctx->HasOutput("RunningMeanZ"), "Output", "RunningMeanZ",
+                     "ResNetUnitOp");
+      OP_INOUT_CHECK(ctx->HasOutput("RunningVarZ"), "Output", "RunningVarZ",
+                     "ResNetUnitOp");
+    }
+
+    // make sure Mean/RunningMean and Var/RunningVar share memory
+    PADDLE_ENFORCE_EQ(
+        ctx->Inputs("MeanX")[0], ctx->Outputs("RunningMeanX")[0],
+        platform::errors::InvalidArgument(
+            "MeanX and RunningMeanX should share the same memory"));
+    PADDLE_ENFORCE_EQ(ctx->Inputs("VarX")[0], ctx->Outputs("RunningVarX")[0],
+                      platform::errors::InvalidArgument(
+                          "VarX and RunningVarX should share the same memory"));
+    if (has_shortcut) {
+      PADDLE_ENFORCE_EQ(
+          ctx->Inputs("MeanZ")[0], ctx->Outputs("RunningMeanZ")[0],
+          platform::errors::InvalidArgument(
"MeanZ and RunningMeanZ should share the same memory")); + PADDLE_ENFORCE_EQ( + ctx->Inputs("VarZ")[0], ctx->Outputs("RunningVarZ")[0], + platform::errors::InvalidArgument( + "VarZ and RunningVarZ should share the same memory")); + } + + // Check dims of inputs + const auto x_dims = ctx->GetInputDim("X"); + const auto w_dims = ctx->GetInputDim("FilterX"); + const auto bn_param_dims = ctx->GetInputDim("ScaleX"); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, platform::errors::InvalidArgument( + "The dimensions of input " + "must equal to 4." + "But received: the shape of input " + "= [%s], the dimension of input = " + "[%d]", + x_dims, x_dims.size())); + PADDLE_ENFORCE_EQ(w_dims.size(), 4, + platform::errors::InvalidArgument( + "The dimensions of filter " + "must equal to 4." + "But received: the shape of filter " + "= [%s], the dimension of filter = [%d] ", + w_dims, w_dims.size())); + PADDLE_ENFORCE_EQ(bn_param_dims.size(), 4, + platform::errors::InvalidArgument( + "The dimensions of bn param " + "must equal to 4." + "But received: the shape of bn param " + "= [%s], the dimension of bn param = [%d] ", + bn_param_dims, bn_param_dims.size())); + auto data_format = ctx->Attrs().Get("data_format"); + PADDLE_ENFORCE_EQ( + data_format, "NHWC", + platform::errors::InvalidArgument("The data format must equal to NHWC. " + "But received: the data format " + "= [%s]", + data_format)); + // Calculate the dims of outputs + int batch = x_dims[0]; + int output_channel = w_dims[0]; + int filter_size = w_dims[2]; + int stride = ctx->Attrs().Get("stride"); + int padding = ctx->Attrs().Get("padding"); + int out_h = (x_dims[1] + padding * 2 - filter_size) / stride + 1; + int out_w = (x_dims[2] + padding * 2 - filter_size) / stride + 1; + std::vector out_shape = {batch, out_h, out_w, output_channel}; + + auto y_dims = framework::make_ddim(out_shape); + auto bitmask_dims = GetBitmaskDims(out_shape); + // Set dims of outputs + ctx->SetOutputDim("Y", y_dims); + ctx->SetOutputDim("BitMask", bitmask_dims); + ctx->SetOutputDim("ConvX", y_dims); + ctx->SetOutputDim("SavedMeanX", bn_param_dims); + ctx->SetOutputDim("SavedInvstdX", bn_param_dims); + ctx->SetOutputDim("RunningMeanX", bn_param_dims); + ctx->SetOutputDim("RunningVarX", bn_param_dims); + if (has_shortcut) { + ctx->SetOutputDim("ConvZ", y_dims); + ctx->SetOutputDim("SavedMeanZ", bn_param_dims); + ctx->SetOutputDim("SavedInvstdZ", bn_param_dims); + ctx->SetOutputDim("RunningMeanZ", bn_param_dims); + ctx->SetOutputDim("RunningVarZ", bn_param_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + // By default, the type of the scale, bias, mean, + // and var tensors should be float when input tensor's dtype is float16. 
+    auto bn_param_type = framework::proto::VarType::FP32;
+
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("ScaleX")->type(),
+                      platform::errors::InvalidArgument(
+                          "Scale input should be of float type"));
+    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("BiasX")->type(),
+                      platform::errors::InvalidArgument(
+                          "Bias input should be of float type"));
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                   library);
+  }
+};
+
+class ResNetUnitOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input 1 tensor");
+    AddInput("FilterX", "Filter tensor of input 1");
+    AddInput("ScaleX", "Scale tensor of input 1 used in batchnorm");
+    AddInput("BiasX", "Bias tensor of input 1 used in batchnorm");
+    AddInput("MeanX", "Mean tensor of input 1 used in batchnorm");
+    AddInput("VarX", "Variance tensor of input 1 used in batchnorm");
+    AddInput("Z", "The input 2 tensor").AsDispensable();
+    AddInput("FilterZ", "Filter tensor of input 2").AsDispensable();
+    AddInput("ScaleZ", "Scale tensor of input 2").AsDispensable();
+    AddInput("BiasZ", "Bias tensor of input 2").AsDispensable();
+    AddInput("MeanZ", "Mean tensor of input 2").AsDispensable();
+    AddInput("VarZ", "Variance tensor of input 2").AsDispensable();
+    AddOutput("Y", "The result of the resnet unit");
+    AddOutput("BitMask", "The bitmask generated after relu");
+    AddOutput("ConvX", "The output of input 1 after conv");
+    AddOutput("SavedMeanX", "Mean of input 1 in the current batch");
+    AddOutput("SavedInvstdX", "Invstd of input 1 in the current batch");
+    AddOutput("RunningMeanX", "Shared memory with MeanX");
+    AddOutput("RunningVarX", "Shared memory with VarX");
+    AddOutput("ConvZ", "The output of input 2 after conv").AsDispensable();
+    AddOutput("SavedMeanZ", "Mean of input 2 in the current batch")
+        .AsDispensable();
+    AddOutput("SavedInvstdZ", "Invstd of input 2 in the current batch")
+        .AsDispensable();
+    AddOutput("RunningMeanZ", "Shared memory with MeanZ").AsDispensable();
+    AddOutput("RunningVarZ", "Shared memory with VarZ").AsDispensable();
+    AddAttr<int>("stride", "").SetDefault(1);
+    AddAttr<int>("stride_z", "").SetDefault(1);
+    AddAttr<int>("padding", "").SetDefault(0);
+    AddAttr<int>("dilation", "").SetDefault(1);
+    AddAttr<int>("group", "").SetDefault(1);
+    AddAttr<float>("momentum", "").SetDefault(0.9);
+    AddAttr<float>("epsilon", "").SetDefault(1e-5);
+    AddAttr<std::string>("data_format", "").SetDefault("NHWC");
+    AddAttr<bool>("fuse_add", "").SetDefault(false);
+    AddAttr<bool>("has_shortcut", "").SetDefault(false);
+    AddAttr<bool>("use_global_stats", "").SetDefault(false);
+    AddAttr<bool>(
+        "is_test",
+        "(bool, default false) Set to true for inference only, false "
+        "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
+    AddAttr<std::string>("act_type", "The activation type to be fused.")
+        .SetDefault("relu");
+    AddComment(R"DOC(
+Fusion op for the basic unit of a ResNet block.
+
+The implementation is based on the latest fusion op interface in cuDNN v8.0.
+For more details:
+https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnFusedOps_t
+
+)DOC");
+  }
+};
+
+class ResNetUnitGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    // check input
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("FilterX"), "Input", "FilterX",
+                   "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("ConvX"), "Input", "ConvX",
+                   "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("ScaleX"), "Input", "ScaleX",
+                   "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("BiasX"), "Input", "BiasX",
+                   "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("SavedMeanX"), "Input", "SavedMeanX",
+                   "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("SavedInvstdX"), "Input", "SavedInvstdX",
+                   "ResNetUnitGradOp");
+
+    bool fuse_add = ctx->Attrs().Get<bool>("fuse_add");
+    bool has_shortcut = ctx->Attrs().Get<bool>("has_shortcut");
+    if (fuse_add || has_shortcut) {
+      OP_INOUT_CHECK(ctx->HasInput("Z"), "Input", "Z", "ResNetUnitGradOp");
+    }
+    if (has_shortcut) {
+      OP_INOUT_CHECK(ctx->HasInput("FilterZ"), "Input", "FilterZ",
+                     "ResNetUnitGradOp");
+      OP_INOUT_CHECK(ctx->HasInput("ConvZ"), "Input", "ConvZ",
+                     "ResNetUnitGradOp");
+      OP_INOUT_CHECK(ctx->HasInput("ScaleZ"), "Input", "ScaleZ",
+                     "ResNetUnitGradOp");
+      OP_INOUT_CHECK(ctx->HasInput("BiasZ"), "Input", "BiasZ",
+                     "ResNetUnitGradOp");
+      OP_INOUT_CHECK(ctx->HasInput("SavedMeanZ"), "Input", "SavedMeanZ",
+                     "ResNetUnitGradOp");
+      OP_INOUT_CHECK(ctx->HasInput("SavedInvstdZ"), "Input", "SavedInvstdZ",
+                     "ResNetUnitGradOp");
+    }
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasInput("BitMask"), "Input", "BitMask",
+                   "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
+                   framework::GradVarName("Y"), "ResNetUnitGradOp");
+
+    // check output
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   framework::GradVarName("X"), "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterX")), "Output",
+                   framework::GradVarName("FilterX"), "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleX")), "Output",
+                   framework::GradVarName("ScaleX"), "ResNetUnitGradOp");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasX")), "Output",
+                   framework::GradVarName("BiasX"), "ResNetUnitGradOp");
+    if (fuse_add) {
+      OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Z")), "Output",
+                     framework::GradVarName("Z"), "ResNetUnitGradOp");
+    }
+    if (has_shortcut) {
+      OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("FilterZ")),
+                     "Output", framework::GradVarName("FilterZ"),
+                     "ResNetUnitGradOp");
+      OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("ScaleZ")),
+                     "Output", framework::GradVarName("ScaleZ"),
+                     "ResNetUnitGradOp");
+      OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("BiasZ")), "Output",
+                     framework::GradVarName("BiasZ"), "ResNetUnitGradOp");
+    }
+    const auto x_dims = ctx->GetInputDim("X");
+    const auto filter_x_dims = ctx->GetInputDim("FilterX");
+    const auto param_dims = ctx->GetInputDim("ScaleX");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+    ctx->SetOutputDim(framework::GradVarName("FilterX"), filter_x_dims);
+    ctx->SetOutputDim(framework::GradVarName("ScaleX"), param_dims);
+    ctx->SetOutputDim(framework::GradVarName("BiasX"), param_dims);
+    if (fuse_add || has_shortcut) {
+      const auto z_dims = ctx->GetInputDim("Z");
+      ctx->SetOutputDim(framework::GradVarName("Z"), z_dims);
+    }
+    if (has_shortcut) {
+      const auto filter_z_dims = ctx->GetInputDim("FilterZ");
+      ctx->SetOutputDim(framework::GradVarName("FilterZ"), filter_z_dims);
+      ctx->SetOutputDim(framework::GradVarName("ScaleZ"), param_dims);
+      ctx->SetOutputDim(framework::GradVarName("BiasZ"), param_dims);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(
+        ctx.InputVar(framework::GradVarName("Y")),
+        platform::errors::NotFound(
+            "Can not find Y@GRAD in the execution context."));
+
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
+        layout, library);
+  }
+};
+
+template <typename T>
+class ResNetUnitGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("resnet_unit_grad");
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("FilterX", this->Input("FilterX"));
+    op->SetInput("ConvX", this->Output("ConvX"));
+    op->SetInput("ScaleX", this->Input("ScaleX"));
+    op->SetInput("BiasX", this->Input("BiasX"));
+    op->SetInput("SavedMeanX", this->Output("SavedMeanX"));
+    op->SetInput("SavedInvstdX", this->Output("SavedInvstdX"));
+    op->SetInput("Z", this->Input("Z"));
+    op->SetInput("FilterZ", this->Input("FilterZ"));
+    op->SetInput("ConvZ", this->Output("ConvZ"));
+    op->SetInput("ScaleZ", this->Input("ScaleZ"));
+    op->SetInput("BiasZ", this->Input("BiasZ"));
+    op->SetInput("SavedMeanZ", this->Output("SavedMeanZ"));
+    op->SetInput("SavedInvstdZ", this->Output("SavedInvstdZ"));
+    op->SetInput("Y", this->Output("Y"));
+    op->SetInput("BitMask", this->Output("BitMask"));
+    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
+
+    op->SetAttrMap(this->Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetOutput(framework::GradVarName("FilterX"),
+                  this->InputGrad("FilterX"));
+    op->SetOutput(framework::GradVarName("ScaleX"), this->InputGrad("ScaleX"));
+    op->SetOutput(framework::GradVarName("BiasX"), this->InputGrad("BiasX"));
+    op->SetOutput(framework::GradVarName("Z"), this->InputGrad("Z"));
+    op->SetOutput(framework::GradVarName("FilterZ"),
+                  this->InputGrad("FilterZ"));
+    op->SetOutput(framework::GradVarName("ScaleZ"), this->InputGrad("ScaleZ"));
+    op->SetOutput(framework::GradVarName("BiasZ"), this->InputGrad("BiasZ"));
+  }
+};
+
+class ResNetUnitOpInferVarType
+    : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string>& GetInputOutputWithSameType()
+      const override {
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Y"}};
+    return m;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(resnet_unit, ops::ResNetUnitOp, ops::ResNetUnitOpMaker,
+                  ops::ResNetUnitOpInferVarType,
+                  ops::ResNetUnitGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ResNetUnitGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(resnet_unit_grad, ops::ResNetUnitGradOp);
diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cu b/paddle/fluid/operators/fused/resnet_unit_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a0126e5a9d4283a718741e18754b15e26e56a28c
--- /dev/null
+++ b/paddle/fluid/operators/fused/resnet_unit_op.cu
@@ -0,0 +1,298 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/fused/cudnn_bn_stats_finalize.cu.h"
+#include "paddle/fluid/operators/fused/cudnn_norm_conv.cu.h"
+#include "paddle/fluid/operators/fused/cudnn_scale_bias_add_relu.cu.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class ResNetUnitKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()), true,
+        platform::errors::PreconditionNotMet("It must use CUDAPlace."));
+    PADDLE_ENFORCE_EQ(platform::CudnnDataType<T>::type, CUDNN_DATA_HALF,
+                      platform::errors::Unavailable(
+                          "ResNetUnitOp only supports float16 for now."));
+
+    // input x
+    const Tensor *input_x = ctx.Input<Tensor>("X");
+    const Tensor *filter_x = ctx.Input<Tensor>("FilterX");
+    const Tensor *scale_x = ctx.Input<Tensor>("ScaleX");
+    const Tensor *bias_x = ctx.Input<Tensor>("BiasX");
+    // norm conv
+    Tensor *conv_out_x = ctx.Output<Tensor>("ConvX");
+    // bn finalize
+    Tensor *saved_mean_x = ctx.Output<Tensor>("SavedMeanX");
+    Tensor *saved_invstd_x = ctx.Output<Tensor>("SavedInvstdX");
+    Tensor *running_mean_x = ctx.Output<Tensor>("RunningMeanX");
+    Tensor *running_var_x = ctx.Output<Tensor>("RunningVarX");
+    // sbar
+    Tensor *output = ctx.Output<Tensor>("Y");
+    Tensor *bitmask = ctx.Output<Tensor>("BitMask");
+    // attrs
+    int padding = ctx.Attr<int>("padding");
+    int stride = ctx.Attr<int>("stride");
+    int stride_z = ctx.Attr<int>("stride_z");
+    int dilation = ctx.Attr<int>("dilation");
+    int group = ctx.Attr<int>("group");
+    double eps = static_cast<double>(ctx.Attr<float>("epsilon"));
+    double momentum = static_cast<double>(ctx.Attr<float>("momentum"));
+    bool has_shortcut = ctx.Attr<bool>("has_shortcut");
+    bool fuse_add = ctx.Attr<bool>("fuse_add");
+    bool use_global_stats = ctx.Attr<bool>("use_global_stats");
+    bool is_test = ctx.Attr<bool>("is_test");
+    bool is_train = !is_test && !use_global_stats;
+    std::string act_type = ctx.Attr<std::string>("act_type");
+
+    auto input_x_shape = framework::vectorize<int>(input_x->dims());
+    auto filter_x_shape = framework::vectorize<int>(filter_x->dims());
+    auto param_dims = scale_x->dims();
+    auto param_shape = framework::vectorize<int>(scale_x->dims());
+    auto output_shape = framework::vectorize<int>(output->dims());
+    auto bitmask_shape = framework::vectorize<int>(bitmask->dims());
+    int output_channel = filter_x_shape[0];
+    int64_t ele_count =
+        std::accumulate(output_shape.begin(), output_shape.end(), 1,
+                        std::multiplies<int>()) /
+        output_channel;
+
+    auto place = ctx.GetPlace();
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+
+    // 1. Conv
+    Tensor sum_x;
+    Tensor sum_of_squares_x;
+    sum_x.Resize(param_dims);
+    sum_of_squares_x.Resize(param_dims);
+    CudnnNormConvolution<T> conv_x_op(dev_ctx, input_x_shape, filter_x_shape,
+                                      output_shape, padding, stride, dilation,
+                                      group);
+    conv_x_op.Forward(dev_ctx, *input_x, *filter_x, conv_out_x, &sum_x,
+                      &sum_of_squares_x);
+
+    // 2. BN
+    Tensor equiv_scale_x;
+    Tensor equiv_bias_x;
+    equiv_scale_x.Resize(param_dims);
+    equiv_bias_x.Resize(param_dims);
+    CudnnBNStatsFinalize<T> bn_x_op(dev_ctx, param_shape);
+    bn_x_op.Forward(dev_ctx, sum_x, sum_of_squares_x, *scale_x, *bias_x,
+                    saved_mean_x, saved_invstd_x, running_mean_x,
+                    running_var_x, &equiv_scale_x, &equiv_bias_x, eps,
+                    momentum, ele_count, is_train);
+
+    // 3. scale + bias + add + relu
+    CudnnScaleBiasAddRelu<T> sbar_op(dev_ctx, act_type, fuse_add, has_shortcut,
+                                     output_shape, param_shape, bitmask_shape);
+    if (has_shortcut) {
+      // input z
+      const Tensor *input_z = ctx.Input<Tensor>("Z");
+      const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
+      const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
+      const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");
+      // norm conv
+      Tensor *conv_out_z = ctx.Output<Tensor>("ConvZ");
+      // bn finalize
+      Tensor *saved_mean_z = ctx.Output<Tensor>("SavedMeanZ");
+      Tensor *saved_invstd_z = ctx.Output<Tensor>("SavedInvstdZ");
+      Tensor *running_mean_z = ctx.Output<Tensor>("RunningMeanZ");
+      Tensor *running_var_z = ctx.Output<Tensor>("RunningVarZ");
+
+      auto input_z_shape = framework::vectorize<int>(input_z->dims());
+      auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
+
+      // 3.1 Conv for second input
+      Tensor sum_z;
+      Tensor sum_of_squares_z;
+      sum_z.Resize(param_dims);
+      sum_of_squares_z.Resize(param_dims);
+      CudnnNormConvolution<T> conv_z_op(dev_ctx, input_z_shape, filter_z_shape,
+                                        output_shape, padding, stride_z,
+                                        dilation, group);
+      conv_z_op.Forward(dev_ctx, *input_z, *filter_z, conv_out_z, &sum_z,
+                        &sum_of_squares_z);
+
+      // 3.2 BN for second input
+      Tensor equiv_scale_z;
+      Tensor equiv_bias_z;
+      equiv_scale_z.Resize(param_dims);
+      equiv_bias_z.Resize(param_dims);
+      CudnnBNStatsFinalize<T> bn_z_op(dev_ctx, param_shape);
+      bn_z_op.Forward(dev_ctx, sum_z, sum_of_squares_z, *scale_z, *bias_z,
+                      saved_mean_z, saved_invstd_z, running_mean_z,
+                      running_var_z, &equiv_scale_z, &equiv_bias_z, eps,
+                      momentum, ele_count, is_train);
+      // 3.3 sbar
+      sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x,
+                      conv_out_z, &equiv_scale_z, &equiv_bias_z, output,
+                      bitmask);
+    } else {
+      const Tensor *input_z = fuse_add ? ctx.Input<Tensor>("Z") : nullptr;
ctx.Input("Z") : nullptr; + sbar_op.Forward(dev_ctx, *conv_out_x, equiv_scale_x, equiv_bias_x, + input_z, nullptr, nullptr, output, bitmask); + } + } +}; + +template +class ResNetUnitGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::PreconditionNotMet("It must use CUDAPlace.")); + PADDLE_ENFORCE_EQ(platform::CudnnDataType::type, CUDNN_DATA_HALF, + platform::errors::Unavailable( + "ResNetUnitOp only supports float16 for now.")); + + const Tensor *y_grad = ctx.Input(framework::GradVarName("Y")); + + const Tensor *x = ctx.Input("X"); + const Tensor *filter_x = ctx.Input("FilterX"); + const Tensor *scale_x = ctx.Input("ScaleX"); + const Tensor *bias_x = ctx.Input("BiasX"); + const Tensor *saved_mean_x = ctx.Input("SavedMeanX"); + const Tensor *saved_invstd_x = ctx.Input("SavedInvstdX"); + + const Tensor *conv_out_x = ctx.Input("ConvX"); + const Tensor *output = ctx.Input("Y"); + const Tensor *bitmask = ctx.Input("BitMask"); + + Tensor *x_grad = ctx.Output(framework::GradVarName("X")); + Tensor *filter_x_grad = + ctx.Output(framework::GradVarName("FilterX")); + Tensor *scale_x_grad = ctx.Output(framework::GradVarName("ScaleX")); + Tensor *bias_x_grad = ctx.Output(framework::GradVarName("BiasX")); + + int padding = ctx.Attr("padding"); + int stride = ctx.Attr("stride"); + int stride_z = ctx.Attr("stride_z"); + int dilate = ctx.Attr("dilate"); + int group = ctx.Attr("group"); + double eps = static_cast(ctx.Attr("epsilon")); + double momentum = static_cast(ctx.Attr("momentum")); + bool has_shortcut = ctx.Attr("has_shortcut"); + bool fuse_add = ctx.Attr("fuse_add"); + bool use_global_stats = ctx.Attr("use_global_stats"); + std::string act_type = ctx.Attr("act_type"); + + auto x_shape = framework::vectorize(x->dims()); + auto filter_x_shape = framework::vectorize(filter_x->dims()); + auto param_shape = framework::vectorize(scale_x->dims()); + auto output_shape = framework::vectorize(output->dims()); + auto bitmask_shape = framework::vectorize(bitmask->dims()); + + auto place = ctx.GetPlace(); + auto &dev_ctx = ctx.template device_context(); + + // 1. 
+    // 1. Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
+    // scale_x_grad, bias_x_grad
+    Tensor conv_out_x_grad;
+    conv_out_x_grad.Resize(conv_out_x->dims());
+    CudnnScaleBiasAddRelu<T> sbar_x_op(dev_ctx, act_type, fuse_add,
+                                       has_shortcut, output_shape, param_shape,
+                                       bitmask_shape);
+    if (has_shortcut) {
+      //       X                       Z
+      //       |                       |
+      //    NormConv               NormConv
+      //       |                       |
+      // BNStatsFinalize       BNStatsFinalize
+      //        \                     /
+      //          ScaleBiasAddRelu
+      //                  |
+      //                  Y
+      const Tensor *z = ctx.Input<Tensor>("Z");
+      const Tensor *filter_z = ctx.Input<Tensor>("FilterZ");
+      const Tensor *scale_z = ctx.Input<Tensor>("ScaleZ");
+      const Tensor *bias_z = ctx.Input<Tensor>("BiasZ");
+      const Tensor *saved_mean_z = ctx.Input<Tensor>("SavedMeanZ");
+      const Tensor *saved_invstd_z = ctx.Input<Tensor>("SavedInvstdZ");
+      const Tensor *conv_out_z = ctx.Input<Tensor>("ConvZ");
+
+      Tensor *z_grad = ctx.Output<Tensor>(framework::GradVarName("Z"));
+      Tensor *filter_z_grad =
+          ctx.Output<Tensor>(framework::GradVarName("FilterZ"));
+      Tensor *scale_z_grad =
+          ctx.Output<Tensor>(framework::GradVarName("ScaleZ"));
+      Tensor *bias_z_grad =
+          ctx.Output<Tensor>(framework::GradVarName("BiasZ"));
+
+      // 1.1 Backward of BN + Add (+ Relu) for x, get conv_out_x_grad,
+      // scale_x_grad, bias_x_grad and z_grad_temp
+      Tensor z_grad_temp;
+      z_grad_temp.Resize(conv_out_z->dims());
+      sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
+                         *saved_mean_x, *saved_invstd_x, bitmask,
+                         &conv_out_x_grad, &z_grad_temp, scale_x_grad,
+                         bias_x_grad, eps);
+
+      // 1.2 bn backward for z, get conv_out_z_grad, dscale_z, dbias_z
+      Tensor conv_out_z_grad;
+      conv_out_z_grad.Resize(conv_out_z->dims());
+      CudnnScaleBiasAddRelu<T> sbar_z_op(
+          dev_ctx, "", false, false, output_shape, param_shape, bitmask_shape);
+      sbar_z_op.Backward(dev_ctx, z_grad_temp, *conv_out_z, *scale_z, *bias_z,
+                         *saved_mean_z, *saved_invstd_z, nullptr,
+                         &conv_out_z_grad, nullptr, scale_z_grad, bias_z_grad,
+                         eps);
+
+      // 1.3 Backward of Conv for z, get z_grad and filter_z_grad
+      auto z_shape = framework::vectorize<int>(z->dims());
+      auto filter_z_shape = framework::vectorize<int>(filter_z->dims());
+      CudnnNormConvolutionGrad<T> conv_z_op(dev_ctx, z_shape, filter_z_shape,
+                                            output_shape, padding, stride_z,
+                                            dilation, group);
+      conv_z_op.Backward(dev_ctx, *z, *filter_z, conv_out_z_grad, z_grad,
+                         filter_z_grad);
+    } else {
+      // 1.1 Backward of BN (+ Add + Relu) for x, get conv_out_x_grad,
+      // scale_x_grad, bias_x_grad (and z_grad)
+      Tensor *z_grad =
+          fuse_add ? ctx.Output<Tensor>(framework::GradVarName("Z")) : nullptr;
+      sbar_x_op.Backward(dev_ctx, *y_grad, *conv_out_x, *scale_x, *bias_x,
+                         *saved_mean_x, *saved_invstd_x, bitmask,
+                         &conv_out_x_grad, z_grad, scale_x_grad, bias_x_grad,
+                         eps);
+    }
+
+    // 2. Backward of Conv for x, get x_grad and filter_x_grad
+    CudnnNormConvolutionGrad<T> conv_x_op(dev_ctx, x_shape, filter_x_shape,
+                                          output_shape, padding, stride,
+                                          dilation, group);
+    conv_x_op.Backward(dev_ctx, *x, *filter_x, conv_out_x_grad, x_grad,
+                       filter_x_grad);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+#if CUDNN_VERSION >= 8000
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(resnet_unit, ops::ResNetUnitKernel<plat::float16>);
+REGISTER_OP_CUDA_KERNEL(resnet_unit_grad,
+                        ops::ResNetUnitGradKernel<plat::float16>);
+#endif
diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
index 5978d3829aecaed912c86677ae64956466ef1532..6317be9a2e2e2051accf0d10d2b7faa30a4d307d 100644
--- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
@@ -80,6 +80,27 @@ def _dtype_to_str(dtype):
     return 'fp32'
 
 
+def _keep_fp32_input(op, in_name):
+    op_type = op.type
+    if op_type in ['batch_norm', 'layer_norm']:
+        # Scale, Bias, Mean, Variance should be float32.
+        return in_name != 'X'
+    if op_type == 'fused_bn_add_activation':
+        return in_name not in {'X', 'Z'}
+    if op_type == 'resnet_unit':
+        return in_name not in {'X', 'FilterX', 'Z', 'FilterZ'}
+    return False
+
+
+def _keep_fp32_output(op, out_name):
+    op_type = op.type
+    if op_type in ['batch_norm', 'fused_bn_add_activation', 'layer_norm']:
+        return out_name != 'Y'
+    if op_type == 'resnet_unit':
+        return out_name not in {'Y', 'ConvX', 'ConvZ'}
+    return False
+
+
 def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     """
     Insert cast op and rename args of input and output.
@@ -97,11 +118,9 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     num_cast_ops = 0
 
     for in_name in op.input_names:
-        if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
-                'batch_norm', 'fused_bn_add_activation', 'layer_norm'
-        ]:
-            if in_name not in {'X', 'Z'}:
-                continue
+        if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(op,
+                                                                       in_name):
+            continue
         for in_var_name in op.input(in_name):
             in_var = block._find_var_recursive(in_var_name)
             if in_var.type not in _valid_types or in_var.dtype == dest_dtype:
@@ -154,9 +173,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
             op._set_attr('in_dtype', dest_dtype)
     if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
         for out_name in op.output_names:
-            if op.type in [
-                    'batch_norm', 'fused_bn_add_activation', 'layer_norm'
-            ] and out_name != 'Y':
+            if _keep_fp32_output(op, out_name):
                 continue
             for out_var_name in op.output(out_name):
                 out_var = block.var(out_var_name)
@@ -371,9 +388,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
             keep_fp32_ops.add(op)
             continue  # processed below
         for in_name in op.input_names:
-            if op.type in {
-                    'batch_norm', 'fused_bn_add_activation', 'layer_norm'
-            } and in_name not in {'X', 'Z'}:
+            if _keep_fp32_input(op, in_name):
                 continue
             for in_var_name in op.input(in_name):
                 in_var = None
@@ -401,9 +416,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
                         format(op.type, in_var_name, in_var.dtype))
 
             for out_name in op.output_names:
-                if op.type in {
-                        'batch_norm', 'fused_bn_add_activation', 'layer_norm'
-                } and out_name != 'Y':
+                if _keep_fp32_output(op, out_name):
                     continue
                 for out_var_name in op.output(out_name):
                     out_var = None