Commit 27073c28 authored by chengduoZH

nccl_all_reduce_op_handle => all_reduce_op_handle

Parent 2d94697a
paddle/fluid/framework/details/CMakeLists.txt
@@ -12,16 +12,16 @@ cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)
 cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
 if(WITH_GPU)
-  nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
+  nv_library(nccl_all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
       dynload_cuda variable_visitor)
   set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
   nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
   nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
 else()
-  cc_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
+  cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
       variable_visitor)
-  set(multi_devices_graph_builder_deps)
+  set(multi_devices_graph_builder_deps all_reduce_op_handle)
   cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
   cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
 endif()
......
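
Note: a minimal sketch of what the nv_library/cc_library split above selects at compile time. The macro PADDLE_WITH_CUDA is taken from the source hunks below (Paddle defines it for WITH_GPU builds); the program itself is a stand-alone illustration, not Paddle code.

    #include <iostream>

    int main() {
    #ifdef PADDLE_WITH_CUDA
      // GPU build path: the handle links against dynload_cuda and uses NCCL.
      std::cout << "all_reduce_op_handle built with NCCL support\n";
    #else
      // CPU-only build path: the same sources compile without the NCCL context map.
      std::cout << "all_reduce_op_handle built without NCCL\n";
    #endif
    }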
paddle/fluid/framework/details/all_reduce_op_handle.cc (renamed from nccl_all_reduce_op_handle.cc)
@@ -13,8 +13,8 @@
 // limitations under the License.
 
 #include <algorithm>
+#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/container_cast.h"
-#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_and_gather.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
@@ -23,25 +23,23 @@ namespace framework {
 namespace details {
 
 #ifdef PADDLE_WITH_CUDA
-NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
-    const std::vector<Scope *> &local_scopes,
-    const std::vector<platform::Place> &places,
-    const platform::NCCLContextMap *ctxs)
+AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places,
+                                     const platform::NCCLContextMap *ctxs)
     : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
-  if (ctxs) {
+  if (nccl_ctxs_) {
     for (auto &p : places_) {
       this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
     }
   }
 }
 #else
-NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
-    const std::vector<Scope *> &local_scopes,
-    const std::vector<platform::Place> &places)
+AllReduceOpHandle::AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places)
     : local_scopes_(local_scopes), places_(places) {}
 #endif
 
-void NCCLAllReduceOpHandle::RunImpl() {
+void AllReduceOpHandle::RunImpl() {
   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
   } else {
@@ -133,7 +131,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
   }
 }
 
-std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
+std::string AllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
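
Note: a compilable, self-contained analogue of the constructor logic above. The stand-in types (Place, DeviceContext, ContextMap, Handle) are assumptions for illustration; in Paddle they correspond to platform::Place, platform::DeviceContext, platform::NCCLContextMap, and AllReduceOpHandle.

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    using Place = std::string;        // stand-in for platform::Place
    struct DeviceContext {};          // stand-in for platform::DeviceContext

    struct ContextMap {               // stand-in for platform::NCCLContextMap
      std::map<Place, DeviceContext> ctxs;
      DeviceContext *DevCtx(const Place &p) { return &ctxs[p]; }
    };

    struct Handle {                   // stand-in for AllReduceOpHandle
      Handle(const std::vector<Place> &places, ContextMap *ctxs)
          : places_(places), nccl_ctxs_(ctxs) {
        // As in the diff: pre-fill per-place device contexts only when a
        // context map was supplied; nccl_ctxs_ may legitimately be null.
        if (nccl_ctxs_) {
          for (auto &p : places_) {
            dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
          }
        }
      }
      std::vector<Place> places_;
      ContextMap *nccl_ctxs_;
      std::map<Place, DeviceContext *> dev_ctxes_;
    };

    int main() {
      ContextMap ctxs;
      Handle with_nccl({"gpu:0", "gpu:1"}, &ctxs);  // two contexts pre-filled
      Handle without_nccl({"cpu"}, nullptr);        // map stays empty
      std::cout << with_nccl.dev_ctxes_.size() << " "
                << without_nccl.dev_ctxes_.size() << "\n";  // prints "2 0"
    }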
paddle/fluid/framework/details/all_reduce_op_handle.h (renamed from nccl_all_reduce_op_handle.h)
@@ -28,14 +28,14 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-struct NCCLAllReduceOpHandle : public OpHandleBase {
+struct AllReduceOpHandle : public OpHandleBase {
 #ifdef PADDLE_WITH_CUDA
-  NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
-                        const std::vector<platform::Place> &places,
-                        const platform::NCCLContextMap *ctxs);
+  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places,
+                    const platform::NCCLContextMap *ctxs);
 #else
-  NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
-                        const std::vector<platform::Place> &places);
+  AllReduceOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places);
 #endif
 
   std::string Name() const override;
......
paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -17,10 +17,10 @@
 #include <utility>
 #include <vector>
 
+#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
-#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
@@ -283,6 +283,19 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(
   return false;
 }
 
+void MultiDevSSAGraphBuilder::SetCommunicationContext(
+    OpHandleBase *op_handle, const platform::Place &p) const {
+#ifdef PADDLE_WITH_CUDA
+  if (nccl_ctxs_ == nullptr) {
+    op_handle->SetDeviceContext(p,
+                                platform::DeviceContextPool::Instance().Get(p));
+  }
+#else
+  op_handle->SetDeviceContext(p,
+                              platform::DeviceContextPool::Instance().Get(p));
+#endif
+}
+
 void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
@@ -306,19 +319,6 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
   }
 }
 
-void MultiDevSSAGraphBuilder::SetCommunicationContext(
-    OpHandleBase *op_handle, const platform::Place &p) const {
-#ifdef PADDLE_WITH_CUDA
-  if (nccl_ctxs_ == nullptr) {
-    op_handle->SetDeviceContext(p,
-                                platform::DeviceContextPool::Instance().Get(p));
-  }
-#else
-  op_handle->SetDeviceContext(p,
-                              platform::DeviceContextPool::Instance().Get(p));
-#endif
-}
-
 void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
                                                     const OpDesc &op,
                                                     int dev_id) const {
@@ -331,9 +331,9 @@ void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
   result->ops_.emplace_back(
-      new NCCLAllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new NCCLAllReduceOpHandle(local_scopes_, places_));
+  result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
 #endif
 
   auto *op_handle = result->ops_.back().get();
......
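
Note: a hedged analogue of SetCommunicationContext above, reusing the stand-in types from the earlier sketch. The point of the function: when NCCL contexts exist (a GPU build with nccl_ctxs_ set), the all-reduce handle already received its device contexts in its constructor, so the builder only falls back to the global DeviceContextPool when nccl_ctxs_ is null or the build has no CUDA at all.

    // Sketch only; op_handle, pool_ctx, and the stand-in types are assumptions.
    void SetCommunicationContext(Handle *op_handle, const Place &p,
                                 ContextMap *nccl_ctxs,
                                 DeviceContext *pool_ctx) {
    #ifdef PADDLE_WITH_CUDA
      if (nccl_ctxs == nullptr) {
        op_handle->dev_ctxes_[p] = pool_ctx;  // CUDA build, but no NCCL map
      }
    #else
      op_handle->dev_ctxes_[p] = pool_ctx;    // CPU-only build: always the pool
    #endif
    }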