Commit 2eeaa8d5 authored by Xin Pan

Graph in ParallelExecutor Builder

Parent: 7781297c
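The gist of the change: SSA build products that previously lived in dedicated SSAGraph members (vars_, ops_, dep_vars_) now travel inside the generic Graph::attrs map as boost::any values and are fetched back with boost::any_cast. A minimal standalone sketch of that pattern follows; the "ops" key mirrors the diff, while the container type and main() are illustrative only:

// Sketch only: typed containers stored behind boost::any in a string-keyed
// attribute map, recovered by casting back to the exact stored type.
#include <map>
#include <string>
#include <vector>
#include <boost/any.hpp>

struct Graph {
  std::map<std::string, boost::any> attrs;
};

int main() {
  Graph g;
  // Store a typed container under a string key; the graph itself is
  // oblivious to the type.
  g.attrs["ops"] = new std::vector<std::string>();
  // Reading back requires the exact stored type, else bad_any_cast is thrown.
  auto *ops = boost::any_cast<std::vector<std::string> *>(g.attrs["ops"]);
  ops->push_back("all_reduce");
  delete ops;  // attrs holds raw pointers, so ownership is by convention.
  return 0;
}

The cost of this flexibility is visible throughout the diff below: every access site repeats the cast, and the new-ed containers are not owned by the map.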
@@ -25,6 +25,7 @@
 #include "paddle/fluid/framework/details/reduce_op_handle.h"
 #include "paddle/fluid/framework/details/rpc_op_handle.h"
 #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
+#include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/scope.h"
@@ -66,11 +67,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
   }
 }

-void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
-                                                const OpDesc &op,
+void MultiDevSSAGraphBuilder::CreateOpHandleIOs(Graph *result, const OpDesc &op,
                                                 size_t place_id) const {
   auto p = places_[place_id];
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();

   op_handle->SetDeviceContext(p,
                               platform::DeviceContextPool::Instance().Get(p));
@@ -169,18 +170,21 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
+  std::unique_ptr<Graph> graph(new Graph);
   for (auto *var : program.Block(0).AllVars()) {
     all_vars_.emplace(var->Name(), var);
   }

-  auto graph = new SSAGraph();
-  SSAGraph &result = *graph;
+  Graph &result = *graph;
   std::unordered_set<std::string> og_has_been_broadcast;

   // We cannot invoke resize. It is a bug of GCC 4.8
-  result.vars_ = std::vector<
+  result.attrs["vars"] = new std::vector<
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
+  result.attrs["dep_vars"] =
+      new std::unordered_set<std::unique_ptr<VarHandleBase>>();
+  result.attrs["ops"] = new std::vector<std::unique_ptr<OpHandleBase>>();

   // find send/recv vars so that we can place the distributed training
   // realted op in the place 0
@@ -303,7 +307,15 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
    */
   AddOutputToLeafOps(&result);

-  return std::unique_ptr<SSAGraph>(graph);
+  std::unique_ptr<SSAGraph> ssa_graph(new SSAGraph);
+  ssa_graph->vars_ =
+      std::move(*boost::any_cast<GraphVars *>(graph->attrs["vars"]));
+  ssa_graph->ops_ =
+      std::move(*boost::any_cast<GraphOps *>(graph->attrs["ops"]));
+  ssa_graph->dep_vars_ =
+      std::move(*boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"]));
+  return std::move(ssa_graph);
 }

 bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
@@ -327,7 +339,7 @@ void MultiDevSSAGraphBuilder::SetCommunicationContext(
 #endif
 }

-void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateBroadcastOp(Graph *result,
                                                 const std::string &p_name,
                                                 size_t src_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
@@ -336,42 +348,50 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(SSAGraph *result,
   auto *op_handle = new BroadcastOpHandle(local_scopes_, places_);
 #endif

-  result->ops_.emplace_back(op_handle);
-  auto *in = result->vars_.at(src_dev_id).at(p_name).back().get();
+  boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);
+  auto *in = boost::any_cast<GraphVars *>(result->attrs["vars"])
+                 ->at(src_dev_id)
+                 .at(p_name)
+                 .back()
+                 .get();
   op_handle->AddInput(in);

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_.at(i).at(p_name);
+    auto &vars =
+        boost::any_cast<GraphVars *>(result->attrs["vars"])->at(i).at(p_name);
     auto *out_var = new VarHandle(vars.size(), i, p_name, p);
     vars.emplace_back(out_var);
     op_handle->AddOutput(out_var);
   }
 }

-void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateComputationalOp(Graph *result,
                                                     const OpDesc &op,
                                                     int dev_id) const {
-  result->ops_.emplace_back(
-      new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(
+          new ComputationOpHandle(op, local_scopes_[dev_id], places_[dev_id]));
   CreateOpHandleIOs(result, op, dev_id);
 }

-void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::InsertAllReduceOp(Graph *result,
                                                 const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
-      new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new AllReduceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new AllReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_[i][og];
+    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
@@ -383,19 +403,23 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
 }

 void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
-    SSAGraph *result, const std::vector<std::string> &datas) const {
+    Graph *result, const std::vector<std::string> &datas) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
-      new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(
+          new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new DataBalanceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();
   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
     for (const std::string &d_name : datas) {
-      auto &vars = result->vars_[i][d_name];
+      auto &vars =
+          (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][d_name];
       PADDLE_ENFORCE(!vars.empty());
       op_handle->AddInput(vars.back().get());
       auto var = new VarHandle(vars.size(), i, d_name, p);
@@ -441,7 +465,7 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const std::string &varname) const {
   return got == var_name_on_devices_.end() ? -1 : got->second;
 }

-void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
+void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(Graph *result) const {
   for (size_t i = 0; i < places_.size(); ++i) {
     // Insert ScaleCost OpHandle
 #ifdef PADDLE_WITH_CUDA
@@ -456,7 +480,7 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
     auto *op_handle =
         new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i],
                                   places_[i], communication_dev_ctx);
-    result->ops_.emplace_back(op_handle);
+    boost::any_cast<GraphOps *>(result->attrs["ops"])->emplace_back(op_handle);

     // FIXME: Currently ScaleLossGradOp only use device_count as scale
     // factor. So it does not depend on any other operators.
@@ -469,37 +493,41 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
   }
 }

-void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateComputationalOps(Graph *result,
                                                      const OpDesc &op,
                                                      size_t num_places) const {
   for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
-    result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
+    boost::any_cast<GraphOps *>(result->attrs["ops"])
+        ->emplace_back(new ComputationOpHandle(op, s, p));
     CreateOpHandleIOs(result, op, scope_idx);
   }
 }

-VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
+VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(Graph *result,
                                                    const std::string &og,
                                                    int dst_dev_id) const {
 #ifdef PADDLE_WITH_CUDA
-  result->ops_.emplace_back(
-      new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new ReduceOpHandle(local_scopes_, places_, nccl_ctxs_));
 #else
-  result->ops_.emplace_back(new ReduceOpHandle(local_scopes_, places_));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new ReduceOpHandle(local_scopes_, places_));
 #endif
-  auto *op_handle = result->ops_.back().get();
+  auto *op_handle =
+      boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get();

   for (size_t i = 0; i < places_.size(); ++i) {
     auto &p = places_[i];
     SetCommunicationContext(op_handle, p);
-    auto &vars = result->vars_[i][og];
+    auto &vars = (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[i][og];
     PADDLE_ENFORCE(!vars.empty());
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
   }
-  auto &vars = result->vars_[dst_dev_id][og];
+  auto &vars =
+      (*boost::any_cast<GraphVars *>(result->attrs["vars"]))[dst_dev_id][og];
   auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
@@ -508,19 +536,20 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,

 // Find the first occurence of `prev_op_name` and make current `op` depend
 // on it.
-void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+void MultiDevSSAGraphBuilder::ConnectOp(Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
-  for (auto &prev_op : result->ops_) {
+  for (auto &prev_op : (*boost::any_cast<GraphOps *>(result->attrs["ops"]))) {
     if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
-      result->dep_vars_.emplace(dep_var);
+      boost::any_cast<GraphDepVars *>(result->attrs["dep_vars"])
+          ->emplace(dep_var);
       op->AddInput(dep_var);
     }
   }
 }

-void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateDistTrainOp(Graph *result,
                                                 const OpDesc &op) const {
   int op_dev_id = -1;
   if (op.Type() == "split_byref" || op.Type() == "split_selected_rows") {
@@ -550,12 +579,14 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
   CreateComputationalOp(result, op, op_dev_id);
   if (op.Type() == "concat") {
-    ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "fetch_barrier");
   }
 }

 // Create RPC related op handles that connects its in ops and out ops.
-void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
+void MultiDevSSAGraphBuilder::CreateRPCOp(Graph *result,
                                           const OpDesc &op) const {
   int op_dev_id = -1;
   if (op.Type() == "send") {
@@ -584,15 +615,22 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
   PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
                  op.Type());

-  result->ops_.emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id],
-                                            op.Type(), places_[op_dev_id]));
+  boost::any_cast<GraphOps *>(result->attrs["ops"])
+      ->emplace_back(new RPCOpHandle(op, local_scopes_[op_dev_id], op.Type(),
+                                     places_[op_dev_id]));

   if (op.Type() == "send_barrier") {
-    ConnectOp(result, result->ops_.back().get(), "send");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "send");
   } else if (op.Type() == "recv") {
-    ConnectOp(result, result->ops_.back().get(), "send_barrier");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result, result->ops_.back().get(), "recv");
+    ConnectOp(result,
+              boost::any_cast<GraphOps *>(result->attrs["ops"])->back().get(),
+              "recv");
   } else if (op.Type() == "send") {
     // do nothing
   } else {
......
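Every builder method above repeats the same boost::any_cast lookup. A hypothetical helper, not part of this commit, could centralize the pattern:

// Hypothetical convenience accessor (GetAttr is illustrative only).
template <typename AttrType>
AttrType &GetAttr(Graph *graph, const std::string &name) {
  // Attributes are stored as raw AttrType* behind boost::any.
  return *boost::any_cast<AttrType *>(graph->attrs[name]);
}

// Usage sketch: GetAttr<GraphOps>(result, "ops").emplace_back(op_handle);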
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/details/build_strategy.h"
 #include "paddle/fluid/framework/details/ssa_graph_builder.h"
+#include "paddle/fluid/framework/ir/graph.h"

 namespace paddle {
 namespace platform {
@@ -50,7 +51,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   int GetVarDeviceID(const std::string &varname) const override;

  private:
-  void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
+  void CreateOpHandleIOs(Graph *result, const OpDesc &op,
                          size_t device_id) const;

  private:
@@ -65,8 +66,8 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   bool IsScaleLossOp(const OpDesc &op) const;

-  void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
-  void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
+  void CreateRPCOp(Graph *result, const OpDesc &op) const;
+  void CreateDistTrainOp(Graph *result, const OpDesc &op) const;

   /**
    * Is this operator as the end-point operator before/after send operator.
@@ -81,17 +82,16 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   std::vector<std::string> FindDistTrainRecvVars(
       const ProgramDesc &program) const;

-  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+  void ConnectOp(Graph *result, OpHandleBase *op,
                  const std::string &prev_op_name) const;

-  void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
+  void CreateComputationalOps(Graph *result, const OpDesc &op,
                               size_t num_places) const;

-  void CreateScaleLossGradOp(SSAGraph *result) const;
-  VarHandle *CreateReduceOp(SSAGraph *result, const std::string &og,
+  void CreateScaleLossGradOp(Graph *result) const;
+  VarHandle *CreateReduceOp(Graph *result, const std::string &og,
                             int dst_dev_id) const;

-  void CreateComputationalOp(SSAGraph *result, const OpDesc &op,
-                             int dev_id) const;
+  void CreateComputationalOp(Graph *result, const OpDesc &op, int dev_id) const;

   bool IsParameterGradientOnce(
       const std::string &og,
@@ -99,12 +99,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
   int GetOpDeviceID(const OpDesc &op) const;

-  void InsertAllReduceOp(SSAGraph *result, const std::string &og) const;
+  void InsertAllReduceOp(Graph *result, const std::string &og) const;

-  void InsertDataBalanceOp(SSAGraph *result,
+  void InsertDataBalanceOp(Graph *result,
                            const std::vector<std::string> &datas) const;

-  void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
+  void CreateBroadcastOp(Graph *result, const std::string &p_name,
                          size_t src_dev_id) const;

   bool IsSparseGradient(const std::string &og) const;
......
@@ -17,8 +17,8 @@
 namespace paddle {
 namespace framework {
 namespace details {
-void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
-  for (auto &var_map : graph->vars_) {
+void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
+  for (auto &var_map : *boost::any_cast<GraphVars *>(graph->attrs["vars"])) {
     for (auto &name_pair : var_map) {
       if (name_pair.second.size() <= 1) {
         continue;
@@ -40,7 +40,8 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
         auto *dep_var = new DummyVarHandle();
         read_op->AddOutput(dep_var);
         write_op->AddInput(dep_var);
-        graph->dep_vars_.emplace(dep_var);
+        boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
+            ->emplace(dep_var);
       }
     }
   }
@@ -48,9 +49,10 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
 }

 VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
-    SSAGraph *graph, const std::string &each_var_name,
+    Graph *graph, const std::string &each_var_name,
     const platform::Place &place, size_t place_offset) {
-  auto &var_holders = graph->vars_[place_offset];
+  auto &var_holders =
+      (*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset];
   auto &var_holder = var_holders[each_var_name];
   VarHandle *var = nullptr;
   if (var_holder.empty()) {
@@ -62,24 +64,29 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
   return var;
 }

-void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
+void SSAGraphBuilder::CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
                                      const std::string &each_var_name,
                                      const platform::Place &place,
                                      size_t place_offset) {
-  auto &vars = graph->vars_[place_offset][each_var_name];
+  auto &vars =
+      (*boost::any_cast<GraphVars *>(graph->attrs["vars"]))[place_offset]
+                                                           [each_var_name];
   size_t version = vars.size();
   auto var = new VarHandle(version, place_offset, each_var_name, place);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
 }

-void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
-  for (auto &op : graph->ops_) {
+void SSAGraphBuilder::AddOutputToLeafOps(Graph *graph) {
+  GraphOps &all_ops = *boost::any_cast<GraphOps *>(graph->attrs["ops"]);
+  for (auto &op : all_ops) {
     if (!op->Outputs().empty()) {
       continue;
     }
     auto *dummy_leaf = new DummyVarHandle();
-    graph->dep_vars_.emplace(dummy_leaf);
+    boost::any_cast<GraphDepVars *>(graph->attrs["dep_vars"])
+        ->emplace(dummy_leaf);
     op->AddOutput(dummy_leaf);
   }
 }
......
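The VarHandle bookkeeping above is what makes the graph SSA: each write to a variable name appends a handle with the next version number, and readers consume the latest one. A standalone sketch of that versioning discipline, with VarHandle reduced to the fields that matter here:

// Simplified SSA versioning: one vector of handles per variable name.
#include <memory>
#include <string>
#include <vector>

struct VarHandle {
  size_t version;
  size_t scope_idx;
  std::string name;
  VarHandle(size_t v, size_t s, std::string n)
      : version(v), scope_idx(s), name(std::move(n)) {}
};

int main() {
  std::vector<std::unique_ptr<VarHandle>> versions;  // all handles for "w"
  // Each write appends a handle whose version is the current count.
  versions.emplace_back(new VarHandle(versions.size(), 0, "w"));  // w@0
  versions.emplace_back(new VarHandle(versions.size(), 0, "w"));  // w@1
  // A reader always consumes the newest version.
  VarHandle *latest = versions.back().get();  // w@1
  return latest->version == 1 ? 0 : 1;
}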
@@ -16,15 +16,24 @@
 #include <memory>
 #include <string>
 #include <vector>

 #include "paddle/fluid/framework/details/ssa_graph.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/platform/place.h"

+#include "paddle/fluid/framework/ir/graph.h"
+
 namespace paddle {
 namespace framework {
 namespace details {

+typedef std::vector<
+    std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
+    GraphVars;
+typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
+typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
+
 class SSAGraphBuilder {
  public:
   SSAGraphBuilder() {}
@@ -42,20 +51,20 @@ class SSAGraphBuilder {
    *
    * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
    */
-  static void PolishGraphToSupportDataHazards(SSAGraph *graph);
+  static void PolishGraphToSupportDataHazards(Graph *graph);

-  static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
+  static VarHandle *CreateOrGetLatestVarHandle(Graph *graph,
                                                const std::string &each_var_name,
                                                const platform::Place &place,
                                                size_t place_offset);

   // Add an output variable (each_var_name, place, place_offset) to op_handle,
   // which belongs to graph
-  static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
+  static void CreateOpOutput(Graph *graph, OpHandleBase *op_handle,
                              const std::string &each_var_name,
                              const platform::Place &place, size_t place_offset);

-  static void AddOutputToLeafOps(SSAGraph *graph);
+  static void AddOutputToLeafOps(Graph *graph);
 };

 } // namespace details
 } // namespace framework
......
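For context on PolishGraphToSupportDataHazards: a write-after-read (WAR) hazard arises when read_op consumes version v of a variable and write_op later produces version v+1; with no edge between them, a parallel scheduler could start write_op first. The pass breaks the hazard with a dummy dependency variable, exactly as in the .cc hunk above:

   w@0 -> read_op -> dep_var(dummy) -> write_op -> w@1

so write_op cannot be scheduled until read_op has finished.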
@@ -27,7 +27,7 @@ namespace framework {
 class Graph {
  public:
-  std::map<std::string, std::vector<boost::any>> attrs;
+  std::map<std::string, boost::any> attrs;

   std::vector<Node *> inputs;
   std::vector<Node *> outputs;
......
@@ -14,6 +14,27 @@ limitations under the License. */
 #pragma once

+#include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
+#include "paddle/fluid/framework/program_desc.h"

 namespace paddle {
-namespace framework {} // namespace framework
+namespace framework {
+
+class Pass {
+ public:
+  Pass() = default;
+  virtual ~Pass() {}
+
+  virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) {
+    return std::move(graph);
+  }
+};
+
+std::unique_ptr<Graph> ProgramToGraph(const ProgramDesc& program) {
+  std::unique_ptr<Graph> g(new Graph);
+  return std::move(g);
+}
+
+} // namespace framework
 } // namespace paddle
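The new Pass base class is an identity transform by default; concrete passes override Apply. A hedged usage sketch (DebugPass is hypothetical, not part of this commit):

// Hypothetical pass built on the interface introduced above.
class DebugPass : public Pass {
 public:
  std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) override {
    // Inspect or rewrite graph->attrs and nodes here, then return ownership.
    return graph;
  }
};

// Wiring it up:
//   std::unique_ptr<Graph> g = ProgramToGraph(program);
//   g = DebugPass().Apply(std::move(g));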