提交 21a45420 编写于 作者: X Xin Pan

polish and test

上级 2782e71a
...@@ -170,6 +170,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( ...@@ -170,6 +170,7 @@ size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
const std::vector<std::string> &var_names) const { const std::vector<std::string> &var_names) const {
int64_t numel_sum = 0; int64_t numel_sum = 0;
for (auto var_name : var_names) { for (auto var_name : var_names) {
if (all_vars_.find(var_name) == all_vars_.end()) continue;
auto var_desc = all_vars_.at(var_name); auto var_desc = all_vars_.at(var_name);
PADDLE_ENFORCE_NOT_NULL(var_desc); PADDLE_ENFORCE_NOT_NULL(var_desc);
auto dim = framework::make_ddim(var_desc->GetShape()); auto dim = framework::make_ddim(var_desc->GetShape());
...@@ -271,6 +272,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -271,6 +272,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != if (strategy_.gradient_scale_ !=
BuildStrategy::GradientScaleStrategy::kCustomized) { BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle?
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
} }
// This assumes the backward generating code will ensure IsScaleLossOp // This assumes the backward generating code will ensure IsScaleLossOp
...@@ -288,6 +290,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -288,6 +290,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
} else { } else {
// This op runs on all devices, and its output may have parameter's // This op runs on all devices, and its output may have parameter's
// gradients. // gradients.
// TODO(paddle-dev): Why is so special about "read" op?
if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
node->Op()->SetAttr("throw_eof_exp", false); node->Op()->SetAttr("throw_eof_exp", false);
CreateComputationalOps(&result, node, places_.size()); CreateComputationalOps(&result, node, places_.size());
...@@ -363,6 +366,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply( ...@@ -363,6 +366,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::Apply(
* Only variables should be the leaves of graph. * Only variables should be the leaves of graph.
*/ */
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
PADDLE_ENFORCE(!ir::HasCircle(result));
return graph; return graph;
} }
...@@ -620,6 +624,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -620,6 +624,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
if (node->Op()->Type() == "split_byref" || if (node->Op()->Type() == "split_byref" ||
node->Op()->Type() == "split_selected_rows") { node->Op()->Type() == "split_selected_rows") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(input_var_names[0]); op_dev_id = GetVarDeviceID(input_var_names[0]);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
...@@ -657,7 +662,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -657,7 +662,10 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
ir::Node *node) const { ir::Node *node) const {
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(node->inputs[0]->Name()); op_dev_id = GetVarDeviceID(node->inputs[0]->Name());
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by // the variable name which contains .block means it was splited by
// split_byref op // split_byref op
// so that we can balance the variable blocks to all the pserver // so that we can balance the variable blocks to all the pserver
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -33,8 +34,7 @@ void RPCOpHandle::RunImpl() { ...@@ -33,8 +34,7 @@ void RPCOpHandle::RunImpl() {
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of use DebugString() // FIXME(Yancey1989): need a better solution instead of use DebugString()
if (in->Node()->Name().find(ir::Node::kControlDepVarName) != if (ir::IsControlDepVar(*in->Node())) { // HACK
std::string::npos) { // HACK
continue; continue;
} }
if (in->GeneratedOp()) { if (in->GeneratedOp()) {
......
...@@ -17,36 +17,6 @@ ...@@ -17,36 +17,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
}
auto it_new = name_pair.second.rbegin();
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
OpHandleBase *write_op = (*it_new)->GeneratedOp();
const auto &read_ops = (*it_old)->PendingOps();
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
// Read Write is the same op.
continue;
}
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
}
}
}
}
}
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place, ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
......
...@@ -57,15 +57,6 @@ class SSAGraphBuilder : public ir::Pass { ...@@ -57,15 +57,6 @@ class SSAGraphBuilder : public ir::Pass {
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
protected: protected:
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
static void PolishGraphToSupportDataHazards(ir::Graph *graph);
static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node,
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
......
...@@ -2,4 +2,5 @@ cc_library(node SRCS node.cc DEPS proto_desc) ...@@ -2,4 +2,5 @@ cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node) cc_library(graph SRCS graph.cc DEPS node)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node) cc_library(pass SRCS pass.cc DEPS graph node)
cc_test(graph_test SRCS graph_test.cc DEPS graph proto_desc op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph_helper op_registry)
...@@ -107,6 +107,10 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -107,6 +107,10 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
} }
} }
} }
bool IsControlDepVar(const ir::Node &var) {
return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -57,25 +57,34 @@ class Graph { ...@@ -57,25 +57,34 @@ class Graph {
const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; } const std::unordered_set<ir::Node *> &Nodes() const { return node_set_; }
// Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc) {
return AddNode(new ir::Node(var_desc)); return AddNode(new ir::Node(var_desc));
} }
// Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) { ir::Node *CreateOpNode(OpDesc *op_desc) {
return AddNode(new ir::Node(op_desc)); return AddNode(new ir::Node(op_desc));
} }
// Create a control dependency var that connects 2 operations. The
// var doesn't hold any data. Other than that, it's no different from
// other var, considering dependency analysis.
ir::Node *CreateControlDepVar() { ir::Node *CreateControlDepVar() {
// TODO(panyx0718): control var name should be unique. // TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf( const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
} }
// A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type)); return AddNode(new ir::Node(name, type));
} }
// Clear all node information of the graph and return the ownership of the
// nodes.
std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() { std::vector<std::unique_ptr<ir::Node>> ReleaseNodes() {
std::vector<std::unique_ptr<ir::Node>> ret; std::vector<std::unique_ptr<ir::Node>> ret;
for (auto &n : nodes_) { for (auto &n : nodes_) {
...@@ -108,6 +117,8 @@ class Graph { ...@@ -108,6 +117,8 @@ class Graph {
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_; std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_; std::unordered_set<ir::Node *> node_set_;
}; };
bool IsControlDepVar(const ir::Node &var);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -59,12 +59,9 @@ bool HasCircleHelper( ...@@ -59,12 +59,9 @@ bool HasCircleHelper(
in_trace->erase(node); in_trace->erase(node);
return false; return false;
} }
} // namespace
bool HasCircle(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildOperationAdjList(graph);
bool HasCircleInternal(
const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list) {
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::unordered_set<ir::Node *> in_trace; std::unordered_set<ir::Node *> in_trace;
for (auto &adj : adj_list) { for (auto &adj : adj_list) {
...@@ -74,10 +71,16 @@ bool HasCircle(const Graph &graph) { ...@@ -74,10 +71,16 @@ bool HasCircle(const Graph &graph) {
} }
return false; return false;
} }
} // namespace
bool HasCircle(const Graph &graph) {
return HasCircleInternal(BuildOperationAdjList(graph));
}
std::vector<ir::Node *> TopologySortOperations(const Graph &graph) { std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list = std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list =
BuildOperationAdjList(graph); BuildOperationAdjList(graph);
PADDLE_ENFORCE(!HasCircleInternal(adj_list));
std::unordered_set<ir::Node *> visited; std::unordered_set<ir::Node *> visited;
std::vector<ir::Node *> ret; std::vector<ir::Node *> ret;
for (auto adj : adj_list) { for (auto adj : adj_list) {
......
...@@ -24,10 +24,14 @@ limitations under the License. */ ...@@ -24,10 +24,14 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Test if the graph contains circle.
bool HasCircle(const Graph &graph); bool HasCircle(const Graph &graph);
// Topology Sort the operations in the graph from inputs to outputs.
// `graph` cannot contain circle.
std::vector<ir::Node *> TopologySortOperations(const Graph &graph); std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
// Build an adjacency list of operations for the `graph`.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList( std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph); const Graph &graph);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace ir {
void BuildCircleGraph(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
o1->outputs.push_back(v1);
o1->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o1);
}
void BuildCircleGraph2(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
o2->outputs.push_back(v2);
o1->inputs.push_back(v2);
v2->inputs.push_back(o2);
v2->outputs.push_back(o1);
}
void BuildNoCircleGraph(Graph* g) {
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
// o1->v1->o2
o1->outputs.push_back(v1);
o2->inputs.push_back(v1);
v1->inputs.push_back(o1);
v1->outputs.push_back(o2);
// o2->v2->o3
// o2->v2->o4
o2->outputs.push_back(v2);
o3->inputs.push_back(v2);
o4->inputs.push_back(v2);
v2->inputs.push_back(o2);
v2->outputs.push_back(o3);
v2->outputs.push_back(o4);
// o2->v3->o5
o2->outputs.push_back(v3);
o5->inputs.push_back(v3);
v3->inputs.push_back(o2);
v3->outputs.push_back(o5);
// o3-v4->o5
o3->outputs.push_back(v4);
o5->inputs.push_back(v4);
v4->inputs.push_back(o3);
v4->outputs.push_back(o5);
}
TEST(GraphHelperTest, Basic) {
ProgramDesc prog;
Graph g(prog);
BuildCircleGraph(&g);
ASSERT_TRUE(HasCircle(g));
Graph g2(prog);
BuildCircleGraph2(&g2);
ASSERT_TRUE(HasCircle(g2));
auto adj_list = BuildOperationAdjList(g2);
for (auto& adj : adj_list) {
auto& adj_set = adj.second;
if (adj.first->Name() == "op1") {
ASSERT_EQ((*adj_set.begin())->Name(), "op2");
} else if (adj.first->Name() == "op2") {
ASSERT_EQ((*adj_set.begin())->Name(), "op1");
} else {
ASSERT_TRUE(false);
}
}
Graph g3(prog);
BuildNoCircleGraph(&g3);
ASSERT_FALSE(HasCircle(g3));
auto sorted = TopologySortOperations(g3);
std::map<std::string, size_t> node_map;
for (size_t i = 0; i < sorted.size(); ++i) {
node_map[sorted[i]->Name()] = i;
}
ASSERT_EQ(node_map.at("op1"), 0);
ASSERT_EQ(node_map.at("op2"), 1);
ASSERT_TRUE(node_map.at("op3") < node_map.at("op5"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -23,6 +23,7 @@ inline bool NeedSend(const framework::Scope& scope, ...@@ -23,6 +23,7 @@ inline bool NeedSend(const framework::Scope& scope,
const std::string& varname) { const std::string& varname) {
// dummy variable is only used in parallel executor to represent // dummy variable is only used in parallel executor to represent
// some dependency relationship, we don't need to send/recv it. // some dependency relationship, we don't need to send/recv it.
// TODO(paddle-dev): Why would parallel executor logic leaked into here?
if (varname.find(framework::ir::Node::kControlDepVarName) != if (varname.find(framework::ir::Node::kControlDepVarName) !=
std::string::npos) std::string::npos)
return false; return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册