提交 93355cc0 编写于 作者: X Xin Pan

fix control deps

上级 f6d99d1f
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto)
cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
......
......@@ -94,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
}
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
const std::vector<ir::Node *> &nodes) const {
std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
......@@ -114,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
}
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
const std::vector<ir::Node *> &nodes) const {
std::vector<std::string> recv_vars;
for (auto &node : nodes) {
if (node->NodeType() != ir::Node::Type::kOperation) continue;
OpDesc *op = node->Op();
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
......@@ -214,6 +212,19 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
}
}
// Verify that no operations before optimize ops depends on optimize ops.
std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
optimize_ops.end());
for (size_t i = 0; i < last_backward; ++i) {
for (ir::Node *in : sorted_ret[i]->inputs) {
for (ir::Node *pre_n : in->inputs) {
PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
"optimize operations cannot be depended by forward "
"or backward node %s -> %s",
pre_n->Name(), sorted_ret[i]->Name());
}
}
}
sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
optimize_ops.end());
return sorted_ret;
......@@ -221,18 +232,16 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
std::unique_ptr<ir::Graph> graph) const {
// Rebuild the graph structure.
// Give the topology sort order and rebuild the graph structure.
std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
auto nodes = std::move(graph->nodes);
graph->nodes.clear();
auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph;
for (auto &node : nodes) {
if (node->NodeType() == ir::Node::Type::kVariable) {
all_vars_.emplace(node->Name(), node->Var());
}
}
ir::Graph &result = *graph;
std::unordered_set<std::string> og_has_been_broadcast;
// We cannot invoke resize. It is a bug of GCC 4.8
......@@ -242,8 +251,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// find send/recv vars so that we can place the distributed training
// realted op in the place 0
auto send_vars = FindDistTrainSendVars(nodes);
auto recv_vars = FindDistTrainRecvVars(nodes);
auto send_vars = FindDistTrainSendVars(sorted_ops);
auto recv_vars = FindDistTrainRecvVars(sorted_ops);
std::vector<std::unordered_set<std::string>> bcast_var_name_set;
bcast_var_name_set.resize(places_.size());
......@@ -589,8 +598,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const {
for (auto &prev_op : result->Get<GraphOps>("ops")) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle(
result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
prev_op->AddOutput(dep_var);
result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
op->AddInput(dep_var);
......
......@@ -76,10 +76,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
const std::vector<ir::Node *> &nodes) const;
std::vector<std::string> FindDistTrainRecvVars(
const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
const std::vector<ir::Node *> &nodes) const;
void ConnectOp(ir::Graph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
......
......@@ -33,7 +33,8 @@ void RPCOpHandle::RunImpl() {
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of use DebugString()
if (in->DebugString() == "dummy") { // HACK
if (in->Node()->Name().find(ir::Node::kControlDepVarName) !=
std::string::npos) { // HACK
continue;
}
if (in->GeneratedOp()) {
......
......@@ -17,6 +17,36 @@
namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
}
auto it_new = name_pair.second.rbegin();
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
OpHandleBase *write_op = (*it_new)->GeneratedOp();
const auto &read_ops = (*it_old)->PendingOps();
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
// Read Write is the same op.
continue;
}
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
}
}
}
}
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) {
......@@ -56,8 +86,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle(
graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
......
......@@ -57,6 +57,15 @@ class SSAGraphBuilder : public ir::Pass {
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected:
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static void PolishGraphToSupportDataHazards(ir::Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
const platform::Place &place,
size_t place_offset);
......
......@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const {
return ss.str();
}
std::string DummyVarHandle::DebugString() const { return "dummy"; }
std::string DummyVarHandle::DebugString() const { return node_->Name(); }
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -34,7 +34,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op);
// For input args, reuse the same var name if it was created before.
// Otherwise, create a new one.
for (auto &each_var_name : op->InputArgumentNames()) {
ir::Node *var = nullptr;
if (var_nodes.find(each_var_name) != var_nodes.end()) {
......@@ -43,16 +44,16 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name].push_back(var);
} else {
// TODO(paddle-dev): Seems some assumption doesn't hold?
VLOG(3) << op->Type()
<< " input var not in all_var list: " << each_var_name;
// Operation input var can be optional (dispensable). Which means
// the operation doesn't really need the var at runtime. In this
// case, the no-existed var is ready at the beginning.
var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
var_nodes[each_var_name].push_back(var);
}
node->inputs.push_back(var);
var->outputs.push_back(node);
}
// For output args, always create a new var.
for (auto &each_var_name : op->OutputArgumentNames()) {
ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
var_nodes[each_var_name].push_back(var);
......@@ -67,6 +68,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
for (auto &var : var_nodes) {
auto &versions = var.second;
if (versions.size() <= 1) continue;
......@@ -85,8 +87,18 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
// Read Write is the same op.
continue;
}
ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
ir::Node::Type::kVariable);
// 2 ops might have been connected via other vars.
bool has_dep = false;
for (ir::Node *r_out : read_op->outputs) {
for (ir::Node *w_in : write_op->inputs) {
if (r_out == w_in) {
has_dep = true;
break;
}
}
}
if (has_dep) continue;
ir::Node *dep_var = CreateControlDepVar();
read_op->outputs.push_back(dep_var);
dep_var->inputs.push_back(read_op);
write_op->inputs.push_back(dep_var);
......
......@@ -27,6 +27,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
class Graph {
public:
explicit Graph(const ProgramDesc &program);
......@@ -54,28 +55,58 @@ class Graph {
};
}
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
ir::Node *CreateVarNode(VarDesc *var_desc) {
nodes.emplace_back(new ir::Node(var_desc));
return nodes.back().get();
return AddNode(new ir::Node(var_desc));
}
ir::Node *CreateOpNode(OpDesc *op_desc) {
nodes.emplace_back(new ir::Node(op_desc));
return nodes.back().get();
return AddNode(new ir::Node(op_desc));
}
ir::Node *CreateControlDepVar() {
// TODO(panyx0718): control var name should be unique.
const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
}
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
nodes.emplace_back(new ir::Node(name, type));
return nodes.back().get();
return AddNode(new ir::Node(name, type));
}
std::vector<std::unique_ptr<ir::Node>> nodes;
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
std::vector<std::unique_ptr<ir::Node>> ret;
for (auto &n : nodes_) {
ret.emplace_back(n.second.release());
}
nodes_.clear();
node_set_.clear();
return ret;
}
private:
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
nodes_[node].reset(node);
node_set_.insert(node);
return node;
}
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
// NOTE: program_ shouldn't be exposed to user.
const ProgramDesc &program_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_;
};
} // namespace ir
} // namespace framework
......
......@@ -33,9 +33,8 @@ void SortHelper(
}
}
LOG(ERROR) << "topology sort insert: " << node->Name()
<< reinterpret_cast<void *>(node) << " input "
<< node->inputs.size();
VLOG(3) << "topology sort insert: " << node->Name()
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size();
ret->push_back(node);
}
......@@ -93,17 +92,17 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
for (auto &n : graph.nodes) {
for (auto &n : graph.Nodes()) {
if (n->NodeType() != ir::Node::Type::kOperation) continue;
if (adj_list.find(n.get()) == adj_list.end()) {
adj_list[n.get()] = std::unordered_set<ir::Node *>();
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::unordered_set<ir::Node *>();
}
for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
adj_list[n.get()].insert(adj_n);
LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n.get())
adj_list[n].insert(adj_n);
VLOG(3) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
}
}
......
......@@ -30,6 +30,7 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph);
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -94,20 +94,21 @@ TEST(GraphTest, Basic) {
prog.MutableBlock(0)->Var("test_out")->GetType());
std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
ASSERT_EQ(g->nodes[0]->Name(), "sum");
ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a");
ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b");
ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c");
ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out");
ASSERT_EQ(g->nodes[1]->Name(), "test_a");
ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[2]->Name(), "test_b");
ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[3]->Name(), "test_c");
ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes[4]->Name(), "test_out");
ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum");
ASSERT_EQ(g->nodes.size(), 5);
std::vector<ir::Node *> nodes(g->Nodes().begin(), g->Nodes().end());
ASSERT_EQ(nodes[0]->Name(), "sum");
ASSERT_EQ(nodes[0]->inputs[0]->Name(), "test_a");
ASSERT_EQ(nodes[0]->inputs[1]->Name(), "test_b");
ASSERT_EQ(nodes[0]->inputs[2]->Name(), "test_c");
ASSERT_EQ(nodes[0]->outputs[0]->Name(), "test_out");
ASSERT_EQ(nodes[1]->Name(), "test_a");
ASSERT_EQ(nodes[1]->outputs[0]->Name(), "sum");
ASSERT_EQ(nodes[2]->Name(), "test_b");
ASSERT_EQ(nodes[2]->outputs[0]->Name(), "sum");
ASSERT_EQ(nodes[3]->Name(), "test_c");
ASSERT_EQ(nodes[3]->outputs[0]->Name(), "sum");
ASSERT_EQ(nodes[4]->Name(), "test_out");
ASSERT_EQ(nodes[4]->inputs[0]->Name(), "sum");
ASSERT_EQ(nodes.size(), 5);
}
} // namespace framework
} // namespace paddle
......@@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib)
set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/node.h"
namespace paddle {
namespace operators {
......@@ -22,7 +23,9 @@ inline bool NeedSend(const framework::Scope& scope,
const std::string& varname) {
// dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it.
if (varname == "dummy") return false;
if (varname.find(framework::ir::Node::kControlDepVarName) !=
std::string::npos)
return false;
auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册