Commit 27073c28 authored by chengduoZH

nccl_all_reduce_op_handle => all_reduce_op_handle

Parent 2d94697a
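In short: nccl_all_reduce_op_handle.cc and nccl_all_reduce_op_handle.h become all_reduce_op_handle.cc and all_reduce_op_handle.h, the class NCCLAllReduceOpHandle is renamed to AllReduceOpHandle (its NCCL-specific pieces stay behind PADDLE_WITH_CUDA), the CMake targets are pointed at the renamed source, and SetCommunicationContext is moved earlier in the multi-devices graph builder.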
@@ -12,16 +12,16 @@ cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

 if(WITH_GPU)
-  nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
+  nv_library(nccl_all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
       dynload_cuda variable_visitor)
   set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
   nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
   nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
 else()
-  cc_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
+  cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
       variable_visitor)
-  set(multi_devices_graph_builder_deps)
+  set(multi_devices_graph_builder_deps all_reduce_op_handle)
   cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
   cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 endif()
...
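Note the asymmetry this hunk leaves behind: the WITH_GPU branch keeps the old target name nccl_all_reduce_op_handle and only switches its SRCS to all_reduce_op_handle.cc, while the CPU branch renames the target itself to all_reduce_op_handle and now lists it in multi_devices_graph_builder_deps.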
@@ -13,8 +13,8 @@
 // limitations under the License.

 #include <algorithm>

+#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/container_cast.h"
-#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
@@ -23,25 +23,23 @@ namespace framework {
 namespace details {

 #ifdef PADDLE_WITH_CUDA
-NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
-    const std::vector<Scope *> &local_scopes,
-    const std::vector<platform::Place> &places,
-    const platform::NCCLContextMap *ctxs)
+AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places,
+                                     const platform::NCCLContextMap *ctxs)
     : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
-  if (ctxs) {
+  if (nccl_ctxs_) {
     for (auto &p : places_) {
       this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
     }
   }
 }
 #else
-NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
-    const std::vector<Scope *> &local_scopes,
-    const std::vector<platform::Place> &places)
+AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places)
     : local_scopes_(local_scopes), places_(places) {}
 #endif

-void NCCLAllReduceOpHandle::RunImpl() {
+void AllReduceOpHandle::RunImpl() {
   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
   } else {
@@ -133,7 +131,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
   }
 }

-std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
+std::string AllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
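Two details here beyond the mechanical rename: the CUDA constructor now guards on the member nccl_ctxs_ rather than the parameter ctxs (equivalent, since the member is initialized first, and consistent with the nccl_ctxs_->DevCtx(p) call inside the guard), and Name() still returns the old string "nccl_all_reduce", so logs and graph dumps keep the pre-rename identifier. A minimal sketch of what the null guard permits, assuming a single-device CUDA build; scope and gpu_place are hypothetical stand-ins:

  // Hypothetical single-device construction: a null NCCLContextMap is
  // accepted, the constructor simply skips dev_ctxes_ registration, and
  // per the comment in RunImpl() the op returns early when there is only
  // one non-dummy input.
  std::vector<Scope *> local_scopes = {scope};        // one scope ...
  std::vector<platform::Place> places = {gpu_place};  // ... one place
  AllReduceOpHandle handle(local_scopes, places, /*ctxs=*/nullptr);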
@@ -28,13 +28,13 @@ namespace paddle {
 namespace framework {
 namespace details {

-struct NCCLAllReduceOpHandle : public OpHandleBase {
+struct AllReduceOpHandle : public OpHandleBase {
 #ifdef PADDLE_WITH_CUDA
-  NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
-                        const std::vector<platform::Place> &places,
-                        const platform::NCCLContextMap *ctxs);
+  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places,
+                    const platform::NCCLContextMap *ctxs);
 #else
-  NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
-                        const std::vector<platform::Place> &places);
+  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places);
 #endif

   std::string Name() const override;
...
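Pieced together, the renamed declaration reads as follows; members the hunk does not show are elided with a comment:

  struct AllReduceOpHandle : public OpHandleBase {
  #ifdef PADDLE_WITH_CUDA
    AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
                      const std::vector<platform::Place> &places,
                      const platform::NCCLContextMap *ctxs);
  #else
    AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
                      const std::vector<platform::Place> &places);
  #endif

    std::string Name() const override;

    // RunImpl() and the local_scopes_/places_/nccl_ctxs_ members are
    // declared further down the header, outside this hunk.
  };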
@@ -17,10 +17,10 @@
 #include <utility>
 #include <vector>

+#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
-#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
@@ -283,6 +283,19 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
   return false;
 }

+void MultiDevSSAGraphBuilder::SetCommunicationContext(
+    OpHandleBase *op_handle, const platform::Place &p) const {
+#ifdef PADDLE_WITH_CUDA
+  if (nccl_ctxs_ == nullptr) {
+    op_handle->SetDeviceContext(p,
+                                platform::DeviceContextPool::Instance().Get(p));
+  }
+#else
+  op_handle->SetDeviceContext(p,
+                              platform::DeviceContextPool::Instance().Get(p));
+#endif
+}
+
 void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
@@ -306,19 +319,6 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
   }
 }

-void MultiDevSSAGraphBuilder::SetCommunicationContext(
-    OpHandleBase *op_handle, const platform::Place &p) const {
-#ifdef PADDLE_WITH_CUDA
-  if (nccl_ctxs_ == nullptr) {
-    op_handle->SetDeviceContext(p,
-                                platform::DeviceContextPool::Instance().Get(p));
-  }
-#else
-  op_handle->SetDeviceContext(p,
-                              platform::DeviceContextPool::Instance().Get(p));
-#endif
-}
-
 void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
                                                     const OpDesc &op,
                                                     int dev_id) const {
@@ -331,9 +331,9 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
   result->ops_.emplace_back(
-      new NCCLAllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new NCCLAllReduceOpHandle(local_scopes_, places_));
+  result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
 #endif

   auto *op_handle = result->ops_.back().get();
...
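SetCommunicationContext is moved, not modified: under CUDA it installs a default device context only when nccl_ctxs_ is null, while the CPU build always does. A hedged sketch of how InsertNCCLAllReduceOp would continue past the lines shown, assuming the usual per-place wiring (the loop body is an assumption, not part of this hunk):

  auto *op_handle = result->ops_.back().get();
  for (size_t i = 0; i < places_.size(); ++i) {
    auto &p = places_[i];
    // Falls back to DeviceContextPool::Instance().Get(p) when no NCCL
    // context map is available.
    SetCommunicationContext(op_handle, p);
    // ... the VarHandles for the gradient on place i are attached here ...
  }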