diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 17aabc25b3fa41270e71051c4bf77f9415f68587..7fb00504ee2db05d5e9ecaab61dd4bded7d2ca2a 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -149,6 +149,10 @@ if (WITH_ASCEND_CL) op_library(sync_batch_norm_op) endif() +if (WITH_MLU) + op_library(sync_batch_norm_op) +endif() + op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) op_library(eye_op DEPS ${OP_HEADER_DEPS}) op_library(recurrent_op DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 5531250f363b526ae45ca309a8043fb63f62a0c4..175fa9f94470f86aa75ef63dde043edb0e705b20 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -259,15 +259,16 @@ MLUCnnlTensorDesc::~MLUCnnlTensorDesc() { MLUCnnlActivationDesc::MLUCnnlActivationDesc( const cnnlActivationMode_t act_mode, const float ceof) { PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); - PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetActivationDescriptor_v4( - active_desc_, - act_mode, - CNNL_ACTIVATION_HIGH_PRECISION, - CNNL_NOT_PROPAGATE_NAN, - ceof, - 1.0f /*sliced_dim*/, - 1.67326319217681884765625 /*selu_alpha*/, - 1.05070102214813232421875 /*selu_lambda*/)); + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetActivationDescriptor_v5(active_desc_, + act_mode, + CNNL_ACTIVATION_HIGH_PRECISION, + CNNL_NOT_PROPAGATE_NAN, + ceof, + 1.0f /*sliced_dim*/, + 1.67326319217681884765625 /*selu_alpha*/, + 1.05070102214813232421875 /*selu_lambda*/, + false /*is_elu_mode*/)); } MLUCnnlActivationDesc::MLUCnnlActivationDesc( @@ -278,14 +279,15 @@ MLUCnnlActivationDesc::MLUCnnlActivationDesc( const float selu_lambda) { PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateActivationDescriptor(&active_desc_)); PADDLE_ENFORCE_MLU_SUCCESS( - cnnlSetActivationDescriptor_v4(active_desc_, + cnnlSetActivationDescriptor_v5(active_desc_, act_mode, CNNL_ACTIVATION_HIGH_PRECISION, CNNL_NOT_PROPAGATE_NAN, ceof, sliced_dim, selu_alpha, - selu_lambda)); + selu_lambda, + false /*is_elu_mode*/)); } const cnnlActivationDescriptor_t MLUCnnlActivationDesc::get() const { @@ -2350,6 +2352,36 @@ MLURNNDesc::~MLURNNDesc() { workspace_size)); } +/* static */ void MLUCnnl::Pow(const ExecutionContext& ctx, + cnnlComputationPreference_t prefer, + const cnnlTensorDescriptor_t input1_desc, + const void* input1, + const cnnlTensorDescriptor_t input2_desc, + const void* input2, + const cnnlTensorDescriptor_t output_desc, + void* output) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + size_t workspace_size; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetPowWorkspaceSize( + handle, input1_desc, input2_desc, output_desc, &workspace_size)); + + auto& dev_ctx = GetDevCtxFromCTX(ctx); + Tensor workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlPow(handle, + prefer, + input1_desc, + input1, + input2_desc, + input2, + workspace_ptr, + workspace_size, + output_desc, + output)); +} + /* static */ void MLUCnnl::PowR(const ExecutionContext& ctx, cnnlComputationPreference_t prefer, const cnnlTensorDescriptor_t input1_desc, @@ -4895,5 +4927,180 @@ MLURNNDesc::~MLURNNDesc() { grads_image)); } +/* static */ void MLUCnnl::SyncBatchNormStats( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const float eps, + const cnnlTensorDescriptor_t mean_desc, + void* mean, + const cnnlTensorDescriptor_t invstd_desc, + void* invstd) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormStats( + handle, x_desc, x, eps, mean_desc, mean, invstd_desc, invstd)); +} + +/* static */ void MLUCnnl::SyncBatchNormGatherStatsWithCounts( + const ExecutionContext& ctx, + float momentum, + float eps, + const cnnlTensorDescriptor_t mean_all_desc, + const void* mean_all, + const cnnlTensorDescriptor_t invstd_all_desc, + const void* invstd_all, + const cnnlTensorDescriptor_t moving_mean_desc, + void* moving_mean, + const cnnlTensorDescriptor_t moving_var_desc, + void* moving_var, + const cnnlTensorDescriptor_t count_all_desc, + const void* count_all, + const cnnlTensorDescriptor_t mean_desc, + void* mean, + const cnnlTensorDescriptor_t invstd_desc, + void* invstd) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSyncBatchNormGatherStatsWithCounts(handle, + mean_all_desc, + mean_all, + invstd_all_desc, + invstd_all, + moving_mean_desc, + moving_mean, + moving_var_desc, + moving_var, + momentum, + eps, + count_all_desc, + count_all, + mean_desc, + mean, + invstd_desc, + invstd)); +} + +/* static */ void MLUCnnl::SyncBatchNormElemt( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t mean_desc, + const void* mean, + const cnnlTensorDescriptor_t invstd_desc, + const void* invstd, + const cnnlTensorDescriptor_t weight_desc, + const void* weight, + const cnnlTensorDescriptor_t bias_desc, + const void* bias, + const cnnlTensorDescriptor_t y_desc, + void* y) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormElemt(handle, + x_desc, + x, + mean_desc, + mean, + invstd_desc, + invstd, + weight_desc, + weight, + bias_desc, + bias, + y_desc, + y)); +} + +/* static */ void MLUCnnl::SyncBatchnormBackwardReduce( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t desc_dz, + const void* dz, + const cnnlTensorDescriptor_t desc_x, + const void* x, + const cnnlTensorDescriptor_t desc_mean, + const void* mean, + const cnnlTensorDescriptor_t desc_invstd, + const void* invstd, + const cnnlTensorDescriptor_t desc_dweight, + void* dweight, + const cnnlTensorDescriptor_t desc_dbias, + void* dbias, + const cnnlTensorDescriptor_t desc_sum_dy, + void* sum_dy, + const cnnlTensorDescriptor_t desc_sum_dy_xmu, + void* sum_dy_xmu, + const bool needs_input_grad0, + const bool needs_input_grad1, + const bool needs_input_grad2) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSyncBatchnormBackwardReduce(handle, + desc_dz, + dz, + desc_x, + x, + desc_mean, + mean, + desc_invstd, + invstd, + desc_dweight, + dweight, + desc_dbias, + dbias, + desc_sum_dy, + sum_dy, + desc_sum_dy_xmu, + sum_dy_xmu, + needs_input_grad0, + needs_input_grad1, + needs_input_grad2)); +} + +/* static */ void MLUCnnl::SyncBatchNormBackwardElemt( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t diff_y_desc, + const void* diff_y, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t mean_desc, + const void* mean, + const cnnlTensorDescriptor_t invstd_desc, + const void* invstd, + const cnnlTensorDescriptor_t weight_desc, + const void* weight, + const cnnlTensorDescriptor_t sum_dy_desc, + const void* sum_dy, + const cnnlTensorDescriptor_t sum_dy_xmu_desc, + const void* sum_dy_xmu, + const cnnlTensorDescriptor_t count_desc, + const void* count, + const cnnlTensorDescriptor_t diff_x_desc, + void* diff_x) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSyncBatchNormBackwardElemtV2(handle, + diff_y_desc, + diff_y, + x_desc, + x, + mean_desc, + mean, + invstd_desc, + invstd, + weight_desc, + weight, + sum_dy_desc, + sum_dy, + sum_dy_xmu_desc, + sum_dy_xmu, + count_desc, + count, + diff_x_desc, + diff_x)); +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 07c5031ee2eb19781ae1d75dee35e42e3d41164a..0d4c7d2e5a3297ec0b17ac67ba55ef52c62cac84 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -1276,6 +1276,15 @@ class MLUCnnl { const cnnlTensorDescriptor_t output_desc, void* output); + static void Pow(const ExecutionContext& ctx, + cnnlComputationPreference_t prefer, + const cnnlTensorDescriptor_t input1_desc, + const void* input1, + const cnnlTensorDescriptor_t input2_desc, + const void* input2, + const cnnlTensorDescriptor_t output_desc, + void* output); + static void PowR(const ExecutionContext& ctx, cnnlComputationPreference_t prefer, const cnnlTensorDescriptor_t input1_desc, @@ -2030,8 +2039,152 @@ class MLUCnnl { const void* boxes, const cnnlTensorDescriptor_t grads_image_desc, void* grads_image); + + static void SyncBatchNormStats(const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const float eps, + const cnnlTensorDescriptor_t mean_desc, + void* mean, + const cnnlTensorDescriptor_t invstd_desc, + void* invstd); + + static void SyncBatchNormGatherStatsWithCounts( + const ExecutionContext& ctx, + float momentum, + float eps, + const cnnlTensorDescriptor_t mean_all_desc, + const void* mean_all, + const cnnlTensorDescriptor_t invstd_all_desc, + const void* invstd_all, + const cnnlTensorDescriptor_t moving_mean_desc, + void* moving_mean, + const cnnlTensorDescriptor_t moving_var_desc, + void* moving_var, + const cnnlTensorDescriptor_t count_all_desc, + const void* count_all, + const cnnlTensorDescriptor_t mean_desc, + void* mean, + const cnnlTensorDescriptor_t invstd_desc, + void* invstd); + + static void SyncBatchNormElemt(const ExecutionContext& ctx, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t mean_desc, + const void* mean, + const cnnlTensorDescriptor_t invstd_desc, + const void* invstd, + const cnnlTensorDescriptor_t weight_desc, + const void* weight, + const cnnlTensorDescriptor_t bias_desc, + const void* bias, + const cnnlTensorDescriptor_t y_desc, + void* y); + + static void SyncBatchnormBackwardReduce( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t desc_dz, + const void* dz, + const cnnlTensorDescriptor_t desc_x, + const void* x, + const cnnlTensorDescriptor_t desc_mean, + const void* mean, + const cnnlTensorDescriptor_t desc_invstd, + const void* invstd, + const cnnlTensorDescriptor_t desc_dweight, + void* dweight, + const cnnlTensorDescriptor_t desc_dbias, + void* dbias, + const cnnlTensorDescriptor_t desc_sum_dy, + void* sum_dy, + const cnnlTensorDescriptor_t desc_sum_dy_xmu, + void* sum_dy_xmu, + const bool needs_input_grad0, + const bool needs_input_grad1, + const bool needs_input_grad2); + + static void SyncBatchNormBackwardElemt( + const ExecutionContext& ctx, + const cnnlTensorDescriptor_t diff_y_desc, + const void* diff_y, + const cnnlTensorDescriptor_t x_desc, + const void* x, + const cnnlTensorDescriptor_t mean_desc, + const void* mean, + const cnnlTensorDescriptor_t invstd_desc, + const void* invstd, + const cnnlTensorDescriptor_t weight_desc, + const void* weight, + const cnnlTensorDescriptor_t sum_dy_desc, + const void* sum_dy, + const cnnlTensorDescriptor_t sum_dy_xmu_desc, + const void* sum_dy_xmu, + const cnnlTensorDescriptor_t count_desc, + const void* count, + const cnnlTensorDescriptor_t diff_x_desc, + void* diff_x); }; +const std::map, std::vector>> + TransPermMap = { + // trans_mode, (forward_perm, backward_perm) + {"3D_NCHW2NHWC", {{0, 2, 1}, {0, 2, 1}}}, + {"4D_NCHW2NHWC", {{0, 2, 3, 1}, {0, 3, 1, 2}}}, + {"5D_NCHWD2NDHWC", {{0, 4, 2, 3, 1}, {0, 4, 2, 3, 1}}}, + {"5D_NHWDC2NDHWC", {{0, 3, 1, 2, 4}, {0, 2, 3, 4, 1}}}}; + +inline void SetMLUTransposePerm(const framework::DDim& dims, + const DataLayout& data_layout, + std::vector* forward_perm, + std::vector* backward_perm, + std::vector* out_shape) { + const int dim_size = dims.size(); + PADDLE_ENFORCE_EQ((dim_size >= 3) && (dim_size <= 5), + true, + platform::errors::InvalidArgument( + "MLUTransposePerm func only support (dim_size >= 3) && " + "(dim_size <= 5), but now dim_size is %d.", + dim_size)); + + PADDLE_ENFORCE_EQ( + (data_layout == DataLayout::kNCHW) || (data_layout == DataLayout::kNHWC), + true, + platform::errors::InvalidArgument( + "MLUTransposePerm func only support DataLayout: kNCHW or kNHWC, but " + "now data_layout is %s.", + data_layout)); + + // case 1: NCHW of Paddle != NHWC of MLU when dims==3,4 + // case 2: NHWDC and NCHWD of Paddle != NDHWC of MLU when dims==5 + std::string map_key = ""; + if (data_layout == DataLayout::kNCHW) { + switch (dim_size) { + case 3: + map_key = "3D_NCHW2NHWC"; + break; + case 4: + map_key = "4D_NCHW2NHWC"; + break; + case 5: + map_key = "5D_NCHWD2NDHWC"; + break; + } + } else if (data_layout == DataLayout::kNHWC && dim_size == 5) { + map_key = "5D_NHWDC2NDHWC"; + } + assert(map_key != ""); + forward_perm->assign(TransPermMap.at(map_key).first.begin(), + TransPermMap.at(map_key).first.end()); + backward_perm->assign(TransPermMap.at(map_key).second.begin(), + TransPermMap.at(map_key).second.end()); + + auto in_dims = phi::vectorize(dims); + for (size_t i = 0; i < in_dims.size(); i++) { + out_shape->push_back(in_dims[forward_perm->at(i)]); + } +} + template inline void TransposeFromMLUTensor(const ExecutionContext& ctx, const std::vector perm, diff --git a/paddle/fluid/operators/sync_batch_norm_op_mlu.cc b/paddle/fluid/operators/sync_batch_norm_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce511a12bbfdb2a685150d2a8a4980c599480ccd --- /dev/null +++ b/paddle/fluid/operators/sync_batch_norm_op_mlu.cc @@ -0,0 +1,492 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the Licnse. */ + +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/batch_norm_op.h" +#include "paddle/fluid/platform/collective_helper.h" +#if defined(PADDLE_WITH_CNCL) +#include "paddle/fluid/platform/device/mlu/cncl_helper.h" +#endif +#include "paddle/fluid/operators/mlu/mlu_baseop.h" + +namespace paddle { +namespace operators { + +#define GET_LAYOUT_OFFSET 2 +using Tensor = framework::Tensor; +static std::vector supported_input_layout = { + CNNL_LAYOUT_NC, CNNL_LAYOUT_NLC, CNNL_LAYOUT_NHWC, CNNL_LAYOUT_NDHWC}; + +template +class SyncBatchNormMLUKernel : public framework::OpKernel { + using MPDType = typename details::MPTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + float epsilon = ctx.Attr("epsilon"); + float momentum = ctx.Attr("momentum"); + const bool is_test = ctx.Attr("is_test"); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool trainable_stats = ctx.Attr("trainable_statistics"); + const std::string layout_str = ctx.Attr("data_layout"); + const DataLayout layout = framework::StringToDataLayout(layout_str); + + PADDLE_ENFORCE_EQ(use_global_stats, + false, + platform::errors::InvalidArgument( + "sync_batch_norm doesn't support " + "to set use_global_stats True. Please use batch_norm " + "in this case.")); + + const auto *x = ctx.Input("X"); + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + const auto *mean = ctx.Input("Mean"); + const auto *variance = ctx.Input("Variance"); + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_variance = ctx.Output("SavedVariance"); + auto *y = ctx.Output("Y"); + + const auto &x_dims = x->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), + 2, + platform::errors::InvalidArgument( + "The Input dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), + 5, + platform::errors::InvalidArgument( + "The Input dim size should be less than 6.")); + + int N, C, H, W, D; + ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + + y->mutable_data(ctx.GetPlace()); + mean_out->mutable_data(ctx.GetPlace()); + variance_out->mutable_data(ctx.GetPlace()); + saved_mean->mutable_data(ctx.GetPlace()); + saved_variance->mutable_data(ctx.GetPlace()); + + Tensor trans_x; + Tensor trans_y; + std::vector forward_perm; + std::vector backward_perm; + std::vector trans_shape; + const bool need_transpose = + ((layout == DataLayout::kNCHW && x_dims.size() != 2) || + x_dims.size() == 5); + if (need_transpose) { + SetMLUTransposePerm( + x_dims, layout, &forward_perm, &backward_perm, &trans_shape); + trans_x.mutable_data(phi::make_ddim(trans_shape), ctx.GetPlace()); + trans_y.mutable_data(phi::make_ddim(trans_shape), ctx.GetPlace()); + MLUCnnlTensorDesc desc_x(*x); + MLUCnnlTensorDesc desc_trans_x( + trans_shape.size(), trans_shape.data(), ToCnnlDataType(x->dtype())); + MLUCnnl::Transpose(ctx, + forward_perm, + x_dims.size(), + desc_x.get(), + GetBasePtr(x), + desc_trans_x.get(), + GetBasePtr(&trans_x)); + } else { + trans_x = *x; + trans_y = *y; + } + + MLUCnnlTensorDesc desc_trans( + trans_x, + supported_input_layout[x_dims.size() - GET_LAYOUT_OFFSET], + ToCnnlDataType()); + + bool test_mode = is_test && (!trainable_stats); + if (test_mode) { // inference + MLUCnnlTensorDesc desc_weight_bias_mean_var(*bias); + MLUCnnl::FusedBatchNorm(ctx, + false /*is_training*/, + desc_trans.get(), + GetBasePtr(&trans_x), + desc_weight_bias_mean_var.get(), + GetBasePtr(scale), + GetBasePtr(bias), + GetBasePtr(mean), + GetBasePtr(variance), + epsilon, + momentum, + desc_trans.get(), + GetBasePtr(&trans_y), + nullptr, + nullptr, + nullptr, + nullptr); + } else { // training + if (ctx.HasInput("MomentumTensor")) { + const auto *mom_tensor = ctx.Input("MomentumTensor"); + Tensor mom_cpu; + paddle::framework::TensorCopySync( + *mom_tensor, platform::CPUPlace(), &mom_cpu); + momentum = mom_cpu.data()[0]; + } + + Tensor local_mean, local_var; + local_mean.mutable_data(mean->dims(), ctx.GetPlace()); + local_var.mutable_data(variance->dims(), ctx.GetPlace()); + MLUCnnlTensorDesc desc_mean_var(*mean_out); + + // cacl local_mean and local_var + MLUCnnl::SyncBatchNormStats(ctx, + desc_trans.get(), + GetBasePtr(&trans_x), + epsilon, + desc_mean_var.get(), + GetBasePtr(&local_mean), + desc_mean_var.get(), + GetBasePtr(&local_var)); + + Tensor input_count; + input_count.mutable_data(phi::make_ddim({1}), ctx.GetPlace()); + FillMLUTensorWithHostValue( + ctx, static_cast(x->numel() / C), &input_count); + + Tensor count_all; + Tensor mean_all(mean->dtype()); + Tensor invstd_all(variance->dtype()); + + auto &dev_ctx = + ctx.template device_context(); + auto stream = dev_ctx.stream(); + auto *comm = dev_ctx.cncl_comm(); + if (comm) { + auto *comm = paddle::platform::CNCLCommContext::Instance() + .Get(0, ctx.GetPlace()) + ->comm(); + int count; + PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCommCount(&count, comm)); + count_all.mutable_data(phi::make_ddim({count}), ctx.GetPlace()); + cnclDataType_t dtype = platform::ToCNCLDataType( + framework::TransToProtoVarType(count_all.dtype())); + PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&input_count), + GetBasePtr(&count_all), + 1, + dtype, + comm, + stream)); + + mean_all.mutable_data(phi::make_ddim({count, mean->numel()}), + ctx.GetPlace()); + invstd_all.mutable_data( + phi::make_ddim({count, variance->numel()}), ctx.GetPlace()); + + auto cncl_dtype = platform::ToCNCLDataType( + framework::TransToProtoVarType(mean_all.dtype())); + PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&local_mean), + GetBasePtr(&mean_all), + local_mean.numel(), + cncl_dtype, + comm, + stream)); + + PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&local_var), + GetBasePtr(&invstd_all), + local_var.numel(), + cncl_dtype, + comm, + stream)); + + } else { + count_all = input_count; + mean_all.ShareDataWith(local_mean); + invstd_all.ShareDataWith(local_var); + mean_all.Resize(phi::make_ddim({1, local_mean.numel()})); + invstd_all.Resize(phi::make_ddim({1, local_var.numel()})); + } + + MLUCnnlTensorDesc desc_all_mean_invstd( + invstd_all, CNNL_LAYOUT_NC, ToCnnlDataType()); + MLUCnnlTensorDesc desc_moving_mean_var(*mean_out); + MLUCnnlTensorDesc desc_saved_mean_var(*saved_mean); + MLUCnnlTensorDesc desc_count_all(count_all); + + MLUCnnl::SyncBatchNormGatherStatsWithCounts(ctx, + momentum, + epsilon, + desc_all_mean_invstd.get(), + GetBasePtr(&mean_all), + desc_all_mean_invstd.get(), + GetBasePtr(&invstd_all), + desc_moving_mean_var.get(), + GetBasePtr(mean_out), + desc_moving_mean_var.get(), + GetBasePtr(variance_out), + desc_count_all.get(), + GetBasePtr(&count_all), + desc_saved_mean_var.get(), + GetBasePtr(saved_mean), + desc_saved_mean_var.get(), + GetBasePtr(saved_variance)); + + MLUCnnlTensorDesc desc_other_param(*saved_mean); + MLUCnnl::SyncBatchNormElemt(ctx, + desc_trans.get(), + GetBasePtr(&trans_x), + desc_other_param.get(), + GetBasePtr(saved_mean), + desc_other_param.get(), + GetBasePtr(saved_variance), + desc_other_param.get(), + GetBasePtr(scale), + desc_other_param.get(), + GetBasePtr(bias), + desc_trans.get(), + GetBasePtr(&trans_y)); + } + if (need_transpose) { + MLUCnnlTensorDesc desc_y(*y); + MLUCnnlTensorDesc desc_trans_y(trans_y); + MLUCnnl::Transpose(ctx, + backward_perm, + trans_y.dims().size(), + desc_trans_y.get(), + GetBasePtr(&trans_y), + desc_y.get(), + GetBasePtr(y)); + } + } +}; + +template +class SyncBatchNormMLUGradKernel : public framework::OpKernel { + using MPDType = typename details::MPTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const std::string layout_str = ctx.Attr("data_layout"); + const DataLayout layout = framework::StringToDataLayout(layout_str); + + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + + // init output + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + const auto *saved_mean = ctx.Input("SavedMean"); + const auto *saved_inv_var = ctx.Input("SavedVariance"); + + const Tensor *x; + if (ctx.HasInput("Y")) { + PADDLE_ENFORCE_EQ(true, + false, + platform::errors::InvalidArgument( + "sync_batch_norm_grad doesn't support input Y")); + } else { + x = ctx.Input("X"); + } + + const auto &x_dims = x->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), + 2, + platform::errors::InvalidArgument( + "The Input X dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), + 5, + platform::errors::InvalidArgument( + "The Input X dim size should be less than 6.")); + + int N, C, H, W, D; + ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + PADDLE_ENFORCE_EQ(scale->dims()[0], + C, + platform::errors::InvalidArgument( + "Expected first dim for input parameter(scale) of " + "OP(sync_batch_norm) be (%d), but given (%d).", + C, + scale->dims()[0])); + + d_x->mutable_data(ctx.GetPlace()); + if (d_scale && d_bias) { + d_scale->mutable_data(ctx.GetPlace()); + d_bias->mutable_data(ctx.GetPlace()); + } + PADDLE_ENFORCE_EQ(scale->dims().size(), + 1UL, + platform::errors::InvalidArgument( + "Expected rank for input parameter(scale) of " + "OP(sync_batch_norm) be (1), but given (%d).", + scale->dims().size())); + + Tensor trans_x; + Tensor trans_dy; + Tensor trans_dx; + std::vector forward_perm; + std::vector backward_perm; + std::vector trans_shape; + const bool need_transpose = + ((layout == DataLayout::kNCHW && x_dims.size() != 2) || + x_dims.size() == 5); + if (need_transpose) { + SetMLUTransposePerm( + x_dims, layout, &forward_perm, &backward_perm, &trans_shape); + trans_x.mutable_data(phi::make_ddim(trans_shape), ctx.GetPlace()); + trans_dy.mutable_data(phi::make_ddim(trans_shape), ctx.GetPlace()); + trans_dx.mutable_data(phi::make_ddim(trans_shape), ctx.GetPlace()); + MLUCnnlTensorDesc desc_x(*x); + MLUCnnlTensorDesc desc_trans_x( + trans_shape.size(), trans_shape.data(), ToCnnlDataType(x->dtype())); + MLUCnnl::Transpose(ctx, + forward_perm, + x_dims.size(), + desc_x.get(), + GetBasePtr(x), + desc_trans_x.get(), + GetBasePtr(&trans_x)); + MLUCnnl::Transpose(ctx, + forward_perm, + x_dims.size(), + desc_x.get(), + GetBasePtr(d_y), + desc_trans_x.get(), + GetBasePtr(&trans_dy)); + } else { + trans_x = *x; + trans_dy = *d_y; + trans_dx = *d_x; + } + MLUCnnlTensorDesc desc_trans( + trans_x, + supported_input_layout[x_dims.size() - GET_LAYOUT_OFFSET], + ToCnnlDataType()); + + Tensor sum_dy, sum_dy_xmu; + sum_dy.mutable_data(bias->dims(), ctx.GetPlace()); + sum_dy_xmu.mutable_data(bias->dims(), ctx.GetPlace()); + MLUCnnlTensorDesc desc_other_param(*bias); + + MLUCnnl::SyncBatchnormBackwardReduce( + ctx, + desc_trans.get(), + GetBasePtr(&trans_dy), + desc_trans.get(), + GetBasePtr(&trans_x), + desc_other_param.get(), + GetBasePtr(saved_mean), + desc_other_param.get(), + GetBasePtr(saved_inv_var), + d_scale ? desc_other_param.get() : nullptr, + d_scale ? GetBasePtr(d_scale) : nullptr, + d_bias ? desc_other_param.get() : nullptr, + d_bias ? GetBasePtr(d_bias) : nullptr, + desc_other_param.get(), + GetBasePtr(&sum_dy), + desc_other_param.get(), + GetBasePtr(&sum_dy_xmu), + true /*compute sum_dy, sum_dy_xmu*/, + d_scale ? true : false /*compute d_scale*/, + d_bias ? true : false /*compute d_bias*/); + + Tensor numel_count; + numel_count.mutable_data(phi::make_ddim({1}), ctx.GetPlace()); + FillMLUTensorWithHostValue( + ctx, static_cast(x->numel() / C), &numel_count); + + auto &dev_ctx = + ctx.template device_context(); + auto stream = dev_ctx.stream(); + auto *comm = dev_ctx.cncl_comm(); + if (comm) { + auto *comm = paddle::platform::CNCLCommContext::Instance() + .Get(0, ctx.GetPlace()) + ->comm(); + cnclDataType_t dtype = platform::ToCNCLDataType( + framework::TransToProtoVarType(numel_count.dtype())); + PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&numel_count), + GetBasePtr(&numel_count), + 1, + dtype, + cnclSum, + comm, + stream)); + + auto cncl_dtype = platform::ToCNCLDataType( + framework::TransToProtoVarType(sum_dy.dtype())); + PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&sum_dy), + GetBasePtr(&sum_dy), + sum_dy.numel(), + cncl_dtype, + cnclSum, + comm, + stream)); + + PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&sum_dy_xmu), + GetBasePtr(&sum_dy_xmu), + sum_dy_xmu.numel(), + cncl_dtype, + cnclSum, + comm, + stream)); + } + + if (d_x) { + MLUCnnlTensorDesc desc_count(numel_count); + MLUCnnl::SyncBatchNormBackwardElemt(ctx, + desc_trans.get(), + GetBasePtr(&trans_dy), + desc_trans.get(), + GetBasePtr(&trans_x), + desc_other_param.get(), + GetBasePtr(saved_mean), + desc_other_param.get(), + GetBasePtr(saved_inv_var), + desc_other_param.get(), + GetBasePtr(scale), + desc_other_param.get(), + GetBasePtr(&sum_dy), + desc_other_param.get(), + GetBasePtr(&sum_dy_xmu), + desc_count.get(), + GetBasePtr(&numel_count), + desc_trans.get(), + GetBasePtr(&trans_dx)); + + if (need_transpose) { + MLUCnnlTensorDesc desc_dx(*d_x); + MLUCnnlTensorDesc desc_trans_dx(trans_dx); + MLUCnnl::Transpose(ctx, + backward_perm, + trans_dx.dims().size(), + desc_trans_dx.get(), + GetBasePtr(&trans_dx), + desc_dx.get(), + GetBasePtr(d_x)); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_MLU_KERNEL(sync_batch_norm, + ops::SyncBatchNormMLUKernel, + ops::SyncBatchNormMLUKernel); + +REGISTER_OP_MLU_KERNEL(sync_batch_norm_grad, + ops::SyncBatchNormMLUGradKernel, + ops::SyncBatchNormMLUGradKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt b/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt index cac8e95521d319c2c84a89d979d1328d98a28daf..385879c08a72f524d07492d0a5ec75f38474fc74 100644 --- a/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/mlu/CMakeLists.txt @@ -50,5 +50,7 @@ if(WITH_MLU) set_tests_properties(test_collective_allgather_api_mlu PROPERTIES TIMEOUT 120) set_tests_properties(test_c_comm_init_op_mlu PROPERTIES TIMEOUT 120) + set_tests_properties(test_sync_batch_norm_op_mlu_baseline PROPERTIES TIMEOUT + 120) endif() endif() diff --git a/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..4f80523a18254a9e5b618e7ed227714b06599621 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/sync_batch_norm_op_mlu.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys + +sys.path.append("..") +import signal +import time +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_sync_batch_norm_base_mlu import TestSyncBatchNormRunnerBase, runtime_main +from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator + +from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor + +_set_use_system_allocator(False) +paddle.enable_static() + + +class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase): + + def __init__(self): + self.global_ring_id = 0 + + self.dtype = np.float32 + self.N = 8 + self.C = 16 + self.H = 32 + self.W = 32 + self.dshape = [self.N, self.C, self.H, self.W] + self.atol = 1e-3 + + def get_model(self, + main, + startup, + place, + layout, + seed, + sync_bn=False, + only_forward=False): + """Build program.""" + use_cudnn = False + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + data = fluid.layers.data(name='input', + shape=self.dshape, + dtype=self.dtype, + append_batch_size=False) + conv = fluid.layers.conv2d( + input=data, + num_filters=32, + filter_size=1, + param_attr=fluid.ParamAttr(name='conv2d_weight'), + bias_attr=False, + use_cudnn=use_cudnn) + bn = fluid.layers.batch_norm( + conv, + param_attr=fluid.ParamAttr(name='bn_scale'), + bias_attr=fluid.ParamAttr(name='bn_bias'), + moving_mean_name='bn_moving_mean', + moving_variance_name='bn_moving_variance', + data_layout=layout, + is_test=only_forward) + # if self.dtype == np.float16: + # bn = fluid.layers.cast(bn, 'float32') + sigmoid = fluid.layers.sigmoid(bn) + out = fluid.layers.reduce_sum(sigmoid) + # if not sync_bn: + # out = out / core.get_mlu_device_count() + if not only_forward: + sgd_opt = fluid.optimizer.SGD(learning_rate=0.0) + sgd_opt.backward(out) + return [out, conv, bn] + + +if __name__ == "__main__": + # print('sync_batch_norm_op_mlu.py __main__') + + runtime_main(TestSyncBatchNormOpTraining, "identity", 0)