From 8ad672a287eef9df865d4499e86762d3069e7d30 Mon Sep 17 00:00:00 2001 From: qingqing01 Date: Fri, 15 Mar 2019 11:22:48 +0800 Subject: [PATCH] Support sync batch norm. (#16121) * Support Sync Batch Norm. * Note, do not enable it in one device. Usage: build_strategy = fluid.BuildStrategy() build_strategy.sync_batch_norm = True binary = fluid.compiler.CompiledProgram(tp).with_data_parallel( loss_name=loss_mean.name, build_strategy=build_strategy) --- cmake/operators.cmake | 2 +- paddle/fluid/API.spec | 2 +- .../fluid/framework/details/build_strategy.cc | 7 + .../fluid/framework/details/build_strategy.h | 2 + paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/sync_batch_norm_pass.cc | 45 ++ .../fluid/framework/ir/sync_batch_norm_pass.h | 32 ++ .../ir/sync_batch_norm_pass_tester.cc | 80 +++ paddle/fluid/framework/parallel_executor.cc | 16 + paddle/fluid/operators/CMakeLists.txt | 6 +- paddle/fluid/operators/batch_norm_op.cc | 474 +++++++++--------- paddle/fluid/operators/batch_norm_op.cu | 58 +-- paddle/fluid/operators/batch_norm_op.h | 74 ++- paddle/fluid/operators/sync_batch_norm_op.cc | 20 + paddle/fluid/operators/sync_batch_norm_op.cu | 452 +++++++++++++++++ paddle/fluid/platform/device_context.cc | 2 +- paddle/fluid/platform/device_context.h | 13 + paddle/fluid/platform/init.cc | 3 + paddle/fluid/platform/nccl_helper.h | 4 + paddle/fluid/pybind/pybind.cc | 15 + python/paddle/fluid/compiler.py | 3 + python/paddle/fluid/layers/nn.py | 21 +- .../unittests/test_sync_batch_norm_op.py | 159 ++++++ 23 files changed, 1189 insertions(+), 303 deletions(-) create mode 100644 paddle/fluid/framework/ir/sync_batch_norm_pass.cc create mode 100644 paddle/fluid/framework/ir/sync_batch_norm_pass.h create mode 100644 paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc create mode 100644 paddle/fluid/operators/sync_batch_norm_op.cc create mode 100644 paddle/fluid/operators/sync_batch_norm_op.cu create mode 100644 python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 11a5b1b455..34c6cbd73d 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -110,7 +110,7 @@ function(op_library TARGET) # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" -"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op") +"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 68c6c8fd67..057ce9341b 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -91,7 +91,7 @@ paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'po paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', '043de7333b79ee0ac55053c14ed81625')) paddle.fluid.layers.adaptive_pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '859b887174d06f361658f69cb7c06d95')) paddle.fluid.layers.adaptive_pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '120f4323a3d7ed9c0916f15a59f0e497')) -paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)), ('document', 'c527b71b8a4c60dca8df8a745c2b598d')) +paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)), ('document', '320c6973b02ea179fa89fecc80796464')) paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', 'e45e09e65a2658e07cad987222f0d9ab')) paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b0b8d53821716cd50c42e09b593f3feb')) paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', '03993955ab1e6d3044c44e6f17fc85e9')) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 2cfc76e47f..932d0b4538 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/multi_devices_graph_pass.h" @@ -49,6 +50,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPass("sequential_execution_pass"); } + // Add op fusion. + if (strategy.sync_batch_norm_) { + AppendPass("sync_batch_norm_pass"); + } + // Add op fusion. if (strategy.fuse_relu_depthwise_conv_) { AppendPass("fuse_relu_depthwise_conv_pass"); @@ -227,6 +233,7 @@ std::unique_ptr BuildStrategy::Apply( } // namespace framework } // namespace paddle +USE_PASS(sync_batch_norm_pass); USE_PASS(fuse_relu_depthwise_conv_pass); USE_PASS(fuse_elewise_add_act_pass); USE_PASS(graph_viz_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index d755a2505a..122411641d 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -77,6 +77,8 @@ struct BuildStrategy { bool fuse_relu_depthwise_conv_{false}; + bool sync_batch_norm_{false}; + bool memory_optimize_{true}; // TODO(dzhwinter): // make enable_inplace, memory_optimize_ diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 57cb3911c2..faf7768a7b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -67,6 +67,7 @@ pass_library(conv_elementwise_add_fuse_pass inference) pass_library(conv_affine_channel_fuse_pass inference) pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(identity_scale_op_clean_pass base) +pass_library(sync_batch_norm_pass base) # There may be many transpose-flatten structures in a model, and the output of # these structures will be used as inputs to the concat Op. This pattern will @@ -101,6 +102,7 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) +cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass) cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor) if (WITH_MKLDNN) cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) diff --git a/paddle/fluid/framework/ir/sync_batch_norm_pass.cc b/paddle/fluid/framework/ir/sync_batch_norm_pass.cc new file mode 100644 index 0000000000..b370039915 --- /dev/null +++ b/paddle/fluid/framework/ir/sync_batch_norm_pass.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h" +#include +#include +#include + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr SyncBatchNormPass::ApplyImpl( + std::unique_ptr graph) const { + VLOG(3) << "Use synchronous batch norm"; + for (const Node* n : graph->Nodes()) { + if (n->IsOp()) { + auto* op = n->Op(); + if (op->Type() == "batch_norm") { + op->SetType("sync_batch_norm"); + } + if (op->Type() == "batch_norm_grad") { + op->SetType("sync_batch_norm_grad"); + } + } + } + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(sync_batch_norm_pass, paddle::framework::ir::SyncBatchNormPass); diff --git a/paddle/fluid/framework/ir/sync_batch_norm_pass.h b/paddle/fluid/framework/ir/sync_batch_norm_pass.h new file mode 100644 index 0000000000..51cce3dca6 --- /dev/null +++ b/paddle/fluid/framework/ir/sync_batch_norm_pass.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class SyncBatchNormPass : public Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc b/paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc new file mode 100644 index 0000000000..9c94c1746a --- /dev/null +++ b/paddle/fluid/framework/ir/sync_batch_norm_pass_tester.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h" +#include + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, + const std::vector& inputs, + const std::vector& outputs) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + op->SetAttr("name", name); + op->SetInput("X", inputs); + op->SetOutput("Out", outputs); +} + +// (a, conv_w)->conv2d->b +// (b, bn_scale, bn_bias, mean, var)->batch_norm +// ->(c, mean, var, save_mean, save_inv_var) +ProgramDesc BuildProgramDesc() { + ProgramDesc prog; + for (auto& v : std::vector({"a", "conv_w", "b", "bn_scale", + "bn_bias", "mean", "var", "c", + "save_mean", "save_inv_var"})) { + auto* var = prog.MutableBlock(0)->Var(v); + if (v == "conv_w" || v == "bn_scale" || v == "bn_bias" || v == "mean" || + v == "var") { + var->SetPersistable(true); + } + } + + SetOp(&prog, "conv2d", "conv", std::vector({"a", "conv_w"}), + std::vector({"b"})); + SetOp(&prog, "batch_norm", "bn", + std::vector({"b", "bn_scale", "bn_bias", "mean", "var"}), + std::vector( + {"c", "mean", "var", "save_mean", "save_inv_var"})); + return prog; +} + +TEST(IsTestPass, basic) { + auto prog = BuildProgramDesc(); + + std::unique_ptr graph(new ir::Graph(prog)); + + auto pass = PassRegistry::Instance().Get("sync_batch_norm_pass"); + + graph = pass->Apply(std::move(graph)); + + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op = node->Op(); + auto op_name = boost::get(op->GetAttr("name")); + if (op_name == "bn") { + ASSERT_EQ(op->Type(), "sync_batch_norm"); + } + } + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(sync_batch_norm_pass); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index acc71396b4..56f108cea2 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -14,8 +14,10 @@ limitations under the License. */ #include "paddle/fluid/framework/parallel_executor.h" #include +#include #include #include +#include #include #include "paddle/fluid/framework/ir/graph_helper.h" @@ -251,6 +253,20 @@ ParallelExecutor::ParallelExecutor(const std::vector &places, member_->nccl_ctxs_.reset(new platform::NCCLContextMap( member_->places_, nccl_id, build_strategy.num_trainers_, build_strategy.trainer_id_)); + + std::unique_ptr dev_nccl_ctxs; + dev_nccl_ctxs.reset(new platform::NCCLContextMap(member_->places_)); + // Initialize device context's nccl comm + // Note, more than one ParallelExecutor with same place, the nccl comm will + // be rewrite and there will be some problem. + for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) { + auto &nccl_ctx = dev_nccl_ctxs->at(dev_id); + platform::DeviceContextPool &pool = + platform::DeviceContextPool::Instance(); + auto *dev_ctx = static_cast( + pool.Get(member_->places_[dev_id])); + dev_ctx->set_nccl_comm(nccl_ctx.comm()); + } #else PADDLE_THROW("Not compiled with CUDA"); #endif diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index a3f2a69aef..2f8e0b3a30 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -44,10 +44,10 @@ if (WITH_DISTRIBUTE) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) endif() -register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) +register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) -# warpctc_op needs cudnn 7 above if (WITH_GPU) + # warpctc_op needs cudnn 7 above if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc) else() @@ -58,6 +58,8 @@ if (WITH_GPU) op_library(conv_fusion_op) file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n") endif() + op_library(sync_batch_norm_op) + file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n") else() op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) endif() diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index feac412538..c0ad959309 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/batch_norm_op.h" +#include #include +#include #include "paddle/fluid/framework/data_layout.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" @@ -22,147 +24,150 @@ limitations under the License. */ namespace paddle { namespace operators { -class BatchNormOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), ""); - PADDLE_ENFORCE(ctx->HasInput("Scale"), ""); - PADDLE_ENFORCE(ctx->HasInput("Bias"), ""); - PADDLE_ENFORCE(ctx->HasInput("Mean"), ""); - PADDLE_ENFORCE(ctx->HasInput("Variance"), ""); - PADDLE_ENFORCE(ctx->HasOutput("Y"), ""); - PADDLE_ENFORCE(ctx->HasOutput("MeanOut"), ""); - PADDLE_ENFORCE(ctx->HasOutput("VarianceOut"), ""); - PADDLE_ENFORCE(ctx->HasOutput("SavedMean"), ""); - PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"), ""); - - // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python - PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0], - "Mean and MeanOut should share the same memory"); - PADDLE_ENFORCE_EQ(ctx->Inputs("Variance")[0], - ctx->Outputs("VarianceOut")[0], - "Variance and VarianceOut should share the same memory"); - - const auto x_dims = ctx->GetInputDim("X"); - const DataLayout data_layout = framework::StringToDataLayout( - ctx->Attrs().Get("data_layout")); - - PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, - "Input X must have 2 to 5 dimensions."); - - const int64_t C = - (data_layout == DataLayout::kNCHW ? x_dims[1] - : x_dims[x_dims.size() - 1]); - - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C); - - ctx->SetOutputDim("Y", x_dims); - ctx->SetOutputDim("MeanOut", {C}); - ctx->SetOutputDim("VarianceOut", {C}); - ctx->SetOutputDim("SavedMean", {C}); - ctx->SetOutputDim("SavedVariance", {C}); - ctx->ShareLoD("X", "Y"); +void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Scale"), + "Input(Scale) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), + "Input(Bias) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Mean"), + "Input(Mean) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Variance"), + "Input(Variance) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), + "Output(Y) of ConvOp should not be null."); + bool is_test = ctx->Attrs().Get("is_test"); + if (!is_test) { + PADDLE_ENFORCE(ctx->HasOutput("MeanOut"), + "Output(MeanOut) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("VarianceOut"), + "Output(VarianceOut) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("SavedMean"), + "Output(SavedMean) of ConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"), + "Output(SavedVariance) of ConvOp should not be null."); } - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - auto input_data_type = ctx.Input("X")->type(); - // By default, the type of the scale, bias, mean, - // and var tensors should both be float. (For float or float16 input tensor) - // or double (For double input tensor). - auto bn_param_type = framework::proto::VarType::FP32; - if (input_data_type == framework::proto::VarType::FP64) { - bn_param_type = framework::proto::VarType::FP64; - } - PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Scale")->type(), - "Scale input should be of float type"); - PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Bias")->type(), - "Bias input should be of float type"); - PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Mean")->type(), - "Mean input should be of float type"); - PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Variance")->type(), - "Variance input should be of float type"); - - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; + // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python + PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0], + "Mean and MeanOut should share the same memory"); + PADDLE_ENFORCE_EQ(ctx->Inputs("Variance")[0], ctx->Outputs("VarianceOut")[0], + "Variance and VarianceOut should share the same memory"); + + const auto x_dims = ctx->GetInputDim("X"); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + + PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, + "Input X must have 2 to 5 dimensions."); + + const int64_t C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C); + + ctx->SetOutputDim("Y", x_dims); + ctx->SetOutputDim("MeanOut", {C}); + ctx->SetOutputDim("VarianceOut", {C}); + ctx->SetOutputDim("SavedMean", {C}); + ctx->SetOutputDim("SavedVariance", {C}); + ctx->ShareLoD("X", "Y"); +} + +framework::OpKernelType BatchNormOp::GetExpectedKernelType( + const framework::ExecutionContext &ctx) const { + auto input_data_type = ctx.Input("X")->type(); + // By default, the type of the scale, bias, mean, + // and var tensors should both be float. (For float or float16 input tensor) + // or double (For double input tensor). + auto bn_param_type = framework::proto::VarType::FP32; + if (input_data_type == framework::proto::VarType::FP64) { + bn_param_type = framework::proto::VarType::FP64; + } + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Scale")->type(), + "Scale input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Bias")->type(), + "Bias input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Mean")->type(), + "Mean input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Variance")->type(), + "Variance input should be of float type"); + + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; - } -#endif - - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, - library); + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; } -}; +#endif -class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddAttr("is_test", - "(bool, default false) Set to true for inference only, false " - "for training. Some layers may run faster when this is true.") - .SetDefault(false); - AddAttr("momentum", "").SetDefault(0.9); - AddAttr("epsilon", "") - .SetDefault(1e-5) - .AddCustomChecker([](const float &epsilon) { - PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, - "'epsilon' should be between 0.0 and 0.001."); - }); - AddAttr("data_layout", "").SetDefault("NCHW"); - AddInput("X", "The input tensor"); - AddInput("Scale", - "Scale is a 1-dimensional tensor of size C " - "that is applied to the output"); - AddInput("Bias", - "Bias is a 1-dimensional tensor of size C " - "that is applied to the output"); - AddInput("Mean", - "The global mean (for training) or " - "estimated mean (for testing)"); - AddInput("Variance", - "The global variance (for training) " - "or estimated Variance (for testing)"); - AddOutput("Y", "result after normalization"); - AddOutput("MeanOut", - "Share memory with Mean. " - "Store the global mean when training"); - AddOutput("VarianceOut", - "Share memory with Variance. " - "Store the global Variance when training"); - AddOutput("SavedMean", - "Mean of the current mini batch, " - "will apply to output when training") - .AsIntermediate(); - AddOutput("SavedVariance", - "Variance of the current mini batch, " - "will apply to output when training") - .AsIntermediate(); - AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddAttr("fuse_with_relu", - "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); - AddAttr("use_global_stats", - "(bool, default false) Whether to use global mean and " - "variance. In inference or test mode, set use_global_stats " - "to true or is_test true. the behavior is equivalent. " - "In train mode, when setting use_global_stats True, the " - "global mean and variance are also used during train time, " - "the BN acts as scaling and shiffting.") - .SetDefault(false); - AddComment(R"DOC( + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library); +} + +void BatchNormOpMaker::Make() { + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false); + AddAttr("momentum", "").SetDefault(0.9); + AddAttr("epsilon", "") + .SetDefault(1e-5) + .AddCustomChecker([](const float &epsilon) { + PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, + "'epsilon' should be between 0.0 and 0.001."); + }); + AddAttr("data_layout", "").SetDefault("NCHW"); + AddInput("X", "The input tensor"); + AddInput("Scale", + "Scale is a 1-dimensional tensor of size C " + "that is applied to the output"); + AddInput("Bias", + "Bias is a 1-dimensional tensor of size C " + "that is applied to the output"); + AddInput("Mean", + "The global mean (for training) or " + "estimated mean (for testing)"); + AddInput("Variance", + "The global variance (for training) " + "or estimated Variance (for testing)"); + AddOutput("Y", "result after normalization"); + AddOutput("MeanOut", + "Share memory with Mean. " + "Store the global mean when training"); + AddOutput("VarianceOut", + "Share memory with Variance. " + "Store the global Variance when training"); + AddOutput("SavedMean", + "Mean of the current mini batch, " + "will apply to output when training") + .AsIntermediate(); + AddOutput("SavedVariance", + "Variance of the current mini batch, " + "will apply to output when training") + .AsIntermediate(); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("fuse_with_relu", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("use_global_stats", + "(bool, default false) Whether to use global mean and " + "variance. In inference or test mode, set use_global_stats " + "to true or is_test true. the behavior is equivalent. " + "In train mode, when setting use_global_stats True, the " + "global mean and variance are also used during train time, " + "the BN acts as scaling and shiffting.") + .SetDefault(false); + AddComment(R"DOC( Batch Normalization. Batch Norm has been implemented as discussed in the paper: @@ -173,17 +178,7 @@ The required data format for this layer is one of the following: 2. NCHW `[batch, in_channels, in_height, in_width]` )DOC"); - } -}; - -class BatchNormOpInferVarType - : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map GetInputOutputWithSameType() - const override { - return std::unordered_map{{"X", /*->*/ "Y"}}; - } -}; +} template class BatchNormKernel @@ -336,82 +331,75 @@ class BatchNormKernel } }; -class BatchNormGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - // check input - PADDLE_ENFORCE(ctx->HasInput("X")); - PADDLE_ENFORCE(ctx->HasInput("Scale"), "Input(scale) should not be null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), - "Input(Y@GRAD) should not be null."); - PADDLE_ENFORCE(ctx->HasInput("SavedMean"), - "Input(SavedMean) should not be null."); - PADDLE_ENFORCE(ctx->HasInput("SavedVariance"), - "Input(SavedVariance) should not be null"); - - // check output - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); - if (ctx->HasOutput(framework::GradVarName("Scale"))) { - PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), - "Output(Scale@GRAD) and Output(Bias@GRAD) should not be " - "null at same time"); - } - const bool use_global_stats = ctx->Attrs().Get("use_global_stats"); - if (use_global_stats) { - PADDLE_ENFORCE(!ctx->Attrs().Get("use_mkldnn"), - "Using global stats during training is not supported " - "in gradient op kernel of batch_norm_mkldnn_op now."); - } +void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { + // check input + PADDLE_ENFORCE(ctx->HasInput("X")); + PADDLE_ENFORCE(ctx->HasInput("Scale"), "Input(scale) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("SavedMean"), + "Input(SavedMean) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("SavedVariance"), + "Input(SavedVariance) should not be null"); + + // check output + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); + if (ctx->HasOutput(framework::GradVarName("Scale"))) { + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), + "Output(Scale@GRAD) and Output(Bias@GRAD) should not be " + "null at same time"); + } + const bool use_global_stats = ctx->Attrs().Get("use_global_stats"); + if (use_global_stats) { + PADDLE_ENFORCE(!ctx->Attrs().Get("use_mkldnn"), + "Using global stats during training is not supported " + "in gradient op kernel of batch_norm_mkldnn_op now."); + } - const auto x_dims = ctx->GetInputDim("X"); - const DataLayout data_layout = framework::StringToDataLayout( - ctx->Attrs().Get("data_layout")); - const int C = - (data_layout == DataLayout::kNCHW ? x_dims[1] - : x_dims[x_dims.size() - 1]); + const auto x_dims = ctx->GetInputDim("X"); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - if (ctx->HasOutput(framework::GradVarName("Scale"))) { - ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); - ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); - } + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + if (ctx->HasOutput(framework::GradVarName("Scale"))) { + ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); + ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); } +} - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - const auto *var = ctx.InputVar(framework::GradVarName("Y")); - if (var == nullptr) { - PADDLE_THROW("can't find Y@GRAD"); - } - const Tensor *t = nullptr; - if (var->IsType()) { - t = &var->Get(); - } else if (var->IsType()) { - t = &var->Get(); - } - if (t == nullptr) { - PADDLE_THROW("can't find Y@GRAD"); - } +framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( + const framework::ExecutionContext &ctx) const { + const auto *var = ctx.InputVar(framework::GradVarName("Y")); + if (var == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } + const Tensor *t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } + if (t == nullptr) { + PADDLE_THROW("can't find Y@GRAD"); + } - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; - } + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } #endif - return framework::OpKernelType(ctx.Input("X")->type(), - ctx.GetPlace(), layout, library); - } -}; + return framework::OpKernelType(ctx.Input("X")->type(), ctx.GetPlace(), + layout, library); +} template class BatchNormGradKernel @@ -572,37 +560,31 @@ class BatchNormGradKernel } }; -class BatchNormGradMaker : public framework::SingleGradOpDescMaker { - public: - using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; - - protected: - std::unique_ptr Apply() const override { - auto *op = new framework::OpDesc(); - op->SetType("batch_norm_grad"); - op->SetInput("X", Input("X")); - op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); - - op->SetInput("Scale", Input("Scale")); - op->SetInput("Bias", Input("Bias")); - op->SetInput("SavedMean", Output("SavedMean")); - op->SetInput("SavedVariance", Output("SavedVariance")); - - // used when setting use_global_stats True during training - if (boost::get(GetAttr("use_global_stats"))) { - op->SetInput("Mean", Output("MeanOut")); - op->SetInput("Variance", Output("VarianceOut")); - } +std::unique_ptr BatchNormGradMaker::Apply() const { + auto *op = new framework::OpDesc(); + op->SetType(GradOpType()); + op->SetInput("X", Input("X")); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + + op->SetInput("Scale", Input("Scale")); + op->SetInput("Bias", Input("Bias")); + op->SetInput("SavedMean", Output("SavedMean")); + op->SetInput("SavedVariance", Output("SavedVariance")); + + // used when setting use_global_stats True during training + if (boost::get(GetAttr("use_global_stats"))) { + op->SetInput("Mean", Output("MeanOut")); + op->SetInput("Variance", Output("VarianceOut")); + } - op->SetAttrMap(Attrs()); + op->SetAttrMap(Attrs()); - op->SetOutput(framework::GradVarName("X"), InputGrad("X")); - op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale")); - op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale")); + op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias")); - return std::unique_ptr(op); - } -}; + return std::unique_ptr(op); +} class BatchNormInplaceInToOut : public framework::InplaceInToOut { public: @@ -642,10 +624,10 @@ class BatchNormGradInplaceInToOut : public framework::InplaceInToOut { namespace ops = paddle::operators; REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, - ops::BatchNormOpInferVarType, ops::BatchNormGradMaker, - ops::BatchNormInplaceInToOut); -REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp, - ops::BatchNormGradInplaceInToOut); + ops::BatchNormOpInferVarType, ops::BatchNormGradMaker) +// ops::BatchNormInplaceInToOut); +REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp) +// ops::BatchNormGradInplaceInToOut); REGISTER_OP_CPU_KERNEL( batch_norm, ops::BatchNormKernel, diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 1c45746a92..36d297ec55 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -33,26 +33,6 @@ using CudnnDataType = platform::CudnnDataType; template using BatchNormParamType = typename CudnnDataType::BatchNormParamType; -void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout, - int *N, int *C, int *H, int *W, int *D) { - *N = dims[0]; - if (dims.size() == 2) { - *C = dims[1]; - *H = 1; - *W = 1; - *D = 1; - } else { - *C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; - *H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1]; - *W = dims.size() > 3 - ? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2]) - : 1; - *D = dims.size() > 4 - ? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3]) - : 1; - } -} - template class BatchNormKernel : public framework::OpKernel { @@ -196,22 +176,6 @@ class BatchNormKernel } }; -template -static __global__ void KeBNBackwardData(const T *dy, - const BatchNormParamType *scale, - const BatchNormParamType *variance, - const double epsilon, const int C, - const int HxW, const int num, T *dx) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (int i = gid; i < num; i += stride) { - const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C; - BatchNormParamType inv_var = 1.0 / sqrt(variance[c] + epsilon); - dx[i] = static_cast(static_cast>(dy[i]) * - scale[c] * inv_var); - } -} - template static __global__ void KeBNBackwardScaleBias( const T *dy, const T *x, const BatchNormParamType *mean, @@ -248,6 +212,22 @@ static __global__ void KeBNBackwardScaleBias( } } +template +static __global__ void KeBNBackwardData(const T *dy, + const BatchNormParamType *scale, + const BatchNormParamType *variance, + const double epsilon, const int C, + const int HxW, const int num, T *dx) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C; + BatchNormParamType inv_var = 1.0 / sqrt(variance[c] + epsilon); + dx[i] = static_cast(static_cast>(dy[i]) * + scale[c] * inv_var); + } +} + template class BatchNormGradKernel : public framework::OpKernel { @@ -383,7 +363,7 @@ class BatchNormGradKernel KeBNBackwardScaleBias<<< grid2, block, 0, dev_ctx.stream()>>>( d_y->data(), x->data(), running_mean_data, running_var_data, - epsilon, C, H * W, num, d_scale->data>(), + epsilon, N, C, H * W * D, d_scale->data>(), d_bias->data>()); } } else { @@ -394,10 +374,10 @@ class BatchNormGradKernel running_var_data, epsilon, C, H * W, num, d_x->data()); } if (d_scale && d_bias) { - KeBNBackwardScaleBias<<< + KeBNBackwardScaleBias<<< grid2, block, 0, dev_ctx.stream()>>>( d_y->data(), x->data(), running_mean_data, running_var_data, - epsilon, C, H * W, num, d_scale->data>(), + epsilon, N, C, H * W * D, d_scale->data>(), d_bias->data>()); } } diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h index 5e3d630d68..6e89d73eb2 100644 --- a/paddle/fluid/operators/batch_norm_op.h +++ b/paddle/fluid/operators/batch_norm_op.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include +#include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -35,17 +38,84 @@ template using ConstEigenVectorArrayMap = Eigen::Map>; +class BatchNormOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override; +}; + +class BatchNormGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override; +}; + +class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override; +}; + +class BatchNormGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override; + + virtual std::string GradOpType() const { + return this->ForwardOpType() + "_grad"; + } +}; + +class BatchNormOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Y"}}; + } +}; + template class BatchNormKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override; + void Compute(const framework::ExecutionContext &ctx) const override; }; template class BatchNormGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override; + void Compute(const framework::ExecutionContext &ctx) const override; }; +inline void ExtractNCWHD(const framework::DDim &dims, + const DataLayout &data_layout, int *N, int *C, int *H, + int *W, int *D) { + *N = dims[0]; + if (dims.size() == 2) { + *C = dims[1]; + *H = 1; + *W = 1; + *D = 1; + } else { + *C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1]; + *H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1]; + *W = dims.size() > 3 + ? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2]) + : 1; + *D = dims.size() > 4 + ? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3]) + : 1; + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/sync_batch_norm_op.cc b/paddle/fluid/operators/sync_batch_norm_op.cc new file mode 100644 index 0000000000..d6cf27fd77 --- /dev/null +++ b/paddle/fluid/operators/sync_batch_norm_op.cc @@ -0,0 +1,20 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/batch_norm_op.h" + +namespace ops = paddle::operators; +REGISTER_OPERATOR(sync_batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, + ops::BatchNormOpInferVarType, ops::BatchNormGradMaker); +REGISTER_OPERATOR(sync_batch_norm_grad, ops::BatchNormGradOp); diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu new file mode 100644 index 0000000000..a5984bfaaa --- /dev/null +++ b/paddle/fluid/operators/sync_batch_norm_op.cu @@ -0,0 +1,452 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include "cub/cub.cuh" +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/operators/batch_norm_op.h" +#include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/nccl_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DataLayout = framework::DataLayout; +template +using CudnnDataType = platform::CudnnDataType; + +template +__global__ void KeLocalStats(const T *x, int N, int M, int C, T *mean_var) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + for (int k = blockIdx.x; k < C; k += gridDim.x) { + T x_sum = 0; + T x2_sum = 0; + for (int i = threadIdx.x; i < N * M; i += BlockDim) { + int id = layout == framework::DataLayout::kNCHW + ? (i / M) * C * M + k * M + i % M + : i * C + k; + T x_in = x[id]; + x_sum += x_in; + x2_sum += x_in * x_in; + } + __syncthreads(); + T out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + mean_var[k] = out / (N * M); + } + out = BlockReduce(temp_storage).Reduce(x2_sum, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + mean_var[k + C] = out / (N * M); + } + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + mean_var[2 * C] = static_cast(1.0); + } +} + +template +__global__ void KeSyncAndMovingStats(T *means, T *variances, T *num_dev, + const int C, const T momentum, + const double epsilon, T *sv_mean_data, + T *sv_inv_var_data, T *moving_means, + T *moving_variances) { + // sync stats across multi-devices + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < C; i += stride) { + T mean = means[i] / (*num_dev); + T var = variances[i] / (*num_dev); + var = var - mean * mean; + + // sync stats + sv_mean_data[i] = mean; + sv_inv_var_data[i] = 1.0 / sqrt(var + epsilon); + variances[i] = var; + + // moving stats + moving_means[i] = moving_means[i] * momentum + mean * (1. - momentum); + moving_variances[i] = + moving_variances[i] * momentum + var * (1. - momentum); + } +} + +template +static __global__ void KeNormAffine(const T *x, const T *scale, const T *bias, + const T *mean, const T *variance, + const double epsilon, const int C, + const int M, const int num, T *y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? (i / M) % C : i % C; + y[i] = (x[i] - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c]; + } +} + +template +class SyncBatchNormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + double epsilon = static_cast(ctx.Attr("epsilon")); + const float momentum = ctx.Attr("momentum"); + const bool is_test = ctx.Attr("is_test"); + const std::string layout_str = ctx.Attr("data_layout"); + const DataLayout layout = framework::StringToDataLayout(layout_str); + const bool use_global_stats = ctx.Attr("use_global_stats"); + PADDLE_ENFORCE( + !use_global_stats, + "sync_batch_norm doesn't support to set use_global_stats True. ", + "Please use batch_norm in this case."); + + const auto *x = ctx.Input("X"); + const auto &x_dims = x->dims(); + PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, + "The Input dim size should be between 2 and 5"); + int N, C, H, W, D; + ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + int x_numel = x->numel(); + + const T *x_d = x->data(); + const T *s_d = ctx.Input("Scale")->data(); + const T *b_d = ctx.Input("Bias")->data(); + + auto *y = ctx.Output("Y"); + T *y_d = y->mutable_data(ctx.GetPlace()); + + const T *mean_data = nullptr; + const T *var_data = nullptr; + + auto &dev_ctx = ctx.cuda_device_context(); + auto stream = dev_ctx.stream(); + auto *comm = dev_ctx.nccl_comm(); + const int block = 512; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + + paddle::memory::AllocationPtr alloc_ptr{nullptr}; + + if (is_test) { + const auto *est_mean = ctx.Input("Mean"); + const auto *est_var = ctx.Input("Variance"); + mean_data = est_mean->data(); + var_data = est_var->data(); + } else { + auto &allocator = + platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx); + // x, x^2, 1, here 1 is used to calc device num + // device num also can be got from platform::DeviceContextPool + const int bytes = (C * 2 + 1) * sizeof(T); + alloc_ptr = allocator.Allocate(bytes); + + T *stats = reinterpret_cast(alloc_ptr->ptr()); + const int threads = 256; + int grid = std::min(C, (max_threads + threads - 1) / threads); + if (layout == framework::DataLayout::kNCHW) { + KeLocalStats< + T, threads, + framework::DataLayout::kNCHW><<>>( + x_d, N, H * W * D, C, stats); + } else { + KeLocalStats< + T, threads, + framework::DataLayout::kNHWC><<>>( + x_d, N, H * W * D, C, stats); + } + + Tensor c_g_st; + T *c_g_st_d = c_g_st.mutable_data({2 * C + 1}, platform::CPUPlace()); + auto gplace = boost::get(ctx.GetPlace()); + memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0); + + int dtype = platform::ToNCCLDataType(x->type()); + // In-place operation + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + stats, stats, 2 * C + 1, static_cast(dtype), ncclSum, + comm, stream)); + + // moving mean/variance + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + T *est_mean_data = mean_out->mutable_data(ctx.GetPlace()); + T *est_var_data = variance_out->mutable_data(ctx.GetPlace()); + + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_inv_variance = ctx.Output("SavedVariance"); + T *sv_mean_data = saved_mean->mutable_data(ctx.GetPlace()); + T *sv_inv_var_data = saved_inv_variance->mutable_data(ctx.GetPlace()); + + // Note, Input('Mean')/Input('Variance') share variable with + // Output('MeanOut')/Output('VarianceOut') + KeSyncAndMovingStats<<<(C + block - 1) / block, block, 0, stream>>>( + stats, stats + C, stats + 2 * C, C, momentum, epsilon, sv_mean_data, + sv_inv_var_data, est_mean_data, est_var_data); + + mean_data = sv_mean_data; + var_data = stats + C; + } + + int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; + if (layout == framework::DataLayout::kNCHW) { + KeNormAffine<<>>( + x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel, + y_d); + } else { + KeNormAffine<<>>( + x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel, + y_d); + } + } +}; + +template +__global__ void KeBackwardLocalStats(const T *dy, const T *x, const T *means, + int N, int M, int C, T *sum_dy_prod) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + for (int k = blockIdx.x; k < C; k += gridDim.x) { + T sum1 = 0; + T sum2 = 0; + T mean = means[k]; + for (int i = threadIdx.x; i < N * M; i += blockDim.x) { + int id = layout == framework::DataLayout::kNCHW + ? (i / M) * C * M + k * M + i % M + : i * C + k; + T g = dy[id]; + sum1 += g; + sum2 += g * (x[id] - mean); + } + + __syncthreads(); + T out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + sum_dy_prod[k] = out; + } + out = BlockReduce(temp_storage).Reduce(sum2, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + sum_dy_prod[k + C] = out; + } + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + sum_dy_prod[2 * C] = static_cast(1.0); + } +} + +template +static __global__ void KeBNBackwardScaleBias(const T *dy, const T *x, + const T *mean, + const T *inv_variance, + const double epsilon, const int N, + const int C, const int HxW, + T *dscale, T *dbias) { + const int outer_size = C; + const int inner_size = N * HxW; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + T ds_sum = static_cast(0); + T db_sum = static_cast(0); + + T inv_var_i = inv_variance[i]; + T mean_i = mean[i]; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int id = layout == framework::DataLayout::kNCHW + ? ((j / HxW) * C + i) * HxW + (j % HxW) + : j * outer_size + i; + ds_sum += dy[id] * (x[id] - mean_i); + db_sum += dy[id]; + } + __syncthreads(); + double os = BlockReduce(temp_storage) + .Reduce(static_cast(ds_sum), cub::Sum()); + __syncthreads(); + double ob = BlockReduce(temp_storage) + .Reduce(static_cast(db_sum), cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + dscale[i] = static_cast(os * inv_var_i); + dbias[i] = static_cast(ob); + } + __syncthreads(); + } +} + +template +static __global__ void KeBNBackwardData(const T *dy, const T *x, const T *beta, + const T *mean, const T *inv_variance, + const T *g_sum_dy, + const T *g_sum_dy_prod, + const T *num_dev, const double epsilon, + const int C, const int HxW, + const int num, T *dx) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + T scale = static_cast(C) / num; + T dev_num = num_dev[0]; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C; + T inv_var = inv_variance[c]; + T s_d = beta[c]; + T gvar = -1.0 * (g_sum_dy_prod[c] / dev_num) * s_d * inv_var * + (inv_var * inv_var); + T gmean = -1.0 * (g_sum_dy[c] / dev_num) * s_d * inv_var; + + dx[i] = + dy[i] * s_d * inv_var + gmean * scale + gvar * scale * (x[i] - mean[c]); + } +} + +// Deriving the Gradient for the Backward Pass of Batch Normalization +// https://kevinzakka.github.io/2016/09/14/batch_normalization/ +template +class SyncBatchNormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + double epsilon = static_cast(ctx.Attr("epsilon")); + const std::string layout_str = ctx.Attr("data_layout"); + + const DataLayout layout = framework::StringToDataLayout(layout_str); + const auto *x = ctx.Input("X"); + const auto *d_y = ctx.Input(framework::GradVarName("Y")); + const auto *scale = ctx.Input("Scale"); + + const auto &x_dims = x->dims(); + + PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, + "The Input dim size should be between 2 and 5"); + int N, C, H, W, D; + ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + + // init output + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + d_x->mutable_data(ctx.GetPlace()); + if (d_scale && d_bias) { + d_scale->mutable_data(ctx.GetPlace()); + d_bias->mutable_data(ctx.GetPlace()); + } + PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); + PADDLE_ENFORCE_EQ(scale->dims()[0], C); + + std::vector dims; + std::vector strides; + if (layout == DataLayout::kNCHW) { + dims = {N, C, H, W, D}; + strides = {C * H * W * D, H * W * D, W * D, D, 1}; + } else { + dims = {N, C, H, W, D}; + strides = {H * W * C * D, 1, W * D * C, D * C, C}; + } + + const T *x_d = x->data(); + const T *dy_d = d_y->data(); + + auto &dev_ctx = ctx.cuda_device_context(); + auto stream = dev_ctx.stream(); + auto *comm = dev_ctx.nccl_comm(); + + const T *saved_mean = ctx.Input("SavedMean")->data(); + const T *saved_inv_var = ctx.Input("SavedVariance")->data(); + auto &allocator = + platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx); + const int bytes = (C * 2 + 1) * sizeof(T); + auto alloc_ptr = allocator.Allocate(bytes); + T *stats = reinterpret_cast(alloc_ptr->ptr()); + + const int threads = 256; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + int grid = std::min(C, (max_threads + threads - 1) / threads); + int x_numel = x->numel(); + int fsize = H * W * D; + + if (layout == framework::DataLayout::kNCHW) { + KeBackwardLocalStats< + T, threads, + framework::DataLayout::kNCHW><<>>( + dy_d, x_d, saved_mean, N, fsize, C, stats); + } else { + KeBackwardLocalStats< + T, threads, + framework::DataLayout::kNHWC><<>>( + dy_d, x_d, saved_mean, N, fsize, C, stats); + } + int dtype = platform::ToNCCLDataType(x->type()); + // In-place operation + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + stats, stats, 2 * C + 1, static_cast(dtype), ncclSum, + comm, stream)); + + const int block = 512; + int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; + if (layout == framework::DataLayout::kNCHW) { + if (d_scale && d_bias) { + KeBNBackwardScaleBias< + T, threads, + framework::DataLayout::kNCHW><<>>( + dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, + d_scale->data(), d_bias->data()); + } + if (d_x) { + KeBNBackwardData< + T, framework::DataLayout::kNCHW><<>>( + dy_d, x_d, scale->data(), saved_mean, saved_inv_var, stats, + stats + C, stats + 2 * C, epsilon, C, fsize, x->numel(), + d_x->data()); + } + } else { + if (d_scale && d_bias) { + KeBNBackwardScaleBias< + T, threads, + framework::DataLayout::kNHWC><<>>( + dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, + d_scale->data(), d_bias->data()); + } + if (d_x) { + KeBNBackwardData< + T, framework::DataLayout::kNHWC><<>>( + dy_d, x_d, scale->data(), saved_mean, saved_inv_var, stats, + stats + C, stats + 2 * C, epsilon, C, fsize, x->numel(), + d_x->data()); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + sync_batch_norm, ops::SyncBatchNormKernel, + ops::SyncBatchNormKernel); +REGISTER_OP_CUDA_KERNEL( + sync_batch_norm_grad, + ops::SyncBatchNormGradKernel, + ops::SyncBatchNormGradKernel); diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 920b43b2b1..ada9a19736 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -57,7 +57,6 @@ DeviceContextPool::DeviceContextPool( for (auto& p : places) { set.insert(p); } - for (auto& p : set) { if (platform::is_cpu_place(p)) { #ifdef PADDLE_WITH_MKLDNN @@ -317,6 +316,7 @@ CUDADeviceContext::~CUDADeviceContext() { eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); + PADDLE_ENFORCE(dynload::ncclCommDestroy(nccl_comm_)); } Place CUDADeviceContext::GetPlace() const { return place_; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index d376f90ad5..3f7ce3d944 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -265,6 +265,12 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cuda stream in the device context. */ cudaStream_t stream() const; + /*! \brief Return nccl communicators. */ + ncclComm_t nccl_comm() const { return nccl_comm_; } + + /*! \brief Set nccl communicators. */ + void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; } + template void RecordEvent(cudaEvent_t ev, Callback callback) { callback(); @@ -289,6 +295,13 @@ class CUDADeviceContext : public DeviceContext { std::unique_ptr cublas_handle_; std::unique_ptr cublas_tensor_core_handle_; + // NCCL communicator (single process version) for NCCL collective operations. + // NCCL collective operations provides fast collectives over multiple GPUs + // both within and across nodes. + // But, this collectives is used for collectives over multiple GPUs within + // nodes. + ncclComm_t nccl_comm_{nullptr}; + int compute_capability_; int runtime_version_; int driver_version_; diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 4dcf7e7904..d53a4029e1 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include // for strdup #include +#include +#include #include #include @@ -140,6 +142,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { places.emplace_back(platform::CPUPlace()); platform::DeviceContextPool::Init(places); platform::DeviceTemporaryAllocator::Init(); + #ifndef PADDLE_WITH_MKLDNN platform::SetNumThreads(FLAGS_paddle_num_threads); #endif diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h index 6ae21ee829..0428c40f98 100644 --- a/paddle/fluid/platform/nccl_helper.h +++ b/paddle/fluid/platform/nccl_helper.h @@ -16,9 +16,11 @@ #pragma once #include +#include #include #include // NOLINT #include +#include #include #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/platform/dynload/nccl.h" @@ -78,6 +80,8 @@ struct NCCLContext { cudaStream_t stream() const { return ctx_->stream(); } + ncclComm_t comm() const { return comm_; } + int device_id() const { return boost::get(ctx_->GetPlace()).device; } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 552a5e0c32..5a753d0a78 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1230,6 +1230,21 @@ All parameter, weight, gradient are variables in Paddle. it will save GPU memory and may make the execution faster. This options is only available in GPU devices. Default False)DOC") + .def_property( + "sync_batch_norm", + [](const BuildStrategy &self) { return self.sync_batch_norm_; }, + [](BuildStrategy &self, bool b) { + PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized."); + self.sync_batch_norm_ = b; + }, + R"DOC(The type is BOOL, sync_batch_norm indicates whether to use + synchronous batch normalization which synchronizes the mean + and variance through multi-devices in training phase. + + Current implementation doesn't support FP16 training and CPU. + And only synchronous on one machine, not all machines. + + Default False)DOC") .def_property( "memory_optimize", [](const BuildStrategy &self) { return self.memory_optimize_; }, diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py index 2bf30e39d3..5732377bd6 100644 --- a/python/paddle/fluid/compiler.py +++ b/python/paddle/fluid/compiler.py @@ -223,6 +223,9 @@ class CompiledProgram(object): tps), "num_trainers == len(end_points)" self._build_strategy.trainers_endpoints = tps + if self._build_strategy.sync_batch_norm: + self._build_strategy.enable_sequential_execution = True + self._persistable_vars = [] for node in self._graph.nodes(): if node.is_var() and node.var() is not None and node.var().persistable() and \ diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9886f4e84c..ea028b0566 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2922,11 +2922,17 @@ def batch_norm(input, y_i &\\gets \\gamma \\hat{x_i} + \\beta Args: - input(variable): The input variable which is a LoDTensor. + input(variable): The rank of input variable can be 2, 3, 4, 5. act(string, Default None): Activation type, linear|relu|prelu|... - is_test(bool, Default False): Used for training or training. - momentum(float, Default 0.9): - epsilon(float, Default 1e-05): + is_test (bool, Default False): A flag indicating whether it is in + test phrase or not. + momentum(float, Default 0.9): The value used for the moving_mean and + moving_var computation. The updated formula is: + :math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)` + :math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)` + Default is 0.9. + epsilon(float, Default 1e-05): A value added to the denominator for + numerical stability. Default is 1e-5. param_attr(ParamAttr|None): The parameter attribute for Parameter `scale` of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm will create ParamAttr as param_attr. If the Initializer of the param_attr @@ -2984,15 +2990,8 @@ def batch_norm(input, shape=param_shape, dtype=dtype, default_initializer=Constant(1.0)) - # setting stop_gradient=True to reduce computation - if use_global_stats and helper.param_attr.learning_rate == 0.: - scale.stop_gradient = True - bias = helper.create_parameter( attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) - # setting stop_gradient=True to reduce computation - if use_global_stats and helper.bias_attr.learning_rate == 0.: - bias.stop_gradient = True mean = helper.create_parameter( attr=ParamAttr( diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py new file mode 100644 index 0000000000..f6a658cb1b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -0,0 +1,159 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import os +import six +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler + + +class TestSyncBatchNormOpTraining(unittest.TestCase): + def setUp(self): + #self.dtype = np.float32 + self.dtype = np.float64 + self.N = 32 + self.C = 16 + self.H = 64 + self.W = 32 + self.dshape = [self.N, self.C, self.H, self.W] + + def build_program(self, + place, + layout, + seed, + sync_bn=False, + only_forward=False): + main = fluid.Program() + startup = fluid.Program() + main.random_seed = seed + startup.random_seed = seed + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + data = fluid.layers.data( + name='input', + shape=self.dshape, + dtype=self.dtype, + append_batch_size=False) + conv = fluid.layers.conv2d( + input=data, + num_filters=32, + filter_size=1, + param_attr=fluid.ParamAttr(name='conv2d_weight'), + bias_attr=False, + use_cudnn=False) + bn = fluid.layers.batch_norm( + conv, + param_attr=fluid.ParamAttr(name='bn_scale'), + bias_attr=fluid.ParamAttr(name='bn_bias'), + moving_mean_name='bn_moving_mean', + moving_variance_name='bn_moving_variance', + data_layout=layout, + is_test=only_forward) + sigmoid = fluid.layers.sigmoid(bn) + out = fluid.layers.reduce_sum(sigmoid) + if not sync_bn: + out = out / core.get_cuda_device_count() + if not only_forward: + sgd_opt = fluid.optimizer.SGD(learning_rate=0.0) + sgd_opt.backward(out) + return main, startup, [out, conv, bn] + + def compare(self, place, layout, only_forward): + seed = 10 + os.environ['FLAGS_cudnn_deterministic'] = "1" + data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2 + # Single-GPU, N = 32 per GPU + main, startup, outs = self.build_program(place, layout, seed, False, + only_forward) + exe = fluid.Executor(place) + exe.run(startup) + fetch_names = [v.name for v in outs] + [ + 'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' + ] + if not only_forward: + others = [ + 'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD', + 'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD' + ] + fetch_names += others + bn_fetches = exe.run(program=main, + feed={'input': data}, + fetch_list=fetch_names) + + ##################################################################### + # Multi-GPUs, self.N / core.get_cuda_device_count() per GPU + main, startup, outs = self.build_program(place, layout, seed, True, + only_forward) + exe = fluid.Executor(place) + exe.run(startup) + fetch_names = [v.name for v in outs] + [ + 'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' + ] + if not only_forward: + others = [ + 'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD', + 'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD' + ] + fetch_names += others + for nm in fetch_names: + fv = fluid.framework._get_var(str(nm), program=main) + fv.persistable = True + build_strategy = fluid.BuildStrategy() + build_strategy.sync_batch_norm = True + build_strategy.enable_inplace = False + build_strategy.memory_optimize = False + comp_prog = compiler.CompiledProgram(main).with_data_parallel( + outs[0].name if not only_forward else None, + build_strategy=build_strategy) + sync_bn_fetches = exe.run(program=comp_prog, + feed={'input': data}, + fetch_list=fetch_names) + + for i in six.moves.xrange(1, len(sync_bn_fetches)): + bn_val = bn_fetches[i] + sync_bn_val = sync_bn_fetches[i] + if sync_bn_val.shape != bn_val.shape: + sync_bn_val = sync_bn_val[:bn_val.shape[0]] + self.assertTrue( + np.allclose( + bn_val, sync_bn_val, atol=1e-3), + "Output (" + fetch_names[i] + ") has diff. \n" + "\nBN " + + str(bn_val) + "\n" + "Sync BN " + str(sync_bn_val)) + + def test_train(self): + if not core.is_compiled_with_cuda(): + return + + places = [core.CUDAPlace(0)] + for place in places: + for layout in ["NCHW", "NHWC"]: + self.compare(place, layout, False) + + def test_infer(self): + if not core.is_compiled_with_cuda(): + return + + places = [core.CUDAPlace(0)] + for place in places: + for layout in ["NCHW", "NHWC"]: + self.compare(place, layout, True) + + +if __name__ == '__main__': + unittest.main() -- GitLab