Unverified commit 8ad672a2 authored by qingqing01, committed by GitHub

Support sync batch norm. (#16121)

* Support Sync Batch Norm.
* Note: do not enable it when running on a single device.

Usage:

build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = True
binary = fluid.compiler.CompiledProgram(tp).with_data_parallel(
        loss_name=loss_mean.name,
        build_strategy=build_strategy)
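
For context, a minimal end-to-end sketch of how the option is intended to be used (the network, shapes, and learning rate below are illustrative assumptions modeled on the unit test added in this PR; only the BuildStrategy / CompiledProgram part comes from the change itself):

import paddle.fluid as fluid
from paddle.fluid import compiler

# Build a small network containing batch_norm; with sync_batch_norm enabled,
# the sync_batch_norm_pass rewrites batch_norm/batch_norm_grad ops to their
# sync_* variants at compile time.
main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    data = fluid.layers.data(name='input', shape=[16, 32, 32], dtype='float32')
    conv = fluid.layers.conv2d(input=data, num_filters=16, filter_size=3)
    bn = fluid.layers.batch_norm(conv)
    loss_mean = fluid.layers.reduce_mean(bn)
    fluid.optimizer.SGD(learning_rate=0.01).minimize(loss_mean)

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(startup)

build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = True  # only meaningful with more than one GPU
binary = compiler.CompiledProgram(main).with_data_parallel(
    loss_name=loss_mean.name, build_strategy=build_strategy)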
Parent 4ae23cc3
...@@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
...@@ -91,7 +91,7 @@ paddle.fluid.layers.pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'po
paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True)), ('document', '043de7333b79ee0ac55053c14ed81625'))
paddle.fluid.layers.adaptive_pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '859b887174d06f361658f69cb7c06d95'))
paddle.fluid.layers.adaptive_pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '120f4323a3d7ed9c0916f15a59f0e497'))
paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)), ('document', 'c527b71b8a4c60dca8df8a745c2b598d'))
paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)), ('document', '320c6973b02ea179fa89fecc80796464'))
paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', 'e45e09e65a2658e07cad987222f0d9ab'))
paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b0b8d53821716cd50c42e09b593f3feb'))
paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', '03993955ab1e6d3044c44e6f17fc85e9'))
......
...@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <memory>
#include <utility>
#include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
...@@ -49,6 +50,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("sequential_execution_pass");
}
// Add op fusion.
if (strategy.sync_batch_norm_) {
AppendPass("sync_batch_norm_pass");
}
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
AppendPass("fuse_relu_depthwise_conv_pass");
...@@ -227,6 +233,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
} // namespace framework
} // namespace paddle
USE_PASS(sync_batch_norm_pass);
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
......
...@@ -77,6 +77,8 @@ struct BuildStrategy {
bool fuse_relu_depthwise_conv_{false};
bool sync_batch_norm_{false};
bool memory_optimize_{true};
// TODO(dzhwinter):
// make enable_inplace, memory_optimize_
......
...@@ -67,6 +67,7 @@ pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(conv_affine_channel_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
# There may be many transpose-flatten structures in a model, and the output of
# these structures will be used as inputs to the concat Op. This pattern will
...@@ -101,6 +102,7 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <memory>
#include <string>
#include <utility>
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> SyncBatchNormPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Use synchronous batch norm";
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
if (op->Type() == "batch_norm") {
op->SetType("sync_batch_norm");
}
if (op->Type() == "batch_norm_grad") {
op->SetType("sync_batch_norm_grad");
}
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(sync_batch_norm_pass, paddle::framework::ir::SyncBatchNormPass);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class SyncBatchNormPass : public Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/sync_batch_norm_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("name", name);
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
}
// (a, conv_w)->conv2d->b
// (b, bn_scale, bn_bias, mean, var)->batch_norm
// ->(c, mean, var, save_mean, save_inv_var)
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "conv_w", "b", "bn_scale",
"bn_bias", "mean", "var", "c",
"save_mean", "save_inv_var"})) {
auto* var = prog.MutableBlock(0)->Var(v);
if (v == "conv_w" || v == "bn_scale" || v == "bn_bias" || v == "mean" ||
v == "var") {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", "conv", std::vector<std::string>({"a", "conv_w"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "batch_norm", "bn",
std::vector<std::string>({"b", "bn_scale", "bn_bias", "mean", "var"}),
std::vector<std::string>(
{"c", "mean", "var", "save_mean", "save_inv_var"}));
return prog;
}
TEST(IsTestPass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("sync_batch_norm_pass");
graph = pass->Apply(std::move(graph));
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op_name == "bn") {
ASSERT_EQ(op->Type(), "sync_batch_norm");
}
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(sync_batch_norm_pass);
...@@ -14,8 +14,10 @@ limitations under the License. */
#include "paddle/fluid/framework/parallel_executor.h"
#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
...@@ -251,6 +253,20 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
member_->places_, nccl_id, build_strategy.num_trainers_,
build_strategy.trainer_id_));
std::unique_ptr<platform::NCCLContextMap> dev_nccl_ctxs;
dev_nccl_ctxs.reset(new platform::NCCLContextMap(member_->places_));
// Initialize device context's nccl comm
// Note: if more than one ParallelExecutor is created with the same places,
// the nccl comm will be overwritten and problems may occur.
for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) {
auto &nccl_ctx = dev_nccl_ctxs->at(dev_id);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
pool.Get(member_->places_[dev_id]));
dev_ctx->set_nccl_comm(nccl_ctx.comm());
}
#else
PADDLE_THROW("Not compiled with CUDA");
#endif
......
...@@ -44,10 +44,10 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cudnn 7 above
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
else()
...@@ -58,6 +58,8 @@ if (WITH_GPU)
op_library(conv_fusion_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
endif()
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
......
...@@ -33,26 +33,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
int *N, int *C, int *H, int *W, int *D) {
*N = dims[0];
if (dims.size() == 2) {
*C = dims[1];
*H = 1;
*W = 1;
*D = 1;
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3])
: 1;
}
}
template <typename T>
class BatchNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
...@@ -196,22 +176,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
}
};
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *variance,
const double epsilon, const int C,
const int HxW, const int num, T *dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> inv_var = 1.0 / sqrt(variance[c] + epsilon);
dx[i] = static_cast<T>(static_cast<BatchNormParamType<T>>(dy[i]) *
scale[c] * inv_var);
}
}
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(
const T *dy, const T *x, const BatchNormParamType<T> *mean,
...@@ -248,6 +212,22 @@ static __global__ void KeBNBackwardScaleBias(
}
}
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy,
const BatchNormParamType<T> *scale,
const BatchNormParamType<T> *variance,
const double epsilon, const int C,
const int HxW, const int num, T *dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
BatchNormParamType<T> inv_var = 1.0 / sqrt(variance[c] + epsilon);
dx[i] = static_cast<T>(static_cast<BatchNormParamType<T>>(dy[i]) *
scale[c] * inv_var);
}
}
template <typename T>
class BatchNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
...@@ -383,7 +363,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
KeBNBackwardScaleBias<T, block, framework::DataLayout::kNCHW><<<
grid2, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
epsilon, C, H * W, num, d_scale->data<BatchNormParamType<T>>(),
epsilon, N, C, H * W * D, d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} else {
...@@ -394,10 +374,10 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
running_var_data, epsilon, C, H * W, num, d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, framework::DataLayout::kNCHW><<<
KeBNBackwardScaleBias<T, block, framework::DataLayout::kNHWC><<<
grid2, block, 0, dev_ctx.stream()>>>(
d_y->data<T>(), x->data<T>(), running_mean_data, running_var_data,
epsilon, C, H * W, num, d_scale->data<BatchNormParamType<T>>(),
epsilon, N, C, H * W * D, d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
}
......
...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -35,17 +38,84 @@ template <typename T> ...@@ -35,17 +38,84 @@ template <typename T>
using ConstEigenVectorArrayMap = using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>; Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
class BatchNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class BatchNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override;
virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
};
class BatchNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
}
};
template <typename DeviceContext, typename T>
class BatchNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename DeviceContext, typename T>
class BatchNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
void Compute(const framework::ExecutionContext &ctx) const override;
};
inline void ExtractNCWHD(const framework::DDim &dims,
const DataLayout &data_layout, int *N, int *C, int *H,
int *W, int *D) {
*N = dims[0];
if (dims.size() == 2) {
*C = dims[1];
*H = 1;
*W = 1;
*D = 1;
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3])
: 1;
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
namespace ops = paddle::operators;
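// Note: sync_batch_norm reuses batch_norm's operator definition, proto maker,
// InferVarType, and gradient maker (declared in batch_norm_op.h); only the
// CUDA compute kernels differ and live in sync_batch_norm_op.cu.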
REGISTER_OPERATOR(sync_batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
ops::BatchNormOpInferVarType, ops::BatchNormGradMaker);
REGISTER_OPERATOR(sync_batch_norm_grad, ops::BatchNormGradOp);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void KeLocalStats(const T *x, int N, int M, int C, T *mean_var) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int k = blockIdx.x; k < C; k += gridDim.x) {
T x_sum = 0;
T x2_sum = 0;
for (int i = threadIdx.x; i < N * M; i += BlockDim) {
int id = layout == framework::DataLayout::kNCHW
? (i / M) * C * M + k * M + i % M
: i * C + k;
T x_in = x[id];
x_sum += x_in;
x2_sum += x_in * x_in;
}
__syncthreads();
T out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
mean_var[k] = out / (N * M);
}
out = BlockReduce(temp_storage).Reduce(x2_sum, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
mean_var[k + C] = out / (N * M);
}
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
mean_var[2 * C] = static_cast<T>(1.0);
}
}
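// Cross-device statistics: KeLocalStats (above) writes each device's local
// per-channel mean E_d[x] and E_d[x^2], plus a trailing constant 1, into a
// (2*C + 1) buffer. SyncBatchNormKernel::Compute all-reduces (sums) these
// buffers with NCCL, so the trailing slot becomes the device count D.
// KeSyncAndMovingStats (below) then divides by D, giving (when every device
// processes the same number of elements)
//   mean = (1/D) * sum_d E_d[x],   var = (1/D) * sum_d E_d[x^2] - mean^2,
// and updates the moving statistics with the usual momentum rule.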
template <typename T>
__global__ void KeSyncAndMovingStats(T *means, T *variances, T *num_dev,
const int C, const T momentum,
const double epsilon, T *sv_mean_data,
T *sv_inv_var_data, T *moving_means,
T *moving_variances) {
// sync stats across multi-devices
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < C; i += stride) {
T mean = means[i] / (*num_dev);
T var = variances[i] / (*num_dev);
var = var - mean * mean;
// sync stats
sv_mean_data[i] = mean;
sv_inv_var_data[i] = 1.0 / sqrt(var + epsilon);
variances[i] = var;
// moving stats
moving_means[i] = moving_means[i] * momentum + mean * (1. - momentum);
moving_variances[i] =
moving_variances[i] * momentum + var * (1. - momentum);
}
}
template <typename T, framework::DataLayout layout>
static __global__ void KeNormAffine(const T *x, const T *scale, const T *bias,
const T *mean, const T *variance,
const double epsilon, const int C,
const int M, const int num, T *y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? (i / M) % C : i % C;
y[i] = (x[i] - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c];
}
}
template <typename DeviceContext, typename T>
class SyncBatchNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const std::string layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout layout = framework::StringToDataLayout(layout_str);
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
PADDLE_ENFORCE(
!use_global_stats,
"sync_batch_norm doesn't support to set use_global_stats True. ",
"Please use batch_norm in this case.");
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
int x_numel = x->numel();
const T *x_d = x->data<T>();
const T *s_d = ctx.Input<Tensor>("Scale")->data<T>();
const T *b_d = ctx.Input<Tensor>("Bias")->data<T>();
auto *y = ctx.Output<Tensor>("Y");
T *y_d = y->mutable_data<T>(ctx.GetPlace());
const T *mean_data = nullptr;
const T *var_data = nullptr;
auto &dev_ctx = ctx.cuda_device_context();
auto stream = dev_ctx.stream();
auto *comm = dev_ctx.nccl_comm();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
paddle::memory::AllocationPtr alloc_ptr{nullptr};
if (is_test) {
const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance");
mean_data = est_mean->data<T>();
var_data = est_var->data<T>();
} else {
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
// x, x^2, 1, here 1 is used to calc device num
// device num also can be got from platform::DeviceContextPool
const int bytes = (C * 2 + 1) * sizeof(T);
alloc_ptr = allocator.Allocate(bytes);
T *stats = reinterpret_cast<T *>(alloc_ptr->ptr());
const int threads = 256;
int grid = std::min(C, (max_threads + threads - 1) / threads);
if (layout == framework::DataLayout::kNCHW) {
KeLocalStats<
T, threads,
framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
x_d, N, H * W * D, C, stats);
} else {
KeLocalStats<
T, threads,
framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
x_d, N, H * W * D, C, stats);
}
Tensor c_g_st;
T *c_g_st_d = c_g_st.mutable_data<T>({2 * C + 1}, platform::CPUPlace());
auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
int dtype = platform::ToNCCLDataType(x->type());
// In-place operation
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
// moving mean/variance
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
T *est_mean_data = mean_out->mutable_data<T>(ctx.GetPlace());
T *est_var_data = variance_out->mutable_data<T>(ctx.GetPlace());
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance");
T *sv_mean_data = saved_mean->mutable_data<T>(ctx.GetPlace());
T *sv_inv_var_data = saved_inv_variance->mutable_data<T>(ctx.GetPlace());
// Note, Input('Mean')/Input('Variance') share variable with
// Output('MeanOut')/Output('VarianceOut')
KeSyncAndMovingStats<T><<<(C + block - 1) / block, block, 0, stream>>>(
stats, stats + C, stats + 2 * C, C, momentum, epsilon, sv_mean_data,
sv_inv_var_data, est_mean_data, est_var_data);
mean_data = sv_mean_data;
var_data = stats + C;
}
int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
if (layout == framework::DataLayout::kNCHW) {
KeNormAffine<T,
framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel,
y_d);
} else {
KeNormAffine<T,
framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel,
y_d);
}
}
};
template <typename T, const int BlockDim, framework::DataLayout layout>
__global__ void KeBackwardLocalStats(const T *dy, const T *x, const T *means,
int N, int M, int C, T *sum_dy_prod) {
typedef cub::BlockReduce<double, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int k = blockIdx.x; k < C; k += gridDim.x) {
T sum1 = 0;
T sum2 = 0;
T mean = means[k];
for (int i = threadIdx.x; i < N * M; i += blockDim.x) {
int id = layout == framework::DataLayout::kNCHW
? (i / M) * C * M + k * M + i % M
: i * C + k;
T g = dy[id];
sum1 += g;
sum2 += g * (x[id] - mean);
}
__syncthreads();
T out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
sum_dy_prod[k] = out;
}
out = BlockReduce(temp_storage).Reduce(sum2, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
sum_dy_prod[k + C] = out;
}
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
sum_dy_prod[2 * C] = static_cast<T>(1.0);
}
}
template <typename T, int BlockDim, framework::DataLayout layout>
static __global__ void KeBNBackwardScaleBias(const T *dy, const T *x,
const T *mean,
const T *inv_variance,
const double epsilon, const int N,
const int C, const int HxW,
T *dscale, T *dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<double, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T ds_sum = static_cast<T>(0);
T db_sum = static_cast<T>(0);
T inv_var_i = inv_variance[i];
T mean_i = mean[i];
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int id = layout == framework::DataLayout::kNCHW
? ((j / HxW) * C + i) * HxW + (j % HxW)
: j * outer_size + i;
ds_sum += dy[id] * (x[id] - mean_i);
db_sum += dy[id];
}
__syncthreads();
double os = BlockReduce(temp_storage)
.Reduce(static_cast<double>(ds_sum), cub::Sum());
__syncthreads();
double ob = BlockReduce(temp_storage)
.Reduce(static_cast<double>(db_sum), cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
dscale[i] = static_cast<T>(os * inv_var_i);
dbias[i] = static_cast<T>(ob);
}
__syncthreads();
}
}
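// KeBNBackwardData (below) applies the standard batch-norm input gradient,
// but with the per-channel sums reduced over all D devices. Writing gamma for
// the Scale input (the `beta` argument here), m = D * N * HxW for the global
// per-channel element count, and x_hat = (x - mean) * inv_variance, it
// evaluates
//   dx = gamma * inv_variance *
//        (dy - sum(dy) / m - x_hat * sum(dy * x_hat) / m),
// with sum(dy) and sum(dy * (x - mean)) taken from g_sum_dy / g_sum_dy_prod
// and the device count D from num_dev.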
template <typename T, framework::DataLayout layout>
static __global__ void KeBNBackwardData(const T *dy, const T *x, const T *beta,
const T *mean, const T *inv_variance,
const T *g_sum_dy,
const T *g_sum_dy_prod,
const T *num_dev, const double epsilon,
const int C, const int HxW,
const int num, T *dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
T scale = static_cast<T>(C) / num;
T dev_num = num_dev[0];
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
T inv_var = inv_variance[c];
T s_d = beta[c];
T gvar = -1.0 * (g_sum_dy_prod[c] / dev_num) * s_d * inv_var *
(inv_var * inv_var);
T gmean = -1.0 * (g_sum_dy[c] / dev_num) * s_d * inv_var;
dx[i] =
dy[i] * s_d * inv_var + gmean * scale + gvar * scale * (x[i] - mean[c]);
}
}
// Deriving the Gradient for the Backward Pass of Batch Normalization
// https://kevinzakka.github.io/2016/09/14/batch_normalization/
template <typename DeviceContext, typename T>
class SyncBatchNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout layout = framework::StringToDataLayout(layout_str);
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"The Input dim size should be between 2 and 5");
int N, C, H, W, D;
ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
if (d_scale && d_bias) {
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
}
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
std::vector<int> dims;
std::vector<int> strides;
if (layout == DataLayout::kNCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * C * D, 1, W * D * C, D * C, C};
}
const T *x_d = x->data<T>();
const T *dy_d = d_y->data<T>();
auto &dev_ctx = ctx.cuda_device_context();
auto stream = dev_ctx.stream();
auto *comm = dev_ctx.nccl_comm();
const T *saved_mean = ctx.Input<Tensor>("SavedMean")->data<T>();
const T *saved_inv_var = ctx.Input<Tensor>("SavedVariance")->data<T>();
auto &allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
const int bytes = (C * 2 + 1) * sizeof(T);
auto alloc_ptr = allocator.Allocate(bytes);
T *stats = reinterpret_cast<T *>(alloc_ptr->ptr());
const int threads = 256;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
int grid = std::min(C, (max_threads + threads - 1) / threads);
int x_numel = x->numel();
int fsize = H * W * D;
if (layout == framework::DataLayout::kNCHW) {
KeBackwardLocalStats<
T, threads,
framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
dy_d, x_d, saved_mean, N, fsize, C, stats);
} else {
KeBackwardLocalStats<
T, threads,
framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
dy_d, x_d, saved_mean, N, fsize, C, stats);
}
int dtype = platform::ToNCCLDataType(x->type());
// In-place operation
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
stats, stats, 2 * C + 1, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
const int block = 512;
int grid2 = (std::min(x_numel, max_threads) + block - 1) / block;
if (layout == framework::DataLayout::kNCHW) {
if (d_scale && d_bias) {
KeBNBackwardScaleBias<
T, threads,
framework::DataLayout::kNCHW><<<grid, threads, 0, stream>>>(
dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
d_scale->data<T>(), d_bias->data<T>());
}
if (d_x) {
KeBNBackwardData<
T, framework::DataLayout::kNCHW><<<grid2, block, 0, stream>>>(
dy_d, x_d, scale->data<T>(), saved_mean, saved_inv_var, stats,
stats + C, stats + 2 * C, epsilon, C, fsize, x->numel(),
d_x->data<T>());
}
} else {
if (d_scale && d_bias) {
KeBNBackwardScaleBias<
T, threads,
framework::DataLayout::kNHWC><<<grid, threads, 0, stream>>>(
dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize,
d_scale->data<T>(), d_bias->data<T>());
}
if (d_x) {
KeBNBackwardData<
T, framework::DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
dy_d, x_d, scale->data<T>(), saved_mean, saved_inv_var, stats,
stats + C, stats + 2 * C, epsilon, C, fsize, x->numel(),
d_x->data<T>());
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
sync_batch_norm, ops::SyncBatchNormKernel<plat::CUDADeviceContext, float>,
ops::SyncBatchNormKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
sync_batch_norm_grad,
ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, float>,
ops::SyncBatchNormGradKernel<plat::CUDADeviceContext, double>);
...@@ -57,7 +57,6 @@ DeviceContextPool::DeviceContextPool(
for (auto& p : places) {
set.insert(p);
}
for (auto& p : set) {
if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
...@@ -317,6 +316,7 @@ CUDADeviceContext::~CUDADeviceContext() {
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
PADDLE_ENFORCE(dynload::ncclCommDestroy(nccl_comm_));
}
Place CUDADeviceContext::GetPlace() const { return place_; }
......
...@@ -265,6 +265,12 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream() const;
/*! \brief Return nccl communicators. */
ncclComm_t nccl_comm() const { return nccl_comm_; }
/*! \brief Set nccl communicators. */
void set_nccl_comm(ncclComm_t comm) { nccl_comm_ = comm; }
template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) {
callback();
...@@ -289,6 +295,13 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<CublasHandleHolder> cublas_handle_;
std::unique_ptr<CublasHandleHolder> cublas_tensor_core_handle_;
// NCCL communicator (single process version) for NCCL collective operations.
// NCCL collective operations provide fast collectives over multiple GPUs
// both within and across nodes.
// Here, however, the communicator is only used for collectives over the
// GPUs within a single node.
ncclComm_t nccl_comm_{nullptr};
int compute_capability_;
int runtime_version_;
int driver_version_;
......
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <string.h> // for strdup
#include <algorithm>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
...@@ -140,6 +142,7 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
places.emplace_back(platform::CPUPlace());
platform::DeviceContextPool::Init(places);
platform::DeviceTemporaryAllocator::Init();
#ifndef PADDLE_WITH_MKLDNN
platform::SetNumThreads(FLAGS_paddle_num_threads);
#endif
......
...@@ -16,9 +16,11 @@
#pragma once
#include <stdio.h>
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <typeindex>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/dynload/nccl.h"
...@@ -78,6 +80,8 @@ struct NCCLContext {
cudaStream_t stream() const { return ctx_->stream(); }
ncclComm_t comm() const { return comm_; }
int device_id() const {
return boost::get<platform::CUDAPlace>(ctx_->GetPlace()).device;
}
......
...@@ -1230,6 +1230,21 @@ All parameter, weight, gradient are variables in Paddle.
it will save GPU memory and may make the execution faster.
This options is only available in GPU devices.
Default False)DOC")
.def_property(
"sync_batch_norm",
[](const BuildStrategy &self) { return self.sync_batch_norm_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finalized.");
self.sync_batch_norm_ = b;
},
R"DOC(The type is BOOL, sync_batch_norm indicates whether to use
synchronous batch normalization which synchronizes the mean
and variance through multi-devices in training phase.
Current implementation doesn't support FP16 training and CPU.
And only synchronous on one machine, not all machines.
Default False)DOC")
.def_property(
"memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; },
......
...@@ -223,6 +223,9 @@ class CompiledProgram(object):
tps), "num_trainers == len(end_points)"
self._build_strategy.trainers_endpoints = tps
if self._build_strategy.sync_batch_norm:
self._build_strategy.enable_sequential_execution = True
self._persistable_vars = []
for node in self._graph.nodes():
if node.is_var() and node.var() is not None and node.var().persistable() and \
......
...@@ -2922,11 +2922,17 @@ def batch_norm(input,
y_i &\gets \gamma \hat{x_i} + \beta
Args:
input(variable): The input variable which is a LoDTensor.
input(variable): The rank of input variable can be 2, 3, 4, 5.
act(string, Default None): Activation type, linear|relu|prelu|...
is_test(bool, Default False): Used for training or training.
momentum(float, Default 0.9):
epsilon(float, Default 1e-05):
is_test (bool, Default False): A flag indicating whether it is in
    test phase or not.
momentum(float, Default 0.9): The value used for the moving_mean and
    moving_var computation. The updated formula is:
    :math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
    :math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
    Default is 0.9.
epsilon(float, Default 1e-05): A value added to the denominator for
    numerical stability. Default is 1e-5.
param_attr(ParamAttr|None): The parameter attribute for Parameter `scale`
of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as param_attr. If the Initializer of the param_attr
...@@ -2984,15 +2990,8 @@ def batch_norm(input,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
# setting stop_gradient=True to reduce computation
if use_global_stats and helper.param_attr.learning_rate == 0.:
scale.stop_gradient = True
bias = helper.create_parameter(
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)
# setting stop_gradient=True to reduce computation
if use_global_stats and helper.bias_attr.learning_rate == 0.:
bias.stop_gradient = True
mean = helper.create_parameter(
attr=ParamAttr(
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import os
import six
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import compiler
class TestSyncBatchNormOpTraining(unittest.TestCase):
def setUp(self):
#self.dtype = np.float32
self.dtype = np.float64
self.N = 32
self.C = 16
self.H = 64
self.W = 32
self.dshape = [self.N, self.C, self.H, self.W]
def build_program(self,
place,
layout,
seed,
sync_bn=False,
only_forward=False):
main = fluid.Program()
startup = fluid.Program()
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
data = fluid.layers.data(
name='input',
shape=self.dshape,
dtype=self.dtype,
append_batch_size=False)
conv = fluid.layers.conv2d(
input=data,
num_filters=32,
filter_size=1,
param_attr=fluid.ParamAttr(name='conv2d_weight'),
bias_attr=False,
use_cudnn=False)
bn = fluid.layers.batch_norm(
conv,
param_attr=fluid.ParamAttr(name='bn_scale'),
bias_attr=fluid.ParamAttr(name='bn_bias'),
moving_mean_name='bn_moving_mean',
moving_variance_name='bn_moving_variance',
data_layout=layout,
is_test=only_forward)
sigmoid = fluid.layers.sigmoid(bn)
out = fluid.layers.reduce_sum(sigmoid)
if not sync_bn:
out = out / core.get_cuda_device_count()
if not only_forward:
sgd_opt = fluid.optimizer.SGD(learning_rate=0.0)
sgd_opt.backward(out)
return main, startup, [out, conv, bn]
def compare(self, place, layout, only_forward):
seed = 10
os.environ['FLAGS_cudnn_deterministic'] = "1"
data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2
# Single-GPU, N = 32 per GPU
main, startup, outs = self.build_program(place, layout, seed, False,
only_forward)
exe = fluid.Executor(place)
exe.run(startup)
fetch_names = [v.name for v in outs] + [
'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias'
]
if not only_forward:
others = [
'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD',
'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD'
]
fetch_names += others
bn_fetches = exe.run(program=main,
feed={'input': data},
fetch_list=fetch_names)
#####################################################################
# Multi-GPUs, self.N / core.get_cuda_device_count() per GPU
main, startup, outs = self.build_program(place, layout, seed, True,
only_forward)
exe = fluid.Executor(place)
exe.run(startup)
fetch_names = [v.name for v in outs] + [
'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias'
]
if not only_forward:
others = [
'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD',
'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD'
]
fetch_names += others
for nm in fetch_names:
fv = fluid.framework._get_var(str(nm), program=main)
fv.persistable = True
build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = True
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
comp_prog = compiler.CompiledProgram(main).with_data_parallel(
outs[0].name if not only_forward else None,
build_strategy=build_strategy)
sync_bn_fetches = exe.run(program=comp_prog,
feed={'input': data},
fetch_list=fetch_names)
for i in six.moves.xrange(1, len(sync_bn_fetches)):
bn_val = bn_fetches[i]
sync_bn_val = sync_bn_fetches[i]
if sync_bn_val.shape != bn_val.shape:
sync_bn_val = sync_bn_val[:bn_val.shape[0]]
self.assertTrue(
np.allclose(
bn_val, sync_bn_val, atol=1e-3),
"Output (" + fetch_names[i] + ") has diff. \n" + "\nBN " +
str(bn_val) + "\n" + "Sync BN " + str(sync_bn_val))
def test_train(self):
if not core.is_compiled_with_cuda():
return
places = [core.CUDAPlace(0)]
for place in places:
for layout in ["NCHW", "NHWC"]:
self.compare(place, layout, False)
def test_infer(self):
if not core.is_compiled_with_cuda():
return
places = [core.CUDAPlace(0)]
for place in places:
for layout in ["NCHW", "NHWC"]:
self.compare(place, layout, True)
if __name__ == '__main__':
unittest.main()