提交 baea2cf1 编写于 作者: T typhoonzero

wip

上级 01c6618d
......@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
......@@ -34,26 +35,46 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs)
platform::NCCLContextMap *nccl_ctxs, bool distributed)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
distributed_(distributed),
nccl_ctxs_(nccl_ctxs) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes)
const std::vector<Scope *> &local_scopes, bool distributed)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes) {
local_scopes_(local_scopes),
distributed_(distributed) {
#endif
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
const platform::Place &p,
const size_t &i) const {
auto *op_handle = result->ops_.back().get();
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
}
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
const ProgramDesc &program) const {
auto graph = new SSAGraph();
......@@ -72,6 +93,17 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
// append send op if program is distributed trainer main program.
// always use the first device
if (is_forwarding && distributed_ && op->Type() == "send") {
auto &p = places_[0];
auto *s = local_scopes_[0];
size_t i = 0;
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
CreateOpHandleIOs(&result, op, p, i);
continue;
}
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
......@@ -81,18 +113,19 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
CreateOpHandleIOs(&result, op, p, i);
// auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var =
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
// for (auto &each_var_name : var_names) {
// VarHandle *var =
// CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
// op_handle->AddInput(var);
// }
auto var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(&result, op_handle, each_var_name, p, i);
}
// for (auto &each_var_name : var_names) {
// CreateOpOutput(&result, op_handle, each_var_name, p, i);
// }
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
......
......@@ -14,6 +14,9 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
......@@ -31,21 +34,28 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs);
platform::NCCLContextMap *nccl_ctxs,
bool distributed = false);
#else
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes);
const std::vector<Scope *> &local_scopes,
bool distributed = false);
#endif
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i) const;
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
const std::vector<Scope *> &local_scopes_;
std::unordered_set<std::string> grad_names_;
bool distributed_;
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_;
......
......@@ -18,61 +18,24 @@ namespace paddle {
namespace framework {
namespace details {
SendOpHandle::SendOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap &ctxs)
: local_scopes_(local_scopes), places_(places) {}
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place) {}
void SendOpHandle::RunImpl() {
if (inputs_.size() == 1) {
return; // No need to all reduce when GPU count = 1;
} else {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
in->generated_op_->Wait(dev_ctxes_[p]);
}
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
int dtype = -1;
size_t numel = 0;
std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
int dev_id = boost::get<platform::CUDAPlace>(p).device;
auto &lod_tensor = s->FindVar(var_name)->Get<LoDTensor>();
void *buffer = const_cast<void *>(lod_tensor.data<void>());
if (dtype == -1) {
dtype = platform::ToNCCLDataType(lod_tensor.type());
}
if (numel == 0) {
numel = static_cast<size_t>(lod_tensor.numel());
}
auto &nccl_ctx = nccl_ctxs_.at(dev_id);
auto stream = nccl_ctx.stream();
auto comm = nccl_ctx.comm_;
all_reduce_calls.emplace_back([=] {
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
comm, stream));
});
}
platform::NCCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
}
op_->Run(*local_scope_, place_);
}
std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; }
std::string SendOpHandle::Name() const { return "send"; }
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -19,6 +19,8 @@
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/nccl_helper.h"
......@@ -27,19 +29,18 @@ namespace framework {
namespace details {
struct SendOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
const platform::NCCLContextMap &nccl_ctxs_;
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
SendOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap &ctxs);
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
std::string Name() const override;
// Delay and buffer nccl_all_reduce together can significantly increase
// performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return true; };
bool IsMultiDeviceTransfer() override { return false; };
protected:
void RunImpl() override;
......
......@@ -951,6 +951,13 @@ class Block(object):
if var.type == core.VarDesc.VarType.STEP_SCOPES:
ret_var = self.create_var(
name=var.name, persistable=var.persistable, type=var.type)
elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
ret_var = self.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
persistable=True)
else:
ret_var = self.create_var(
name=var.name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册