From 29796efee75e902b358110a26eb3a740773ba082 Mon Sep 17 00:00:00 2001 From: fwenguang <95677191+fwenguang@users.noreply.github.com> Date: Fri, 21 Jan 2022 10:46:33 +0800 Subject: [PATCH] [MLU]add batch_norm mlu kernel (#39070) --- paddle/fluid/operators/batch_norm_op_mlu.cc | 275 ++++++++++++++++++++ 1 file changed, 275 insertions(+) create mode 100644 paddle/fluid/operators/batch_norm_op_mlu.cc diff --git a/paddle/fluid/operators/batch_norm_op_mlu.cc b/paddle/fluid/operators/batch_norm_op_mlu.cc new file mode 100644 index 00000000000..534af63d2a0 --- /dev/null +++ b/paddle/fluid/operators/batch_norm_op_mlu.cc @@ -0,0 +1,275 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/batch_norm_op.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +template +class MLUBatchNormOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto &place = ctx.GetPlace(); + const float epsilon = ctx.Attr("epsilon"); + float momentum = ctx.Attr("momentum"); + const bool is_test = ctx.Attr("is_test"); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool trainable_stats = ctx.Attr("trainable_statistics"); + bool test_mode = is_test && (!trainable_stats); + + bool global_stats = test_mode || use_global_stats; + + const std::string data_layout_str = ctx.Attr("data_layout"); + DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + + const auto *x = ctx.Input("X"); + const auto &x_dims = x->dims(); + PADDLE_ENFORCE_GE( + x_dims.size(), 2, + platform::errors::InvalidArgument( + "The size of input X's dimensions should be larger than 1." + "But received: the size of input X's dimensions is [%d]", + x_dims.size())); + PADDLE_ENFORCE_LE( + x_dims.size(), 5, + platform::errors::InvalidArgument( + "The size of input X's dimensions should be less than 6." + "But received: the size of input X's dimensions is [%d]", + x_dims.size())); + const int N = x_dims[0]; + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int sample_size = x->numel() / N / C; + + const auto *running_mean = ctx.Input("Mean"); + const auto *running_var = ctx.Input("Variance"); + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + + auto *y = ctx.Output("Y"); + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_variance = ctx.Output("SavedVariance"); + + // alloc memory + y->mutable_data(place); + mean_out->mutable_data(place); + variance_out->mutable_data(place); + saved_mean->mutable_data(place); + saved_variance->mutable_data(place); + + Tensor transformed_x; + Tensor transformed_y; + const int transformed_dim_size = 4; + const int transformed_shape[transformed_dim_size] = {N, sample_size, 1, C}; + MLUCnnlTensorDesc transformed_desc(transformed_dim_size, transformed_shape, + ToCnnlDataType(), CNNL_LAYOUT_NHWC); + MLUCnnlTensorDesc others_input_desc(*scale); + // input dimension is 2 and the format is NCHW. The input can be regarded as + // NHWC format. Don't need to transpose. + bool need_transpose = + (data_layout == DataLayout::kNCHW && x_dims.size() != 2); + if (need_transpose) { + auto &dev_ctx = ctx.template device_context(); + transformed_x = ctx.AllocateTmpTensor( + framework::DDim(transformed_shape, transformed_dim_size), dev_ctx); + transformed_y = ctx.AllocateTmpTensor( + framework::DDim(transformed_shape, transformed_dim_size), dev_ctx); + + const int x_reshaped[] = {N, C, sample_size, 1}; + MLUCnnlTensorDesc x_reshaped_desc(transformed_dim_size, x_reshaped, + ToCnnlDataType()); + const std::vector perm = {0, 2, 3, 1}; + MLUCnnl::Transpose(ctx, perm, transformed_dim_size, x_reshaped_desc.get(), + GetBasePtr(x), transformed_desc.get(), + GetBasePtr(&transformed_x)); + } else { + transformed_x = *x; + transformed_y = *y; + } + + if (ctx.HasInput("MomentumTensor")) { + const auto *mom_tensor = ctx.Input("MomentumTensor"); + Tensor mom_cpu; + TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu); + momentum = mom_cpu.data()[0]; + } + + MLUCnnl::FusedBatchNorm( + ctx, !global_stats, transformed_desc.get(), GetBasePtr(&transformed_x), + others_input_desc.get(), GetBasePtr(scale), GetBasePtr(bias), + GetBasePtr(running_mean), GetBasePtr(running_var), epsilon, momentum, + transformed_desc.get(), GetBasePtr(&transformed_y), + GetBasePtr(mean_out), GetBasePtr(variance_out), GetBasePtr(saved_mean), + GetBasePtr(saved_variance)); + + if (need_transpose) { + const int y_reshaped[] = {N, C, sample_size, 1}; + MLUCnnlTensorDesc y_reshaped_desc(transformed_dim_size, y_reshaped, + ToCnnlDataType()); + const std::vector perm = {0, 3, 1, 2}; + MLUCnnl::Transpose(ctx, perm, transformed_y.dims().size(), + transformed_desc.get(), GetBasePtr(&transformed_y), + y_reshaped_desc.get(), GetBasePtr(y)); + } + } +}; + +template +class MLUBatchNormGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *x = ctx.Input("X"); + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + const auto *saved_mean = ctx.Input("SavedMean"); + // SavedVariance have been reverted in forward operator + const auto *saved_inv_variance = ctx.Input("SavedVariance"); + const std::string data_layout_str = ctx.Attr("data_layout"); + bool use_global_stats = ctx.Attr("use_global_stats"); + const bool is_test = ctx.Attr("is_test"); + const float epsilon = ctx.Attr("epsilon"); + DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + auto &dev_ctx = ctx.template device_context(); + auto d_x_tmp = + ctx.AllocateTmpTensor(x->dims(), dev_ctx); + auto scale_grad_tmp = + ctx.AllocateTmpTensor(scale->dims(), dev_ctx); + auto bias_grad_tmp = + ctx.AllocateTmpTensor(bias->dims(), dev_ctx); + + if (d_x == nullptr) { + d_x = &d_x_tmp; + } + if (d_scale == nullptr) { + d_scale = &scale_grad_tmp; + } + if (d_bias == nullptr) { + d_bias = &bias_grad_tmp; + } + + const auto &place = ctx.GetPlace(); + d_x->mutable_data(place); + d_scale->mutable_data(place); + d_bias->mutable_data(place); + + use_global_stats = is_test || use_global_stats; + + const auto &x_dims = x->dims(); + PADDLE_ENFORCE_GE( + x_dims.size(), 2, + platform::errors::InvalidArgument( + "The size of input X's dimensions should be larger than 1." + "But received: the size of input X's dimensions is [%d]", + x_dims.size())); + PADDLE_ENFORCE_LE( + x_dims.size(), 5, + platform::errors::InvalidArgument( + "The size of input X's dimensions should be less than 6." + "But received: the size of input X's dimensions is [%d]", + x_dims.size())); + const int N = x_dims[0]; + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int sample_size = x->numel() / N / C; + + Tensor transformed_d_y; + Tensor transformed_x; + Tensor transformed_d_x; + const int transformed_dim_size = 4; + const int transformed_shape[transformed_dim_size] = {N, sample_size, 1, C}; + + MLUCnnlTensorDesc transformed_desc(transformed_dim_size, transformed_shape, + ToCnnlDataType(), CNNL_LAYOUT_NHWC); + MLUCnnlTensorDesc others_input_desc(*scale); + + bool need_transpose = + (data_layout == DataLayout::kNCHW && x_dims.size() != 2); + if (need_transpose) { + transformed_d_y = ctx.AllocateTmpTensor( + framework::DDim(transformed_shape, transformed_dim_size), dev_ctx); + transformed_x = ctx.AllocateTmpTensor( + framework::DDim(transformed_shape, transformed_dim_size), dev_ctx); + transformed_d_x = ctx.AllocateTmpTensor( + framework::DDim(transformed_shape, transformed_dim_size), dev_ctx); + const int org_reshaped[] = {N, C, sample_size, 1}; + MLUCnnlTensorDesc org_reshaped_desc(transformed_dim_size, org_reshaped, + ToCnnlDataType()); + const std::vector perm = {0, 2, 3, 1}; + MLUCnnl::Transpose(ctx, perm, transformed_dim_size, + org_reshaped_desc.get(), GetBasePtr(d_y), + transformed_desc.get(), GetBasePtr(&transformed_d_y)); + MLUCnnl::Transpose(ctx, perm, transformed_dim_size, + org_reshaped_desc.get(), GetBasePtr(x), + transformed_desc.get(), GetBasePtr(&transformed_x)); + } else { + transformed_d_y = *d_y; + transformed_x = *x; + transformed_d_x = *d_x; + } + + if (use_global_stats) { + const auto *running_mean = ctx.Input("Mean"); + const auto *running_variance = ctx.Input("Variance"); + MLUCnnl::FusedBatchNormGrad( + ctx, true /*is_training*/, transformed_desc.get(), + GetBasePtr(&transformed_d_y), transformed_desc.get(), + GetBasePtr(&transformed_x), others_input_desc.get(), + GetBasePtr(scale), GetBasePtr(running_mean), + GetBasePtr(running_variance), epsilon, transformed_desc.get(), + GetBasePtr(&transformed_d_x), GetBasePtr(d_scale), + GetBasePtr(d_bias)); + } else { + MLUCnnl::FusedBatchNormGrad( + ctx, true /*is_training*/, transformed_desc.get(), + GetBasePtr(&transformed_d_y), transformed_desc.get(), + GetBasePtr(&transformed_x), others_input_desc.get(), + GetBasePtr(scale), GetBasePtr(saved_mean), + GetBasePtr(saved_inv_variance), epsilon, transformed_desc.get(), + GetBasePtr(&transformed_d_x), GetBasePtr(d_scale), + GetBasePtr(d_bias)); + } + + if (need_transpose) { + const int d_x_reshaped[] = {N, C, sample_size, 1}; + MLUCnnlTensorDesc d_x_reshaped_desc(transformed_dim_size, d_x_reshaped, + ToCnnlDataType()); + const std::vector perm = {0, 3, 1, 2}; + MLUCnnl::Transpose(ctx, perm, transformed_dim_size, + transformed_desc.get(), GetBasePtr(&transformed_d_x), + d_x_reshaped_desc.get(), GetBasePtr(d_x)); + } + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_MLU_KERNEL(batch_norm, ops::MLUBatchNormOpKernel, + ops::MLUBatchNormOpKernel); +REGISTER_OP_MLU_KERNEL(batch_norm_grad, ops::MLUBatchNormGradOpKernel, + ops::MLUBatchNormGradOpKernel); -- GitLab