Commit baea2cf1 authored by typhoonzero

wip

Parent 01c6618d
paddle/fluid/framework/details/CMakeLists.txt
@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
 nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
     dynload_cuda)
 cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
 cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
 cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+#include "paddle/fluid/framework/details/send_op_handle.h"
 #include "paddle/fluid/framework/scope.h"
 
 #ifdef PADDLE_WITH_CUDA
@@ -34,26 +35,46 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
     const std::vector<Scope *> &local_scopes,
-    platform::NCCLContextMap *nccl_ctxs)
+    platform::NCCLContextMap *nccl_ctxs, bool distributed)
     : loss_var_name_(loss_var_name),
       places_(places),
       local_scopes_(local_scopes),
+      distributed_(distributed),
       nccl_ctxs_(nccl_ctxs) {
 #else
 MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
     const std::vector<platform::Place> &places,
     const std::string &loss_var_name,
     const std::unordered_set<std::string> &params,
-    const std::vector<Scope *> &local_scopes)
+    const std::vector<Scope *> &local_scopes, bool distributed)
     : loss_var_name_(loss_var_name),
       places_(places),
-      local_scopes_(local_scopes) {
+      local_scopes_(local_scopes),
+      distributed_(distributed) {
 #endif
   for (auto &p : params) {
     grad_names_.insert(GradVarName(p));
   }
 }
 
+void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
+                                                const platform::Place &p,
+                                                const size_t &i) const {
+  auto *op_handle = result->ops_.back().get();
+
+  auto var_names = op->InputArgumentNames();
+  for (auto &each_var_name : var_names) {
+    VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
+    op_handle->AddInput(var);
+  }
+
+  var_names = op->OutputArgumentNames();
+  for (auto &each_var_name : var_names) {
+    CreateOpOutput(result, op_handle, each_var_name, p, i);
+  }
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
@@ -72,6 +93,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       }
     }
 
+    // Append the send op only if this program is a distributed trainer's
+    // main program; it always runs on the first device.
+    if (is_forwarding && distributed_ && op->Type() == "send") {
+      auto &p = places_[0];
+      auto *s = local_scopes_[0];
+      size_t i = 0;
+      result.ops_.emplace_back(new SendOpHandle(*op, s, p));
+      CreateOpHandleIOs(&result, op, p, i);
+      continue;
+    }
+
     for (size_t i = 0; i < places_.size(); ++i) {
       auto &p = places_[i];
       auto *s = local_scopes_[i];
@@ -81,18 +113,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
           platform::DeviceContextPool::Instance().Get(p));
 
-      auto var_names = op->InputArgumentNames();
-      for (auto &each_var_name : var_names) {
-        VarHandle *var =
-            CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
-        op_handle->AddInput(var);
-      }
-      var_names = op->OutputArgumentNames();
-      for (auto &each_var_name : var_names) {
-        CreateOpOutput(&result, op_handle, each_var_name, p, i);
-      }
+      CreateOpHandleIOs(&result, op, p, i);
+      // auto var_names = op->InputArgumentNames();
+      // for (auto &each_var_name : var_names) {
+      //   VarHandle *var =
+      //       CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
+      //   op_handle->AddInput(var);
+      // }
+      auto var_names = op->OutputArgumentNames();
+      // for (auto &each_var_name : var_names) {
+      //   CreateOpOutput(&result, op_handle, each_var_name, p, i);
+      // }
 
       if (is_forwarding) {
         if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -14,6 +14,9 @@
 
 #pragma once
 
+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
 
 namespace paddle {
@@ -31,21 +34,28 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
                           const std::vector<Scope *> &local_scopes,
-                          platform::NCCLContextMap *nccl_ctxs);
+                          platform::NCCLContextMap *nccl_ctxs,
+                          bool distributed = false);
 #else
   MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
                           const std::string &loss_var_name,
                           const std::unordered_set<std::string> &params,
-                          const std::vector<Scope *> &local_scopes);
+                          const std::vector<Scope *> &local_scopes,
+                          bool distributed = false);
 #endif
 
   std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
 
+ private:
+  void CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
+                         const platform::Place &p, const size_t &i) const;
+
  private:
   std::string loss_var_name_;
   const std::vector<platform::Place> &places_;
   const std::vector<Scope *> &local_scopes_;
   std::unordered_set<std::string> grad_names_;
+  bool distributed_;
 
 #ifdef PADDLE_WITH_CUDA
   platform::NCCLContextMap *nccl_ctxs_;
...
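For orientation, here is a minimal usage sketch of the new flag (not part of this commit). It uses the non-CUDA constructor declared above; the loss variable name and the wrapper function are hypothetical:

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"

// Builds the SSA graph for a distributed trainer's main program.
// With distributed = true, any "send" op in the forward pass is wrapped in a
// single SendOpHandle pinned to places[0] instead of being replicated once
// per device.
std::unique_ptr<paddle::framework::details::SSAGraph> BuildTrainerGraph(
    const paddle::framework::ProgramDesc &main_program,
    const std::vector<paddle::platform::Place> &places,
    const std::unordered_set<std::string> &params,
    const std::vector<paddle::framework::Scope *> &local_scopes) {
  paddle::framework::details::MultiDevSSAGraphBuilder builder(
      places, /*loss_var_name=*/"loss", params, local_scopes,
      /*distributed=*/true);
  return builder.Build(main_program);
}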
paddle/fluid/framework/details/send_op_handle.cc
@@ -18,61 +18,24 @@ namespace paddle {
 namespace framework {
 namespace details {
 
-SendOpHandle::SendOpHandle(const std::vector<Scope *> &local_scopes,
-                           const std::vector<platform::Place> &places,
-                           const platform::NCCLContextMap &ctxs)
-    : local_scopes_(local_scopes), places_(places) {}
+SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
+                           const Scope *local_scope,
+                           const platform::Place &place)
+    : op_(framework::OpRegistry::CreateOp(op_desc)),
+      local_scope_(local_scope),
+      place_(place) {}
 
 void SendOpHandle::RunImpl() {
-  if (inputs_.size() == 1) {
-    return;  // No need to all reduce when GPU count = 1;
-  } else {
-    // Wait input done
-    for (auto *in : inputs_) {
-      auto &p = static_cast<VarHandle *>(in)->place_;
-      in->generated_op_->Wait(dev_ctxes_[p]);
-    }
-    auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
-    int dtype = -1;
-    size_t numel = 0;
-    std::vector<std::function<void()>> all_reduce_calls;
-    for (size_t i = 0; i < local_scopes_.size(); ++i) {
-      auto &p = places_[i];
-      auto *s = local_scopes_[i];
-      int dev_id = boost::get<platform::CUDAPlace>(p).device;
-      auto &lod_tensor = s->FindVar(var_name)->Get<LoDTensor>();
-      void *buffer = const_cast<void *>(lod_tensor.data<void>());
-      if (dtype == -1) {
-        dtype = platform::ToNCCLDataType(lod_tensor.type());
-      }
-      if (numel == 0) {
-        numel = static_cast<size_t>(lod_tensor.numel());
-      }
-      auto &nccl_ctx = nccl_ctxs_.at(dev_id);
-      auto stream = nccl_ctx.stream();
-      auto comm = nccl_ctx.comm_;
-      all_reduce_calls.emplace_back([=] {
-        PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
-            buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
-            comm, stream));
-      });
-    }
-    platform::NCCLGroupGuard guard;
-    for (auto &call : all_reduce_calls) {
-      call();
-    }
-  }
+  // Wait until every input variable has been generated.
+  for (auto *in : inputs_) {
+    auto &p = static_cast<VarHandle *>(in)->place_;
+    in->generated_op_->Wait(dev_ctxes_[p]);
+  }
+  op_->Run(*local_scope_, place_);
 }
 
-std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
+std::string SendOpHandle::Name() const { return "send"; }
 
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
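A hedged sketch of driving the rebuilt handle directly (the builder does the equivalent with an OpDesc taken from the trainer's ProgramDesc). The variable names and the X/Out argument slots are illustrative, and a real send OpDesc would also carry its RPC endpoint attributes:

// local_scope and place stand in for local_scopes_[0] and places_[0].
paddle::framework::OpDesc op_desc;
op_desc.SetType("send");
op_desc.SetInput("X", {"w1@GRAD"});  // gradients pushed to the pserver
op_desc.SetOutput("Out", {"w1"});    // parameters fetched back afterwards

paddle::framework::details::SendOpHandle handle(op_desc, local_scope, place);
// RunImpl() waits on the op that generated each input, then runs the wrapped
// "send" operator synchronously in local_scope on place.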
paddle/fluid/framework/details/send_op_handle.h
@@ -19,6 +19,8 @@
 
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 
@@ -27,19 +29,18 @@ namespace framework {
 namespace details {
 
 struct SendOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-  const platform::NCCLContextMap &nccl_ctxs_;
+  std::unique_ptr<OperatorBase> op_;
+  const Scope *local_scope_;
+  const platform::Place &place_;
 
-  SendOpHandle(const std::vector<Scope *> &local_scopes,
-               const std::vector<platform::Place> &places,
-               const platform::NCCLContextMap &ctxs);
+  SendOpHandle(const framework::OpDesc &op_desc, const Scope *local_scope,
+               const platform::Place &place);
 
   std::string Name() const override;
 
   // Delay and buffer nccl_all_reduce together can significantly increase
   // performance. Disable this feature by returning false.
-  bool IsMultiDeviceTransfer() override { return true; };
+  bool IsMultiDeviceTransfer() override { return false; };
 
  protected:
   void RunImpl() override;
...
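Presumably the point of flipping IsMultiDeviceTransfer() to false is to keep send out of the executor's delay-and-batch path: buffering helps nccl_all_reduce style transfers, but an RPC should be issued as soon as its inputs are ready.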
python/paddle/fluid/framework.py
@@ -951,6 +951,13 @@ class Block(object):
         if var.type == core.VarDesc.VarType.STEP_SCOPES:
             ret_var = self.create_var(
                 name=var.name, persistable=var.persistable, type=var.type)
+        elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
+            ret_var = self.create_var(
+                name=var.name,
+                shape=var.shape,
+                dtype=var.dtype,
+                type=var.type,
+                persistable=True)
         else:
             ret_var = self.create_var(
                 name=var.name,
...
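The new SELECTED_ROWS branch mirrors the dense-tensor path below it but forces persistable=True, presumably because these cloned variables back sparse parameters and gradients in distributed programs and must outlive a single iteration.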