Commit 93355cc0 authored by Xin Pan

fix control deps

Parent f6d99d1f
-cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto)
+cc_library(var_handle SRCS var_handle.cc DEPS place framework_proto node)
 cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
......
@@ -94,12 +94,11 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
 }
 
 std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
-    const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
+    const std::vector<ir::Node *> &nodes) const {
   std::vector<std::string> send_vars;
   // since parameters are all in block 0,
   // it's enough to only scan send ops in block 0
   for (auto &node : nodes) {
-    if (node->NodeType() != ir::Node::Type::kOperation) continue;
     OpDesc *op = node->Op();
     // TODO(Yancey1989): use a graceful method to find send op,
     // instead of the the hard code string
@@ -114,10 +113,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
 }
 
 std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
-    const std::vector<std::unique_ptr<ir::Node>> &nodes) const {
+    const std::vector<ir::Node *> &nodes) const {
   std::vector<std::string> recv_vars;
   for (auto &node : nodes) {
-    if (node->NodeType() != ir::Node::Type::kOperation) continue;
     OpDesc *op = node->Op();
     // TODO(Yancey1989): use a graceful method to find recv op,
     // instead of the hard code string
@@ -214,6 +212,19 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
     }
   }
 
+  // Verify that no operations before optimize ops depends on optimize ops.
+  std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
+                                              optimize_ops.end());
+  for (size_t i = 0; i < last_backward; ++i) {
+    for (ir::Node *in : sorted_ret[i]->inputs) {
+      for (ir::Node *pre_n : in->inputs) {
+        PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
+                       "optimize operations cannot be depended by forward "
+                       "or backward node %s -> %s",
+                       pre_n->Name(), sorted_ret[i]->Name());
+      }
+    }
+  }
   sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(),
                     optimize_ops.end());
   return sorted_ret;
@@ -221,18 +232,16 @@ std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
 
 std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
     std::unique_ptr<ir::Graph> graph) const {
-  // Rebuild the graph structure.
+  // Give the topology sort order and rebuild the graph structure.
   std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
-  auto nodes = std::move(graph->nodes);
-  graph->nodes.clear();
+  auto nodes = graph->ReleaseNodes();
+  ir::Graph &result = *graph;
 
   for (auto &node : nodes) {
     if (node->NodeType() == ir::Node::Type::kVariable) {
       all_vars_.emplace(node->Name(), node->Var());
     }
   }
-  ir::Graph &result = *graph;
 
   std::unordered_set<std::string> og_has_been_broadcast;
   // We cannot invoke resize. It is a bug of GCC 4.8
@@ -242,8 +251,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
 
   // find send/recv vars so that we can place the distributed training
   // realted op in the place 0
-  auto send_vars = FindDistTrainSendVars(nodes);
-  auto recv_vars = FindDistTrainRecvVars(nodes);
+  auto send_vars = FindDistTrainSendVars(sorted_ops);
+  auto recv_vars = FindDistTrainRecvVars(sorted_ops);
 
   std::vector<std::unordered_set<std::string>> bcast_var_name_set;
   bcast_var_name_set.resize(places_.size());
@@ -589,8 +598,7 @@ void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op,
                                         const std::string &prev_op_name) const {
   for (auto &prev_op : result->Get<GraphOps>("ops")) {
     if (prev_op->Name() == prev_op_name) {
-      auto *dep_var = new DummyVarHandle(
-          result->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
+      auto *dep_var = new DummyVarHandle(result->CreateControlDepVar());
       prev_op->AddOutput(dep_var);
       result->Get<GraphDepVars>("dep_vars").emplace(dep_var);
       op->AddInput(dep_var);
......
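The check added to SortOpsAndDelayOptimizeOp above is easy to miss in the diff; the following standalone C++ toy restates it under simplified assumptions (the Node type and names here are stand-ins, not Paddle's):

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

// Toy node: for an op node, `inputs` holds its input variable nodes; for a
// variable node, `inputs` holds the ops that produce it (as in ir::Node).
struct Node {
  std::string name;
  std::vector<Node *> inputs;
};

// Returns true iff some op scheduled before `last_backward` consumes a
// variable produced by an op in `optimize_set`; this is the condition the
// real code rejects with PADDLE_ENFORCE.
bool DependsOnOptimize(const std::vector<Node *> &sorted, size_t last_backward,
                       const std::unordered_set<Node *> &optimize_set) {
  for (size_t i = 0; i < last_backward; ++i)
    for (Node *in : sorted[i]->inputs)      // each input variable
      for (Node *pre_n : in->inputs)        // each producer of that variable
        if (optimize_set.count(pre_n)) return true;
  return false;
}

int main() {
  Node sgd{"sgd", {}};
  Node w{"w", {&sgd}};  // variable written by the optimize op
  Node fc{"fc", {&w}};  // forward op reading that variable
  std::vector<Node *> sorted{&fc};
  assert(DependsOnOptimize(sorted, 1, {&sgd}));  // would trip the enforce
  return 0;
}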
@@ -76,10 +76,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
                           const std::vector<std::string> &recv_vars) const;
 
   std::vector<std::string> FindDistTrainSendVars(
-      const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
+      const std::vector<ir::Node *> &nodes) const;
 
   std::vector<std::string> FindDistTrainRecvVars(
-      const std::vector<std::unique_ptr<ir::Node>> &nodes) const;
+      const std::vector<ir::Node *> &nodes) const;
 
   void ConnectOp(ir::Graph *result, OpHandleBase *op,
                  const std::string &prev_op_name) const;
......
@@ -33,7 +33,8 @@ void RPCOpHandle::RunImpl() {
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
     // FIXME(Yancey1989): need a better solution instead of use DebugString()
-    if (in->DebugString() == "dummy") {  // HACK
+    if (in->Node()->Name().find(ir::Node::kControlDepVarName) !=
+        std::string::npos) {  // HACK
       continue;
     }
     if (in->GeneratedOp()) {
......
@@ -17,6 +17,36 @@
 namespace paddle {
 namespace framework {
 namespace details {
+void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
+  for (auto &var_map : graph->Get<GraphVars>("vars")) {
+    for (auto &name_pair : var_map) {
+      if (name_pair.second.size() <= 1) {
+        continue;
+      }
+      auto it_new = name_pair.second.rbegin();
+      auto it_old = name_pair.second.rbegin();
+      ++it_old;
+      for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
+        OpHandleBase *write_op = (*it_new)->GeneratedOp();
+        const auto &read_ops = (*it_old)->PendingOps();
+
+        for (auto *read_op : read_ops) {
+          // Manually add a dependency var from read_op to write_op;
+          if (read_op == write_op) {
+            // Read Write is the same op.
+            continue;
+          }
+
+          auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
+          read_op->AddOutput(dep_var);
+          write_op->AddInput(dep_var);
+          graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
+        }
+      }
+    }
+  }
+}
+
 VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
     ir::Graph *graph, ir::Node *node, const platform::Place &place,
     size_t place_offset) {
@@ -56,8 +86,7 @@ void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
     if (!op->Outputs().empty()) {
       continue;
     }
-    auto *dummy_leaf = new DummyVarHandle(
-        graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
+    auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
     graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
     op->AddOutput(dummy_leaf);
   }
......
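PolishGraphToSupportDataHazards above serializes a reader before a later writer of the same variable by wiring a dummy control-dep variable between them. A minimal standalone sketch of that wiring, using toy types rather than Paddle's OpHandleBase/DummyVarHandle:

#include <cassert>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<Node *> inputs, outputs;
};

// Force `writer` to wait for `reader` by inserting a dummy variable that the
// reader produces and the writer consumes: a WAR hazard made explicit.
Node *AddControlDep(Node *reader, Node *writer,
                    std::vector<std::unique_ptr<Node>> *owned) {
  owned->emplace_back(new Node{"control_dep@0", {}, {}});
  Node *dep = owned->back().get();
  reader->outputs.push_back(dep);
  dep->inputs.push_back(reader);
  writer->inputs.push_back(dep);
  dep->outputs.push_back(writer);
  return dep;
}

int main() {
  std::vector<std::unique_ptr<Node>> owned;
  Node read_x{"read_x", {}, {}};
  Node write_x{"write_x", {}, {}};
  Node *dep = AddControlDep(&read_x, &write_x, &owned);
  assert(write_x.inputs.front() == dep);  // write_x now runs after read_x
  return 0;
}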
@@ -57,6 +57,15 @@ class SSAGraphBuilder : public ir::Pass {
   DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
 
  protected:
+  /**
+   * We only handle write after read(WAR), since it should not have a write
+   * after write in program. If there are write after write operators, we need
+   * prune them.
+   *
+   * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
+   */
+  static void PolishGraphToSupportDataHazards(ir::Graph *graph);
+
   static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
                                                const platform::Place &place,
                                                size_t place_offset);
......
@@ -26,7 +26,7 @@ std::string VarHandle::DebugString() const {
   return ss.str();
 }
 
-std::string DummyVarHandle::DebugString() const { return "dummy"; }
+std::string DummyVarHandle::DebugString() const { return node_->Name(); }
 
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
@@ -34,7 +34,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
   std::map<std::string, std::vector<ir::Node *>> var_nodes;
   for (auto *op : program.Block(0).AllOps()) {
     ir::Node *node = CreateOpNode(op);
+    // For input args, reuse the same var name if it was created before.
+    // Otherwise, create a new one.
     for (auto &each_var_name : op->InputArgumentNames()) {
       ir::Node *var = nullptr;
       if (var_nodes.find(each_var_name) != var_nodes.end()) {
@@ -43,16 +44,16 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
         var = CreateVarNode(all_vars.at(each_var_name));
         var_nodes[each_var_name].push_back(var);
       } else {
-        // TODO(paddle-dev): Seems some assumption doesn't hold?
-        VLOG(3) << op->Type()
-                << " input var not in all_var list: " << each_var_name;
+        // Operation input var can be optional (dispensable). Which means
+        // the operation doesn't really need the var at runtime. In this
+        // case, the no-existed var is ready at the beginning.
         var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
         var_nodes[each_var_name].push_back(var);
       }
       node->inputs.push_back(var);
       var->outputs.push_back(node);
     }
+    // For output args, always create a new var.
     for (auto &each_var_name : op->OutputArgumentNames()) {
       ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
       var_nodes[each_var_name].push_back(var);
@@ -67,6 +68,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
    *
    * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
    */
+
   for (auto &var : var_nodes) {
     auto &versions = var.second;
     if (versions.size() <= 1) continue;
@@ -85,8 +87,18 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
           // Read Write is the same op.
           continue;
         }
-        ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName,
-                                            ir::Node::Type::kVariable);
+        // 2 ops might have been connected via other vars.
+        bool has_dep = false;
+        for (ir::Node *r_out : read_op->outputs) {
+          for (ir::Node *w_in : write_op->inputs) {
+            if (r_out == w_in) {
+              has_dep = true;
+              break;
+            }
+          }
+        }
+        if (has_dep) continue;
+        ir::Node *dep_var = CreateControlDepVar();
         read_op->outputs.push_back(dep_var);
         dep_var->inputs.push_back(read_op);
         write_op->inputs.push_back(dep_var);
......
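The "already connected via other vars" test above prevents redundant control edges; in isolation it is just a pairwise scan. A tiny self-contained version with a toy Node, the same shape as the real loop:

#include <cassert>
#include <vector>

struct Node {
  std::vector<Node *> inputs, outputs;
};

// True when some output of read_op is already an input of write_op, i.e. a
// data edge already orders the two ops and no control-dep var is needed.
bool HasDep(const Node &read_op, const Node &write_op) {
  for (Node *r_out : read_op.outputs)
    for (Node *w_in : write_op.inputs)
      if (r_out == w_in) return true;
  return false;
}

int main() {
  Node var, read_op, write_op;
  read_op.outputs.push_back(&var);
  write_op.inputs.push_back(&var);
  assert(HasDep(read_op, write_op));  // skip creating the control-dep var
  return 0;
}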
@@ -27,6 +27,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 namespace ir {
+
 class Graph {
  public:
   explicit Graph(const ProgramDesc &program);
@@ -54,28 +55,58 @@ class Graph {
     };
   }
 
+  const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
+
   ir::Node *CreateVarNode(VarDesc *var_desc) {
-    nodes.emplace_back(new ir::Node(var_desc));
-    return nodes.back().get();
+    return AddNode(new ir::Node(var_desc));
   }
 
   ir::Node *CreateOpNode(OpDesc *op_desc) {
-    nodes.emplace_back(new ir::Node(op_desc));
-    return nodes.back().get();
+    return AddNode(new ir::Node(op_desc));
+  }
+
+  ir::Node *CreateControlDepVar() {
+    // TODO(panyx0718): control var name should be unique.
+    const std::string name = string::Sprintf(
+        "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
+    return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
   }
 
   ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
-    nodes.emplace_back(new ir::Node(name, type));
-    return nodes.back().get();
+    return AddNode(new ir::Node(name, type));
   }
 
-  std::vector<std::unique_ptr<ir::Node>> nodes;
+  std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
+    std::vector<std::unique_ptr<ir::Node>> ret;
+    for (auto &n : nodes_) {
+      ret.emplace_back(n.second.release());
+    }
+    nodes_.clear();
+    node_set_.clear();
+    return ret;
+  }
 
  private:
+  // This method takes ownership of `node`.
+  ir::Node *AddNode(ir::Node *node) {
+    PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
+    nodes_[node].reset(node);
+    node_set_.insert(node);
+    return node;
+  }
+
+  void RemoveNode(ir::Node *node) {
+    PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
+    node_set_.erase(node);
+    nodes_.erase(node);
+  }
+
   // NOTE: program_ shouldn't be exposed to user.
   const ProgramDesc &program_;
   std::map<std::string, boost::any> attrs_;
   std::map<std::string, std::function<void(void)>> attr_dels_;
+  std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
+  std::unordered_set<ir::Node *> node_set_;
 };
 
 }  // namespace ir
 }  // namespace framework
......
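The new Graph API above replaces the public `nodes` vector with owned storage plus accessors. A standalone sketch of that ownership scheme; ToyGraph is a hypothetical stand-in for ir::Graph, not Paddle's class:

#include <cassert>
#include <map>
#include <memory>
#include <unordered_set>
#include <vector>

struct Node { int id; };

class ToyGraph {
 public:
  // Takes ownership of `node`, mirroring the AddNode pattern above.
  Node *AddNode(Node *node) {
    assert(node_set_.count(node) == 0);
    nodes_[node].reset(node);
    node_set_.insert(node);
    return node;
  }
  const std::unordered_set<Node *> &Nodes() const { return node_set_; }
  // Hands ownership of every node back to the caller and empties the graph,
  // as MultiDevSSAGraphBuilder::Apply does before rebuilding.
  std::vector<std::unique_ptr<Node>> ReleaseNodes() {
    std::vector<std::unique_ptr<Node>> ret;
    for (auto &n : nodes_) ret.emplace_back(n.second.release());
    nodes_.clear();
    node_set_.clear();
    return ret;
  }

 private:
  std::map<Node *, std::unique_ptr<Node>> nodes_;
  std::unordered_set<Node *> node_set_;
};

int main() {
  ToyGraph g;
  g.AddNode(new Node{1});
  assert(g.Nodes().size() == 1);
  auto owned = g.ReleaseNodes();  // graph no longer owns (or lists) the node
  assert(owned.size() == 1 && g.Nodes().empty());
  return 0;
}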
@@ -33,9 +33,8 @@ void SortHelper(
     }
   }
 
-  LOG(ERROR) << "topology sort insert: " << node->Name()
-             << reinterpret_cast<void *>(node) << " input "
-             << node->inputs.size();
+  VLOG(3) << "topology sort insert: " << node->Name()
+          << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
 
   ret->push_back(node);
 }
@@ -93,17 +92,17 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     const Graph &graph) {
   std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
 
-  for (auto &n : graph.nodes) {
+  for (auto &n : graph.Nodes()) {
     if (n->NodeType() != ir::Node::Type::kOperation) continue;
-    if (adj_list.find(n.get()) == adj_list.end()) {
-      adj_list[n.get()] = std::unordered_set<ir::Node *>();
+    if (adj_list.find(n) == adj_list.end()) {
+      adj_list[n] = std::unordered_set<ir::Node *>();
     }
     for (auto &var : n->inputs) {
       for (auto &adj_n : var->inputs) {
         PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
-        adj_list[n.get()].insert(adj_n);
-        LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
-                   << " -> " << n->Name() << reinterpret_cast<void *>(n.get())
-                   << " via " << var->Name() << reinterpret_cast<void *>(var);
+        adj_list[n].insert(adj_n);
+        VLOG(3) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
+                << " -> " << n->Name() << reinterpret_cast<void *>(n)
+                << " via " << var->Name() << reinterpret_cast<void *>(var);
       }
     }
......
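For reference, the DFS that TopologySortOperations layers on BuildOperationAdjList visits all predecessors of an op before emitting it. A reduced sketch with integer op ids (my simplification, not the Paddle source):

#include <cassert>
#include <map>
#include <set>
#include <vector>

using AdjList = std::map<int, std::set<int>>;  // op -> its predecessor ops

void SortHelper(const AdjList &adj, int node, std::set<int> *visited,
                std::vector<int> *ret) {
  visited->insert(node);
  for (int pre : adj.at(node))
    if (!visited->count(pre)) SortHelper(adj, pre, visited, ret);
  ret->push_back(node);  // emitted only after every predecessor
}

int main() {
  AdjList adj{{0, {}}, {1, {0}}, {2, {0, 1}}};  // 0 -> 1 -> 2
  std::set<int> visited;
  std::vector<int> order;
  for (const auto &kv : adj)
    if (!visited.count(kv.first)) SortHelper(adj, kv.first, &visited, &order);
  assert((order == std::vector<int>{0, 1, 2}));
  return 0;
}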
@@ -30,6 +30,7 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
 
 std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
     const Graph &graph);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
@@ -94,20 +94,21 @@ TEST(GraphTest, Basic) {
             prog.MutableBlock(0)->Var("test_out")->GetType());
 
   std::unique_ptr<ir::Graph> g(new ir::Graph(prog));
-  ASSERT_EQ(g->nodes[0]->Name(), "sum");
-  ASSERT_EQ(g->nodes[0]->inputs[0]->Name(), "test_a");
-  ASSERT_EQ(g->nodes[0]->inputs[1]->Name(), "test_b");
-  ASSERT_EQ(g->nodes[0]->inputs[2]->Name(), "test_c");
-  ASSERT_EQ(g->nodes[0]->outputs[0]->Name(), "test_out");
-  ASSERT_EQ(g->nodes[1]->Name(), "test_a");
-  ASSERT_EQ(g->nodes[1]->outputs[0]->Name(), "sum");
-  ASSERT_EQ(g->nodes[2]->Name(), "test_b");
-  ASSERT_EQ(g->nodes[2]->outputs[0]->Name(), "sum");
-  ASSERT_EQ(g->nodes[3]->Name(), "test_c");
-  ASSERT_EQ(g->nodes[3]->outputs[0]->Name(), "sum");
-  ASSERT_EQ(g->nodes[4]->Name(), "test_out");
-  ASSERT_EQ(g->nodes[4]->inputs[0]->Name(), "sum");
-  ASSERT_EQ(g->nodes.size(), 5);
+  std::vector<ir::Node *> nodes(g->Nodes().begin(), g->Nodes().end());
+  ASSERT_EQ(nodes[0]->Name(), "sum");
+  ASSERT_EQ(nodes[0]->inputs[0]->Name(), "test_a");
+  ASSERT_EQ(nodes[0]->inputs[1]->Name(), "test_b");
+  ASSERT_EQ(nodes[0]->inputs[2]->Name(), "test_c");
+  ASSERT_EQ(nodes[0]->outputs[0]->Name(), "test_out");
+  ASSERT_EQ(nodes[1]->Name(), "test_a");
+  ASSERT_EQ(nodes[1]->outputs[0]->Name(), "sum");
+  ASSERT_EQ(nodes[2]->Name(), "test_b");
+  ASSERT_EQ(nodes[2]->outputs[0]->Name(), "sum");
+  ASSERT_EQ(nodes[3]->Name(), "test_c");
+  ASSERT_EQ(nodes[3]->outputs[0]->Name(), "sum");
+  ASSERT_EQ(nodes[4]->Name(), "test_out");
+  ASSERT_EQ(nodes[4]->inputs[0]->Name(), "sum");
+  ASSERT_EQ(nodes.size(), 5);
 }
 
 }  // namespace framework
 }  // namespace paddle
@@ -192,9 +192,9 @@ if(WITH_DISTRIBUTE)
   set(DISTRIBUTE_DEPS "")
   if(WITH_GRPC)
-    set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
+    set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf node)
   else()
-    set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib)
+    set(DISTRIBUTE_DEPS sendrecvop_brpc brpc leveldb snappystream snappy protobuf ssl crypto zlib node)
     if(WITH_BRPC_RDMA)
       find_library(IBVERBS_LIBRARY NAMES ibverbs)
       ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 #include <string>
+#include "paddle/fluid/framework/ir/node.h"
 
 namespace paddle {
 namespace operators {
@@ -22,7 +23,9 @@ inline bool NeedSend(const framework::Scope& scope,
                      const std::string& varname) {
   // dummy variable is only used in parallel executor to represent
   // some dependency relationship, we don't need to send/recv it.
-  if (varname == "dummy") return false;
+  if (varname.find(framework::ir::Node::kControlDepVarName) !=
+      std::string::npos)
+    return false;
   auto* var = scope.FindVar(varname);
   PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
                           varname);
......
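Both this helper and RPCOpHandle above now recognize control-dep variables by a name substring instead of the magic string "dummy". A sketch of that check; the constant's actual value lives in paddle/fluid/framework/ir/node.h, and "__control_var" here is an assumption:

#include <cassert>
#include <string>

// Assumed value of ir::Node::kControlDepVarName; see ir/node.h for the
// real definition.
static const char kControlDepVarName[] = "__control_var";

bool IsControlDepVar(const std::string &varname) {
  return varname.find(kControlDepVarName) != std::string::npos;
}

int main() {
  assert(IsControlDepVar("__control_var@42"));  // skipped by send/recv
  assert(!IsControlDepVar("fc_0.w_0"));         // ordinary parameter
  return 0;
}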