diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 2258f2168cc6c7a40df5dcd8a23626d7f89f32c5..c9661f5ea9d28b1e0f908434d70bba68b2955d3f 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -12,16 +12,14 @@ cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) if(WITH_GPU) - nv_library(nccl_all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory + nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) - set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle) nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) else() cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) - set(multi_devices_graph_builder_deps all_reduce_op_handle) cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) endif() @@ -30,7 +28,7 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle) + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle) cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 35c5d0433fe231630c3f919c2b896f1e27729701..a6e31ea24b5a0cfaaa9420a09f78eaef9f69b684 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -240,7 +240,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( CreateReduceOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0); } else { - InsertNCCLAllReduceOp(&result, g_name); + InsertAllReduceOp(&result, g_name); } break; } @@ -327,8 +327,8 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, CreateOpHandleIOs(result, op, dev_id); } -void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( - SSAGraph *result, const std::string &og) const { +void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, + const std::string &og) const { #ifdef PADDLE_WITH_CUDA result->ops_.emplace_back( new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_)); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index bcedc9b8b87578789b86c45b758484a2ab7a9fa8..78581755fe4890800636944d6cd89875a852cc19 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -100,7 +100,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { const std::vector> &var_name_on_devices, const OpDesc &op) const; - void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; + void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, size_t src_dev_id) const;