提交 b3cd2ae8 编写于 作者: L luotao1

Merge branch 'develop' into ner_ut2

...@@ -36,6 +36,7 @@ paddle.fluid.default_startup_program ArgSpec(args=[], varargs=None, keywords=Non ...@@ -36,6 +36,7 @@ paddle.fluid.default_startup_program ArgSpec(args=[], varargs=None, keywords=Non
paddle.fluid.default_main_program ArgSpec(args=[], varargs=None, keywords=None, defaults=None) paddle.fluid.default_main_program ArgSpec(args=[], varargs=None, keywords=None, defaults=None)
paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None) paddle.fluid.program_guard ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.get_var ArgSpec(args=['name', 'program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.name_scope ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None) paddle.fluid.Executor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.Executor.close ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False)) paddle.fluid.Executor.run ArgSpec(args=['self', 'program', 'feed', 'fetch_list', 'feed_var_name', 'fetch_var_name', 'scope', 'return_numpy', 'use_program_cache'], varargs=None, keywords=None, defaults=(None, None, None, 'feed', 'fetch', None, True, False))
......
...@@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, ...@@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name) const { ir::Graph *result, const std::string &loss_grad_name) const {
for (size_t i = 0; i < places_.size(); ++i) { for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
#ifdef PADDLE_WITH_CUDA auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
auto *communication_dev_ctx =
nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i])
: platform::DeviceContextPool::Instance().Get(places_[i]);
#else
auto *communication_dev_ctx =
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
#endif
auto *op_handle = new ScaleLossGradOpHandle( auto *op_handle = new ScaleLossGradOpHandle(
result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation),
local_scopes_.size(), local_scopes_[i], places_[i], local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx);
communication_dev_ctx);
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle); result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
...@@ -744,7 +736,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -744,7 +736,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
.emplace(varname, op_dev_id); .emplace(varname, op_dev_id);
} }
} else { } else {
PADDLE_ENFORCE( PADDLE_THROW(
"the distribute training related op should be in [split_byref, " "the distribute training related op should be in [split_byref, "
"concat]."); "concat].");
} }
......
...@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) { ...@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
auto* while_pat_node = gpd.pattern().RetriveNode("while"); auto* while_pat_node = gpd.pattern().RetrieveNode("while");
auto* while_node = subgraph.at(while_pat_node); auto* while_node = subgraph.at(while_pat_node);
marked_nodes.insert(while_node); marked_nodes.insert(while_node);
}; };
......
...@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) { ...@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
} }
void BuildFCPattern(PDPattern* pattern) { void BuildFCPattern(PDPattern* pattern) {
// make sure the selected MUL op has one input argument is a parameter. // Create Operators
auto* mul_parameter_var = pattern->NewNode( auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul");
[](Node* node) { auto* elementwise_add_op =
return node->IsVar() && node->outputs.size() == 1UL && pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add");
node->outputs.front()->Op()->Type() == "mul" && node->Var() && // Create variables
node->Var()->Persistable(); // check is a parameter // w
}, auto* mul_weight_var = pattern->NewNode("mul_weight")
"mul_weight" /*name*/); ->AsInput()
->assert_is_op_nth_input("mul", "Y", 0);
auto* mul_tmp_input_var = pattern->NewNode( // x
[](Node* node) { auto* mul_tmp_var = pattern->NewNode("mul_tmp_var")
bool result = ->AsInput()
node->IsVar() && node->outputs.size() >= 1UL && node->Var() && ->assert_is_op_nth_input("mul", "X", 0);
!node->Var()->Persistable(); // this input is not an parameter. // intermediate variable, will be removed in the IR after fuse.
if (!result) return false; auto* mul_out_var = pattern->NewNode("mul_out")
// check whether one output is MUL op. ->AsIntermediate()
for (auto* op : node->outputs) { ->assert_is_only_output_of_op("mul")
if (op->IsOp() && op->Op()->Type() == "mul") return true; ->assert_is_op_input("elementwise_add");
} // bias
return false; auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar")
}, ->assert_is_op_input("elementwise_add")
"mul_tmp_var" /*name*/); ->AsInput();
// output
// select a MUL op auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out")
auto* mul_op = pattern->NewNode( ->AsOutput()
[](Node* node) { ->assert_is_op_output("elementwise_add");
return node->IsOp() && // start from an Op
node->Op()->Type() == "mul"; // type is mul mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var});
// the output should be consumed only by one element_add, that check
// leaves in a Var PDNode.
},
"mul" /*name*/);
// make sure the MUL op's output has only one consumer and links to an
// ELEMENTWISE_ADD op.
auto* mul_out_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && // starts from a Var
node->outputs.size() == 1UL && // only has one consumer
node->outputs.front()->IsOp() && // check basic logic
node->Var() && // not a ControlDepVar
node->outputs.front()->Op()->Type() ==
"elementwise_add"; // a very strong validation
},
"mul_out");
// this check is not essential, just to make the corresponding variable Node
// retrival easier.
auto* elementwise_add_tmp_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
VarOutLinksToOp(node, "elementwise_add");
},
"elementwise_add_tmpvar");
// select an ELEMENTWISE_ADD op
auto* elementwise_add_op = pattern->NewNode(
[](Node* node) {
return node->IsOp() && node->Op()->Type() == "elementwise_add";
},
"elementwise_add" /*name*/);
// get the ELEMENTWISE_ADD op's output
auto* elementwise_add_out_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
node->inputs.front()->Op()->Type() == "elementwise_add";
},
"elementwise_add_out");
mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
.LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var}) elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
.LinksTo({elementwise_add_out_var}); .LinksTo({elementwise_add_out_var});
} }
...@@ -120,6 +77,7 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) { ...@@ -120,6 +77,7 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get()); PADDLE_ENFORCE(graph.get());
FusePassBase::Init("fc", graph.get());
std::unordered_set<Node*> nodes2delete; std::unordered_set<Node*> nodes2delete;
...@@ -127,11 +85,12 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -127,11 +85,12 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
BuildFCPattern(gpd.mutable_pattern()); BuildFCPattern(gpd.mutable_pattern());
#define GET_NODE(id) \ #define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \ PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \
"pattern has no Node called %s", #id); \ "pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \ auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle FC fuse"; VLOG(4) << "handle FC fuse";
...@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
graph->RemoveNode(mul); graph->RemoveNode(mul);
graph->RemoveNode(elementwise_add); graph->RemoveNode(elementwise_add);
graph->RemoveNode(mul_out); // tmp variable graph->RemoveNode(mul_out); // tmp variable
found_fc_count++;
}; };
gpd(graph.get(), handler); gpd(graph.get(), handler);
AddStatis(found_fc_count);
return graph; return graph;
} }
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
...@@ -23,7 +24,7 @@ namespace ir { ...@@ -23,7 +24,7 @@ namespace ir {
/* /*
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp. * Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
*/ */
class FCFusePass : public Pass { class FCFusePass : public FusePassBase {
public: public:
virtual ~FCFusePass() {} virtual ~FCFusePass() {}
......
...@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type, ...@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& outputs) { const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp(); auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type); op->SetType(type);
op->SetInput("Xs", inputs); if (type == "mul") {
op->SetOutput("Ys", outputs); op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
} else if (type == "elementwise_add") {
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
} }
// a->OP0->b // a->OP0->b
......
...@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node")); auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
marked_nodes.insert(id); marked_nodes.insert(id);
}; };
gpd(graph.get(), handler); gpd(graph.get(), handler);
...@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
#undef GET_NODE #undef GET_NODE
#undef SET_IN #undef SET_IN
LOG(INFO) << "hidden_n: " << hidden_n->Name(); VLOG(4) << "hidden_n: " << hidden_n->Name();
LOG(INFO) << "cell: " << cell_n->Name(); VLOG(4) << "cell: " << cell_n->Name();
LOG(INFO) << "xx: " << xx_n->Name(); VLOG(4) << "xx: " << xx_n->Name();
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {}); op_desc.SetInput("C0", {});
......
...@@ -22,21 +22,37 @@ namespace paddle { ...@@ -22,21 +22,37 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kParamScopeAttr[] = "param_scope"; static const char kParamScopeAttr[] = "__param_scope__";
static const char kFuseStatisAttr[] = "__fuse_statis__";
class FusePassBase : public Pass { class FusePassBase : public Pass {
public: public:
void Init(Graph* graph) const { graph_ = graph; } void Init(const std::string& repr, Graph* graph) const {
repr_ = repr;
graph_ = graph;
}
Scope* param_scope() const { Scope* param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr); return graph_->Get<framework::Scope*>(kParamScopeAttr);
} }
void AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_);
PADDLE_ENFORCE(!repr_.empty());
if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
}
auto& info =
graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
info[repr_] = count_of_fused;
}
virtual ~FusePassBase() {} virtual ~FusePassBase() {}
protected: protected:
mutable Graph* graph_; mutable Graph* graph_;
mutable std::string repr_;
}; };
} // namespace ir } // namespace ir
......
...@@ -27,6 +27,19 @@ namespace ir { ...@@ -27,6 +27,19 @@ namespace ir {
size_t PDPattern::id_ = 0UL; size_t PDPattern::id_ = 0UL;
PDNode* PDPattern::NewNode(const std::string& name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
"PDNode's name should be unique, get duplicate [%s]",
name);
}
nodes_.emplace_back(new PDNode(this, name));
auto* cur = nodes_.back().get();
node_map_[name] = cur;
return cur;
}
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
if (!name.empty()) { if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0, PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
...@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { ...@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
return cur; return cur;
} }
PDNode* PDPattern::RetriveNode(const std::string& id) const { PDNode* PDPattern::RetrieveNode(const std::string& id) const {
auto it = node_map_.find(id); auto it = node_map_.find(id);
if (it == node_map_.end()) { if (it == node_map_.end()) {
return nullptr; return nullptr;
...@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph, ...@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
auto subgraphs = DetectPatterns(); auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs); UniquePatterns(&subgraphs);
RemoveOverlappedMatch(&subgraphs); RemoveOverlappedMatch(&subgraphs);
ValidateByNodeRole(&subgraphs);
if (subgraphs.empty()) return;
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern"; LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
int id = 0; int id = 0;
for (auto& g : subgraphs) { for (auto& g : subgraphs) {
...@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { ...@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
} }
} }
} }
// Check to early stop if some PDNode can't find matched Node.
for (auto& pdnode : pattern_.nodes()) {
if (!pdnodes2nodes_.count(pdnode.get())) {
VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
return false;
}
}
VLOG(3) << pdnodes2nodes_.size() << " nodes marked"; VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
return !pdnodes2nodes_.empty(); return !pdnodes2nodes_.empty();
} }
// The intermediate Nodes can only link to the nodes inside the pattern, or this
// subgraph will be droped.
void GraphPatternDetector::ValidateByNodeRole(
std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
std::vector<GraphPatternDetector::subgraph_t> result;
subgraphs->erase(
std::remove_if(
subgraphs->begin(), subgraphs->end(),
[](const GraphPatternDetector::subgraph_t& subgraph) -> bool {
// Collect the inputs and outputs.
std::unordered_set<Node*> ios;
for (auto& item : subgraph) {
if (!item.first->IsIntermediate()) {
ios.insert(item.second);
}
}
for (auto& item : subgraph) {
if (item.first->IsIntermediate()) {
for (auto* x : item.second->inputs) {
if (!ios.count(x)) {
return true;
}
}
for (auto* x : item.second->outputs) {
if (!ios.count(x)) {
return true;
}
}
}
}
return false;
}),
subgraphs->end());
}
struct HitGroup { struct HitGroup {
std::unordered_map<PDNode*, Node*> roles; std::unordered_map<PDNode*, Node*> roles;
...@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() { ...@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
// in edges of PDNodes. // in edges of PDNodes.
for (const auto& edge : pattern_.edges()) { for (const auto& edge : pattern_.edges()) {
VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name(); VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
// TODO(Superjomn) Fix bug here, the groups might be duplicate here.
// Each role has two PDNodes, which indicates two roles. // Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected. // Detect two Nodes that can match these two roles and they are connected.
auto& pre_groups = bi_records[step % 2]; auto& pre_groups = bi_records[step % 2];
...@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() { ...@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
// source -> target // source -> target
for (Node* source : pdnodes2nodes_[edge.first]) { for (Node* source : pdnodes2nodes_[edge.first]) {
for (Node* target : pdnodes2nodes_[edge.second]) { for (Node* target : pdnodes2nodes_[edge.second]) {
VLOG(8) << "check " << source->id() << " -- " << target->id();
// TODO(Superjomn) add some prune strategies. // TODO(Superjomn) add some prune strategies.
for (const auto& group : pre_groups) { for (const auto& group : pre_groups) {
HitGroup new_group = group; HitGroup new_group = group;
...@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() { ...@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
} }
} }
VLOG(3) << "step " << step << " get records: " << cur_groups.size(); VLOG(3) << "step " << step << " get records: " << cur_groups.size();
for (auto& group : cur_groups) {
for (auto& item : group.roles) {
VLOG(4) << "node " << item.second->id() << " as " << item.first->name();
}
VLOG(4) << "=========================================================";
}
} }
for (auto& group : bi_records[step % 2]) { for (auto& group : bi_records[step % 2]) {
...@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) { ...@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
return *this; return *this;
} }
PDNode* PDNode::assert_is_op() {
asserts_.emplace_back([this](Node* x) { return x && x->IsOp(); });
return this;
}
PDNode* PDNode::assert_is_op(const std::string& op_type) {
asserts_.emplace_back([this, op_type](Node* x) {
return x && x->IsOp() && x->Op()->Type() == op_type;
});
return this;
}
PDNode* PDNode::assert_is_var() {
asserts_.emplace_back([this](Node* x) { return x && x->IsVar(); });
return this;
}
PDNode* PDNode::assert_var_not_persistable() {
assert_is_var();
asserts_.emplace_back([this](Node* x) { return !x->Var()->Persistable(); });
return this;
}
PDNode* PDNode::assert_is_persistable_var() {
assert_is_var();
asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); });
return this;
}
PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth) {
assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
if (IsNthInput(x, op, argument, nth)) return true;
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
if (IsNthOutput(x, op, argument, nth)) return true;
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
op->inputs.size() == 1) {
return true;
}
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
op->outputs.size() == 1) {
return true;
}
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
return true;
}
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
return true;
}
}
return false;
});
return this;
}
PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) {
assert_is_op(op_type);
asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; });
return this;
}
PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) {
assert_is_op(op_type);
asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; });
return this;
}
PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
asserts_.emplace_back(std::move(teller));
return this;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -39,14 +39,24 @@ struct PDNode { ...@@ -39,14 +39,24 @@ struct PDNode {
// tell whether an ir::Node* is a candidation for a PDNode. // tell whether an ir::Node* is a candidation for a PDNode.
using teller_t = std::function<bool(Node*)>; using teller_t = std::function<bool(Node*)>;
enum class Type { kOp, kVar }; enum class Type { kOp, kVar };
enum class Role {
kUnknown, // No role,
kInput, // an input and will be retained,
kOutput, // an output and will be retained,
kIntermediate // will be removed after handler.
};
// this link to others // this link to others
PDNode& LinksTo(const std::vector<PDNode*>& others); PDNode& LinksTo(const std::vector<PDNode*>& others);
PDNode& LinksFrom(const std::vector<PDNode*>& others); PDNode& LinksFrom(const std::vector<PDNode*>& others);
bool Tell(Node* node) const { bool Tell(Node* node) const {
PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); if (teller_) return teller_(node);
return teller_(node);
for (auto& asrt : asserts_) {
if (!asrt(node)) return false;
}
return true;
} }
bool IsOp() const { return type_ == Type::kOp; } bool IsOp() const { return type_ == Type::kOp; }
...@@ -54,10 +64,52 @@ struct PDNode { ...@@ -54,10 +64,52 @@ struct PDNode {
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
PDNode(const PDNode&) = delete;
PDNode& operator=(const PDNode&) = delete; PDNode& operator=(const PDNode&) = delete;
PDNode(const PDNode&) = delete;
// Mark this node is an Input of a subgraph and will be retained.
PDNode* AsInput() {
role_ = Role::kInput;
return this;
}
// Mark this node is an Output of a subgraph and will be retained.
PDNode* AsOutput() {
role_ = Role::kOutput;
return this;
}
// Mark this node will be removed, so all the links should be inside a matched
// sub-graph.
PDNode* AsIntermediate() {
role_ = Role::kIntermediate;
return this;
}
bool IsIntermediate() const { return role_ == Role::kIntermediate; }
bool IsInput() const { return role_ == Role::kInput; }
bool IsOutput() const { return role_ == Role::kOutput; }
// Assertions, helper functions to simplify the pattern definition.
PDNode* assert_is_op();
PDNode* assert_is_op(const std::string& op_type);
PDNode* assert_is_var();
PDNode* assert_var_not_persistable();
PDNode* assert_is_persistable_var();
PDNode* assert_is_op_output(const std::string& op_type);
PDNode* assert_is_op_input(const std::string& op_type);
PDNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth);
PDNode* assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth);
PDNode* assert_is_only_input_of_op(const std::string& op_type);
PDNode* assert_is_only_output_of_op(const std::string& op_type);
PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n);
PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n);
PDNode* assert_more(teller_t&& teller);
private: private:
PDNode(PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: pattern_(pattern), name_(name), type_(type) {}
PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "", PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar) Type type = Type::kVar)
: teller_(std::move(teller)), : teller_(std::move(teller)),
...@@ -71,10 +123,13 @@ struct PDNode { ...@@ -71,10 +123,13 @@ struct PDNode {
friend class PDPattern; friend class PDPattern;
// Will removed latter.
teller_t teller_; teller_t teller_;
std::vector<teller_t> asserts_;
PDPattern* pattern_; PDPattern* pattern_;
std::string name_; std::string name_;
Type type_; Type type_;
Role role_{Role::kUnknown};
}; };
/* /*
...@@ -87,19 +142,18 @@ struct PDNode { ...@@ -87,19 +142,18 @@ struct PDNode {
* This pattern can be defined as with the following pseudo codes * This pattern can be defined as with the following pseudo codes
* *
* // Create two operator PDNodes. * // Create two operator PDNodes.
* MUL = PDPattern.NewNode() * MUL = PDPattern.NewNode().assert_is_op("mul");
* ELE = PDPattern.NewNode() * ELE = PDPattern.NewNode().assert_is_op("elementwise_add");
* // Create the variable PDNodes. * // Create the variable PDNodes.
* MUL_out = PDPattern.NewNode() * MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \
* // Add teller to define some rules that help to filter the target Nodes. * .assert_is_op_input("elementwise_add") \
* MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul"; * .AsIntermediate();
* ELE.teller = lambda(node): \ * // Add relations.
* node->IsOp() && node->Op()->Type == "elementwise_add"; * MUL->LinksTo({MUL_out});
* MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs) * MUL_out->LinksTo({ELE});
* && (ELE in node->outputs)
* *
* One can add more specific tellers for PDNodes or edges, both the Operator * One can add more specific asserts for PDNodes or edges, both the Operator
* and Variable Nodes can be ruled in PDNode.teller. * and Variable Nodes can be ruled in PDNode.assert_more(...).
* *
* PDPattern can record the general patterns, such as the pattern represents * PDPattern can record the general patterns, such as the pattern represents
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place. * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
...@@ -112,7 +166,8 @@ class PDPattern { ...@@ -112,7 +166,8 @@ class PDPattern {
void AddEdge(PDNode* a, PDNode* b); void AddEdge(PDNode* a, PDNode* b);
PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID()); PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID());
PDNode* RetriveNode(const std::string& id) const; PDNode* NewNode(const std::string& name = NewID());
PDNode* RetrieveNode(const std::string& id) const;
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; } const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; } const std::vector<edge_t>& edges() const { return edges_; }
...@@ -185,6 +240,9 @@ class GraphPatternDetector { ...@@ -185,6 +240,9 @@ class GraphPatternDetector {
// Remove overlapped match subgraphs, when overlapped, keep the previous one. // Remove overlapped match subgraphs, when overlapped, keep the previous one.
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs); void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
// Validate whether the intermediate nodes are linked by external nodes.
void ValidateByNodeRole(std::vector<subgraph_t>* subgraphs);
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph); FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph);
FRIEND_TEST(GraphPatternDetecter, DetectPatterns); FRIEND_TEST(GraphPatternDetecter, DetectPatterns);
...@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument, ...@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
return var->Name() == op->Op()->Input(argument)[nth]; return var->Name() == op->Op()->Input(argument)[nth];
} }
static bool IsNthOutput(Node* var, Node* op, const std::string& argument,
size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->inputs.size() <= nth) return false;
return var->Name() == op->Op()->Output(argument)[nth];
}
static void GraphSafeRemoveNodes(Graph* graph, static void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) { const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) { for (auto* node : nodes) {
......
...@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ...@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
ASSERT_LE(count, 2); ASSERT_LE(count, 2);
} }
TEST(GraphPatternDetector, IntermediateCheck) {
ProgramDesc program;
Graph graph(program);
BuildGraph(&graph);
// o2->v2->o3
// o2->v2->o4
// check o2+o3 fuse, should fail because v2 also link to o4.
GraphPatternDetector detector;
auto* op2 = detector.mutable_pattern()->NewNode(
[](Node* x) { return x && x->IsOp() && x->Name() == "op2"; }, "op2");
auto* op3 = detector.mutable_pattern()->NewNode(
[](Node* x) { return x && x->IsOp() && x->Name() == "op3"; }, "op3");
auto* v2 =
detector.mutable_pattern()
->NewNode(
[](Node* x) { return x && x->IsVar() && x->Name() == "var2"; },
"var2")
->AsIntermediate();
v2->LinksFrom({op2}).LinksTo({op3});
int count = 0;
detector(&graph, [&](const GraphPatternDetector::subgraph_t& g,
Graph* graph) { ++count; });
EXPECT_EQ(count, 0);
count = 0;
v2->AsInput();
detector(&graph, [&](const GraphPatternDetector::subgraph_t& g,
Graph* graph) { ++count; });
ASSERT_EQ(count, 1);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -16,13 +16,27 @@ limitations under the License. */ ...@@ -16,13 +16,27 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kGraphVizPath[] = "graph_viz_path";
using inference::analysis::Dot; using inference::analysis::Dot;
namespace {
const char kGraphVizPath[] = "graph_viz_path";
std::string FormatName(const Node* node) {
if (!node->IsOp() || !node->Op() ||
!node->Op()->HasAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName())) {
return node->Name();
}
const std::string full_scope = boost::get<std::string>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpNamescopeAttrName()));
return string::Sprintf("%s%s", full_scope.c_str(), node->Name().c_str());
}
} // namespace
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
...@@ -54,7 +68,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( ...@@ -54,7 +68,7 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
auto marked_nodes = ConsumeMarkedNodes(graph.get()); auto marked_nodes = ConsumeMarkedNodes(graph.get());
// Create nodes // Create nodes
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")"; std::string node_id = FormatName(n) + "(" + std::to_string(n->id()) + ")";
if (n->IsOp()) { if (n->IsOp()) {
decltype(op_attrs) attr = decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_op_attrs : op_attrs; marked_nodes.count(n) ? marked_op_attrs : op_attrs;
......
...@@ -55,11 +55,11 @@ class Node { ...@@ -55,11 +55,11 @@ class Node {
std::string Name() const { return name_; } std::string Name() const { return name_; }
VarDesc* Var() { VarDesc* Var() {
PADDLE_ENFORCE(type_ == Type::kVariable); PADDLE_ENFORCE(IsVar());
return var_desc_.get(); return var_desc_.get();
} }
OpDesc* Op() { OpDesc* Op() const {
PADDLE_ENFORCE(IsOp()); PADDLE_ENFORCE(IsOp());
return op_desc_.get(); return op_desc_.get();
} }
......
...@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) { ...@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl( std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(graph.get()); FusePassBase::Init("seq_concat_fc_fuse", graph.get());
GraphPatternDetector detector; GraphPatternDetector detector;
auto* pattern = detector.mutable_pattern(); auto* pattern = detector.mutable_pattern();
auto* concat_out = BuildSeqExpandConcatPattern(pattern); auto* concat_out = BuildSeqExpandConcatPattern(pattern);
BuildFCPattern(pattern, concat_out); BuildFCPattern(pattern, concat_out);
#define GET_NODE(id, pattern) \ #define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \ PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
"pattern has no Node called %s", #id); \ "pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetriveNode(#id)); \ auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph, detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
......
...@@ -129,6 +129,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -129,6 +129,9 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
"Optimized for variable") "Optimized for variable")
.SetDefault({}); .SetDefault({});
AddAttr<std::string>(OpNamescopeAttrName(), "Operator name with namesope.")
.SetDefault("");
Validate(); Validate();
} }
......
...@@ -39,6 +39,7 @@ class OpProtoAndCheckerMaker { ...@@ -39,6 +39,7 @@ class OpProtoAndCheckerMaker {
public: public:
static const char *OpRoleAttrName() { return "op_role"; } static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; } static const char *OpRoleVarAttrName() { return "op_role_var"; }
static const char *OpNamescopeAttrName() { return "op_namescope"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker); void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
......
...@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
void AddGraphvizDebugerPass(Pass* pass) { void AddGraphvizDebugerPass(Pass* pass) {
auto* debuger_pass = pass->CreateGraphvizDebugerPass(); auto* debuger_pass = pass->CreateGraphvizDebugerPass();
if (debuger_pass) { if (debuger_pass) {
LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]";
Register(debuger_pass->repr(), debuger_pass); Register(debuger_pass->repr(), debuger_pass);
} }
} }
......
...@@ -16,10 +16,13 @@ ...@@ -16,10 +16,13 @@
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
...@@ -31,6 +34,8 @@ namespace paddle { ...@@ -31,6 +34,8 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using namespace framework;
TEST(Analyzer, analysis_without_tensorrt) { TEST(Analyzer, analysis_without_tensorrt) {
FLAGS_IA_enable_tensorrt_subgraph_engine = false; FLAGS_IA_enable_tensorrt_subgraph_engine = false;
Argument argument; Argument argument;
...@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
EXPECT_NEAR(data[i], base_data[i], 1e-3); EXPECT_NEAR(data[i], base_data[i], 1e-3);
} }
} }
if (use_analysis && activate_ir) {
AnalysisPredictor *analysis_predictor =
dynamic_cast<AnalysisPredictor *>(predictor.get());
auto &fuse_statis = analysis_predictor->analysis_argument()
.Get<std::unordered_map<std::string, int>>(
framework::ir::kFuseStatisAttr);
for (auto &item : fuse_statis) {
LOG(INFO) << "fused " << item.first << " " << item.second;
}
ASSERT_TRUE(fuse_statis.count("fc"));
EXPECT_EQ(fuse_statis.at("fc"), 1);
}
} }
// Directly infer with the original model. // Directly infer with the original model.
......
...@@ -64,7 +64,8 @@ struct Argument { ...@@ -64,7 +64,8 @@ struct Argument {
template <typename T> template <typename T>
void Set(const std::string& key, T* data) { void Set(const std::string& key, T* data) {
PADDLE_ENFORCE_NOT_NULL(data); PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key); PADDLE_ENFORCE(!attrs_.count(key), "Duplicate set Argument's attr [%s]",
key);
attrs_[key] = data; attrs_[key] = data;
attr_deleters_[key] = [data, key, this]() { attr_deleters_[key] = [data, key, this]() {
VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include <vector> #include <vector>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
...@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters( ...@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument) ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc) ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
PADDLE_ENFORCE(!argument->transformed_program_desc);
// The transformed_program_desc should inherit all the VarDesc and BlockDesc // The transformed_program_desc should inherit all the VarDesc and BlockDesc
// from the original program desc. The operators of the main block(the first // from the original program desc. The operators of the main block(the first
// block) should rewritten by data flow graph. // block) should rewritten by data flow graph.
...@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { ...@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
} }
} }
if (argument_->Has("param_scope")) { if (argument_->Has(framework::ir::kParamScopeAttr)) {
LOG(WARNING) << "parameter changes in the scope takes effect"; LOG(WARNING) << "parameter changes in the scope takes effect";
} }
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir, ...@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const std::string &prog_file, const std::string &prog_file,
const std::string &param_file) { const std::string &param_file) {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
argument_->Set("param_scope", new framework::Scope); argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope);
// Load parameters. // Load parameters.
VLOG(3) << "Loading parameters from " << model_dir; VLOG(3) << "Loading parameters from " << model_dir;
LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir, LoadParams(&argument_->Get<framework::Scope>(framework::ir::kParamScopeAttr),
prog_file, param_file); model_dir, prog_file, param_file);
} }
bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir, bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
......
...@@ -14,12 +14,14 @@ ...@@ -14,12 +14,14 @@
#pragma once #pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using namespace framework;
static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__"; static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
...@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path); ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
// Load program. // Load program.
auto program = LoadProgramDesc(*argument->fluid_model_program_path); auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset( argument->origin_program_desc.reset(new proto::ProgramDesc(program));
new framework::proto::ProgramDesc(program));
// Create main data flow graph. // Create main data flow graph.
if (!argument->main_dfg) { if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph); argument->main_dfg.reset(new DataFlowGraph);
} }
argument->Set("ir_program_desc", new framework::ProgramDesc(program)); argument->Set("ir_program_desc", new ProgramDesc(program));
LOG(INFO) << "Loading parameters"; LOG(INFO) << "Loading parameters";
// Load parameters to argument if needed. // Load parameters to argument if needed.
...@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
void Run(DataFlowGraph *graph) override { void Run(DataFlowGraph *graph) override {
// Call all the IR Passes // Call all the IR Passes
IRPassManager ir_passes( IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"),
argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr); nullptr);
// Pass the scope from analysis to IR if needed. // Pass the scope from analysis to IR if needed.
if (argument_->Has("param_scope")) { if (argument_->Has(ir::kParamScopeAttr)) {
// Here the address is passed, attention that IR doesn't own the scope, so // Here the address is passed, attention that IR doesn't own the scope, so
// the real scope in analysis should live during the IR phase. // the real scope in analysis should live during the IR phase.
ir_passes.graph().Set( ir_passes.graph().Set(
"param_scope", new framework::Scope *( ir::kParamScopeAttr,
&argument_->Get<framework::Scope>("param_scope"))); new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr)));
} }
const auto &ir_passes_to_apply = const auto &ir_passes_to_apply =
...@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
PADDLE_ENFORCE(argument_->main_dfg.get()); PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph()); argument_->main_dfg->Build(ir_passes.graph());
// inherit the arguments from ir.
if (ir_passes.graph().Has(ir::kFuseStatisAttr)) {
argument_->Set(
ir::kFuseStatisAttr,
new std::unordered_map<std::string, int>(
ir_passes.graph().Get<std::unordered_map<std::string, int>>(
ir::kFuseStatisAttr)));
}
} }
void EnableParamModify(const std::string &model_dir, void EnableParamModify(const std::string &model_dir,
...@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
private: private:
// Load parameters from a single file or from a directory. // Load parameters from a single file or from a directory.
bool LoadParams(framework::Scope *scope, const std::string &dir, bool LoadParams(Scope *scope, const std::string &dir,
const std::string &prog_file, const std::string &param_file); const std::string &prog_file, const std::string &param_file);
private: private:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program, ...@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
framework::Scope *scope) framework::Scope *scope)
: program_(program) { : program_(program) {
graph_.reset(new framework::ir::Graph(program)); graph_.reset(new framework::ir::Graph(program));
if (scope) graph_->Set("param_scope", new framework::Scope *(scope)); if (scope)
graph_->Set(framework::ir::kParamScopeAttr, new framework::Scope *(scope));
} }
void IRPassManager::Apply(const std::vector<std::string> &passes) { void IRPassManager::Apply(const std::vector<std::string> &passes) {
......
...@@ -12,31 +12,18 @@ ...@@ -12,31 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
namespace paddle { namespace paddle {
using inference::analysis::Argument; bool AnalysisPredictor::Init(
using inference::Singleton; const std::shared_ptr<framework::Scope>& parent_scope) {
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
explicit AnalysisPredictor(const NativeConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
VLOG(3) << "Predictor::init()"; VLOG(3) << "Predictor::init()";
if (config_.use_gpu) { if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device); place_ = paddle::platform::CUDAPlace(config_.device);
...@@ -58,8 +45,8 @@ class AnalysisPredictor : public NativePaddlePredictor { ...@@ -58,8 +45,8 @@ class AnalysisPredictor : public NativePaddlePredictor {
if (!config_.model_dir.empty()) { if (!config_.model_dir.empty()) {
// Parameters are saved in separate files sited in // Parameters are saved in separate files sited in
// the specified `dirname`. // the specified `dirname`.
inference_program_ = paddle::inference::Load( inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(),
executor_.get(), scope_.get(), config_.model_dir); config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) { } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file. // All parameters are saved in a single file.
// The file names should be consistent with that used // The file names should be consistent with that used
...@@ -78,55 +65,43 @@ class AnalysisPredictor : public NativePaddlePredictor { ...@@ -78,55 +65,43 @@ class AnalysisPredictor : public NativePaddlePredictor {
PADDLE_ENFORCE(scope_.get()); PADDLE_ENFORCE(scope_.get());
executor_->CreateVariables(*inference_program_, executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0); sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names // Get the feed_target_names and fetch_target_names
PrepareFeedFetch(); PrepareFeedFetch();
return true; return true;
} }
bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override {
return NativePaddlePredictor::Run(inputs, output_data, batch_size);
}
void OptimizeInferenceProgram() { void AnalysisPredictor::OptimizeInferenceProgram() {
LOG(INFO) << "optimize begin"; LOG(INFO) << "optimize begin";
FLAGS_IA_enable_ir = true; FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false; FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = ""; // Don't output the model. FLAGS_IA_output_storage_path = ""; // Don't output the model.
// Analyze inference_program // Analyze inference_program
Argument argument;
if (!config_.model_dir.empty()) { if (!config_.model_dir.empty()) {
argument.fluid_model_dir.reset(new std::string(config_.model_dir)); argument_.fluid_model_dir.reset(new std::string(config_.model_dir));
} else { } else {
PADDLE_ENFORCE( PADDLE_ENFORCE(
!config_.param_file.empty(), !config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set."); "Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty()); PADDLE_ENFORCE(!config_.prog_file.empty());
argument.fluid_model_program_path.reset( argument_.fluid_model_program_path.reset(
new std::string(config_.prog_file)); new std::string(config_.prog_file));
argument.fluid_model_param_path.reset( argument_.fluid_model_param_path.reset(new std::string(config_.param_file));
new std::string(config_.param_file));
} }
argument.origin_program_desc.reset( argument_.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto())); new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument); Analyzer().Run(&argument_);
CHECK(argument.transformed_program_desc); CHECK(argument_.transformed_program_desc);
VLOG(5) << "to prepare executor"; VLOG(5) << "to prepare executor";
// LOG(INFO) << "transformed_parogram_desc " << // LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString(); // argument.transformed_program_desc->DebugString();
inference_program_.reset( inference_program_.reset(
new framework::ProgramDesc(*argument.transformed_program_desc)); new framework::ProgramDesc(*argument_.transformed_program_desc));
PADDLE_ENFORCE(argument.Has("param_scope")); PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr));
// Update scope. // Update scope.
scope_.reset(argument.Release<framework::Scope>("param_scope")); scope_.reset(
argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
LOG(INFO) << "optimize end =="; LOG(INFO) << "optimize end ==";
} }
private:
NativeConfig config_;
};
template <> template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor< std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle {
using inference::analysis::Argument;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
explicit AnalysisPredictor(const NativeConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope);
bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override {
return NativePaddlePredictor::Run(inputs, output_data, batch_size);
}
void OptimizeInferenceProgram();
Argument& analysis_argument() { return argument_; }
private:
NativeConfig config_;
Argument argument_;
};
} // namespace paddle
...@@ -24,7 +24,7 @@ limitations under the License. */ ...@@ -24,7 +24,7 @@ limitations under the License. */
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include <mkldnn.hpp> #include "mkldnn.hpp"
#endif #endif
#include <map> #include <map>
......
...@@ -43,6 +43,9 @@ void BindConstValue(pybind11::module* m) { ...@@ -43,6 +43,9 @@ void BindConstValue(pybind11::module* m) {
op_proto_and_checker_maker.def( op_proto_and_checker_maker.def(
"kOpRoleVarAttrName", "kOpRoleVarAttrName",
framework::OpProtoAndCheckerMaker::OpRoleVarAttrName); framework::OpProtoAndCheckerMaker::OpRoleVarAttrName);
op_proto_and_checker_maker.def(
"kOpNameScopeAttrName",
framework::OpProtoAndCheckerMaker::OpNamescopeAttrName);
} }
} // namespace pybind } // namespace pybind
......
...@@ -43,6 +43,7 @@ __all__ = [ ...@@ -43,6 +43,7 @@ __all__ = [
'default_main_program', 'default_main_program',
'program_guard', 'program_guard',
'get_var', 'get_var',
'name_scope',
] ]
EMPTY_VAR_NAME = core.kEmptyVarName() EMPTY_VAR_NAME = core.kEmptyVarName()
...@@ -52,6 +53,70 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix() ...@@ -52,6 +53,70 @@ ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
class NameScope(object):
def __init__(self, name="", parent=None):
self._children = dict()
self._name = name
self._parent = parent
def child(self, prefix):
if prefix not in self._children:
new_child = NameScope(prefix, self)
self._children[prefix] = [new_child]
else:
new_child = NameScope(prefix + "_%d" % len(self._children[prefix]),
self)
self._children[prefix].append(new_child)
return new_child
def parent(self):
return self._parent
def name(self):
return self._name
_name_scope = NameScope()
@contextlib.contextmanager
def name_scope(prefix=None):
"""
Generate hierarchical name prefix for the operators.
Note: This should only used for debugging and visualization purpose.
Don't use it for serious analysis such as graph/program transformations.
Args:
prefix(str): prefix.
Examples:
.. code-block:: python
with name_scope("encoder"):
...
with name_scope("decoder"):
...
with name_scope("attention"):
...
"""
# TODO(panyx0718): Only [0-9a-z].
assert prefix, "namescope prefix cannot be empty."
global _name_scope
_name_scope = _name_scope.child(prefix)
yield
_name_scope = _name_scope.parent()
def _full_name_scope():
global _name_scope
scope = _name_scope
name = ""
while scope:
name = scope.name() + "/" + name
scope = scope.parent()
return name
def generate_control_dev_var_name(): def generate_control_dev_var_name():
import random import random
return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random()) return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random())
...@@ -515,6 +580,9 @@ class Operator(object): ...@@ -515,6 +580,9 @@ class Operator(object):
self.desc.set_type(type) self.desc.set_type(type)
proto = OpProtoHolder.instance().get_op_proto(type) proto = OpProtoHolder.instance().get_op_proto(type)
namescope_var_name = op_maker.kOpNameScopeAttrName()
op_attrs[namescope_var_name] = _full_name_scope()
def find_name(var_list, name): def find_name(var_list, name):
for var_name in var_list: for var_name in var_list:
if var_list[var_name] is not None and var_name == name: if var_list[var_name] is not None and var_name == name:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
from __future__ import print_function from __future__ import print_function
import re import re
from collections import defaultdict from collections import defaultdict
from paddle.fluid.framework import Program, Variable from paddle.fluid.framework import Program, Variable, name_scope
from . import framework from . import framework
from . import layers from . import layers
from .backward import append_backward from .backward import append_backward
...@@ -237,7 +237,7 @@ class Optimizer(object): ...@@ -237,7 +237,7 @@ class Optimizer(object):
if param_and_grad[1] is None: if param_and_grad[1] is None:
continue continue
with param_and_grad[0].block.program.optimized_guard( with param_and_grad[0].block.program.optimized_guard(
param_and_grad): param_and_grad), name_scope("optimizer"):
if param_and_grad[0].trainable is True: if param_and_grad[0].trainable is True:
optimize_op = self._append_optimize_op(loss.block, optimize_op = self._append_optimize_op(loss.block,
param_and_grad) param_and_grad)
......
...@@ -82,8 +82,18 @@ class TestDistRunnerBase(object): ...@@ -82,8 +82,18 @@ class TestDistRunnerBase(object):
strategy = fluid.ExecutionStrategy() strategy = fluid.ExecutionStrategy()
strategy.num_threads = 1 strategy.num_threads = 1
strategy.allow_op_delay = False strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy()
if args.use_reduce:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
else:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
exe = fluid.ParallelExecutor( exe = fluid.ParallelExecutor(
True, loss_name=avg_cost.name, exec_strategy=strategy) True,
loss_name=avg_cost.name,
exec_strategy=strategy,
build_strategy=build_stra)
feed_var_list = [ feed_var_list = [
var for var in trainer_prog.global_block().vars.values() var for var in trainer_prog.global_block().vars.values()
...@@ -123,6 +133,7 @@ def runtime_main(test_class): ...@@ -123,6 +133,7 @@ def runtime_main(test_class):
'--current_endpoint', type=str, required=False, default="") '--current_endpoint', type=str, required=False, default="")
parser.add_argument('--sync_mode', action='store_true') parser.add_argument('--sync_mode', action='store_true')
parser.add_argument('--mem_opt', action='store_true') parser.add_argument('--mem_opt', action='store_true')
parser.add_argument('--use_reduce', action='store_true')
args = parser.parse_args() args = parser.parse_args()
...@@ -149,20 +160,25 @@ class TestDistBase(unittest.TestCase): ...@@ -149,20 +160,25 @@ class TestDistBase(unittest.TestCase):
self._python_interp = "python" self._python_interp = "python"
self._sync_mode = True self._sync_mode = True
self._mem_opt = False self._mem_opt = False
self._use_reduce = False
self._setup_config() self._setup_config()
def start_pserver(self, model_file, check_error_log): def start_pserver(self, model_file, check_error_log):
ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps0_ep, ps1_ep = self._ps_endpoints.split(",")
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist %s %s" ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
sync_mode_str = "--sync_mode" if self._sync_mode else ""
mem_opt_str = "--mem_opt" if self._mem_opt else ""
ps0_cmd = ps_cmd % \ ps0_cmd = ps_cmd % \
(self._python_interp, model_file, self._ps_endpoints, ps0_ep, (self._python_interp, model_file, self._ps_endpoints, ps0_ep,
self._trainers, sync_mode_str, mem_opt_str) self._trainers)
ps1_cmd = ps_cmd % \ ps1_cmd = ps_cmd % \
(self._python_interp, model_file, self._ps_endpoints, ps1_ep, (self._python_interp, model_file, self._ps_endpoints, ps1_ep,
self._trainers, sync_mode_str, mem_opt_str) self._trainers)
if self._sync_mode:
ps0_cmd += " --sync_mode"
ps1_cmd += " --sync_mode"
if self._mem_opt:
ps0_cmd += " --mem_opt"
ps1_cmd += " --mem_opt"
ps0_pipe = subprocess.PIPE ps0_pipe = subprocess.PIPE
ps1_pipe = subprocess.PIPE ps1_pipe = subprocess.PIPE
...@@ -242,17 +258,23 @@ class TestDistBase(unittest.TestCase): ...@@ -242,17 +258,23 @@ class TestDistBase(unittest.TestCase):
self._wait_ps_ready(ps1.pid) self._wait_ps_ready(ps1.pid)
ps0_ep, ps1_ep = self._ps_endpoints.split(",") ps0_ep, ps1_ep = self._ps_endpoints.split(",")
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist %s %s" tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
sync_mode_str = "--sync_mode" if self._sync_mode else ""
mem_opt_str = "--mem_opt" if self._mem_opt else ""
tr0_cmd = tr_cmd % \ tr0_cmd = tr_cmd % \
(self._python_interp, model_file, self._ps_endpoints, (self._python_interp, model_file, self._ps_endpoints,
0, ps0_ep, 0, ps0_ep, self._trainers)
self._trainers, sync_mode_str, mem_opt_str)
tr1_cmd = tr_cmd % \ tr1_cmd = tr_cmd % \
(self._python_interp, model_file, self._ps_endpoints, (self._python_interp, model_file, self._ps_endpoints,
1, ps1_ep, 1, ps1_ep, self._trainers)
self._trainers, sync_mode_str, mem_opt_str)
if self._sync_mode:
tr0_cmd += " --sync_mode"
tr1_cmd += " --sync_mode"
if self._mem_opt:
tr0_cmd += " --mem_opt"
tr1_cmd += " --mem_opt"
if self._use_reduce:
tr0_cmd += " --use_reduce"
tr1_cmd += " --use_reduce"
env0 = {"CUDA_VISIBLE_DEVICES": "0"} env0 = {"CUDA_VISIBLE_DEVICES": "0"}
env1 = {"CUDA_VISIBLE_DEVICES": "1"} env1 = {"CUDA_VISIBLE_DEVICES": "1"}
...@@ -303,6 +325,8 @@ class TestDistBase(unittest.TestCase): ...@@ -303,6 +325,8 @@ class TestDistBase(unittest.TestCase):
# FIXME: use terminate() instead of sigkill. # FIXME: use terminate() instead of sigkill.
os.kill(ps0.pid, signal.SIGKILL) os.kill(ps0.pid, signal.SIGKILL)
os.kill(ps1.pid, signal.SIGKILL) os.kill(ps1.pid, signal.SIGKILL)
ps0.terminate()
ps1.terminate()
ps0.wait() ps0.wait()
ps1.wait() ps1.wait()
FNULL.close() FNULL.close()
......
...@@ -20,6 +20,7 @@ from test_dist_base import TestDistBase ...@@ -20,6 +20,7 @@ from test_dist_base import TestDistBase
class TestDistMnist2x2(TestDistBase): class TestDistMnist2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
self._use_reduce = False
def test_se_resnext(self): def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=1e-7) self.check_with_place("dist_mnist.py", delta=1e-7)
...@@ -37,10 +38,30 @@ class TestDistMnist2x2WithMemopt(TestDistBase): ...@@ -37,10 +38,30 @@ class TestDistMnist2x2WithMemopt(TestDistBase):
class TestDistMnistAsync(TestDistBase): class TestDistMnistAsync(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
self._use_reduce = False
def test_se_resnext(self): def test_se_resnext(self):
self.check_with_place("dist_mnist.py", delta=200) self.check_with_place("dist_mnist.py", delta=200)
# FIXME(typhoonzero): enable these tests once we have 4
# 4 GPUs on CI machine, and the base class should be updated.
#
# class TestDistMnist2x2ReduceMode(TestDistBase):
# def _setup_config(self):
# self._sync_mode = True
# self._use_reduce = True
# def test_se_resnext(self):
# self.check_with_place("dist_mnist.py", delta=1e-7)
# class TestDistMnistAsyncReduceMode(TestDistBase):
# def _setup_config(self):
# self._sync_mode = False
# self._use_reduce = True
# def test_se_resnext(self):
# self.check_with_place("dist_mnist.py", delta=200)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
class TestNameScope(unittest.TestCase):
def test_name_scope(self):
with fluid.name_scope("s1"):
a = fluid.layers.data(name='data', shape=[1], dtype='int32')
b = a + 1
with fluid.name_scope("s2"):
c = b * 1
with fluid.name_scope("s3"):
d = c / 1
with fluid.name_scope("s1"):
f = fluid.layers.pow(d, 2.0)
with fluid.name_scope("s4"):
g = f - 1
for op in fluid.default_main_program().block(0).ops:
if op.type == 'elementwise_add':
self.assertEqual(op.desc.attr("op_namescope"), '/s1/')
elif op.type == 'elementwise_mul':
self.assertEqual(op.desc.attr("op_namescope"), '/s1/s2/')
elif op.type == 'elementwise_div':
self.assertEqual(op.desc.attr("op_namescope"), '/s1/s3/')
elif op.type == 'elementwise_sub':
self.assertEqual(op.desc.attr("op_namescope"), '/s4/')
elif op.type == 'pow':
self.assertEqual(op.desc.attr("op_namescope"), '/s1_1/')
...@@ -67,7 +67,10 @@ class TestOperator(unittest.TestCase): ...@@ -67,7 +67,10 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual( self.assertEqual(
set(mul_op.attr_names), set(mul_op.attr_names),
set(["x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var"])) set([
"x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
"op_namescope"
]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1) self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
......
...@@ -67,6 +67,7 @@ def fc_with_batchnorm(use_feed): ...@@ -67,6 +67,7 @@ def fc_with_batchnorm(use_feed):
hidden = img hidden = img
for _ in range(1): for _ in range(1):
with fluid.name_scope("hidden"):
hidden = fluid.layers.fc( hidden = fluid.layers.fc(
hidden, hidden,
size=200, size=200,
...@@ -75,8 +76,9 @@ def fc_with_batchnorm(use_feed): ...@@ -75,8 +76,9 @@ def fc_with_batchnorm(use_feed):
initializer=fluid.initializer.Constant(value=1.0))) initializer=fluid.initializer.Constant(value=1.0)))
hidden = fluid.layers.batch_norm(input=hidden) hidden = fluid.layers.batch_norm(input=hidden)
with fluid.name_scope("fc_layer"):
prediction = fluid.layers.fc(hidden, size=10, act='softmax') prediction = fluid.layers.fc(hidden, size=10, act='softmax')
with fluid.name_scope("loss"):
loss = fluid.layers.cross_entropy(input=prediction, label=label) loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss) loss = fluid.layers.mean(loss)
return loss return loss
......
...@@ -273,6 +273,10 @@ class DistributeTranspiler(object): ...@@ -273,6 +273,10 @@ class DistributeTranspiler(object):
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
grad_name_to_send_dummy_out[grad_varname] = dummy_output grad_name_to_send_dummy_out[grad_varname] = dummy_output
# get send op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name (split_by_ref and send
# will be on the same place). ParallelExecutor
# will use op_role_var to get expected device place to run this op.
program.global_block()._insert_op( program.global_block()._insert_op(
index=index + 1, index=index + 1,
type="send", type="send",
...@@ -281,8 +285,10 @@ class DistributeTranspiler(object): ...@@ -281,8 +285,10 @@ class DistributeTranspiler(object):
attrs={ attrs={
"epmap": eplist, "epmap": eplist,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: OP_ROLE_VAR_ATTR_NAME: [
[self.grad_name_to_param_name[grad_varname], grad_varname], self.grad_name_to_param_name[grad_varname],
splited_grad_varname
],
"sync_mode": not self.sync_mode, "sync_mode": not self.sync_mode,
}) })
for _, var in enumerate(splited_vars): for _, var in enumerate(splited_vars):
...@@ -326,6 +332,15 @@ class DistributeTranspiler(object): ...@@ -326,6 +332,15 @@ class DistributeTranspiler(object):
recv_dep_in = grad_name_to_send_dummy_out[ recv_dep_in = grad_name_to_send_dummy_out[
self.param_name_to_grad_name[param_varname]] self.param_name_to_grad_name[param_varname]]
all_recv_outputs.extend(splited_var) all_recv_outputs.extend(splited_var)
# get recv op_role_var, if not splited, the grad should have .trainer suffix
# if splited, grad should be the original grad var name. ParallelExecutor
# will use op_role_var to get expected device place to run this op.
orig_grad_name = self.param_name_to_grad_name[param_varname]
recv_op_role_var_name = orig_grad_name
splited_trainer_grad = self.grad_var_mapping[orig_grad_name]
if len(splited_trainer_grad) == 1:
recv_op_role_var_name = splited_trainer_grad[0].name
program.global_block().append_op( program.global_block().append_op(
type="recv", type="recv",
inputs={"X": [recv_dep_in]}, inputs={"X": [recv_dep_in]},
...@@ -333,10 +348,8 @@ class DistributeTranspiler(object): ...@@ -333,10 +348,8 @@ class DistributeTranspiler(object):
attrs={ attrs={
"epmap": eps, "epmap": eps,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME: [ OP_ROLE_VAR_ATTR_NAME:
param_varname, [param_varname, recv_op_role_var_name],
self.param_name_to_grad_name[param_varname]
],
"sync_mode": not self.sync_mode "sync_mode": not self.sync_mode
}) })
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册