From 961fbce8e2d36212f781425ee686c73ac0e19b17 Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Mon, 11 Jun 2018 16:45:41 +0800 Subject: [PATCH] follow comments --- paddle/fluid/framework/details/CMakeLists.txt | 6 ++---- .../fluid/framework/details/multi_devices_graph_builder.cc | 6 +++--- .../fluid/framework/details/multi_devices_graph_builder.h | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 2258f2168cc..c9661f5ea9d 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -12,16 +12,14 @@ cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) if(WITH_GPU) - nv_library(nccl_all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory + nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) - set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) - set(multi_devices_graph_builder_deps all_reduce_op_handle) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) endif() @@ -30,7 +28,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle) cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 35c5d0433fe..a6e31ea24b5 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -240,7 +240,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( CreateReduceOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0); } else { - InsertNCCLAllReduceOp(&result, g_name); + InsertAllReduceOp(&result, g_name); } break; } @@ -327,8 +327,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, CreateOpHandleIOs(result, op, dev_id); } -void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( - SSAGraph *result, const std::string &og) const { +void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, + const std::string &og) const { #ifdef PADDLE_WITH_CUDA result->ops_.emplace_back( new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index bcedc9b8b87..78581755fe4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -100,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector> &var_name_on_devices, const OpDesc &op) const; - void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; + void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, size_t src_dev_id) const; -- GitLab