From 0ca2807c7f0d33bca8e42d04157c39d83061e5c7 Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Thu, 21 Oct 2021 18:42:50 +0800 Subject: [PATCH] [NPU] Add sync_batch_norm and sync_batch_norm_grad NPU Kernel (#36320) * add sync_batch_norm (support train, infer, and fp32, fp16, and NCHW, NHWC) * [NPU] Delete debug codes * [NPU] Remove FP16 --- paddle/fluid/operators/CMakeLists.txt | 5 + paddle/fluid/operators/batch_norm_op_npu.cc | 3 + .../fluid/operators/sync_batch_norm_op_npu.cc | 995 ++++++++++++++++++ .../unittests/npu/sync_batch_norm_op_npu.py | 104 ++ .../npu/test_sync_batch_norm_base_npu.py | 481 +++++++++ .../test_sync_batch_norm_op_npu_baseline.py | 42 + .../npu/test_sync_batch_norm_op_npu_extra.py | 105 ++ 7 files changed, 1735 insertions(+) create mode 100644 paddle/fluid/operators/sync_batch_norm_op_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/sync_batch_norm_op_npu.py create mode 100644 python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_base_npu.py create mode 100644 python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_op_npu_baseline.py create mode 100644 python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_op_npu_extra.py diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 937bfea3a5..dcf492dc6d 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -121,6 +121,11 @@ else() endif() endif() +if (WITH_ASCEND_CL) + op_library(sync_batch_norm_op) + file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(sync_batch_norm);\n") +endif() + op_library(lstm_op DEPS ${OP_HEADER_DEPS} lstm_compute) op_library(eye_op DEPS ${OP_HEADER_DEPS}) op_library(recurrent_op DEPS ${OP_HEADER_DEPS}) diff --git a/paddle/fluid/operators/batch_norm_op_npu.cc b/paddle/fluid/operators/batch_norm_op_npu.cc index 791c365679..3bcd0ac37b 100644 --- a/paddle/fluid/operators/batch_norm_op_npu.cc +++ b/paddle/fluid/operators/batch_norm_op_npu.cc @@ -192,6 +192,9 @@ class NPUBatchNormGradOpKernel : public framework::OpKernel { auto dx_tensor = ctx.AllocateTmpTensor(d_x->dims(), dev_ctx); dx_tensor.ShareDataWith(*d_x); + if (data_layout == DataLayout::kNHWC) { + dx_tensor.set_layout(DataLayout::kNHWC); + } if (use_global_stats) { if (x->dims().size() == 3) { // BNInferGrad only support x rank = 4, diff --git a/paddle/fluid/operators/sync_batch_norm_op_npu.cc b/paddle/fluid/operators/sync_batch_norm_op_npu.cc new file mode 100644 index 0000000000..31289b1c23 --- /dev/null +++ b/paddle/fluid/operators/sync_batch_norm_op_npu.cc @@ -0,0 +1,995 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#include "paddle/fluid/operators/batch_norm_op.h" +#include "paddle/fluid/operators/npu_op_runner.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/hccl_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +void training_or_inference( + const framework::ExecutionContext &ctx, const aclrtStream &stream, + const platform::Place &place, const DataLayout &layout, + const bool &test_mode, const int &N, const int &C, const int &H, + const int &W, const float epsilon, const float &momentum, + const Tensor *common_mean, const Tensor *common_var, const Tensor *x, + const Tensor *scale, const Tensor *bias, const Tensor *mean, + const Tensor *variance, Tensor *mean_out, Tensor *variance_out, + Tensor *saved_mean, Tensor *saved_variance, Tensor *y) { + std::vector axes; + if (layout == framework::DataLayout::kNCHW) { + axes = {0, 2, 3}; + } else if (layout == framework::DataLayout::kNHWC) { + axes = {0, 1, 2}; + } + + std::vector multiples; + if (layout == framework::DataLayout::kNCHW) + multiples = {N, 1, H, W}; + else if (layout == framework::DataLayout::kNHWC) + multiples = {N, H, W, 1}; + + Tensor common_mean_tile_1; + { + common_mean_tile_1.Resize({C}); + common_mean_tile_1.mutable_data(place); + TensorCopySync(*common_mean, place, &common_mean_tile_1); + if (layout == framework::DataLayout::kNCHW) + common_mean_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + common_mean_tile_1.Resize({1, 1, 1, C}); + } + + Tensor common_mean_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + common_mean_tile.Resize(x->dims()); + common_mean_tile.mutable_data(place); + const auto &runner = NpuOpRunner("TileD", {common_mean_tile_1}, + {common_mean_tile}, attr_input); + runner.Run(stream); + } + + Tensor common_var_tile_1; + { + common_var_tile_1.Resize({C}); + common_var_tile_1.mutable_data(place); + TensorCopySync(*common_var, place, &common_var_tile_1); + if (layout == framework::DataLayout::kNCHW) + common_var_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + common_var_tile_1.Resize({1, 1, 1, C}); + } + + Tensor common_var_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + common_var_tile.Resize(x->dims()); + common_var_tile.mutable_data(place); + const auto &runner = NpuOpRunner("TileD", {common_var_tile_1}, + {common_var_tile}, attr_input); + runner.Run(stream); + } + + Tensor common_var_tile_add_epsilon; + { + framework::NPUAttributeMap attr_input = {{"value", epsilon}}; + common_var_tile_add_epsilon.Resize(x->dims()); + common_var_tile_add_epsilon.mutable_data(place); + const auto &runner = NpuOpRunner("Adds", {common_var_tile}, + {common_var_tile_add_epsilon}, attr_input); + runner.Run(stream); + } + + Tensor common_var_tile_add_epsilon_sqrt; + { + common_var_tile_add_epsilon_sqrt.Resize(x->dims()); + common_var_tile_add_epsilon_sqrt.mutable_data(place); + const auto &runner = NpuOpRunner("Sqrt", {common_var_tile_add_epsilon}, + {common_var_tile_add_epsilon_sqrt}, {}); + runner.Run(stream); + } + + Tensor x_sub_common_mean; + { + x_sub_common_mean.Resize(x->dims()); + x_sub_common_mean.mutable_data(place); + const auto &runner = + NpuOpRunner("Sub", {*x, common_mean_tile}, {x_sub_common_mean}, {}); + runner.Run(stream); + } + + Tensor normalized; + { + normalized.Resize(x->dims()); + normalized.mutable_data(place); + const auto &runner = NpuOpRunner( + "Div", {x_sub_common_mean, 
common_var_tile_add_epsilon_sqrt}, + {normalized}, {}); + runner.Run(stream); + } + + Tensor scale_tile_1; + { + scale_tile_1.Resize({C}); + scale_tile_1.mutable_data(place); + TensorCopySync(*scale, place, &scale_tile_1); + if (layout == framework::DataLayout::kNCHW) + scale_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + scale_tile_1.Resize({1, 1, 1, C}); + } + + Tensor scale_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + scale_tile.Resize(x->dims()); + scale_tile.mutable_data(place); + const auto &runner = + NpuOpRunner("TileD", {scale_tile_1}, {scale_tile}, attr_input); + runner.Run(stream); + } + + Tensor normalized_mul_scale; + { + normalized_mul_scale.Resize(x->dims()); + normalized_mul_scale.mutable_data(place); + const auto &runner = NpuOpRunner("Mul", {normalized, scale_tile}, + {normalized_mul_scale}, {}); + runner.Run(stream); + } + + Tensor bias_tile_1; + { + bias_tile_1.Resize({C}); + bias_tile_1.mutable_data(place); + TensorCopySync(*bias, place, &bias_tile_1); + if (layout == framework::DataLayout::kNCHW) + bias_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + bias_tile_1.Resize({1, 1, 1, C}); + } + + Tensor bias_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + bias_tile.Resize(x->dims()); + bias_tile.mutable_data(place); + const auto &runner = + NpuOpRunner("TileD", {bias_tile_1}, {bias_tile}, attr_input); + runner.Run(stream); + } + + // calculate y + { + y->mutable_data(place); + const auto &runner = + NpuOpRunner("Add", {normalized_mul_scale, bias_tile}, {*y}, {}); + runner.Run(stream); + } + + if (!test_mode) { + Tensor ones; + { + ones.Resize({C}); + ones.mutable_data(place); + FillNpuTensorWithConstant(&ones, 1); + } + + // cacl mean_out + { + Tensor common_mean_mul_1_sub_momentum; + { + framework::NPUAttributeMap attr_input = {{"value", 1 - momentum}}; + common_mean_mul_1_sub_momentum.Resize({C}); + common_mean_mul_1_sub_momentum.mutable_data(place); + const auto &runner = + NpuOpRunner("Muls", {*common_mean}, + {common_mean_mul_1_sub_momentum}, attr_input); + runner.Run(stream); + } + + Tensor mean_mul_momentum; + { + framework::NPUAttributeMap attr_input = {{"value", momentum}}; + mean_mul_momentum.Resize({C}); + mean_mul_momentum.mutable_data(place); + const auto &runner = + NpuOpRunner("Muls", {*mean}, {mean_mul_momentum}, attr_input); + runner.Run(stream); + } + + mean_out->mutable_data(place); + + const auto &runner = NpuOpRunner( + "Add", {common_mean_mul_1_sub_momentum, mean_mul_momentum}, + {*mean_out}, {}); + runner.Run(stream); + } + + // cacl variance_out + { + Tensor momentum_mul_var; + { + framework::NPUAttributeMap attr_input = {{"value", momentum}}; + momentum_mul_var.Resize({C}); + momentum_mul_var.mutable_data(place); + const auto &runner = + NpuOpRunner("Muls", {*variance}, {momentum_mul_var}, attr_input); + runner.Run(stream); + } + + Tensor var_ref_mul_1_sub_momentum; + { + framework::NPUAttributeMap attr_input = {{"value", 1 - momentum}}; + var_ref_mul_1_sub_momentum.Resize({C}); + var_ref_mul_1_sub_momentum.mutable_data(place); + const auto &runner = NpuOpRunner( + "Muls", {*common_var}, {var_ref_mul_1_sub_momentum}, attr_input); + runner.Run(stream); + } + + variance_out->mutable_data(place); + + const auto &runner = + NpuOpRunner("Add", {var_ref_mul_1_sub_momentum, momentum_mul_var}, + {*variance_out}, {}); + runner.Run(stream); + } + + // cacl saved_variance + { + Tensor var_ref_add_epsilon; + { + 
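// the "Adds" with value = epsilon below keeps the Sqrt away from zero variance; + // saved_variance finally holds 1 / sqrt(var_ref + epsilon), i.e. the inverse std. +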
framework::NPUAttributeMap attr_input = {{"value", epsilon}}; + var_ref_add_epsilon.Resize({C}); + var_ref_add_epsilon.mutable_data(place); + const auto &runner = NpuOpRunner("Adds", {*common_var}, + {var_ref_add_epsilon}, attr_input); + runner.Run(stream); + } + + Tensor var_ref_add_epsilon_sqrt; + { + var_ref_add_epsilon_sqrt.Resize({C}); + var_ref_add_epsilon_sqrt.mutable_data(place); + const auto &runner = NpuOpRunner("Sqrt", {var_ref_add_epsilon}, + {var_ref_add_epsilon_sqrt}, {}); + runner.Run(stream); + } + + saved_variance->mutable_data(place); + + const auto &runner = NpuOpRunner("Div", {ones, var_ref_add_epsilon_sqrt}, + {*saved_variance}, {}); + runner.Run(stream); + } + } +} + +template +class SyncBatchNormNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const float epsilon = ctx.Attr("epsilon"); + float momentum = ctx.Attr("momentum"); + const bool is_test = ctx.Attr("is_test"); + const std::string layout_str = ctx.Attr("data_layout"); + const DataLayout layout = framework::StringToDataLayout(layout_str); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool trainable_stats = ctx.Attr("trainable_statistics"); + + PADDLE_ENFORCE_EQ(use_global_stats, false, + platform::errors::InvalidArgument( + "sync_batch_norm doesn't support " + "to set use_global_stats True. Please use batch_norm " + "in this case.")); + + const auto *x = ctx.Input("X"); + auto *y = ctx.Output("Y"); + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + const auto *mean = ctx.Input("Mean"); + const auto *variance = ctx.Input("Variance"); + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_variance = ctx.Output("SavedVariance"); + + const auto &x_dims = x->dims(); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, + platform::errors::InvalidArgument( + "The input tensor X's dimension must equal to 4. 
But " + "received X's shape = [%s], X's dimension = [%d].", + x_dims, x_dims.size())); + + int N, C, H, W, D; + ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + + int x_numel = x->numel(); + auto place = ctx.GetPlace(); + auto stream = + ctx.template device_context() + .stream(); + + std::vector axes; + if (layout == framework::DataLayout::kNCHW) { + axes = {0, 2, 3}; + } else if (layout == framework::DataLayout::kNHWC) { + axes = {0, 1, 2}; + } + + bool test_mode = is_test && (!trainable_stats); + if (test_mode) { // inference + // cacl saved_mean + saved_mean->mutable_data(place); + TensorCopySync(*mean, place, saved_mean); + + // cacl saved_variance + saved_variance->mutable_data(place); + TensorCopySync(*variance, place, saved_variance); + + // cacl y + training_or_inference(ctx, stream, place, layout, test_mode, N, C, H, + W, epsilon, momentum, mean, variance, x, scale, + bias, mean, variance, NULL, NULL, NULL, NULL, y); + + } else { // training + if (ctx.HasInput("MomentumTensor")) { + const auto *mom_tensor = ctx.Input("MomentumTensor"); + Tensor mom_cpu; + TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu); + momentum = mom_cpu.data()[0]; + } + + // cacl saved_mean and var_ref + Tensor var_ref; + var_ref.Resize({C}); + var_ref.mutable_data(place); + { + Tensor x_sum; + { + framework::NPUAttributeMap attr_input = {{"keep_dims", false}, + {"axes", axes}}; + x_sum.Resize({C}); + x_sum.mutable_data(place); + const auto &runner = + NpuOpRunner("ReduceSumD", {*x}, {x_sum}, attr_input); + runner.Run(stream); + } + + Tensor x_square; + { + x_square.Resize(x->dims()); + x_square.mutable_data(place); + const auto &runner = NpuOpRunner("Square", {*x}, {x_square}, {}); + runner.Run(stream); + } + + Tensor x_square_sum; + { + framework::NPUAttributeMap attr_input = {{"keep_dims", false}, + {"axes", axes}}; + x_square_sum.Resize({C}); + x_square_sum.mutable_data(place); + const auto &runner = + NpuOpRunner("ReduceSumD", {x_square}, {x_square_sum}, attr_input); + runner.Run(stream); + } + + auto comm = paddle::platform::HCCLCommContext::Instance().Get(0, place); + + float device_counts = 0.0; + if (comm) { + HcclDataType dtype = platform::ToHCCLDataType(mean_out->type()); + + Tensor device_count_tensor; + { + device_count_tensor.Resize({1}); + device_count_tensor.mutable_data(place); + FillNpuTensorWithConstant(&device_count_tensor, 1); + } + + // HcclAllReduce device_count_tensor + { + void *sendbuff = reinterpret_cast( + const_cast(device_count_tensor.data())); + void *recvbuff = sendbuff; + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce( + sendbuff, recvbuff, 1, dtype, HCCL_REDUCE_SUM, comm->comm(), + reinterpret_cast(stream))); + } + + std::vector device_count_vec(1); + TensorToVector(device_count_tensor, ctx.device_context(), + &device_count_vec); + device_counts = device_count_vec[0]; + + // HcclAllReduce x_sum + { + void *sendbuff = reinterpret_cast( + const_cast(x_sum.data())); + void *recvbuff = sendbuff; + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce( + sendbuff, recvbuff, C, dtype, HCCL_REDUCE_SUM, comm->comm(), + reinterpret_cast(stream))); + } + + // HcclAllReduce x_square_sum + { + void *sendbuff = reinterpret_cast( + const_cast(x_square_sum.data())); + void *recvbuff = sendbuff; + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce( + sendbuff, recvbuff, C, dtype, HCCL_REDUCE_SUM, comm->comm(), + reinterpret_cast(stream))); + } + } + + // cacl saved_mean + { + framework::NPUAttributeMap attr_input = { + {"value", 1.0f * C / x_numel / 
device_counts}}; + saved_mean->mutable_data(place); + const auto &runner = + NpuOpRunner("Muls", {x_sum}, {*saved_mean}, attr_input); + runner.Run(stream); + } + + // cacl var_ref + { + Tensor saved_mean_square; + { + saved_mean_square.Resize({C}); + saved_mean_square.mutable_data(place); + const auto &runner = + NpuOpRunner("Square", {*saved_mean}, {saved_mean_square}, {}); + runner.Run(stream); + } + + Tensor var_ref_tmp; + var_ref_tmp.Resize({C}); + var_ref_tmp.mutable_data(place); + { + framework::NPUAttributeMap attr_input = { + {"value", 1.0f * C / x_numel / device_counts}}; + const auto &runner = + NpuOpRunner("Muls", {x_square_sum}, {var_ref_tmp}, attr_input); + runner.Run(stream); + } + + // cacl var_ref + { + const auto &runner = NpuOpRunner( + "Sub", {var_ref_tmp, saved_mean_square}, {var_ref}, {}); + runner.Run(stream); + } + } + } + + training_or_inference(ctx, stream, place, layout, test_mode, N, C, H, + W, epsilon, momentum, saved_mean, &var_ref, x, + scale, bias, mean, variance, mean_out, + variance_out, saved_mean, saved_variance, y); + } + } +}; + +template +class SyncBatchNormNPUGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + float epsilon = ctx.Attr("epsilon"); + const std::string layout_str = ctx.Attr("data_layout"); + const DataLayout layout = framework::StringToDataLayout(layout_str); + + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *scale = ctx.Input("Scale"); + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + const auto *saved_mean = ctx.Input("SavedMean"); + + const Tensor *x; + if (ctx.HasInput("Y")) { + PADDLE_ENFORCE_EQ(true, false, + platform::errors::InvalidArgument( + "sync_batch_norm_grad doesn't support input Y")); + } else { + x = ctx.Input("X"); + } + + int N, C, H, W, D; + ExtractNCWHD(x->dims(), layout, &N, &C, &H, &W, &D); + + int x_numel = x->numel(); + auto place = ctx.GetPlace(); + auto stream = + ctx.template device_context() + .stream(); + + std::vector axes; + if (layout == framework::DataLayout::kNCHW) { + axes = {0, 2, 3}; + } else if (layout == framework::DataLayout::kNHWC) { + axes = {0, 1, 2}; + } + + std::vector multiples; + if (layout == framework::DataLayout::kNCHW) + multiples = {N, 1, H, W}; + else if (layout == framework::DataLayout::kNHWC) + multiples = {N, H, W, 1}; + + auto comm = paddle::platform::HCCLCommContext::Instance().Get(0, place); + HcclDataType dtype = platform::ToHCCLDataType(scale->type()); + + float device_counts = 0.0; + if (comm) { + Tensor device_count_tensor; + { + device_count_tensor.Resize({1}); + device_count_tensor.mutable_data(place); + FillNpuTensorWithConstant(&device_count_tensor, 1); + } + + // HcclAllReduce device_count_tensor + { + void *sendbuff = reinterpret_cast( + const_cast(device_count_tensor.data())); + void *recvbuff = sendbuff; + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce( + sendbuff, recvbuff, 1, dtype, HCCL_REDUCE_SUM, comm->comm(), + reinterpret_cast(stream))); + } + + std::vector device_count_vec(1); + TensorToVector(device_count_tensor, ctx.device_context(), + &device_count_vec); + device_counts = device_count_vec[0]; + PADDLE_ENFORCE_GE(device_counts, 2, platform::errors::PreconditionNotMet( + "device_counts should >= 2.")); + } + + // cacl var_ref + Tensor var_ref; + var_ref.Resize({C}); + var_ref.mutable_data(place); + { + // cacl var_ref + 
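// (recomputed from this card's x as E[x^2] - (E[x])^2, reusing the + // synchronized saved_mean produced by the forward pass) +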
{ + Tensor x_square; + { + x_square.Resize(x->dims()); + x_square.mutable_data(place); + const auto &runner = NpuOpRunner("Square", {*x}, {x_square}, {}); + runner.Run(stream); + } + + Tensor x_square_sum; + { + framework::NPUAttributeMap attr_input = {{"keep_dims", false}, + {"axes", axes}}; + x_square_sum.Resize({C}); + x_square_sum.mutable_data(place); + const auto &runner = + NpuOpRunner("ReduceSumD", {x_square}, {x_square_sum}, attr_input); + runner.Run(stream); + } + + Tensor x_square_sum_mean; + { + framework::NPUAttributeMap attr_input = { + {"value", 1.0f * C / x_numel}}; + x_square_sum_mean.Resize({C}); + x_square_sum_mean.mutable_data(place); + const auto &runner = NpuOpRunner("Muls", {x_square_sum}, + {x_square_sum_mean}, attr_input); + runner.Run(stream); + } + + Tensor mean_square; + { + mean_square.Resize({C}); + mean_square.mutable_data(place); + const auto &runner = + NpuOpRunner("Square", {*saved_mean}, {mean_square}, {}); + runner.Run(stream); + } + + // cacl var_ref + { + const auto &runner = NpuOpRunner( + "Sub", {x_square_sum_mean, mean_square}, {var_ref}, {}); + runner.Run(stream); + } + } + } + + Tensor saved_mean_tile_1; + { + saved_mean_tile_1.Resize({C}); + saved_mean_tile_1.mutable_data(place); + TensorCopySync(*saved_mean, place, &saved_mean_tile_1); + if (layout == framework::DataLayout::kNCHW) + saved_mean_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + saved_mean_tile_1.Resize({1, 1, 1, C}); + } + + Tensor saved_mean_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + saved_mean_tile.Resize(x->dims()); + saved_mean_tile.mutable_data(place); + const auto &runner = NpuOpRunner("TileD", {saved_mean_tile_1}, + {saved_mean_tile}, attr_input); + runner.Run(stream); + } + + Tensor x_sub_saved_mean; + { + x_sub_saved_mean.Resize(x->dims()); + x_sub_saved_mean.mutable_data(place); + const auto &runner = + NpuOpRunner("Sub", {*x, saved_mean_tile}, {x_sub_saved_mean}, {}); + runner.Run(stream); + } + + Tensor var_ref_tile_1; + { + var_ref_tile_1.Resize({C}); + var_ref_tile_1.mutable_data(place); + TensorCopySync(var_ref, place, &var_ref_tile_1); + if (layout == framework::DataLayout::kNCHW) + var_ref_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + var_ref_tile_1.Resize({1, 1, 1, C}); + } + + Tensor var_ref_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + var_ref_tile.Resize(x->dims()); + var_ref_tile.mutable_data(place); + const auto &runner = + NpuOpRunner("TileD", {var_ref_tile_1}, {var_ref_tile}, attr_input); + runner.Run(stream); + } + + Tensor var_ref_tile_add_epsilon; + { + framework::NPUAttributeMap attr_input = {{"value", epsilon}}; + var_ref_tile_add_epsilon.Resize(x->dims()); + var_ref_tile_add_epsilon.mutable_data(place); + const auto &runner = NpuOpRunner("Adds", {var_ref_tile}, + {var_ref_tile_add_epsilon}, attr_input); + runner.Run(stream); + } + + Tensor var_ref_tile_add_epsilon_sqrt; + { + var_ref_tile_add_epsilon_sqrt.Resize(x->dims()); + var_ref_tile_add_epsilon_sqrt.mutable_data(place); + const auto &runner = NpuOpRunner("Sqrt", {var_ref_tile_add_epsilon}, + {var_ref_tile_add_epsilon_sqrt}, {}); + runner.Run(stream); + } + + Tensor dy_mul_x_sub_mean_for_scale; + { + if (d_y->type() == framework::proto::VarType::FP16) { + dy_mul_x_sub_mean_for_scale.Resize(x->dims()); + dy_mul_x_sub_mean_for_scale.mutable_data(place); + const auto &runner = NpuOpRunner("Mul", {*d_y, x_sub_saved_mean}, + {dy_mul_x_sub_mean_for_scale}, {}); 
+ runner.Run(stream); + } else { + dy_mul_x_sub_mean_for_scale.Resize(x->dims()); + dy_mul_x_sub_mean_for_scale.mutable_data(place); + const auto &runner = NpuOpRunner("Mul", {*d_y, x_sub_saved_mean}, + {dy_mul_x_sub_mean_for_scale}, {}); + runner.Run(stream); + } + } + + Tensor dy_mul_x_sub_mean; + { + if (d_y->type() == framework::proto::VarType::FP16) { + dy_mul_x_sub_mean.Resize(x->dims()); + dy_mul_x_sub_mean.mutable_data(place); + const auto &runner = NpuOpRunner("Mul", {*d_y, x_sub_saved_mean}, + {dy_mul_x_sub_mean}, {}); + runner.Run(stream); + } else { + dy_mul_x_sub_mean.Resize(x->dims()); + dy_mul_x_sub_mean.mutable_data(place); + const auto &runner = NpuOpRunner("Mul", {*d_y, x_sub_saved_mean}, + {dy_mul_x_sub_mean}, {}); + runner.Run(stream); + } + } + + // HcclAllReduce dy_mul_x_sub_mean + if (comm) { + { + void *sendbuff = reinterpret_cast( + const_cast(dy_mul_x_sub_mean.data())); + void *recvbuff = sendbuff; + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce( + sendbuff, recvbuff, C, dtype, HCCL_REDUCE_SUM, comm->comm(), + reinterpret_cast(stream))); + } + + { + framework::NPUAttributeMap attr_input = { + {"value", 1.0f / device_counts}}; + const auto &runner = NpuOpRunner("Muls", {dy_mul_x_sub_mean}, + {dy_mul_x_sub_mean}, attr_input); + runner.Run(stream); + } + } + + // cacl d_x + if (d_x) { + Tensor dy_mean; + { + if (d_y->type() == framework::proto::VarType::FP16) { + framework::NPUAttributeMap attr_input = {{"keep_dims", false}, + {"axes", axes}}; + dy_mean.Resize({C}); + dy_mean.mutable_data(place); + const auto &runner = + NpuOpRunner("ReduceMeanD", {*d_y}, {dy_mean}, attr_input); + runner.Run(stream); + } else { + framework::NPUAttributeMap attr_input = {{"keep_dims", false}, + {"axes", axes}}; + dy_mean.Resize({C}); + dy_mean.mutable_data(place); + const auto &runner = + NpuOpRunner("ReduceMeanD", {*d_y}, {dy_mean}, attr_input); + runner.Run(stream); + } + } + + // HcclAllReduce dy_mean + if (comm) { + { + void *sendbuff = reinterpret_cast( + const_cast(dy_mean.data())); + void *recvbuff = sendbuff; + PADDLE_ENFORCE_NPU_SUCCESS(platform::dynload::HcclAllReduce( + sendbuff, recvbuff, C, dtype, HCCL_REDUCE_SUM, comm->comm(), + reinterpret_cast(stream))); + } + + { + framework::NPUAttributeMap attr_input = { + {"value", 1.0f / device_counts}}; + const auto &runner = + NpuOpRunner("Muls", {dy_mean}, {dy_mean}, attr_input); + runner.Run(stream); + } + } + + Tensor dy_mean_tile_1; + { + dy_mean_tile_1.Resize({C}); + dy_mean_tile_1.mutable_data(place); + TensorCopySync(dy_mean, place, &dy_mean_tile_1); + if (layout == framework::DataLayout::kNCHW) + dy_mean_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + dy_mean_tile_1.Resize({1, 1, 1, C}); + } + + Tensor dy_mean_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + dy_mean_tile.Resize(x->dims()); + dy_mean_tile.mutable_data(place); + const auto &runner = + NpuOpRunner("TileD", {dy_mean_tile_1}, {dy_mean_tile}, attr_input); + runner.Run(stream); + } + + Tensor dy_sub_dy_mean; + { + if (d_y->type() == framework::proto::VarType::FP16) { + dy_sub_dy_mean.Resize(x->dims()); + dy_sub_dy_mean.mutable_data(place); + const auto &runner = + NpuOpRunner("Sub", {*d_y, dy_mean_tile}, {dy_sub_dy_mean}, {}); + runner.Run(stream); + } else { + dy_sub_dy_mean.Resize(x->dims()); + dy_sub_dy_mean.mutable_data(place); + const auto &runner = + NpuOpRunner("Sub", {*d_y, dy_mean_tile}, {dy_sub_dy_mean}, {}); + runner.Run(stream); + } + } + + Tensor 
dy_mul_x_sub_mean_mean; + { + framework::NPUAttributeMap attr_input = {{"keep_dims", false}, + {"axes", axes}}; + dy_mul_x_sub_mean_mean.Resize({C}); + dy_mul_x_sub_mean_mean.mutable_data(place); + const auto &runner = NpuOpRunner("ReduceMeanD", {dy_mul_x_sub_mean}, + {dy_mul_x_sub_mean_mean}, attr_input); + runner.Run(stream); + } + + Tensor dy_mul_x_sub_mean_mean_tile_1; + { + dy_mul_x_sub_mean_mean_tile_1.Resize({C}); + dy_mul_x_sub_mean_mean_tile_1.mutable_data(place); + TensorCopySync(dy_mul_x_sub_mean_mean, place, + &dy_mul_x_sub_mean_mean_tile_1); + if (layout == framework::DataLayout::kNCHW) + dy_mul_x_sub_mean_mean_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + dy_mul_x_sub_mean_mean_tile_1.Resize({1, 1, 1, C}); + } + + Tensor dy_mul_x_sub_mean_mean_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + dy_mul_x_sub_mean_mean_tile.Resize(x->dims()); + dy_mul_x_sub_mean_mean_tile.mutable_data(place); + const auto &runner = + NpuOpRunner("TileD", {dy_mul_x_sub_mean_mean_tile_1}, + {dy_mul_x_sub_mean_mean_tile}, attr_input); + runner.Run(stream); + } + + // (x - mean) * np.mean(dy * (x - mean), axis=axis) + // x_sub_saved_mean * dy_mul_x_sub_mean_mean_tile + Tensor tmp1; + { + tmp1.Resize(x->dims()); + tmp1.mutable_data(place); + const auto &runner = NpuOpRunner( + "Mul", {x_sub_saved_mean, dy_mul_x_sub_mean_mean_tile}, {tmp1}, {}); + runner.Run(stream); + } + + // (x - mean) * np.mean(dy * (x - mean), axis=axis) / (var + epsilon) + // tmp1 / (var + epsilon) + // tmp1 / var_ref_tile_add_epsilon + Tensor tmp2; + { + tmp2.Resize(x->dims()); + tmp2.mutable_data(place); + const auto &runner = + NpuOpRunner("Div", {tmp1, var_ref_tile_add_epsilon}, {tmp2}, {}); + runner.Run(stream); + } + + // dy - np.mean(dy, axis) - (x - mean) * np.mean(dy * (x - mean), axis) / + // (var + epsilon) + // dy_sub_dy_mean - tmp2 + Tensor tmp3; + { + tmp3.Resize(x->dims()); + tmp3.mutable_data(place); + const auto &runner = + NpuOpRunner("Sub", {dy_sub_dy_mean, tmp2}, {tmp3}, {}); + runner.Run(stream); + } + + Tensor scale_tile_1; + { + scale_tile_1.Resize({C}); + scale_tile_1.mutable_data(place); + TensorCopySync(*scale, place, &scale_tile_1); + if (layout == framework::DataLayout::kNCHW) + scale_tile_1.Resize({1, C, 1, 1}); + else if (layout == framework::DataLayout::kNHWC) + scale_tile_1.Resize({1, 1, 1, C}); + } + + Tensor scale_tile; + { + framework::NPUAttributeMap attr_input = {{"multiples", multiples}}; + scale_tile.Resize(x->dims()); + scale_tile.mutable_data(place); + const auto &runner = + NpuOpRunner("TileD", {scale_tile_1}, {scale_tile}, attr_input); + runner.Run(stream); + } + + // scale * (dy - np.mean(dy, axis) - (x - mean) * np.mean(dy * (x - mean), + // axis) / (var + epsilon)) + // scale * tmp3 + Tensor dx_1; + { + dx_1.Resize(x->dims()); + dx_1.mutable_data(place); + + const auto &runner = NpuOpRunner("Mul", {scale_tile, tmp3}, {dx_1}, {}); + runner.Run(stream); + } + + // dx_1 / var_ref_tile_add_epsilon_sqrt + { + d_x->Resize(x->dims()); + d_x->mutable_data(place); + const auto &runner = NpuOpRunner( + "Div", {dx_1, var_ref_tile_add_epsilon_sqrt}, {*d_x}, {}); + runner.Run(stream); + } + } + + // cacl d_scale + if (d_scale) { + Tensor d_scale_2; + { + d_scale_2.Resize(x->dims()); + d_scale_2.mutable_data(place); + const auto &runner = NpuOpRunner( + "Div", {dy_mul_x_sub_mean_for_scale, var_ref_tile_add_epsilon_sqrt}, + {d_scale_2}, {}); + runner.Run(stream); + } + + { + framework::NPUAttributeMap attr_input = {{"keep_dims", 
false}, + {"axes", axes}}; + d_scale->mutable_data(place); + const auto &runner = + NpuOpRunner("ReduceSumD", {d_scale_2}, {*d_scale}, attr_input); + runner.Run(stream); + } + } + + // cacl d_bias + if (d_bias) { + if (d_y->type() == framework::proto::VarType::FP16) { + framework::NPUAttributeMap attr_input = {{"keep_dims", false}, + {"axes", axes}}; + d_bias->mutable_data(place); + const auto &runner = + NpuOpRunner("ReduceSumD", {*d_y}, {*d_bias}, attr_input); + runner.Run(stream); + } else { + framework::NPUAttributeMap attr_input = {{"keep_dims", false}, + {"axes", axes}}; + d_bias->mutable_data(place); + const auto &runner = + NpuOpRunner("ReduceSumD", {*d_y}, {*d_bias}, attr_input); + runner.Run(stream); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_NPU_KERNEL( + sync_batch_norm, + ops::SyncBatchNormNPUKernel); +REGISTER_OP_NPU_KERNEL( + sync_batch_norm_grad, + ops::SyncBatchNormNPUGradKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/sync_batch_norm_op_npu.py b/python/paddle/fluid/tests/unittests/npu/sync_batch_norm_op_npu.py new file mode 100644 index 0000000000..361efebce9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/sync_batch_norm_op_npu.py @@ -0,0 +1,104 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
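+ +# This script is launched once per NPU by TestDistBase in +# test_sync_batch_norm_base_npu.py; it builds a conv2d + batch_norm network, +# and the multi-card run later rewrites the batch_norm ops into +# sync_batch_norm before executing.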
+ +from __future__ import print_function + +import numpy as np +import argparse +import os +import sys +sys.path.append("..") +import signal +import time +from contextlib import closing +from six import string_types +import math +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +import unittest +from multiprocessing import Process +import paddle.fluid.layers as layers +from functools import reduce +from test_sync_batch_norm_base_npu import TestSyncBatchNormRunnerBase, runtime_main +from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator + +from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor + +_set_use_system_allocator(False) +paddle.enable_static() + + +class TestSyncBatchNormOpTraining(TestSyncBatchNormRunnerBase): + def __init__(self): + self.global_ring_id = 0 + + self.dtype = np.float32 + self.N = 8 + self.C = 16 + self.H = 32 + self.W = 32 + self.dshape = [self.N, self.C, self.H, self.W] + self.atol = 1e-3 + + def get_model(self, + main, + startup, + place, + layout, + seed, + sync_bn=False, + only_forward=False): + """Build program.""" + use_cudnn = False + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + data = fluid.layers.data( + name='input', + shape=self.dshape, + dtype=self.dtype, + append_batch_size=False) + conv = fluid.layers.conv2d( + input=data, + num_filters=32, + filter_size=1, + param_attr=fluid.ParamAttr(name='conv2d_weight'), + bias_attr=False, + use_cudnn=use_cudnn) + bn = fluid.layers.batch_norm( + conv, + param_attr=fluid.ParamAttr(name='bn_scale'), + bias_attr=fluid.ParamAttr(name='bn_bias'), + moving_mean_name='bn_moving_mean', + moving_variance_name='bn_moving_variance', + data_layout=layout, + is_test=only_forward) + # if self.dtype == np.float16: + # bn = fluid.layers.cast(bn, 'float32') + sigmoid = fluid.layers.sigmoid(bn) + out = fluid.layers.reduce_sum(sigmoid) + # if not sync_bn: + # out = out / core.get_npu_device_count() + if not only_forward: + sgd_opt = fluid.optimizer.SGD(learning_rate=0.0) + sgd_opt.backward(out) + return [out, conv, bn] + + +if __name__ == "__main__": + # print('sync_batch_norm_op_npu.py __main__') + + runtime_main(TestSyncBatchNormOpTraining, "identity", 0) diff --git a/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_base_npu.py b/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_base_npu.py new file mode 100644 index 0000000000..9df216d973 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_base_npu.py @@ -0,0 +1,481 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
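+ +# Shared harness: TestSyncBatchNormRunnerBase executes inside each trainer +# process and compares single-card batch_norm fetches against two-card +# sync_batch_norm fetches; TestDistBase (end of file) spawns the two trainer +# processes and checks that both report success.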
+ +from __future__ import print_function +import numpy as np +import unittest +import time +import argparse +import os +import six +import sys +sys.path.append("..") +import subprocess +import traceback +import functools +import pickle +from contextlib import closing +import paddle.fluid as fluid +import paddle.fluid.unique_name as nameGen +from paddle.fluid import core +from six import string_types +import paddle + +from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator + +from paddle.fluid.tests.unittests.test_sync_batch_norm_op import create_or_get_tensor + +_set_use_system_allocator(False) +paddle.enable_static() + +SEED = 10 + + +class TestSyncBatchNormRunnerBase(object): + def get_model(self, + main, + startup, + place, + layout, + seed, + sync_bn=False, + only_forward=False): + raise NotImplementedError( + "get model should be implemented by child class.") + + def wait_server_ready(self, endpoints): + assert not isinstance(endpoints, string_types) + while True: + all_ok = True + not_ready_endpoints = [] + for ep in endpoints: + ip_port = ep.split(":") + with closing( + socket.socket(socket.AF_INET, + socket.SOCK_STREAM)) as sock: + sock.settimeout(2) + result = sock.connect_ex((ip_port[0], int(ip_port[1]))) + if result != 0: + all_ok = False + not_ready_endpoints.append(ep) + if not all_ok: + sys.stderr.write("server not ready, wait 3 sec to retry...\n") + sys.stderr.write("not ready endpoints:" + str( + not_ready_endpoints) + "\n") + sys.stderr.flush() + time.sleep(3) + else: + break + +#endpoints should be ["ip1:port1","ip2:port2"] + + def initCommunicator(self, program, rank, nranks, wait_port, + current_endpoint, endpoints): + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + if rank == 0 and wait_port: + self.wait_server_ready(other_endpoints) + block = program.global_block() + hccl_id_var = block.create_var( + name=nameGen.generate('hccl_id'), + persistable=True, + type=core.VarDesc.VarType.RAW) + block.append_op( + type='c_gen_hccl_id', + inputs={}, + outputs={'Out': hccl_id_var}, + attrs={ + 'rank': rank, + 'endpoint': current_endpoint, + 'other_endpoints': other_endpoints + }) + block.append_op( + type='c_comm_init_hccl', + inputs={'X': hccl_id_var}, + outputs={}, + attrs={ + 'rank': rank, + 'ring_id': self.global_ring_id, + 'device_id': int(os.getenv("FLAGS_selected_npus")), + 'rank_ids': nranks + }) + + def run_trainer(self, args): + device_id = int(os.getenv("FLAGS_selected_npus", "0")) + place = fluid.NPUPlace(device_id) + places = [place] + + # Test training + for place in places: + for layout in ["NCHW", "NHWC"]: + self._compare(args, place, layout, False) + + # Test inference + for place in places: + for layout in ["NCHW", "NHWC"]: + self._compare(args, place, layout, True) + + # Test FP16 - @TODO + # self.dtype = np.float16 + # self.atol = 1e-2 + + # Test training + # for place in places: + # for layout in ["NCHW", "NHWC"]: + # self._compare(args, place, layout, False) + + # Test inference + # for place in places: + # for layout in ["NCHW", "NHWC"]: + # self._compare(args, place, layout, True) + + sys.stdout.buffer.write( + pickle.dumps( + 'training, inference, fp32, fp16, NCHW, NHWC all passed')) + + def _compare(self, args, place, layout, only_forward): + scope = core.Scope() + + np.random.seed(SEED) + data = np.random.random(size=self.dshape).astype(self.dtype) * 4. 
- 2 + sys.stderr.write("data: " + str(data) + "\n") + data = create_or_get_tensor(scope, "input", + OpTest.np_dtype_to_fluid_dtype(data), place) + + bn_fetches = self._cal_single_card(args, data, place, layout, + only_forward) + fetch_names, sync_bn_fetches = self._cal_multiple_cards( + args, data, place, layout, only_forward) + + sys.stderr.write("len(sync_bn_fetches): " + str(len(sync_bn_fetches)) + + "\n") + for i in six.moves.xrange(0, len(sync_bn_fetches)): + sys.stderr.write("i: " + str(i) + "\n") + sys.stderr.write("fetch_names[i]): " + fetch_names[i] + "\n") + + bn_val = bn_fetches[i] + sync_bn_val = sync_bn_fetches[i] + if sync_bn_val.shape != bn_val.shape: + sync_bn_val = sync_bn_val[:bn_val.shape[0]] + + # i = 0 + if fetch_names[i] == 'reduce_sum_0.tmp_0': + # sys.stderr.write("skip reduce_sum_0.tmp_0 (Out of reduce_sum op)" + "\n") + sys.stderr.write("reduce_sum_0.tmp_0 (Out of reduce_sum op)" + + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + # i = 1 + if fetch_names[i] == 'conv2d_0.tmp_0': + # sys.stderr.write("skip conv2d_0.tmp_0 (X)" + "\n") + sys.stderr.write("conv2d_0.tmp_0 (X)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + # i = 2 + if fetch_names[i] == 'batch_norm_0.tmp_3': + # sys.stderr.write("skip batch_norm_0.tmp_3 (Y)" + "\n") + sys.stderr.write("batch_norm_0.tmp_3 (Y)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + # i = 2 + if fetch_names[i] == 'batch_norm_0.tmp_2': + # sys.stderr.write("skip batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n") + sys.stderr.write( + "batch_norm_0.tmp_2 (ReserveSpace of batch_norm)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + # i = 3 + if fetch_names[i] == 'bn_moving_mean': + sys.stderr.write("skip bn_moving_mean (MeanOut)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + continue + + # i = 4 + if fetch_names[i] == 'bn_moving_variance': + sys.stderr.write("skip bn_moving_variance (VarianceOut)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + continue + + # i = 7 + if fetch_names[i] == 'batch_norm_0.tmp_0': + # sys.stderr.write("skip batch_norm_0.tmp_0 (SavedMean)" + "\n") + sys.stderr.write("batch_norm_0.tmp_0 (SavedMean)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + # i = 8 + if fetch_names[i] == 'batch_norm_0.tmp_1': + sys.stderr.write("skip batch_norm_0.tmp_1 (SavedVariance)" + + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + continue + + # i = 9 + if fetch_names[i] == 'bn_scale@GRAD': + # sys.stderr.write("skip bn_scale@GRAD (Scale@GRAD)" + "\n") + sys.stderr.write("bn_scale@GRAD (Scale@GRAD)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + # i = 10 + if fetch_names[i] == 'bn_bias@GRAD': + # sys.stderr.write("skip bn_bias@GRAD (Bias@GRAD)" + "\n") + sys.stderr.write("bn_bias@GRAD (Bias@GRAD)" + "\n") + sys.stderr.write("bn_val: " + 
str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + # i = 11 + if fetch_names[i] == 'batch_norm_0.tmp_3@GRAD': + # sys.stderr.write("skip batch_norm_0.tmp_3@GRAD (Y@GRAD)" + "\n") + sys.stderr.write("batch_norm_0.tmp_3@GRAD (Y@GRAD)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + # i = 12 + if fetch_names[i] == 'conv2d_0.tmp_0@GRAD': + # sys.stderr.write("skip conv2d_0.tmp_0@GRAD (X@GRAD)" + "\n") + sys.stderr.write("conv2d_0.tmp_0@GRAD (X@GRAD)" + "\n") + sys.stderr.write("bn_val: " + str(bn_val) + "\n") + sys.stderr.write("sync_bn_val: " + str(sync_bn_val) + "\n") + + # continue + + atol = self.atol + if fetch_names[i] == 'conv2d_0.tmp_0@GRAD': + atol = 1e-2 + + assert np.allclose( + bn_val, sync_bn_val, atol=atol), "Output (" + fetch_names[ + i] + ") has diff. \n" + "\nBN " + str( + bn_val) + "\n" + "Sync BN " + str(sync_bn_val) + + def _cal_single_card(self, args, data, place, layout, only_forward): + # Single-NPU, N = 32 per NPU + train_prog = fluid.Program() + startup_prog = fluid.Program() + train_prog.global_seed(SEED) + startup_prog.global_seed(SEED) + paddle.seed(SEED) + + outs = self.get_model(train_prog, startup_prog, place, layout, SEED, + False, only_forward) + + exe = fluid.Executor(place) + exe.run(startup_prog) + fetch_names = [v.name for v in outs] + [ + 'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' + ] + if not only_forward: + others = [ + 'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD', + 'bn_bias@GRAD', 'batch_norm_0.tmp_3@GRAD', 'conv2d_0.tmp_0@GRAD' + ] + fetch_names += others + bn_fetches = exe.run(program=train_prog, + feed={'input': data}, + fetch_list=fetch_names) + + return bn_fetches + + def _cal_multiple_cards(self, args, data, place, layout, only_forward): + # Multi-NPUs, self.N per NPU + # return + assert core.get_npu_device_count() > 1 + + train_prog = fluid.Program() + startup_prog = fluid.Program() + train_prog.global_seed(SEED) + startup_prog.global_seed(SEED) + paddle.seed(SEED) + sys.stderr.write("train_prog: " + train_prog.to_string(True) + "\n") + sys.stderr.write("startup_prog: " + startup_prog.to_string(True) + "\n") + + endpoints = args["endpoints"].split(",") + rank = args["trainerid"] + current_endpoint = args["currentendpoint"] + nranks = 2 + + self.initCommunicator(startup_prog, rank, nranks, True, + current_endpoint, endpoints) + sys.stderr.write("after init, startup_prog: " + startup_prog.to_string( + True) + "\n") + train_prog.global_seed(SEED) + train_prog._sync_with_cpp() + startup_prog.global_seed(SEED) + startup_prog._sync_with_cpp() + paddle.seed(SEED) + + self.rank = rank + outs = self.get_model(train_prog, startup_prog, place, layout, SEED, + True, only_forward) + sys.stderr.write("after get_model, train_prog: " + train_prog.to_string( + True) + "\n") + sys.stderr.write("after get_model, startup_prog: " + + startup_prog.to_string(True) + "\n") + + ops = train_prog.blocks[0].ops + for i, op in enumerate(ops): + if op.type == 'batch_norm': + sys.stderr.write("i: " + str(i) + "\n") + sys.stderr.write("op type: " + op.type + "\n") + op.desc.set_type('sync_batch_norm') + if op.type == 'batch_norm_grad': + sys.stderr.write("i: " + str(i) + "\n") + sys.stderr.write("op type: " + op.type + "\n") + op.desc.set_type('sync_batch_norm_grad') + + sys.stderr.write("after update sync_batch_norm, train_prog: " + + train_prog.to_string(True) + "\n") + + exe = fluid.Executor(place) + 
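# run startup and then the rewritten program, fetching the same tensor + # names as the single-card run so _compare can check them one by one +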
exe.run(startup_prog) + fetch_names = [v.name for v in outs] + [ + 'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' + ] + if not only_forward: + others = [ + 'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD', + 'bn_bias@GRAD', 'batch_norm_0.tmp_3@GRAD', 'conv2d_0.tmp_0@GRAD' + ] + fetch_names += others + sync_bn_fetches = exe.run(program=train_prog, + feed={'input': data}, + fetch_list=fetch_names) + + return fetch_names, sync_bn_fetches + + +def runtime_main(test_class, col_type, sub_type): + args = {} + model = test_class() + args["deviceid"] = os.getenv("FLAGS_selected_npus") + args["trainerid"] = int(os.getenv("PADDLE_TRAINER_ID")) + args["trainernum"] = int(os.getenv("PADDLE_TRAINERS_NUM")) + args["endpoints"] = os.getenv('PADDLE_TRAINER_ENDPOINTS') + args["currentendpoint"] = os.getenv("PADDLE_CURRENT_ENDPOINT") + args["col_type"] = col_type + model.run_trainer(args) + + +import socket +from contextlib import closing + + +class TestDistBase(unittest.TestCase): + def setUp(self): + self._port_set = set() + self._trainers = 2 + self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( + self._find_free_port(), self._find_free_port()) + self._python_interp = sys.executable + + def _find_free_port(self): + def __free_port(): + with closing(socket.socket(socket.AF_INET, + socket.SOCK_STREAM)) as s: + s.bind(('', 0)) + return s.getsockname()[1] + + while True: + port = __free_port() + if port not in self._port_set: + self._port_set.add(port) + return port + + def _run_cluster(self, model_file, envs): + worker_endpoints = self._ps_endpoints.split(",") + w0_ep, w1_ep = worker_endpoints + # print("w0_ep:", w0_ep, " w1_ep:", w1_ep) + env0 = { + "FLAGS_selected_npus": "0", + "PADDLE_TRAINER_ID": "0", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w0_ep, + } + + env1 = { + "FLAGS_selected_npus": "1", + "PADDLE_TRAINER_ID": "1", + "PADDLE_TRAINERS_NUM": "2", + "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints, + "PADDLE_CURRENT_ENDPOINT": w1_ep, + } + # update environment + env0.update(envs) + env1.update(envs) + + tr_cmd = "%s %s" + tr0_cmd = tr_cmd % (self._python_interp, model_file) + tr1_cmd = tr_cmd % (self._python_interp, model_file) + tr0_pipe = open("/tmp/tr0_err.log", "wb") + tr1_pipe = open("/tmp/tr1_err.log", "wb") + # print(tr0_cmd) + # print(tr1_cmd) + tr0_proc = subprocess.Popen( + tr0_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr0_pipe, + env=env0) + + tr1_proc = subprocess.Popen( + tr1_cmd.strip().split(), + stdout=subprocess.PIPE, + stderr=tr1_pipe, + env=env1) + + tr0_out, tr0_err = tr0_proc.communicate() + tr1_out, tr1_err = tr1_proc.communicate() + + sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err) + sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err) + # close trainer file + tr0_pipe.close() + tr1_pipe.close() + return pickle.loads(tr0_out), pickle.loads( + tr1_out), tr0_proc.pid, tr1_proc.pid + + def check_with_place(self, model_file, col_type, need_envs={}): + tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file, need_envs) + self.assertEqual( + tr0_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed') + self.assertEqual( + tr1_out, 'training, inference, fp32, fp16, NCHW, NHWC all passed') diff --git a/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_op_npu_baseline.py b/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_op_npu_baseline.py new file mode 100644 index 0000000000..54a78ea2d5 --- /dev/null +++
b/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_op_npu_baseline.py @@ -0,0 +1,42 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +import paddle +import os +import sys +sys.path.append("..") + +from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator + +from test_sync_batch_norm_base_npu import TestDistBase + +_set_use_system_allocator(False) +paddle.enable_static() + + +class TestSyncBatchNormOp(TestDistBase): + def _setup_config(self): + pass + + def test_identity(self, col_type="identity"): + dist_env = os.environ + self.check_with_place( + "sync_batch_norm_op_npu.py", col_type, need_envs=dist_env) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_op_npu_extra.py b/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_op_npu_extra.py new file mode 100644 index 0000000000..bafe45b77d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_sync_batch_norm_op_npu_extra.py @@ -0,0 +1,105 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
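+ +# Single-process extras for SyncBatchNorm on NPU: input-type validation, +# the convert_sync_batchnorm helper, and data_format error handling.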
+ +from __future__ import print_function +import unittest +import numpy as np +import paddle +import os +import sys +sys.path.append("..") + +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +import paddle.nn as nn +from paddle.fluid import Program, program_guard + +from paddle.fluid.tests.unittests.op_test import OpTest, _set_use_system_allocator + +# _set_use_system_allocator(False) +paddle.enable_static() + + +class TestDygraphSyncBatchNormAPIError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + my_sync_batch_norm = paddle.nn.SyncBatchNorm(10) + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.NPUPlace(0)) + self.assertRaises(TypeError, my_sync_batch_norm, x1) + + # the input dtype of SyncBatchNorm must be float16 or float32 + # float16 only can be set on GPU place and NPU place + x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="int32") + self.assertRaises(TypeError, my_sync_batch_norm, x2) + + +class TestConvertSyncBatchNorm(unittest.TestCase): + def test_convert(self): + with program_guard(Program(), Program()): + compare_model = paddle.nn.Sequential( + paddle.nn.Conv2D(3, 5, 3), + paddle.nn.BatchNorm2D(5), paddle.nn.BatchNorm2D(5)) + model = paddle.nn.Sequential( + paddle.nn.Conv2D(3, 5, 3), + paddle.nn.BatchNorm2D(5), + paddle.nn.BatchNorm2D( + 5, + weight_attr=fluid.ParamAttr(name='bn.scale'), + bias_attr=fluid.ParamAttr(name='bn.bias'))) + model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) + for idx, sublayer in enumerate(compare_model.sublayers()): + if isinstance(sublayer, paddle.nn.BatchNorm2D): + self.assertEqual( + isinstance(model[idx], paddle.nn.SyncBatchNorm), True) + + +class TestConvertSyncBatchNormCast1(unittest.TestCase): + def test_convert(self): + class Net(nn.Layer): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2D(3, 5, 3) + self.bn = [] + bn = self.add_sublayer('bn', nn.BatchNorm2D(5)) + self.bn.append(bn) + + def forward(self, x): + x = self.conv1(x) + for bn in self.bn: + x = bn(x) + return x + + model = nn.Sequential() + model.add_sublayer('net1', Net()) + model.add_sublayer('net2', Net()) + compare_model = nn.Sequential() + compare_model.add_sublayer('net1', Net()) + compare_model.add_sublayer('net2', Net()) + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + self.assertEqual(len(compare_model.sublayers()), len(model.sublayers())) + + +class TestDygraphSyncBatchNormDataFormatError(unittest.TestCase): + def test_errors(self): + with fluid.dygraph.guard(fluid.NPUPlace(0)): + my_sync_batch_norm = paddle.nn.SyncBatchNorm(10, data_format='CN') + data = np.random.random([3, 3, 3]).astype('float32') + x = paddle.to_tensor(data) + self.assertRaises(ValueError, my_sync_batch_norm, x) + + +if __name__ == '__main__': + unittest.main() -- GitLab
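For reference, the per-channel statistics the forward kernel assembles (ReduceSumD on each card, HcclAllReduce across cards, then a Muls by 1.0f * C / x_numel / device_counts) and the gradient formula quoted in the backward kernel's comments reduce to the NumPy sketch below. This is a simplified single-process illustration of the same math, assuming NCHW layout and float32; the function names and shapes are illustrative and not part of the patch.

import numpy as np

def sync_bn_forward_reference(xs, scale, bias, epsilon=1e-5):
    # xs: one [N, C, H, W] array per device; scale, bias: [C] arrays.
    # Per-channel element count across all devices; the kernel reaches the
    # same value through the factor 1.0f * C / x_numel / device_counts.
    total = sum(x.shape[0] * x.shape[2] * x.shape[3] for x in xs)
    # Global per-channel sums, as produced by ReduceSumD + HcclAllReduce.
    x_sum = sum(x.sum(axis=(0, 2, 3)) for x in xs)
    x_sq_sum = sum((x * x).sum(axis=(0, 2, 3)) for x in xs)
    mean = x_sum / total                    # saved_mean
    var = x_sq_sum / total - mean * mean    # var_ref = E[x^2] - (E[x])^2
    inv_std = 1.0 / np.sqrt(var + epsilon)  # saved_variance in the kernel
    ys = [scale[None, :, None, None] * (x - mean[None, :, None, None]) *
          inv_std[None, :, None, None] + bias[None, :, None, None]
          for x in xs]
    return ys, mean, var

def sync_bn_backward_reference(dy, x, mean, var, scale, epsilon=1e-5):
    # Mirrors the formula in the grad kernel's comments for one card:
    # dx = scale * (dy - mean(dy) - (x - mean) * mean(dy * (x - mean))
    #      / (var + eps)) / sqrt(var + eps)
    axis = (0, 2, 3)
    m = mean[None, :, None, None]
    v = var[None, :, None, None]
    x_centered = x - m
    d_bias = dy.sum(axis=axis)
    d_scale = (dy * x_centered / np.sqrt(v + epsilon)).sum(axis=axis)
    dy_mean = dy.mean(axis=axis, keepdims=True)
    dyx_mean = (dy * x_centered).mean(axis=axis, keepdims=True)
    d_x = scale[None, :, None, None] * (
        dy - dy_mean - x_centered * dyx_mean / (v + epsilon)
    ) / np.sqrt(v + epsilon)
    return d_x, d_scale, d_bias

In the kernel itself, device_counts is obtained by all-reducing a tensor filled with ones, and in the true multi-card backward pass dy_mean and dyx_mean are additionally summed across cards with HcclAllReduce and scaled by 1 / device_counts before use.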