Unverified    Commit af15f6f0    authored by Yan Chunwei, committed by GitHub

fea/refine fuse (#13076)

Parent 819af27d
@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
   auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
-    auto* while_pat_node = gpd.pattern().RetriveNode("while");
+    auto* while_pat_node = gpd.pattern().RetrieveNode("while");
     auto* while_node = subgraph.at(while_pat_node);
     marked_nodes.insert(while_node);
   };
......
@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
 }
 
 void BuildFCPattern(PDPattern* pattern) {
-  // make sure the selected MUL op has one input argument is a parameter.
-  auto* mul_parameter_var = pattern->NewNode(
-      [](Node* node) {
-        return node->IsVar() && node->outputs.size() == 1UL &&
-               node->outputs.front()->Op()->Type() == "mul" && node->Var() &&
-               node->Var()->Persistable();  // check is a parameter
-      },
-      "mul_weight" /*name*/);
-
-  auto* mul_tmp_input_var = pattern->NewNode(
-      [](Node* node) {
-        bool result =
-            node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
-            !node->Var()->Persistable();  // this input is not an parameter.
-        if (!result) return false;
-        // check whether one output is MUL op.
-        for (auto* op : node->outputs) {
-          if (op->IsOp() && op->Op()->Type() == "mul") return true;
-        }
-        return false;
-      },
-      "mul_tmp_var" /*name*/);
-
-  // select a MUL op
-  auto* mul_op = pattern->NewNode(
-      [](Node* node) {
-        return node->IsOp() &&               // start from an Op
-               node->Op()->Type() == "mul";  // type is mul
-        // the output should be consumed only by one element_add, that check
-        // leaves in a Var PDNode.
-      },
-      "mul" /*name*/);
-
-  // make sure the MUL op's output has only one consumer and links to an
-  // ELEMENTWISE_ADD op.
-  auto* mul_out_var = pattern->NewNode(
-      [](Node* node) {
-        return node->IsVar() &&                  // starts from a Var
-               node->outputs.size() == 1UL &&    // only has one consumer
-               node->outputs.front()->IsOp() &&  // check basic logic
-               node->Var() &&                    // not a ControlDepVar
-               node->outputs.front()->Op()->Type() ==
-                   "elementwise_add";  // a very strong validation
-      },
-      "mul_out");
-  // this check is not essential, just to make the corresponding variable Node
-  // retrival easier.
-  auto* elementwise_add_tmp_var = pattern->NewNode(
-      [](Node* node) {
-        return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
-               VarOutLinksToOp(node, "elementwise_add");
-      },
-      "elementwise_add_tmpvar");
-  // select an ELEMENTWISE_ADD op
-  auto* elementwise_add_op = pattern->NewNode(
-      [](Node* node) {
-        return node->IsOp() && node->Op()->Type() == "elementwise_add";
-      },
-      "elementwise_add" /*name*/);
-  // get the ELEMENTWISE_ADD op's output
-  auto* elementwise_add_out_var = pattern->NewNode(
-      [](Node* node) {
-        return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
-               node->inputs.front()->Op()->Type() == "elementwise_add";
-      },
-      "elementwise_add_out");
-
-  mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
-      .LinksTo({mul_out_var});
+  // Create Operators
+  auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul");
+  auto* elementwise_add_op =
+      pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add");
+  // Create variables
+  // w
+  auto* mul_weight_var = pattern->NewNode("mul_weight")
+                             ->AsInput()
+                             ->assert_is_op_nth_input("mul", "Y", 0);
+  // x
+  auto* mul_tmp_var = pattern->NewNode("mul_tmp_var")
+                          ->AsInput()
+                          ->assert_is_op_nth_input("mul", "X", 0);
+  // intermediate variable, will be removed in the IR after fuse.
+  auto* mul_out_var = pattern->NewNode("mul_out")
+                          ->AsIntermediate()
+                          ->assert_is_only_output_of_op("mul")
+                          ->assert_is_op_input("elementwise_add");
+  // bias
+  auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar")
+                                      ->assert_is_op_input("elementwise_add")
+                                      ->AsInput();
+  // output
+  auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out")
+                                      ->AsOutput()
+                                      ->assert_is_op_output("elementwise_add");
+
+  mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var});
   elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
       .LinksTo({elementwise_add_out_var});
 }
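
With this change the FC pattern is declared with named PDNodes, role markers (`AsInput`/`AsIntermediate`/`AsOutput`) and reusable assertions instead of hand-written teller lambdas. As a minimal sketch of the new style (the `relu` op type and node names here are illustrative, not part of this commit):

```cpp
// Sketch: match x -> relu -> out with the assertion-based API added above.
void BuildReluPattern(PDPattern* pattern) {
  auto* relu_op = pattern->NewNode("relu")->assert_is_op("relu");
  auto* relu_in = pattern->NewNode("relu_in")
                      ->AsInput()  // retained after the fuse
                      ->assert_is_op_input("relu");
  auto* relu_out = pattern->NewNode("relu_out")
                       ->AsOutput()  // retained after the fuse
                       ->assert_is_op_output("relu");
  relu_op->LinksFrom({relu_in}).LinksTo({relu_out});
}
```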
@@ -120,6 +77,7 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
 std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   PADDLE_ENFORCE(graph.get());
+  FusePassBase::Init("fc", graph.get());
 
   std::unordered_set<Node*> nodes2delete;
 
@@ -127,11 +85,12 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   BuildFCPattern(gpd.mutable_pattern());
 
 #define GET_NODE(id)                                              \
-  PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)),  \
+  PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \
                  "pattern has no Node called %s", #id);           \
-  auto* id = subgraph.at(gpd.pattern().RetriveNode(#id));         \
+  auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id));        \
   PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
 
+  int found_fc_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle FC fuse";
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
     graph->RemoveNode(mul);
     graph->RemoveNode(elementwise_add);
     graph->RemoveNode(mul_out);  // tmp variable
+
+    found_fc_count++;
   };
 
   gpd(graph.get(), handler);
 
+  AddStatis(found_fc_count);
   return graph;
 }
......
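The body of the handler is elided by the diff; a sketch of how matched nodes are typically pulled out with the `GET_NODE` macro above (the fused-op construction itself is omitted):

```cpp
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                   Graph* g) {
  // Bind local variables to the graph Nodes matched for each named PDNode.
  GET_NODE(mul);              // the mul op node
  GET_NODE(elementwise_add);  // the elementwise_add op node
  GET_NODE(mul_out);          // the intermediate variable
  // ... create the fused "fc" OpDesc, relink inputs/outputs, then remove
  // mul, elementwise_add and mul_out as the tail of the hunk shows.
};
```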
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/pass.h"
@@ -23,7 +24,7 @@ namespace ir {
 /*
  * Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
  */
-class FCFusePass : public Pass {
+class FCFusePass : public FusePassBase {
  public:
   virtual ~FCFusePass() {}
......
@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
            const std::vector<std::string>& outputs) {
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
-  op->SetInput("Xs", inputs);
-  op->SetOutput("Ys", outputs);
+  if (type == "mul") {
+    op->SetInput("X", {inputs[0]});
+    op->SetInput("Y", {inputs[1]});
+  } else if (type == "elementwise_add") {
+    op->SetInput("X", inputs);
+  }
+  op->SetOutput("Out", outputs);
 }
 
 // a->OP0->b
......
@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
+    auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
     marked_nodes.insert(id);
   };
   gpd(graph.get(), handler);
@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
 #undef GET_NODE
 #undef SET_IN
 
-  LOG(INFO) << "hidden_n: " << hidden_n->Name();
-  LOG(INFO) << "cell: " << cell_n->Name();
-  LOG(INFO) << "xx: " << xx_n->Name();
+  VLOG(4) << "hidden_n: " << hidden_n->Name();
+  VLOG(4) << "cell: " << cell_n->Name();
+  VLOG(4) << "xx: " << xx_n->Name();
 
   op_desc.SetInput("H0", {});
   op_desc.SetInput("C0", {});
......
@@ -22,21 +22,37 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-static const char kParamScopeAttr[] = "param_scope";
+static const char kParamScopeAttr[] = "__param_scope__";
+static const char kFuseStatisAttr[] = "__fuse_statis__";
 
 class FusePassBase : public Pass {
  public:
-  void Init(Graph* graph) const { graph_ = graph; }
+  void Init(const std::string& repr, Graph* graph) const {
+    repr_ = repr;
+    graph_ = graph;
+  }
+
   Scope* param_scope() const {
     PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
     return graph_->Get<framework::Scope*>(kParamScopeAttr);
   }
 
+  void AddStatis(int count_of_fused) const {
+    PADDLE_ENFORCE(graph_);
+    PADDLE_ENFORCE(!repr_.empty());
+    if (!graph_->Has(kFuseStatisAttr)) {
+      graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
+    }
+    auto& info =
+        graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
+    info[repr_] = count_of_fused;
+  }
+
   virtual ~FusePassBase() {}
 
  protected:
   mutable Graph* graph_;
+  mutable std::string repr_;
 };
 
 }  // namespace ir
......
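Taken together, a pass derived from the extended FusePassBase now follows this shape (a sketch; `MyFusePass` and the repr string are illustrative):

```cpp
std::unique_ptr<ir::Graph> MyFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  // The repr string keys this pass's entry in kFuseStatisAttr.
  FusePassBase::Init("my_fuse", graph.get());
  int found_count = 0;
  // ... detect subgraphs and rewrite them, incrementing found_count ...
  AddStatis(found_count);  // records {"my_fuse": found_count} on the graph
  return graph;
}
```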
@@ -27,6 +27,19 @@ namespace ir {
 
 size_t PDPattern::id_ = 0UL;
 
+PDNode* PDPattern::NewNode(const std::string& name) {
+  if (!name.empty()) {
+    PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
+                      "PDNode's name should be unique, get duplicate [%s]",
+                      name);
+  }
+
+  nodes_.emplace_back(new PDNode(this, name));
+  auto* cur = nodes_.back().get();
+  node_map_[name] = cur;
+  return cur;
+}
+
 PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
   if (!name.empty()) {
     PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
   return cur;
 }
 
-PDNode* PDPattern::RetriveNode(const std::string& id) const {
+PDNode* PDPattern::RetrieveNode(const std::string& id) const {
   auto it = node_map_.find(id);
   if (it == node_map_.end()) {
     return nullptr;
@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
   auto subgraphs = DetectPatterns();
   UniquePatterns(&subgraphs);
   RemoveOverlappedMatch(&subgraphs);
+  ValidateByNodeRole(&subgraphs);
 
+  if (subgraphs.empty()) return;
   LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
   int id = 0;
   for (auto& g : subgraphs) {
@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
       }
     }
   }
+  // Check to early stop if some PDNode can't find matched Node.
+  for (auto& pdnode : pattern_.nodes()) {
+    if (!pdnodes2nodes_.count(pdnode.get())) {
+      VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
+      return false;
+    }
+  }
   VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
 
   return !pdnodes2nodes_.empty();
 }
// The intermediate Nodes can only link to the nodes inside the pattern, or this
// subgraph will be droped.
void GraphPatternDetector::ValidateByNodeRole(
std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
std::vector<GraphPatternDetector::subgraph_t> result;
subgraphs->erase(
std::remove_if(
subgraphs->begin(), subgraphs->end(),
[](const GraphPatternDetector::subgraph_t& subgraph) -> bool {
// Collect the inputs and outputs.
std::unordered_set<Node*> ios;
for (auto& item : subgraph) {
if (!item.first->IsIntermediate()) {
ios.insert(item.second);
}
}
for (auto& item : subgraph) {
if (item.first->IsIntermediate()) {
for (auto* x : item.second->inputs) {
if (!ios.count(x)) {
return true;
}
}
for (auto* x : item.second->outputs) {
if (!ios.count(x)) {
return true;
}
}
}
}
return false;
}),
subgraphs->end());
}
 
 struct HitGroup {
   std::unordered_map<PDNode*, Node*> roles;
@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
   // in edges of PDNodes.
   for (const auto& edge : pattern_.edges()) {
     VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
+    // TODO(Superjomn) Fix bug here, the groups might be duplicate here.
     // Each role has two PDNodes, which indicates two roles.
     // Detect two Nodes that can match these two roles and they are connected.
     auto& pre_groups = bi_records[step % 2];
@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
     // source -> target
     for (Node* source : pdnodes2nodes_[edge.first]) {
       for (Node* target : pdnodes2nodes_[edge.second]) {
+        VLOG(8) << "check " << source->id() << " -- " << target->id();
         // TODO(Superjomn) add some prune strategies.
         for (const auto& group : pre_groups) {
           HitGroup new_group = group;
@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
       }
     }
     VLOG(3) << "step " << step << " get records: " << cur_groups.size();
+    for (auto& group : cur_groups) {
+      for (auto& item : group.roles) {
+        VLOG(4) << "node " << item.second->id() << " as " << item.first->name();
+      }
+      VLOG(4) << "=========================================================";
+    }
   }
 
   for (auto& group : bi_records[step % 2]) {
@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
   return *this;
 }
 
+PDNode* PDNode::assert_is_op() {
+  asserts_.emplace_back([this](Node* x) { return x && x->IsOp(); });
+  return this;
+}
+PDNode* PDNode::assert_is_op(const std::string& op_type) {
+  asserts_.emplace_back([this, op_type](Node* x) {
+    return x && x->IsOp() && x->Op()->Type() == op_type;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_var() {
+  asserts_.emplace_back([this](Node* x) { return x && x->IsVar(); });
+  return this;
+}
+PDNode* PDNode::assert_var_not_persistable() {
+  assert_is_var();
+  asserts_.emplace_back([this](Node* x) { return !x->Var()->Persistable(); });
+  return this;
+}
+PDNode* PDNode::assert_is_persistable_var() {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); });
+  return this;
+}
+PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
+                                       const std::string& argument, int nth) {
+  assert_is_var();
+  assert_is_op_input(op_type);
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->outputs) {
+      if (IsNthInput(x, op, argument, nth)) return true;
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
+                                        const std::string& argument, int nth) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->inputs) {
+      if (IsNthOutput(x, op, argument, nth)) return true;
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
+          op->inputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->inputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
+          op->outputs.size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->inputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
+  assert_is_var();
+  asserts_.emplace_back([=](Node* x) {
+    for (auto* op : x->outputs) {
+      if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
+        return true;
+      }
+    }
+    return false;
+  });
+  return this;
+}
+PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) {
+  assert_is_op(op_type);
+  asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; });
+  return this;
+}
+PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) {
+  assert_is_op(op_type);
+  asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; });
+  return this;
+}
+PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
+  asserts_.emplace_back(std::move(teller));
+  return this;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
......
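The built-in assertions compose, and `assert_more` keeps an escape hatch for pattern-specific rules. A sketch (the 2-D-weight check is illustrative and assumes `VarDesc::GetShape()`):

```cpp
auto* w = pattern->NewNode("w")
              ->AsInput()
              ->assert_is_persistable_var()
              ->assert_is_op_input("mul")
              ->assert_more([](Node* x) {
                // Extra, pattern-specific rule: require a 2-D weight.
                return x->Var() && x->Var()->GetShape().size() == 2;
              });
```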
@@ -39,14 +39,24 @@ struct PDNode {
   // tell whether an ir::Node* is a candidation for a PDNode.
   using teller_t = std::function<bool(Node*)>;
   enum class Type { kOp, kVar };
+  enum class Role {
+    kUnknown,      // No role,
+    kInput,        // an input and will be retained,
+    kOutput,       // an output and will be retained,
+    kIntermediate  // will be removed after handler.
+  };
 
   // this link to others
   PDNode& LinksTo(const std::vector<PDNode*>& others);
   PDNode& LinksFrom(const std::vector<PDNode*>& others);
 
   bool Tell(Node* node) const {
-    PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
-    return teller_(node);
+    if (teller_) return teller_(node);
+
+    for (auto& asrt : asserts_) {
+      if (!asrt(node)) return false;
+    }
+    return true;
   }
 
   bool IsOp() const { return type_ == Type::kOp; }
@@ -54,10 +64,52 @@ struct PDNode {
 
   const std::string& name() const { return name_; }
 
-  PDNode(const PDNode&) = delete;
   PDNode& operator=(const PDNode&) = delete;
+  PDNode(const PDNode&) = delete;
+
+  // Mark this node is an Input of a subgraph and will be retained.
+  PDNode* AsInput() {
+    role_ = Role::kInput;
+    return this;
+  }
+  // Mark this node is an Output of a subgraph and will be retained.
+  PDNode* AsOutput() {
+    role_ = Role::kOutput;
+    return this;
+  }
+  // Mark this node will be removed, so all the links should be inside a matched
+  // sub-graph.
+  PDNode* AsIntermediate() {
+    role_ = Role::kIntermediate;
+    return this;
+  }
+
+  bool IsIntermediate() const { return role_ == Role::kIntermediate; }
+  bool IsInput() const { return role_ == Role::kInput; }
+  bool IsOutput() const { return role_ == Role::kOutput; }
+
+  // Assertions, helper functions to simplify the pattern definition.
+  PDNode* assert_is_op();
+  PDNode* assert_is_op(const std::string& op_type);
+  PDNode* assert_is_var();
+  PDNode* assert_var_not_persistable();
+  PDNode* assert_is_persistable_var();
+  PDNode* assert_is_op_output(const std::string& op_type);
+  PDNode* assert_is_op_input(const std::string& op_type);
+  PDNode* assert_is_op_nth_input(const std::string& op_type,
+                                 const std::string& argument, int nth);
+  PDNode* assert_is_op_nth_output(const std::string& op_type,
+                                  const std::string& argument, int nth);
+  PDNode* assert_is_only_input_of_op(const std::string& op_type);
+  PDNode* assert_is_only_output_of_op(const std::string& op_type);
+  PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n);
+  PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n);
+  PDNode* assert_more(teller_t&& teller);
 
  private:
+  PDNode(PDPattern* pattern, const std::string& name = "",
+         Type type = Type::kVar)
+      : pattern_(pattern), name_(name), type_(type) {}
   PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
          Type type = Type::kVar)
       : teller_(std::move(teller)),
@@ -71,10 +123,13 @@ struct PDNode {
 
   friend class PDPattern;
 
+  // Will removed latter.
   teller_t teller_;
+  std::vector<teller_t> asserts_;
   PDPattern* pattern_;
   std::string name_;
   Type type_;
+  Role role_{Role::kUnknown};
 };
 
 /*
@@ -87,19 +142,18 @@ struct PDNode {
  * This pattern can be defined as with the following pseudo codes
  *
  *     // Create two operator PDNodes.
- *     MUL = PDPattern.NewNode()
- *     ELE = PDPattern.NewNode()
+ *     MUL = PDPattern.NewNode().assert_is_op("mul");
+ *     ELE = PDPattern.NewNode().assert_is_op("elementwise_add");
  *     // Create the variable PDNodes.
- *     MUL_out = PDPattern.NewNode()
- *     // Add teller to define some rules that help to filter the target Nodes.
- *     MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
- *     ELE.teller = lambda(node): \
- *                      node->IsOp() && node->Op()->Type == "elementwise_add";
- *     MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
- *                                                  && (ELE in node->outputs)
+ *     MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \
+ *                                  .assert_is_op_input("elementwise_add") \
+ *                                  .AsIntermediate();
+ *     // Add relations.
+ *     MUL->LinksTo({MUL_out});
+ *     MUL_out->LinksTo({ELE});
  *
- * One can add more specific tellers for PDNodes or edges, both the Operator
- * and Variable Nodes can be ruled in PDNode.teller.
+ * One can add more specific asserts for PDNodes or edges, both the Operator
+ * and Variable Nodes can be ruled in PDNode.assert_more(...).
  *
  * PDPattern can record the general patterns, such as the pattern represents
  *   - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
@@ -112,7 +166,8 @@ class PDPattern {
 
   void AddEdge(PDNode* a, PDNode* b);
 
   PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID());
-  PDNode* RetriveNode(const std::string& id) const;
+  PDNode* NewNode(const std::string& name = NewID());
+  PDNode* RetrieveNode(const std::string& id) const;
 
   const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
   const std::vector<edge_t>& edges() const { return edges_; }
@@ -185,6 +240,9 @@ class GraphPatternDetector {
   // Remove overlapped match subgraphs, when overlapped, keep the previous one.
   void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
 
+  // Validate whether the intermediate nodes are linked by external nodes.
+  void ValidateByNodeRole(std::vector<subgraph_t>* subgraphs);
+
 #ifdef PADDLE_WITH_TESTING
   FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph);
   FRIEND_TEST(GraphPatternDetecter, DetectPatterns);
@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
   return var->Name() == op->Op()->Input(argument)[nth];
 }
 
+static bool IsNthOutput(Node* var, Node* op, const std::string& argument,
+                        size_t nth) {
+  PADDLE_ENFORCE(var->IsVar());
+  PADDLE_ENFORCE(op->IsOp());
+  if (op->inputs.size() <= nth) return false;
+  return var->Name() == op->Op()->Output(argument)[nth];
+}
+
 static void GraphSafeRemoveNodes(Graph* graph,
                                  const std::unordered_set<const Node*>& nodes) {
   for (auto* node : nodes) {
......
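IsNthOutput mirrors the existing IsNthInput helper: both take a variable node and an op node from the same graph and check the variable's slot in the op's argument list. A usage sketch (the node pointers and argument names are illustrative):

```cpp
// True when var_node is the first "Y" input of op_node,
// or its first "Out" output, respectively.
bool is_weight = IsNthInput(var_node, op_node, "Y", 0);
bool is_result = IsNthOutput(var_node, op_node, "Out", 0);
```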
@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
   ASSERT_LE(count, 2);
 }
 
+TEST(GraphPatternDetector, IntermediateCheck) {
+  ProgramDesc program;
+  Graph graph(program);
+  BuildGraph(&graph);
+  // o2->v2->o3
+  // o2->v2->o4
+  // check o2+o3 fuse, should fail because v2 also link to o4.
+  GraphPatternDetector detector;
+  auto* op2 = detector.mutable_pattern()->NewNode(
+      [](Node* x) { return x && x->IsOp() && x->Name() == "op2"; }, "op2");
+  auto* op3 = detector.mutable_pattern()->NewNode(
+      [](Node* x) { return x && x->IsOp() && x->Name() == "op3"; }, "op3");
+  auto* v2 =
+      detector.mutable_pattern()
+          ->NewNode(
+              [](Node* x) { return x && x->IsVar() && x->Name() == "var2"; },
+              "var2")
+          ->AsIntermediate();
+  v2->LinksFrom({op2}).LinksTo({op3});
+
+  int count = 0;
+  detector(&graph, [&](const GraphPatternDetector::subgraph_t& g,
+                       Graph* graph) { ++count; });
+  EXPECT_EQ(count, 0);
+
+  count = 0;
+  v2->AsInput();
+  detector(&graph, [&](const GraphPatternDetector::subgraph_t& g,
+                       Graph* graph) { ++count; });
+  ASSERT_EQ(count, 1);
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
......
@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
 
 std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
-  FusePassBase::Init(graph.get());
+  FusePassBase::Init("seq_concat_fc_fuse", graph.get());
   GraphPatternDetector detector;
   auto* pattern = detector.mutable_pattern();
   auto* concat_out = BuildSeqExpandConcatPattern(pattern);
   BuildFCPattern(pattern, concat_out);
 
 #define GET_NODE(id, pattern)                               \
-  PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)),  \
+  PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
                  "pattern has no Node called %s", #id);     \
-  auto* id = subgraph.at(pattern.RetriveNode(#id));         \
+  auto* id = subgraph.at(pattern.RetrieveNode(#id));        \
   PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
 
   detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
......
@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
   void AddGraphvizDebugerPass(Pass* pass) {
     auto* debuger_pass = pass->CreateGraphvizDebugerPass();
     if (debuger_pass) {
-      LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]";
       Register(debuger_pass->repr(), debuger_pass);
     }
   }
......
@@ -16,10 +16,13 @@
 
 #include <google/protobuf/text_format.h>
 #include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/inference/analysis/ut_helper.h"
+#include "paddle/fluid/inference/api/analysis_predictor.h"
 #include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
+#include "paddle/fluid/inference/utils/singleton.h"
 #include "paddle/fluid/platform/profiler.h"
 
 DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
@@ -31,6 +34,8 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
+using namespace framework;
+
 TEST(Analyzer, analysis_without_tensorrt) {
   FLAGS_IA_enable_tensorrt_subgraph_engine = false;
   Argument argument;
@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
       EXPECT_NEAR(data[i], base_data[i], 1e-3);
     }
   }
+
+  if (use_analysis && activate_ir) {
+    AnalysisPredictor *analysis_predictor =
+        dynamic_cast<AnalysisPredictor *>(predictor.get());
+    auto &fuse_statis = analysis_predictor->analysis_argument()
+                            .Get<std::unordered_map<std::string, int>>(
+                                framework::ir::kFuseStatisAttr);
+    for (auto &item : fuse_statis) {
+      LOG(INFO) << "fused " << item.first << " " << item.second;
+    }
+
+    ASSERT_TRUE(fuse_statis.count("fc"));
+    EXPECT_EQ(fuse_statis.at("fc"), 1);
+  }
 }
 
 // Directly infer with the original model.
......
@@ -64,7 +64,8 @@ struct Argument {
   template <typename T>
   void Set(const std::string& key, T* data) {
     PADDLE_ENFORCE_NOT_NULL(data);
-    PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
+    PADDLE_ENFORCE(!attrs_.count(key), "Duplicate set Argument's attr [%s]",
+                   key);
     attrs_[key] = data;
     attr_deleters_[key] = [data, key, this]() {
       VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
......
@@ -15,6 +15,7 @@
 #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
 #include <vector>
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/proto_desc.h"
 #include "paddle/fluid/inference/analysis/analyzer.h"
@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
 bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
   ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
   ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
-  PADDLE_ENFORCE(!argument->transformed_program_desc);
   // The transformed_program_desc should inherit all the VarDesc and BlockDesc
   // from the original program desc. The operators of the main block(the first
   // block) should rewritten by data flow graph.
@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
     }
   }
 
-  if (argument_->Has("param_scope")) {
+  if (argument_->Has(framework::ir::kParamScopeAttr)) {
     LOG(WARNING) << "parameter changes in the scope takes effect";
   }
......
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
 #include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
                                       const std::string &prog_file,
                                       const std::string &param_file) {
   PADDLE_ENFORCE(argument_);
-  argument_->Set("param_scope", new framework::Scope);
+  argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope);
   // Load parameters.
   VLOG(3) << "Loading parameters from " << model_dir;
-  LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
-             prog_file, param_file);
+  LoadParams(&argument_->Get<framework::Scope>(framework::ir::kParamScopeAttr),
+             model_dir, prog_file, param_file);
 }
 
 bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
......
@@ -14,12 +14,14 @@
 
 #pragma once
 
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
 #include "paddle/fluid/inference/analysis/pass.h"
 
 namespace paddle {
 namespace inference {
 namespace analysis {
+using namespace framework;
 
 static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
 
@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
     ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
     // Load program.
     auto program = LoadProgramDesc(*argument->fluid_model_program_path);
-    argument->origin_program_desc.reset(
-        new framework::proto::ProgramDesc(program));
+    argument->origin_program_desc.reset(new proto::ProgramDesc(program));
     // Create main data flow graph.
     if (!argument->main_dfg) {
       argument->main_dfg.reset(new DataFlowGraph);
     }
-    argument->Set("ir_program_desc", new framework::ProgramDesc(program));
+    argument->Set("ir_program_desc", new ProgramDesc(program));
 
     LOG(INFO) << "Loading parameters";
     // Load parameters to argument if needed.
@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
 
   void Run(DataFlowGraph *graph) override {
     // Call all the IR Passes
-    IRPassManager ir_passes(
-        argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr);
+    IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"),
+                            nullptr);
     // Pass the scope from analysis to IR if needed.
-    if (argument_->Has("param_scope")) {
+    if (argument_->Has(ir::kParamScopeAttr)) {
       // Here the address is passed, attention that IR doesn't own the scope, so
       // the real scope in analysis should live during the IR phase.
       ir_passes.graph().Set(
-          "param_scope", new framework::Scope *(
-                             &argument_->Get<framework::Scope>("param_scope")));
+          ir::kParamScopeAttr,
+          new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr)));
     }
 
     const auto &ir_passes_to_apply =
@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
     PADDLE_ENFORCE(argument_->main_dfg.get());
     argument_->main_dfg->Build(ir_passes.graph());
+
+    // inherit the arguments from ir.
+    if (ir_passes.graph().Has(ir::kFuseStatisAttr)) {
+      argument_->Set(
+          ir::kFuseStatisAttr,
+          new std::unordered_map<std::string, int>(
+              ir_passes.graph().Get<std::unordered_map<std::string, int>>(
+                  ir::kFuseStatisAttr)));
+    }
   }
 
   void EnableParamModify(const std::string &model_dir,
@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
 
  private:
   // Load parameters from a single file or from a directory.
-  bool LoadParams(framework::Scope *scope, const std::string &dir,
+  bool LoadParams(Scope *scope, const std::string &dir,
                   const std::string &prog_file, const std::string &param_file);
 
  private:
......
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
 #include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/scope.h"
 
@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
                              framework::Scope *scope)
     : program_(program) {
   graph_.reset(new framework::ir::Graph(program));
-  if (scope) graph_->Set("param_scope", new framework::Scope *(scope));
+  if (scope)
+    graph_->Set(framework::ir::kParamScopeAttr, new framework::Scope *(scope));
 }
 
 void IRPassManager::Apply(const std::vector<std::string> &passes) {
......
@@ -12,31 +12,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "paddle/fluid/inference/api/analysis_predictor.h"
 #include <memory>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/inference/analysis/analyzer.h"
-#include "paddle/fluid/inference/api/api_impl.h"
 #include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include "paddle/fluid/inference/utils/singleton.h"
 
 namespace paddle {
 
-using inference::analysis::Argument;
-using inference::Singleton;
-using inference::analysis::Analyzer;
-using framework::proto::ProgramDesc;
-
-/* This predictor is based on the original native predictor with IR and Analysis
- * support. It will optimize IR and Parameters in the runtime.
- * TODO(Superjomn) Replace the Navive predictor?
- */
-class AnalysisPredictor : public NativePaddlePredictor {
- public:
-  explicit AnalysisPredictor(const NativeConfig& config)
-      : NativePaddlePredictor(config), config_(config) {}
-
-  bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
+bool AnalysisPredictor::Init(
+    const std::shared_ptr<framework::Scope>& parent_scope) {
   VLOG(3) << "Predictor::init()";
   if (config_.use_gpu) {
     place_ = paddle::platform::CUDAPlace(config_.device);
@@ -58,8 +45,8 @@ class AnalysisPredictor : public NativePaddlePredictor {
   if (!config_.model_dir.empty()) {
     // Parameters are saved in separate files sited in
     // the specified `dirname`.
-    inference_program_ = paddle::inference::Load(
-        executor_.get(), scope_.get(), config_.model_dir);
+    inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(),
+                                                 config_.model_dir);
   } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
     // All parameters are saved in a single file.
     // The file names should be consistent with that used
@@ -78,55 +65,43 @@ class AnalysisPredictor : public NativePaddlePredictor {
   PADDLE_ENFORCE(scope_.get());
   executor_->CreateVariables(*inference_program_,
                              sub_scope_ ? sub_scope_ : scope_.get(), 0);
 
   // Get the feed_target_names and fetch_target_names
   PrepareFeedFetch();
   return true;
 }
 
-  bool Run(const std::vector<PaddleTensor>& inputs,
-           std::vector<PaddleTensor>* output_data,
-           int batch_size = -1) override {
-    return NativePaddlePredictor::Run(inputs, output_data, batch_size);
-  }
-
-  void OptimizeInferenceProgram() {
+void AnalysisPredictor::OptimizeInferenceProgram() {
   LOG(INFO) << "optimize begin";
   FLAGS_IA_enable_ir = true;
   FLAGS_IA_enable_tensorrt_subgraph_engine = false;
   FLAGS_IA_output_storage_path = "";  // Don't output the model.
   // Analyze inference_program
-  Argument argument;
   if (!config_.model_dir.empty()) {
-    argument.fluid_model_dir.reset(new std::string(config_.model_dir));
+    argument_.fluid_model_dir.reset(new std::string(config_.model_dir));
   } else {
     PADDLE_ENFORCE(
         !config_.param_file.empty(),
         "Either model_dir or (param_file, prog_file) should be set.");
     PADDLE_ENFORCE(!config_.prog_file.empty());
-    argument.fluid_model_program_path.reset(
-        new std::string(config_.prog_file));
-    argument.fluid_model_param_path.reset(
-        new std::string(config_.param_file));
+    argument_.fluid_model_program_path.reset(
+        new std::string(config_.prog_file));
+    argument_.fluid_model_param_path.reset(new std::string(config_.param_file));
   }
 
-  argument.origin_program_desc.reset(
+  argument_.origin_program_desc.reset(
       new ProgramDesc(*inference_program_->Proto()));
-  Singleton<Analyzer>::Global().Run(&argument);
-  CHECK(argument.transformed_program_desc);
+  Analyzer().Run(&argument_);
+  CHECK(argument_.transformed_program_desc);
   VLOG(5) << "to prepare executor";
   // LOG(INFO) << "transformed_parogram_desc " <<
   // argument.transformed_program_desc->DebugString();
   inference_program_.reset(
-      new framework::ProgramDesc(*argument.transformed_program_desc));
-  PADDLE_ENFORCE(argument.Has("param_scope"));
+      new framework::ProgramDesc(*argument_.transformed_program_desc));
+  PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr));
   // Update scope.
-  scope_.reset(argument.Release<framework::Scope>("param_scope"));
+  scope_.reset(
+      argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
   LOG(INFO) << "optimize end ==";
 }
 
- private:
-  NativeConfig config_;
-};
-
 template <>
 std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
......
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/analysis/analyzer.h"
+#include "paddle/fluid/inference/api/api_impl.h"
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
+
+namespace paddle {
+
+using inference::analysis::Argument;
+using inference::analysis::Analyzer;
+using framework::proto::ProgramDesc;
+
+/* This predictor is based on the original native predictor with IR and Analysis
+ * support. It will optimize IR and Parameters in the runtime.
+ * TODO(Superjomn) Replace the Navive predictor?
+ */
+class AnalysisPredictor : public NativePaddlePredictor {
+ public:
+  explicit AnalysisPredictor(const NativeConfig& config)
+      : NativePaddlePredictor(config), config_(config) {}
+
+  bool Init(const std::shared_ptr<framework::Scope>& parent_scope);
+
+  bool Run(const std::vector<PaddleTensor>& inputs,
+           std::vector<PaddleTensor>* output_data,
+           int batch_size = -1) override {
+    return NativePaddlePredictor::Run(inputs, output_data, batch_size);
+  }
+
+  void OptimizeInferenceProgram();
+
+  Argument& analysis_argument() { return argument_; }
+
+ private:
+  NativeConfig config_;
+  Argument argument_;
+};
+
+}  // namespace paddle
......
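For context, the new predictor is reached through the usual factory; a hedged usage sketch (the kAnalysis engine kind is assumed from the truncated CreatePaddlePredictor specialization above, and the model path is illustrative):

```cpp
NativeConfig config;
config.model_dir = "./my_model_dir";  // illustrative path
auto predictor =
    CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
std::vector<PaddleTensor> inputs, outputs;
// ... fill inputs with PaddleTensor data ...
predictor->Run(inputs, &outputs);
```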
@@ -62,7 +62,7 @@ void NativePaddlePredictor::PrepareFeedFetch() {
   for (auto *op : inference_program_->Block(0).AllOps()) {
     if (op->Type() == "feed") {
       int idx = boost::get<int>(op->GetAttr("col"));
-      if (feeds_.size() <= idx) {
+      if (feeds_.size() <= static_cast<size_t>(idx)) {
         feeds_.resize(idx + 1);
       }
       feeds_[idx] = op;
......