From f26ba5bddd08aa93723d28b04a9c51398d25cc80 Mon Sep 17 00:00:00 2001 From: chengduo Date: Tue, 19 Mar 2019 23:38:06 -0500 Subject: [PATCH] Fuse AllReduce (#15921) * fuse all_reduce test=develop * add fuse_parameter_groups_size test=develop * Polish code test=develop * Fix travis-ci test=develop * Add SetGroupAccordingToLayers and SetGroupAccordingToGroupSize test=develop * Add SetGroupAccordingToMemorySize test=develop * fix multi_devices_graph test=develop * reset params_grads test=develop * Polish code test=develop --- paddle/fluid/framework/details/CMakeLists.txt | 11 +- .../alloc_continuous_space_for_grad_pass.cc | 393 ++++++++++++++++++ .../fluid/framework/details/build_strategy.cc | 86 +++- .../fluid/framework/details/build_strategy.h | 3 + .../details/fuse_all_reduce_op_pass.cc | 195 +++++++++ .../details/fused_all_reduce_op_handle.cc | 248 +++++++++++ .../details/fused_all_reduce_op_handle.h | 76 ++++ .../details/multi_devices_graph_pass.cc | 30 +- .../details/multi_devices_graph_pass.h | 14 +- .../framework/details/multi_devices_helper.h | 23 + .../framework/details/reduce_and_gather.h | 25 ++ paddle/fluid/pybind/pybind.cc | 4 + python/paddle/fluid/__init__.py | 3 +- .../unittests/parallel_executor_test_base.py | 2 + .../unittests/test_fuse_all_reduce_pass.py | 121 ++++++ 15 files changed, 1185 insertions(+), 49 deletions(-) create mode 100644 paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc create mode 100644 paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc create mode 100644 paddle/fluid/framework/details/fused_all_reduce_op_handle.cc create mode 100644 paddle/fluid/framework/details/fused_all_reduce_op_handle.h create mode 100644 python/paddle/fluid/tests/unittests/test_fuse_all_reduce_pass.py diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 9f06455ea54..5f29ebc70f0 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -9,6 +9,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper) cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper) cc_library(multi_devices_graph_check_pass SRCS multi_devices_graph_check_pass.cc DEPS multi_devices_helper) +cc_library(alloc_continuous_space_for_grad_pass SRCS alloc_continuous_space_for_grad_pass.cc DEPS graph graph_helper) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) @@ -22,6 +23,8 @@ endif() if(WITH_GPU) nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) + nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory + dynload_cuda variable_visitor) if(WITH_DISTRIBUTE) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda selected_rows_functor sendrecvop_rpc) @@ -35,6 +38,8 @@ if(WITH_GPU) else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) + cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory + variable_visitor) if(WITH_DISTRIBUTE) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope 
ddim selected_rows_functor sendrecvop_rpc) @@ -71,6 +76,8 @@ cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_he cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle) +cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle) + set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass) if (WITH_GPU) list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass) @@ -98,5 +105,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS graph_viz_pass multi_devices_graph_pass multi_devices_graph_print_pass multi_devices_graph_check_pass fuse_elewise_add_act_pass multi_batch_merge_pass - fuse_relu_depthwise_conv_pass - memory_optimize_pass lock_free_optimize_pass) + fuse_relu_depthwise_conv_pass + memory_optimize_pass lock_free_optimize_pass alloc_continuous_space_for_grad_pass fuse_all_reduce_op_pass) diff --git a/paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc b/paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc new file mode 100644 index 00000000000..fbc8bbf56b0 --- /dev/null +++ b/paddle/fluid/framework/details/alloc_continuous_space_for_grad_pass.cc @@ -0,0 +1,393 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "paddle/fluid/framework/details/build_strategy.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/op_registry.h" +DEFINE_uint32(fuse_parameter_memory_size, 0, // 0 KB + "fuse_parameter_memory_size is up limited memory size " + "of one group parameters' gradient which is the input " + "of communication calling(e.g NCCLAllReduce). " + "The default value is 0, it means that " + "not set group according to memory_size."); +DEFINE_int32( + fuse_parameter_groups_size, 3, + "fuse_parameter_groups_size is the size of one group parameters' gradient. " + "The default value is a experimental result. If the " + "fuse_parameter_groups_size is 1, it means that the groups size is " + "the number of parameters' gradient. If the fuse_parameter_groups_size is " + "-1, it means that there are only one group. 
The default value is 3, it is " + "an experimental value."); + +namespace paddle { +namespace framework { +namespace details { + +static const char kUnKnow[] = "@UNKNOW@"; +static framework::proto::VarType::Type kDefaultDtype = + framework::proto::VarType::Type::VarType_Type_BOOL; + +class AllocContinuousSpaceForGradPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override { + ir::Graph &result = *graph; + + auto &places = Get>(kPlaces); + auto &local_scopes = Get>(kLocalScopes); + + ResetAttribute(kParamsAndGrads, &result); + ResetAttribute(kGroupGradsAndParams, &result); + + // NOTE: The operator nodes should be in topology order. + std::vector topo_nodes = ir::TopologySortOperations(result); + auto ¶ms_grads = result.Get(kParamsAndGrads); + for (auto &node : topo_nodes) { + RecordParamsAndGrads(node, ¶ms_grads); + } + + if (params_grads.size() == 0) { + VLOG(10) << "Doesn't find gradients"; + return std::move(graph); + } + + std::unordered_map vars; + for (ir::Node *node : result.Nodes()) { + if (node->IsVar() && node->Var()) { + // Note: The graph may have the same name node. For example, parameter + // is the input of operator and it also is the output of optimizer; + vars.emplace(node->Var()->Name(), node); + } + } + + auto &group_grads_params = + result.Get(kGroupGradsAndParams); + + // Note: the order of params_grads may be changed by SetGroupGradsAndParams. + SetGroupGradsAndParams(vars, params_grads, &group_grads_params); + + params_grads.clear(); + for (auto &group_p_g : group_grads_params) { + params_grads.insert(params_grads.begin(), group_p_g.begin(), + group_p_g.end()); + } + for (auto &p_g : params_grads) { + std::swap(p_g.first, p_g.second); + } + + // Set Gradients as Persistable to prevent this var becoming reusable. + auto dtype = kDefaultDtype; + for (auto &p_g : params_grads) { + // Get gradient var + auto iter = vars.find(p_g.second); + PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second); + iter->second->Var()->SetPersistable(true); + + PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType())); + + // Get Dtype + auto ele_dtype = iter->second->Var()->GetDataType(); + if (dtype == kDefaultDtype) { + dtype = ele_dtype; + PADDLE_ENFORCE_NE(ele_dtype, kDefaultDtype); + } + PADDLE_ENFORCE_EQ(ele_dtype, dtype); + } + + // Create the fused variable name. + if (!result.Has(kFusedVars)) { + result.Set(kFusedVars, new FusedVars); + } + const std::string prefix(kFusedVarNamePrefix); + // The fused_var_name should be unique. 
+ auto fused_var_name = prefix + "GRAD@" + params_grads[0].second; + auto &fused_var_set = result.Get(kFusedVars); + PADDLE_ENFORCE_EQ(fused_var_set.count(fused_var_name), 0); + fused_var_set.insert(fused_var_name); + + InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars, + fused_var_name, params_grads); + + return std::move(graph); + } + + template + void ResetAttribute(const std::string &attr_name, ir::Graph *graph) const { + if (graph->Has(attr_name)) { + VLOG(10) << attr_name << " is reset."; + graph->Erase(attr_name); + } + graph->Set(attr_name, new AttrType); + } + + void SetGroupGradsAndParams( + const std::unordered_map &var_nodes, + const ParamsAndGrads ¶ms_grads, + GroupGradsAndParams *group_grads_params) const { + SetGroupAccordingToLayers(var_nodes, params_grads, group_grads_params); + SetGroupAccordingToMemorySize(var_nodes, group_grads_params); + SetGroupAccordingToGroupSize(var_nodes, group_grads_params); + } + + void SetGroupAccordingToLayers( + const std::unordered_map &var_nodes, + const ParamsAndGrads ¶ms_grads, + GroupGradsAndParams *group_grads_params) const { + std::unordered_map> layer_params; + + for (size_t i = 0; i < params_grads.size(); ++i) { + auto pos = params_grads[i].first.find_first_of("."); + if (pos == std::string::npos) { + layer_params[std::string(kUnKnow)].emplace_back(i); + } else { + layer_params[params_grads[i].first.substr(0, pos)].emplace_back(i); + } + } + + group_grads_params->reserve(layer_params.size()); + for (size_t i = 0; i < params_grads.size(); ++i) { + auto pos = params_grads[i].first.find_first_of("."); + std::string key = kUnKnow; + if (pos != std::string::npos) { + key = params_grads[i].first.substr(0, pos); + } + auto iter = layer_params.find(key); + if (iter == layer_params.end()) continue; + + group_grads_params->emplace_back(); + auto &local_group_grads_params = group_grads_params->back(); + for (auto &idx : iter->second) { + local_group_grads_params.emplace_back( + std::make_pair(params_grads[idx].second, params_grads[idx].first)); + } + layer_params.erase(iter); + } + + VLOG(10) << "SetGroupAccordingToLayers: "; + for (size_t i = 0; i < group_grads_params->size(); ++i) { + VLOG(10) << "group " << i; + std::stringstream out; + for (auto &p_g : group_grads_params->at(i)) { + out << "(" << p_g.second << ", " << p_g.first << "), "; + } + VLOG(10) << out.str(); + } + } + + void SetGroupAccordingToMemorySize( + const std::unordered_map &var_nodes, + GroupGradsAndParams *group_grads_params) const { + if (FLAGS_fuse_parameter_memory_size == 0) { + return; + } + size_t group_memory_size = + static_cast(FLAGS_fuse_parameter_memory_size); + GroupGradsAndParams local_group_grads_params; + + size_t j = 0; + while (j < group_grads_params->size()) { + local_group_grads_params.emplace_back(); + auto &group_p_g = local_group_grads_params.back(); + size_t local_group_memory_size = 0; + while (j < group_grads_params->size()) { + std::for_each( + group_grads_params->at(j).begin(), group_grads_params->at(j).end(), + [&local_group_memory_size, + &var_nodes](const std::pair &g_p) { + auto iter = var_nodes.find(g_p.second); + PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", + g_p.second); + auto shape = iter->second->Var()->GetShape(); + size_t size = + framework::SizeOfType(iter->second->Var()->GetDataType()); + std::for_each(shape.begin(), shape.end(), + [&size](const int64_t &n) { size *= n; }); + local_group_memory_size += size; + }); + group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(), + 
group_grads_params->at(j).end()); + ++j; + if (local_group_memory_size >= group_memory_size) { + break; + } + } + } + + std::swap(*group_grads_params, local_group_grads_params); + + VLOG(10) << string::Sprintf( + "SetGroupAccordingToMemorySize(memory_size: %d):", + FLAGS_fuse_parameter_memory_size); + for (size_t i = 0; i < group_grads_params->size(); ++i) { + VLOG(10) << "group " << i; + std::stringstream out; + for (auto &g_p : group_grads_params->at(i)) { + auto iter = var_nodes.find(g_p.second); + PADDLE_ENFORCE(iter != var_nodes.end(), "%s is not found.", g_p.second); + auto shape = iter->second->Var()->GetShape(); + size_t size = framework::SizeOfType(iter->second->Var()->GetDataType()); + std::for_each(shape.begin(), shape.end(), + [&size](const int64_t &n) { size *= n; }); + out << string::Sprintf("(%s(%d), %s)", g_p.second, size, g_p.first); + } + VLOG(10) << out.str(); + } + } + + void SetGroupAccordingToGroupSize( + const std::unordered_map &var_nodes, + GroupGradsAndParams *group_grads_params) const { + if (FLAGS_fuse_parameter_groups_size == 1) { + return; + } + size_t group_size = static_cast(FLAGS_fuse_parameter_groups_size); + if (FLAGS_fuse_parameter_groups_size == -1) { + group_size = group_grads_params->size(); + } + PADDLE_ENFORCE_GT(group_size, 1); + size_t groups = (group_grads_params->size() + group_size - 1) / group_size; + GroupGradsAndParams local_group_grads_params; + local_group_grads_params.reserve(groups); + + size_t j = 0; + for (size_t i = 0; i < groups; ++i) { + local_group_grads_params.emplace_back(); + auto &group_p_g = local_group_grads_params.back(); + group_p_g.reserve(group_size); + while (j < group_grads_params->size()) { + group_p_g.insert(group_p_g.end(), group_grads_params->at(j).begin(), + group_grads_params->at(j).end()); + ++j; + if (j % group_size == 0) break; + } + } + std::swap(*group_grads_params, local_group_grads_params); + + VLOG(10) << "SetGroupAccordingToGroupSize(group_size: " << group_size + << "): "; + for (size_t i = 0; i < group_grads_params->size(); ++i) { + VLOG(10) << "group " << i; + std::stringstream out; + for (auto &p_g : group_grads_params->at(i)) { + out << "(" << p_g.second << ", " << p_g.first << "), "; + } + VLOG(10) << out.str(); + } + } + + private: + bool IsSupportedVarType(const proto::VarType::Type &type) const { + // Current only support LOD_TENSOR. + return type == proto::VarType::LOD_TENSOR; + } + + void AppendAllocSpaceForVarsOp(const std::vector ¶ms_name, + const std::vector &grads_name, + const std::string &fused_var_name, + BlockDesc *global_block) const { + auto op_desc = global_block->AppendOp(); + op_desc->SetType("alloc_continuous_space"); + op_desc->SetInput("Input", params_name); + op_desc->SetOutput("Output", grads_name); + op_desc->SetOutput("FusedOutput", {fused_var_name}); + } + + void RecordParamsAndGrads(ir::Node *node, + ParamsAndGrads *params_grads) const { + try { + bool is_bk_op = + static_cast(boost::get(node->Op()->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) & + static_cast(OpRole::kBackward)); + if (!is_bk_op) return; + + // Currently, we assume that once gradient is generated, it can be + // broadcast, and each gradient is only broadcast once. 
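The three grouping steps above (first by layer prefix, then by a memory-size budget, then by a fixed group count) can be summarized with a short illustrative Python sketch. This is not the C++ implementation; the function names, the plain-list layout of params_grads, and the sizes mapping are assumptions made only for illustration:

# params_grads: list of (param_name, grad_name) pairs recorded from backward ops.
# sizes: hypothetical mapping from a gradient name to its size in bytes.

def group_by_layer(params_grads):
    # Parameters sharing the prefix before the first '.' (the layer name) are
    # grouped together; names without a '.' fall into the '@UNKNOW@' group.
    groups = {}
    for p, g in params_grads:
        key = p.split('.', 1)[0] if '.' in p else '@UNKNOW@'
        groups.setdefault(key, []).append((g, p))
    return list(groups.values())

def regroup_by_memory(groups, sizes, budget):
    # Merge consecutive groups until the accumulated gradient size reaches the
    # budget (FLAGS_fuse_parameter_memory_size); a budget of 0 skips this step.
    if budget == 0:
        return groups
    merged, cur, cur_size = [], [], 0
    for g in groups:
        cur += g
        cur_size += sum(sizes[grad] for grad, _ in g)
        if cur_size >= budget:
            merged.append(cur)
            cur, cur_size = [], 0
    if cur:
        merged.append(cur)
    return merged

def regroup_by_count(groups, group_size):
    # Merge every `group_size` consecutive groups into one
    # (FLAGS_fuse_parameter_groups_size; -1 collapses everything into one group).
    if group_size == 1:
        return groups
    if group_size == -1:
        group_size = len(groups)
    return [sum(groups[i:i + group_size], [])
            for i in range(0, len(groups), group_size)]

As in the pass, the resulting groups hold (grad, param) pairs, which is why params_grads is rebuilt from the groups and each pair is swapped back to (param, grad) afterwards.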
+ auto backward_vars = + boost::get>(node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + PADDLE_ENFORCE_EQ(backward_vars.size() % 2, static_cast(0)); + + for (size_t i = 0; i < backward_vars.size(); i += 2) { + VLOG(10) << "Trainable parameter: " << backward_vars[i] + << ", gradient: " << backward_vars[i + 1]; + + params_grads->emplace_back(std::make_pair( + backward_vars[i] /*param*/, backward_vars[i + 1] /*grad*/)); + } + } catch (boost::bad_get e) { + } + } + + void InitFusedVarsAndAllocSpaceForVars( + const std::vector &places, + const std::vector &local_scopes, + const std::unordered_map &vars, + const std::string &fused_var_name, + const ParamsAndGrads ¶ms_grads) const { + // Init Gradients and FusedVars + VLOG(10) << "Init FusedVars and Gradients."; + for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) { + auto &scope = *it; + + PADDLE_ENFORCE(scope->FindVar(fused_var_name) == nullptr, + "%s has existed in scope.", fused_var_name); + scope->Var(fused_var_name)->GetMutable(); + + for (auto &p_g : params_grads) { + auto iter = vars.find(p_g.second); + PADDLE_ENFORCE(iter != vars.end()); + PADDLE_ENFORCE_NOT_NULL(iter->second->Var()); + PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(), + proto::VarType::LOD_TENSOR); + scope->Var(p_g.second)->GetMutable(); + } + } + + std::vector grads_name; + std::vector params_name; + grads_name.reserve(params_grads.size()); + params_name.reserve(params_grads.size()); + for (auto &p_g : params_grads) { + params_name.emplace_back(p_g.first); + grads_name.emplace_back(p_g.second); + } + framework::ProgramDesc program_desc; + AppendAllocSpaceForVarsOp(params_name, grads_name, fused_var_name, + program_desc.MutableBlock(0)); + + // Run Only Once Programs + for (size_t i = 0; i < local_scopes.size(); ++i) { + for (auto &op_desc : program_desc.Block(0).AllOps()) { + auto op = OpRegistry::CreateOp(*op_desc); + op->Run(*local_scopes[i], places[i]); + } + } + } +}; + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(alloc_continuous_space_for_grad_pass, + paddle::framework::details::AllocContinuousSpaceForGradPass) + .RequirePassAttr(paddle::framework::details::kPlaces) + .RequirePassAttr(paddle::framework::details::kLocalScopes); diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 932d0b4538e..4184353bcbd 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -46,7 +46,16 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { public: explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy) : ir::PassBuilder(), strategy_(strategy) { + // Add a graph viz pass to record a graph. + if (!strategy_.debug_graphviz_path_.empty()) { + auto viz_pass = AppendPass("graph_viz_pass"); + const std::string graph_path = string::Sprintf( + "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph"); + viz_pass->Set("graph_viz_path", new std::string(graph_path)); + } + if (strategy_.enable_sequential_execution_) { + VLOG(10) << "Add sequential_execution_pass"; AppendPass("sequential_execution_pass"); } @@ -57,6 +66,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Add op fusion. 
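The small program built by InitFusedVarsAndAllocSpaceForVars runs a single alloc_continuous_space op once per scope, which is what gives the fused all-reduce one contiguous buffer: every gradient ends up as a view into one large allocation, so a single collective call can cover all of them. A rough NumPy analogy of that memory layout (illustrative only, not the operator itself; the gradient shapes are made up):

import numpy as np

grad_shapes = {'fc_0.w_0@GRAD': (784, 200), 'fc_0.b_0@GRAD': (200,)}

total = sum(int(np.prod(s)) for s in grad_shapes.values())
fused = np.zeros(total, dtype=np.float32)      # the single FusedOutput buffer

views, offset = {}, 0
for name, shape in grad_shapes.items():
    n = int(np.prod(shape))
    # Each gradient is re-bound to a slice of the fused buffer: writing a
    # gradient writes into the big buffer, and one all-reduce over `fused`
    # reduces every gradient at once.
    views[name] = fused[offset:offset + n].reshape(shape)
    offset += n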
if (strategy.fuse_relu_depthwise_conv_) { + VLOG(10) << "Add fuse_relu_depthwise_conv_pass"; AppendPass("fuse_relu_depthwise_conv_pass"); } @@ -68,29 +78,30 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Add automatically inplace. if (strategy_.enable_inplace_) { + VLOG(10) << "Add inplace_pass"; AppendPass("inplace_pass"); } + if (strategy.fuse_elewise_add_act_ops_) { + VLOG(10) << "Add fuse_elewise_add_act_pass"; + AppendPass("fuse_elewise_add_act_pass"); + } + + // for single card training, fuse_all_reduce_ops is unnecessary. + // alloc_continuous_space_for_grad_pass should be before of MultiDevPass. + if (strategy.fuse_all_reduce_ops_) { + VLOG(10) << "Add alloc_continuous_space_for_grad_pass"; + AppendPass("alloc_continuous_space_for_grad_pass"); + } + // Add a graph viz pass to record a graph. - if (!strategy_.debug_graphviz_path_.empty()) { + if (!strategy.debug_graphviz_path_.empty()) { auto viz_pass = AppendPass("graph_viz_pass"); const std::string graph_path = string::Sprintf( - "%s%s", strategy_.debug_graphviz_path_.c_str(), "_original_graph"); + "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph"); viz_pass->Set("graph_viz_path", new std::string(graph_path)); } - if (strategy.fuse_elewise_add_act_ops_) { - auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass"); - // Add a graph viz pass to record a graph. - if (!strategy.debug_graphviz_path_.empty()) { - auto viz_pass = AppendPass("graph_viz_pass"); - const std::string graph_path = string::Sprintf( - "%s%s", strategy.debug_graphviz_path_.c_str(), "_fused_graph"); - viz_pass->Set("graph_viz_path", - new std::string(graph_path)); - } - } - CollectiveContext *context = CollectiveContext::GetInstance(); context->endpoints_ = strategy_.trainers_endpoints_; context->trainer_id_ = strategy_.trainer_id_; @@ -108,11 +119,19 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // A side-effect of that, memory optimize cannot forsee the fetched vars // , so fetchlist should be set persistable before call the Run interface. if (strategy.memory_optimize_) { - auto memory_optimize_pass = AppendPass("memory_optimize_pass"); + VLOG(10) << "Add memory_optimize_pass"; + AppendPass("memory_optimize_pass"); } AppendMultiDevPass(strategy); + if (strategy.fuse_all_reduce_ops_) { + // NOTE: fuse_all_reduce_ops will count the number of all_reduce operator + // first, if the number is zero, fuse_all_reduce_ops will do nothing. + VLOG(10) << "Add fuse_all_reduce_op_pass"; + AppendPass("fuse_all_reduce_op_pass"); + } + // Add a graph print pass to record a graph with device info. if (!strategy_.debug_graphviz_path_.empty()) { auto multi_devices_print_pass = AppendPass("multi_devices_print_pass"); @@ -129,27 +148,29 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPass("multi_devices_check_pass"); if (SeqOnlyAllReduceOps(strategy)) { + VLOG(10) << "Add all_reduce_deps_pass"; AppendPass("all_reduce_deps_pass"); } if (strategy_.remove_unnecessary_lock_) { + VLOG(10) << "Add modify_op_lock_and_record_event_pass"; AppendPass("modify_op_lock_and_record_event_pass"); } } // Convert graph to run on multi-devices. 
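From user code this whole path is switched on through the BuildStrategy option exposed later in this patch (pybind.cc) and exercised by the new unit test. A minimal usage sketch; the loss variable and the rest of the program are assumed to exist already:

import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
build_strategy.fuse_all_reduce_ops = True   # adds both new passes shown above
# Fusion targets all-reduce mode, which is the default reduce strategy.
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce

exe = fluid.ParallelExecutor(use_cuda=True,
                             loss_name=loss.name,        # assumed loss variable
                             build_strategy=build_strategy)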
void AppendMultiDevPass(const BuildStrategy &strategy) { - ir::Pass *multi_devices_pass; + ir::Pass *multi_devices_pass = nullptr; if (strategy_.is_distribution_) { - VLOG(3) << "multi device parameter server mode"; + VLOG(10) << "Add dist_multi_devices_pass"; multi_devices_pass = AppendPass("dist_multi_devices_pass").get(); } else { if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { - VLOG(3) << "multi devices collective mode with allreduce"; + VLOG(10) << "Add all_reduce_mode_multi_devices_pass"; multi_devices_pass = - AppendPass("allreduce_mode_multi_devices_pass").get(); + AppendPass("all_reduce_mode_multi_devices_pass").get(); } else if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { - VLOG(3) << "multi deivces collective mode with reduce"; + VLOG(10) << "Add reduce_mode_multi_devices_pass"; multi_devices_pass = AppendPass("reduce_mode_multi_devices_pass").get(); } else { PADDLE_THROW("Unknown reduce strategy."); @@ -206,9 +227,26 @@ std::unique_ptr BuildStrategy::Apply( #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; - pass->Erase("nccl_ctxs"); - pass->SetNotOwned("nccl_ctxs", nctx); + pass->Erase(kNCCLCtxs); + pass->SetNotOwned(kNCCLCtxs, nctx); #endif + } else if (pass->Type() == "fuse_all_reduce_op_pass") { + pass->Erase(kPlaces); + pass->SetNotOwned>(kPlaces, &places); + pass->Erase(kLocalScopes); + pass->SetNotOwned>(kLocalScopes, + &local_scopes); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + platform::NCCLContextMap *nctx = use_cuda ? nccl_ctxs : nullptr; + pass->Erase(kNCCLCtxs); + pass->SetNotOwned(kNCCLCtxs, nctx); +#endif + } else if (pass->Type() == "alloc_continuous_space_for_grad_pass") { + pass->Erase(kPlaces); + pass->SetNotOwned>(kPlaces, &places); + pass->Erase(kLocalScopes); + pass->SetNotOwned>(kLocalScopes, + &local_scopes); } else if (pass->Type() == "sequential_execution_pass") { LOG(INFO) << "set enable_sequential_execution:" << enable_sequential_execution_; @@ -239,7 +277,7 @@ USE_PASS(fuse_elewise_add_act_pass); USE_PASS(graph_viz_pass); USE_PASS(multi_batch_merge_pass); USE_PASS(reduce_mode_multi_devices_pass); -USE_PASS(allreduce_mode_multi_devices_pass); +USE_PASS(all_reduce_mode_multi_devices_pass); USE_PASS(dist_multi_devices_pass); USE_PASS(multi_devices_check_pass); USE_PASS(multi_devices_print_pass); @@ -249,4 +287,6 @@ USE_PASS(all_reduce_deps_pass); USE_PASS(modify_op_lock_and_record_event_pass); USE_PASS(inplace_pass); USE_PASS(lock_free_optimize_pass); +USE_PASS(alloc_continuous_space_for_grad_pass); USE_PASS(graph_to_program_pass); +USE_PASS(fuse_all_reduce_op_pass); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 122411641da..4b599fb914d 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "paddle/fluid/framework/ir/pass_builder.h" @@ -75,6 +76,8 @@ struct BuildStrategy { bool fuse_elewise_add_act_ops_{false}; + bool fuse_all_reduce_ops_{false}; + bool fuse_relu_depthwise_conv_{false}; bool sync_batch_norm_{false}; diff --git a/paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc b/paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc new file mode 100644 index 00000000000..f226491c9f5 --- /dev/null +++ b/paddle/fluid/framework/details/fuse_all_reduce_op_pass.cc @@ -0,0 +1,195 @@ +// Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "paddle/fluid/framework/details/all_reduce_op_handle.h" +#include "paddle/fluid/framework/details/container_cast.h" +#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h" +#include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/ir/graph_helper.h" + +namespace paddle { +namespace framework { +namespace details { + +class FuseAllReduceOpPass : public ir::Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override { + ir::Graph &result = *graph; + + auto &places = Get>(kPlaces); + auto &local_scopes = Get>(kLocalScopes); +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto *nccl_ctxs = &Get(kNCCLCtxs); +#endif + + std::unordered_set grads; + auto ¶ms_grads = result.Get(kParamsAndGrads); + size_t num_of_all_reduce = params_grads.size(); + grads.reserve(num_of_all_reduce); + for (auto p_g : params_grads) { + grads.insert(p_g.second); + } + + size_t num_place = places.size(); + std::unordered_map all_reduce_ops; + all_reduce_ops.reserve(grads.size()); + for (auto &node : result.Nodes()) { + if (node->IsOp()) { + PADDLE_ENFORCE(node->IsWrappedBy()); + auto *all_reduce_op_handle = + dynamic_cast(&node->Wrapper()); + if (all_reduce_op_handle) { + auto inputs = DynamicCast(all_reduce_op_handle->Inputs()); + PADDLE_ENFORCE_EQ(inputs.size(), num_place); + // The inputs' name should be the same. + auto &grad_name = inputs[0]->name(); + for (size_t i = 1; i < inputs.size(); ++i) { + PADDLE_ENFORCE_EQ(inputs[i]->name(), grad_name, + "The input name should be the same."); + } + PADDLE_ENFORCE_NE(grads.count(grad_name), static_cast(0)); + all_reduce_ops.emplace(grad_name, node); + } + } + } + + VLOG(10) << "Find all_reduce_ops: " << all_reduce_ops.size(); + if (all_reduce_ops.size() == 0) { + return std::move(graph); + } + + PADDLE_ENFORCE_EQ(all_reduce_ops.size(), grads.size(), + "The number of all_reduce OpHandle is not equal to the " + "number of grads. 
Maybe some gradients are sparse type, " + "it is not supported currently."); + VLOG(10) << "Insert fused_all_reduce"; + + auto &group_grads_params = + graph->Get(kGroupGradsAndParams); + + for (auto &group_g_p : group_grads_params) { + size_t group_size = group_g_p.size(); + PADDLE_ENFORCE_GT(group_size, static_cast(0)); + std::vector group_all_reduce_ops; + group_all_reduce_ops.reserve(group_size); + for (auto &g_p : group_g_p) { + group_all_reduce_ops.emplace_back(all_reduce_ops.at(g_p.first)); + } +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + InsertFusedAllReduce(places, local_scopes, group_size, + group_all_reduce_ops, nccl_ctxs, &result); +#else + InsertFusedAllReduce(places, local_scopes, group_size, + group_all_reduce_ops, &result); +#endif + } + return std::move(graph); + } + + void InsertFusedAllReduce(const std::vector &places, + const std::vector &local_scopes, + const size_t num_of_all_reduce, + const std::vector &all_reduce_ops, +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + const platform::NCCLContextMap *nccl_ctxs, +#endif + ir::Graph *result) const { + std::vector inputs; + std::vector outputs; + for (auto &op : all_reduce_ops) { + auto &op_handle = op->Wrapper(); + inputs.insert(inputs.end(), op_handle.Inputs().begin(), + op_handle.Inputs().end()); + // Remove output + for_each(op_handle.Inputs().begin(), op_handle.Inputs().end(), + [&op_handle](VarHandleBase *var_handle) { + var_handle->RemoveOutput(&op_handle, op_handle.Node()); + }); + + outputs.insert(outputs.end(), op_handle.Outputs().begin(), + op_handle.Outputs().end()); + // Remove Input + for_each( + op_handle.Outputs().begin(), op_handle.Outputs().end(), + [](VarHandleBase *var_handle) { var_handle->ClearGeneratedOp(); }); + + result->RemoveNode(op_handle.Node()); + } + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places, + local_scopes, nccl_ctxs, result); +#else + CreateFusedAllReduceOp(inputs, outputs, num_of_all_reduce, places, + local_scopes, result); +#endif + } + + private: + void CreateFusedAllReduceOp(const std::vector &inputs, + const std::vector &outputs, + const size_t num_of_all_reduce, + const std::vector &places, + const std::vector &local_scopes, +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + const platform::NCCLContextMap *nccl_ctxs, +#endif + ir::Graph *result) const { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + auto *op_handle = new FusedAllReduceOpHandle( + result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation), + local_scopes, places, num_of_all_reduce, nccl_ctxs); +#else + auto *op_handle = new FusedAllReduceOpHandle( + result->CreateEmptyNode("fused_all_reduce", ir::Node::Type::kOperation), + local_scopes, places, num_of_all_reduce); +#endif + + for (auto in : inputs) { + op_handle->AddInput(in); + } + + for (auto out : outputs) { + op_handle->AddOutput(out); + } + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + if (!nccl_ctxs) { + SetCommunicationContext(places, op_handle); + } +#else + SetCommunicationContext(places, op_handle); +#endif + } + + void SetCommunicationContext(const std::vector &places, + FusedAllReduceOpHandle *op_handle) const { + for (size_t i = 0; i < places.size(); ++i) { + op_handle->SetDeviceContext( + places[i], platform::DeviceContextPool::Instance().Get(places[i])); + } + } +}; + +} // namespace details +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fuse_all_reduce_op_pass, + paddle::framework::details::FuseAllReduceOpPass); diff 
--git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc new file mode 100644 index 00000000000..b90655184bc --- /dev/null +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.cc @@ -0,0 +1,248 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/framework/details/fused_all_reduce_op_handle.h" +#include +#include +#include "paddle/fluid/framework/details/container_cast.h" +#include "paddle/fluid/framework/details/reduce_and_gather.h" +#include "paddle/fluid/framework/details/variable_visitor.h" +#include "paddle/fluid/platform/profiler.h" + +DEFINE_bool(skip_fused_all_reduce_check, false, ""); +namespace paddle { +namespace framework { +namespace details { + +typedef std::vector>> + GradientAndLoDTensor; + +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +FusedAllReduceOpHandle::FusedAllReduceOpHandle( + ir::Node *node, const std::vector &local_scopes, + const std::vector &places, const size_t num_of_all_reduce, + const platform::NCCLContextMap *ctxs) + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + num_of_all_reduce_(num_of_all_reduce), + nccl_ctxs_(ctxs) { + if (nccl_ctxs_) { + for (auto &p : places_) { + this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p)); + } + } + PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); +} +#else + +FusedAllReduceOpHandle::FusedAllReduceOpHandle( + ir::Node *node, const std::vector &local_scopes, + const std::vector &places, const size_t num_of_all_reduce) + : OpHandleBase(node), + local_scopes_(local_scopes), + places_(places), + num_of_all_reduce_(num_of_all_reduce) { + PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); +} + +#endif + +void FusedAllReduceOpHandle::RunImpl() { + platform::RecordEvent record_event(Name()); + + VLOG(4) << this->DebugString(); + + WaitInputVarGenerated(); + // The input: grad0(dev0), grad0(dev1), grad1(dev0), grad1(dev1)... + // The output: grad0(dev0), grad0(dev1), grad1(dev0), grad1(dev1)... 
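The comment above fixes the layout of the handle's inputs and outputs: handle j refers to gradient j / place_num on device j % place_num. RunImpl below then verifies, per device, that these gradients really occupy one contiguous chunk of memory before issuing a single collective. An illustrative sketch of that check (hypothetical helper, not part of the handle):

def check_contiguous(tensors, elem_size):
    # tensors: list of (name, data_ptr, numel) for one device, the analogue of
    # the (name, LoDTensor*) pairs gathered by GetGradLoDTensor.
    tensors = sorted(tensors, key=lambda t: t[1])          # sort by address
    for (_, prev_ptr, prev_numel), (_, cur_ptr, _) in zip(tensors, tensors[1:]):
        # Every gradient must start exactly where the previous one ends,
        # otherwise one all-reduce over the fused buffer would touch the
        # wrong memory.
        assert prev_ptr + prev_numel * elem_size == cur_ptr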
+ auto in_var_handles = DynamicCast(this->Inputs()); + auto out_var_handles = DynamicCast(this->Outputs()); + + size_t place_num = places_.size(); + PADDLE_ENFORCE_EQ( + in_var_handles.size(), place_num * num_of_all_reduce_, + "The NoDummyInputSize should be equal to the number of places."); + PADDLE_ENFORCE_EQ( + in_var_handles.size(), out_var_handles.size(), + "The NoDummyInputSize and NoDummyOutputSize should be equal."); + + GradientAndLoDTensor grads_tensor; + grads_tensor.resize(place_num); + + int64_t numel = -1; + auto dtype = static_cast(0); + for (size_t scope_idx = 0; scope_idx < local_scopes_.size(); ++scope_idx) { + auto &g_tensor = grads_tensor.at(scope_idx); + g_tensor.reserve(num_of_all_reduce_); + + GetGradLoDTensor(scope_idx, in_var_handles, out_var_handles, &g_tensor); + + int64_t element_num = 0; + framework::proto::VarType::Type ele_dtype = + static_cast(0); + GetDTypeAndNumel(g_tensor, &ele_dtype, &element_num); + + if (numel == -1) { + numel = element_num; + } + if (dtype == static_cast(0)) { + dtype = ele_dtype; + PADDLE_ENFORCE_NE(ele_dtype, + static_cast(0)); + } + PADDLE_ENFORCE_EQ(ele_dtype, dtype); + + // Check whether the address space is contiguous. + std::sort( + g_tensor.begin(), g_tensor.end(), + [](const std::pair &grad1, + const std::pair &grad2) -> bool { + return grad1.second->data() < grad2.second->data(); + }); + + for (size_t k = 1; k < g_tensor.size(); ++k) { + const void *pre_address = g_tensor.at(k - 1).second->data(); + int64_t len = g_tensor.at(k - 1).second->numel(); + auto offset = len * framework::SizeOfType(dtype); + void *next_address = reinterpret_cast( + reinterpret_cast(pre_address) + offset); + const void *cur_address = g_tensor.at(k).second->data(); + VLOG(10) << k << ", " + << " pre_address(" << g_tensor.at(k - 1).first + << "): " << pre_address << ", cur_address(" + << g_tensor.at(k).first << "): " << cur_address + << ", offset:" << offset << ", " << next_address << ", " + << cur_address; + PADDLE_ENFORCE_EQ(next_address, cur_address); + } + } + + if (!FLAGS_skip_fused_all_reduce_check) { + for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) { + for (size_t j = 1; j < num_of_all_reduce_; ++j) { + PADDLE_ENFORCE_EQ(grads_tensor.at(0).at(j).first, + grads_tensor.at(scope_idx).at(j).first); + } + } + } + + std::vector lod_tensor_data; + for (size_t scope_idx = 0; scope_idx < place_num; ++scope_idx) { + auto data = grads_tensor.at(scope_idx).at(0).second->data(); + lod_tensor_data.emplace_back(data); + } + + if (platform::is_gpu_place(places_[0])) { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr."); + int nccl_dtype = platform::ToNCCLDataType(dtype); + std::vector> all_reduce_calls; + for (size_t i = 0; i < local_scopes_.size(); ++i) { + auto &p = places_[i]; + void *buffer = const_cast(lod_tensor_data.at(i)); + + int dev_id = boost::get(p).device; + auto &nccl_ctx = nccl_ctxs_->at(dev_id); + auto stream = nccl_ctx.stream(); + auto comm = nccl_ctx.comm_; + all_reduce_calls.emplace_back([=] { + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + buffer, buffer, numel, static_cast(nccl_dtype), + ncclSum, comm, stream)); + }); + } + + this->RunAndRecordEvent([&] { + if (all_reduce_calls.size() == 1UL) { + // Do not use NCCLGroup when manage NCCL by per thread per device + all_reduce_calls[0](); + } else { + platform::NCCLGroupGuard guard; + for (auto &call : all_reduce_calls) { + call(); + } + } + }); +#else + PADDLE_THROW("Not compiled with CUDA"); +#endif + } 
else { + // Special handle CPU only Operator's gradient. Like CRF + auto grad_name = grads_tensor.at(0).at(0).first; + auto &trg = *this->local_scopes_[0] + ->FindVar(kLocalExecScopeName) + ->Get() + ->FindVar(grad_name) + ->GetMutable(); + + // Reduce All data to trg in CPU + ReduceBufferData func(lod_tensor_data, trg.data(), numel); + VisitDataType(trg.type(), func); + + for (size_t i = 1; i < local_scopes_.size(); ++i) { + auto &scope = + *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get(); + auto &p = places_[i]; + auto *var = scope.FindVar(grad_name); + auto *dev_ctx = dev_ctxes_.at(p); + size_t size = numel * SizeOfType(trg.type()); + RunAndRecordEvent(p, [&trg, var, dev_ctx, p, size] { + auto dst_ptr = var->GetMutable()->data(); + platform::CPUPlace cpu_place; + memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data(), size); + }); + } + } +} + +void FusedAllReduceOpHandle::GetGradLoDTensor( + const size_t &scope_idx, const std::vector &in_var_handles, + const std::vector &out_var_handles, + std::vector> *grad_tensor) const { + auto *local_scope = + local_scopes_.at(scope_idx)->FindVar(kLocalExecScopeName)->Get(); + size_t place_num = places_.size(); + + for (size_t j = 0; j < in_var_handles.size(); j += place_num) { + auto var_name = in_var_handles[j]->name(); + PADDLE_ENFORCE_EQ(var_name, out_var_handles[j]->name()); + auto &lod_tensor = local_scope->FindVar(var_name)->Get(); + PADDLE_ENFORCE_EQ(lod_tensor.place(), places_.at(scope_idx)); + grad_tensor->emplace_back(std::make_pair(var_name, &lod_tensor)); + } +} + +void FusedAllReduceOpHandle::GetDTypeAndNumel( + const std::vector> &grad_tensor, + proto::VarType::Type *dtype, int64_t *numel) const { + *numel = 0; + for (size_t i = 0; i < grad_tensor.size(); ++i) { + // Get element number + int64_t len = grad_tensor.at(i).second->numel(); + PADDLE_ENFORCE_GT(len, 0); + *numel += len; + + // Get dtype + auto ele_type = grad_tensor.at(i).second->type(); + if (i == 0) { + *dtype = ele_type; + } + PADDLE_ENFORCE_EQ(ele_type, *dtype); + } +} + +std::string FusedAllReduceOpHandle::Name() const { return "fused_all_reduce"; } +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fused_all_reduce_op_handle.h b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h new file mode 100644 index 00000000000..79772c61f8c --- /dev/null +++ b/paddle/fluid/framework/details/fused_all_reduce_op_handle.h @@ -0,0 +1,76 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
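When CUDA is not available, the else branch above falls back to the new ReduceBufferData functor: every local fused buffer is accumulated element-wise into the first scope's tensor, and the result is then copied back to the remaining scopes. A NumPy analogue of that CPU path (illustrative data and sizes):

import numpy as np

buffers = [np.random.rand(16).astype(np.float32) for _ in range(4)]  # one per scope

target = buffers[0]
for buf in buffers[1:]:
    target += buf             # ReduceBufferData: sum every buffer into target
for buf in buffers[1:]:
    np.copyto(buf, target)    # then copy the reduced result back to each scope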
+ +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace framework { +namespace details { + +struct FusedAllReduceOpHandle : public OpHandleBase { +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + FusedAllReduceOpHandle(ir::Node *node, + const std::vector &local_scopes, + const std::vector &places, + const size_t num_of_all_reduce, + const platform::NCCLContextMap *ctxs); +#else + FusedAllReduceOpHandle(ir::Node *node, + const std::vector &local_scopes, + const std::vector &places, + const size_t num_of_all_reduce); +#endif + std::string Name() const override; + + // Delay and buffer nccl_all_reduce together can significantly increase + // performance. Disable this feature by returning false. + bool IsMultiDeviceTransfer() override { return true; }; + + protected: + void RunImpl() override; + + private: + std::vector local_scopes_; + std::vector places_; + size_t num_of_all_reduce_; +#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) + const platform::NCCLContextMap *nccl_ctxs_; +#endif + + // Check the dtype of the input + void GetDTypeAndNumel( + const std::vector> &g_tensor, + proto::VarType::Type *dtype, int64_t *total_num) const; + + // Get gradient's name and LoDTensor + void GetGradLoDTensor(const size_t &scope_idx, + const std::vector &in_var_handles, + const std::vector &out_var_handles, + std::vector> + *grad_tensor) const; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index ed2e0bc65ed..e3cd2340c97 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -11,18 +11,17 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
+#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include #include #include #include #include - #include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h" #include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" -#include "paddle/fluid/framework/details/multi_devices_graph_pass.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" @@ -134,21 +133,26 @@ void AddOutputToLeafOps(ir::Graph *graph) { } } // namespace +void MultiDevSSAGraphBuilderBase::CheckGraph(const ir::Graph &graph) const {} + void MultiDevSSAGraphBuilderBase::Init() const { all_vars_.clear(); loss_var_name_ = Get(kLossVarName); + VLOG(10) << "Init MultiDevSSAGraphBuilder, loss name: " << loss_var_name_; places_ = Get>(kPlaces); local_scopes_ = Get>(kLocalScopes); strategy_ = Get(kStrategy); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) - nccl_ctxs_ = &Get("nccl_ctxs"); + nccl_ctxs_ = &Get(kNCCLCtxs); #endif + PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); } std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( std::unique_ptr graph) const { Init(); + CheckGraph(*graph); std::vector sorted_ops = SortOperations(*graph); auto nodes = graph->ReleaseNodes(); @@ -199,7 +203,6 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( boost::get>(node->Op()->GetNullableAttr( OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); - for (size_t i = 0; i < backward_vars.size(); i += 2) { auto &p_name = backward_vars[i]; auto &g_name = backward_vars[i + 1]; @@ -226,6 +229,7 @@ std::unique_ptr MultiDevSSAGraphBuilderBase::ApplyImpl( * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); + result.Erase(kGraphOps); return graph; } @@ -258,6 +262,11 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp( } } +bool MultiDevSSAGraphBuilderBase::DealWithSpecialOp(ir::Graph *result, + ir::Node *node) const { + return false; +} + std::vector MultiDevSSAGraphBuilderBase::SortOperations( const ir::Graph &graph) const { return ir::TopologySortOperations(graph); @@ -508,20 +517,17 @@ VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result, } bool MultiDevSSAGraphBuilderBase::IsScaleLossOp(ir::Node *node) const { - return boost::get( + return !loss_var_name_.empty() && node->Op() && + boost::get( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == (static_cast(OpRole::kBackward) | - static_cast(OpRole::kLoss)) && - !loss_var_name_.empty(); // If loss_var is empty. 
This is test mode + static_cast(OpRole::kLoss)); } bool MultiDevSSAGraphBuilderBase::IsSparseGradient( const std::string &og) const { PADDLE_ENFORCE(all_vars_.count(og) != 0); - if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { - return true; - } - return false; + return all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS; } void AllReduceSSAGraphBuilder::InsertCollectiveOp( @@ -1007,7 +1013,7 @@ static int MultiDevSSAGraphBuilderRegister(const std::string &builder_mode) { REGISTER_MULTI_DEVICES_PASS(reduce_mode_multi_devices_pass, paddle::framework::details::ReduceSSAGraphBuilder); REGISTER_MULTI_DEVICES_PASS( - allreduce_mode_multi_devices_pass, + all_reduce_mode_multi_devices_pass, paddle::framework::details::AllReduceSSAGraphBuilder); REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass, paddle::framework::details::DistSSAGraphBuilder); diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 9018243dd7b..0ee3a060629 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -34,12 +34,6 @@ namespace framework { class Scope; namespace details { -constexpr char kLossVarName[] = "loss_var_name"; -constexpr char kPlaces[] = "places"; -constexpr char kLocalScopes[] = "local_scopes"; -constexpr char kStrategy[] = "strategy"; -constexpr char kNRanks[] = "nranks"; - class MultiDevSSAGraphBuilderBase : public ir::Pass { protected: std::unique_ptr ApplyImpl( @@ -47,12 +41,14 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass { virtual void Init() const; + virtual void CheckGraph(const ir::Graph &graph) const; + virtual std::vector SortOperations(const ir::Graph &graph) const; virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const = 0; - virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const = 0; + virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const; virtual void InsertPostprocessOps(ir::Graph *result) const = 0; @@ -113,10 +109,6 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase { virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name, const std::string &g_name) const; - virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const { - return false; - } - virtual void InsertPostprocessOps(ir::Graph *result) const {} }; diff --git a/paddle/fluid/framework/details/multi_devices_helper.h b/paddle/fluid/framework/details/multi_devices_helper.h index 9afbb91005c..ab5e0990233 100644 --- a/paddle/fluid/framework/details/multi_devices_helper.h +++ b/paddle/fluid/framework/details/multi_devices_helper.h @@ -16,6 +16,9 @@ #include #include +#include +#include +#include #include #include "paddle/fluid/framework/details/op_handle_base.h" @@ -44,6 +47,26 @@ const char kGraphVars[] = "vars"; typedef std::unordered_set GraphDepVars; const char kGraphDepVars[] = "dep_vars"; +constexpr char kNCCLCtxs[] = "nccl_ctxs"; + +constexpr char kLossVarName[] = "loss_var_name"; +constexpr char kPlaces[] = "places"; +constexpr char kLocalScopes[] = "local_scopes"; +constexpr char kStrategy[] = "strategy"; +constexpr char kNRanks[] = "nranks"; + +typedef std::unordered_set FusedVars; +constexpr char kFusedVars[] = "fused_vars"; + +typedef std::vector> ParamsAndGrads; +constexpr char kParamsAndGrads[] = "params_grads"; + +typedef std::vector>> + GroupGradsAndParams; +constexpr char 
kGroupGradsAndParams[] = "group_grads_params"; + +constexpr char kFusedVarNamePrefix[] = "@FUSEDVAR@"; + } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/reduce_and_gather.h b/paddle/fluid/framework/details/reduce_and_gather.h index 2e5256fbd49..0de8e436518 100644 --- a/paddle/fluid/framework/details/reduce_and_gather.h +++ b/paddle/fluid/framework/details/reduce_and_gather.h @@ -53,6 +53,31 @@ struct ReduceLoDTensor { } }; +struct ReduceBufferData { + const std::vector &src_data_; + void *dst_data_; + int64_t numel_; + + ReduceBufferData(const std::vector &src, void *dst, + int64_t numel) + : src_data_(src), dst_data_(dst), numel_(numel) {} + + template + void apply() const { + T *dst_data = reinterpret_cast(dst_data_); + for (size_t i = 0; i < src_data_.size(); ++i) { + auto srd_data = reinterpret_cast(src_data_[i]); + VLOG(10) << "dst: " << dst_data_ << ", " << srd_data; + if (srd_data == dst_data_) { + continue; + } + + std::transform(srd_data, srd_data + numel_, dst_data, dst_data, + [](T a, T b) -> T { return a + b; }); + } + } +}; + inline void GatherLocalSelectedRows( const std::vector &src_selecte_rows_, const std::vector &in_places, diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 691b437ab0c..a57083a1444 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1263,6 +1263,10 @@ All parameter, weight, gradient are variables in Paddle. "enable_inplace", [](const BuildStrategy &self) { return self.enable_inplace_; }, [](BuildStrategy &self, bool b) { self.enable_inplace_ = b; }) + .def_property( + "fuse_all_reduce_ops", + [](const BuildStrategy &self) { return self.fuse_all_reduce_ops_; }, + [](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; }) .def("_finalize_strategy_and_create_passes", [](BuildStrategy &self) -> std::shared_ptr { return self.CreatePassesFromStrategy(true); diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index cb9c75a14f5..3bc5cd44481 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -132,7 +132,8 @@ def __bootstrap__(): 'allocator_strategy', 'reader_queue_speed_test_mode', 'print_sub_graph_dir', 'pe_profile_fname', 'warpctc_dir', 'inner_op_parallelism', 'enable_parallel_graph', - 'multiple_of_cupti_buffer_size', 'enable_subgraph_optimize', + 'fuse_parameter_groups_size', 'multiple_of_cupti_buffer_size', + 'enable_subgraph_optimize', 'fuse_parameter_memory_size', 'tracer_profile_fname' ] if 'Darwin' not in sysstr: diff --git a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py index a94487e67dc..61fd9af1275 100644 --- a/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py +++ b/python/paddle/fluid/tests/unittests/parallel_executor_test_base.py @@ -43,6 +43,7 @@ class TestParallelExecutorBase(unittest.TestCase): use_ir_memory_optimize=True, enable_inplace=True, fuse_elewise_add_act_ops=False, + fuse_all_reduce_ops=False, fuse_relu_depthwise_conv=False, optimizer=fluid.optimizer.Adam, use_fast_executor=False, @@ -80,6 +81,7 @@ class TestParallelExecutorBase(unittest.TestCase): build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv build_strategy.memory_optimize = False if memory_opt else use_ir_memory_optimize + build_strategy.fuse_all_reduce_ops = fuse_all_reduce_ops # python memory 
optimization is conflict with inplace pass. # Use ir graph memory optimization after inplace pass is the correct way. build_strategy.enable_inplace = False if memory_opt else enable_inplace diff --git a/python/paddle/fluid/tests/unittests/test_fuse_all_reduce_pass.py b/python/paddle/fluid/tests/unittests/test_fuse_all_reduce_pass.py new file mode 100644 index 00000000000..ca8669bbc6f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fuse_all_reduce_pass.py @@ -0,0 +1,121 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from parallel_executor_test_base import TestParallelExecutorBase +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np +import paddle +import paddle.dataset.mnist as mnist +import unittest +import os + + +def simple_fc_net(use_feed): + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + hidden = img + for _ in range(4): + hidden = fluid.layers.fc( + hidden, + size=200, + act='relu', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +def fc_with_batchnorm(use_feed): + img = fluid.layers.data(name='image', shape=[784], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + + hidden = img + for _ in range(2): + hidden = fluid.layers.fc( + hidden, + size=200, + act='relu', + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=1.0))) + + hidden = fluid.layers.batch_norm(input=hidden) + + prediction = fluid.layers.fc(hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + loss = fluid.layers.mean(loss) + return loss + + +class TestMNIST(TestParallelExecutorBase): + @classmethod + def setUpClass(cls): + os.environ['CPU_NUM'] = str(4) + + def _init_data(self, random=True): + np.random.seed(5) + if random: + img = np.random.random(size=[32, 784]).astype(np.float32) + else: + img = np.ones(shape=[32, 784], dtype='float32') + label = np.ones(shape=[32, 1], dtype='int64') + return img, label + + def _compare_fuse_all_reduce_ops(self, model, use_cuda, random_data=True): + if use_cuda and not core.is_compiled_with_cuda(): + return + img, label = self._init_data(random_data) + + def _optimizer(learning_rate=1e-6): + optimizer = fluid.optimizer.SGD( + learning_rate=learning_rate, + regularization=fluid.regularizer.L2Decay(1e-6)) + return optimizer + + not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence( + model, + feed_dict={"image": img, + "label": label}, + use_cuda=use_cuda, + fuse_all_reduce_ops=False, + memory_opt=False, + optimizer=_optimizer) + fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence( + model, + feed_dict={"image": img, + 
"label": label}, + use_cuda=use_cuda, + fuse_all_reduce_ops=True, + memory_opt=False, + optimizer=_optimizer) + + for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss): + self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) + for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss): + self.assertAlmostEquals(loss[0], loss[1], delta=1e-6) + + def test_simple_fc_with_fuse_op(self): + self._compare_fuse_all_reduce_ops(simple_fc_net, True) + self._compare_fuse_all_reduce_ops(simple_fc_net, False) + + def test_batchnorm_fc_with_fuse_op(self): + self._compare_fuse_all_reduce_ops(fc_with_batchnorm, True) + self._compare_fuse_all_reduce_ops(fc_with_batchnorm, False) + + +if __name__ == '__main__': + unittest.main() -- GitLab