From af15f6f0380a3975239d7d4a6c7c682f4058d309 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Fri, 31 Aug 2018 13:22:42 +0800 Subject: [PATCH] fea/refine fuse (#13076) --- .../framework/ir/attention_lstm_fuse_pass.cc | 2 +- paddle/fluid/framework/ir/fc_fuse_pass.cc | 112 ++++------- paddle/fluid/framework/ir/fc_fuse_pass.h | 3 +- .../fluid/framework/ir/fc_fuse_pass_tester.cc | 9 +- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 8 +- paddle/fluid/framework/ir/fuse_pass_base.h | 20 +- .../framework/ir/graph_pattern_detector.cc | 181 +++++++++++++++++- .../framework/ir/graph_pattern_detector.h | 96 ++++++++-- .../ir/graph_pattern_detector_tester.cc | 33 ++++ .../framework/ir/seq_concat_fc_fuse_pass.cc | 10 +- paddle/fluid/inference/analysis/analyzer.cc | 1 - .../inference/analysis/analyzer_tester.cc | 19 ++ paddle/fluid/inference/analysis/argument.h | 3 +- .../analysis/data_flow_graph_to_fluid_pass.cc | 4 +- .../inference/analysis/fluid_to_ir_pass.cc | 7 +- .../inference/analysis/fluid_to_ir_pass.h | 27 ++- .../inference/analysis/ir_pass_manager.cc | 4 +- .../fluid/inference/api/analysis_predictor.cc | 179 ++++++++--------- .../fluid/inference/api/analysis_predictor.h | 51 +++++ paddle/fluid/inference/api/api_impl.cc | 2 +- 20 files changed, 545 insertions(+), 226 deletions(-) create mode 100644 paddle/fluid/inference/api/analysis_predictor.h diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc index 2876de88f17..d2d051a69a3 100644 --- a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc @@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) { auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - auto* while_pat_node = gpd.pattern().RetriveNode("while"); + auto* while_pat_node = gpd.pattern().RetrieveNode("while"); auto* while_node = subgraph.at(while_pat_node); marked_nodes.insert(while_node); }; diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 201160f29df..513742bab69 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) { } void BuildFCPattern(PDPattern* pattern) { - // make sure the selected MUL op has one input argument is a parameter. - auto* mul_parameter_var = pattern->NewNode( - [](Node* node) { - return node->IsVar() && node->outputs.size() == 1UL && - node->outputs.front()->Op()->Type() == "mul" && node->Var() && - node->Var()->Persistable(); // check is a parameter - }, - "mul_weight" /*name*/); - - auto* mul_tmp_input_var = pattern->NewNode( - [](Node* node) { - bool result = - node->IsVar() && node->outputs.size() >= 1UL && node->Var() && - !node->Var()->Persistable(); // this input is not an parameter. - if (!result) return false; - // check whether one output is MUL op. - for (auto* op : node->outputs) { - if (op->IsOp() && op->Op()->Type() == "mul") return true; - } - return false; - }, - "mul_tmp_var" /*name*/); - - // select a MUL op - auto* mul_op = pattern->NewNode( - [](Node* node) { - return node->IsOp() && // start from an Op - node->Op()->Type() == "mul"; // type is mul - // the output should be consumed only by one element_add, that check - // leaves in a Var PDNode. - }, - "mul" /*name*/); - - // make sure the MUL op's output has only one consumer and links to an - // ELEMENTWISE_ADD op. - auto* mul_out_var = pattern->NewNode( - [](Node* node) { - return node->IsVar() && // starts from a Var - node->outputs.size() == 1UL && // only has one consumer - node->outputs.front()->IsOp() && // check basic logic - node->Var() && // not a ControlDepVar - node->outputs.front()->Op()->Type() == - "elementwise_add"; // a very strong validation - }, - "mul_out"); - // this check is not essential, just to make the corresponding variable Node - // retrival easier. - auto* elementwise_add_tmp_var = pattern->NewNode( - [](Node* node) { - return node->IsVar() && node->outputs.size() >= 1UL && node->Var() && - VarOutLinksToOp(node, "elementwise_add"); - }, - "elementwise_add_tmpvar"); - - // select an ELEMENTWISE_ADD op - auto* elementwise_add_op = pattern->NewNode( - [](Node* node) { - return node->IsOp() && node->Op()->Type() == "elementwise_add"; - }, - "elementwise_add" /*name*/); - - // get the ELEMENTWISE_ADD op's output - auto* elementwise_add_out_var = pattern->NewNode( - [](Node* node) { - return node->IsVar() && node->inputs.size() == 1UL && node->Var() && - node->inputs.front()->Op()->Type() == "elementwise_add"; - }, - "elementwise_add_out"); - - mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var}) - .LinksTo({mul_out_var}); + // Create Operators + auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul"); + auto* elementwise_add_op = + pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add"); + // Create variables + // w + auto* mul_weight_var = pattern->NewNode("mul_weight") + ->AsInput() + ->assert_is_op_nth_input("mul", "Y", 0); + // x + auto* mul_tmp_var = pattern->NewNode("mul_tmp_var") + ->AsInput() + ->assert_is_op_nth_input("mul", "X", 0); + // intermediate variable, will be removed in the IR after fuse. + auto* mul_out_var = pattern->NewNode("mul_out") + ->AsIntermediate() + ->assert_is_only_output_of_op("mul") + ->assert_is_op_input("elementwise_add"); + // bias + auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar") + ->assert_is_op_input("elementwise_add") + ->AsInput(); + // output + auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out") + ->AsOutput() + ->assert_is_op_output("elementwise_add"); + + mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var}); elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var}) .LinksTo({elementwise_add_out_var}); } @@ -120,18 +77,20 @@ bool LinksReplace(std::vector* links, Node* from, Node* to) { std::unique_ptr FCFusePass::ApplyImpl( std::unique_ptr graph) const { PADDLE_ENFORCE(graph.get()); + FusePassBase::Init("fc", graph.get()); std::unordered_set nodes2delete; GraphPatternDetector gpd; BuildFCPattern(gpd.mutable_pattern()); -#define GET_NODE(id) \ - PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \ - "pattern has no Node called %s", #id); \ - auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \ +#define GET_NODE(id) \ + PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \ + "pattern has no Node called %s", #id); \ + auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id)); \ PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); + int found_fc_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { VLOG(4) << "handle FC fuse"; @@ -176,10 +135,13 @@ std::unique_ptr FCFusePass::ApplyImpl( graph->RemoveNode(mul); graph->RemoveNode(elementwise_add); graph->RemoveNode(mul_out); // tmp variable + + found_fc_count++; }; gpd(graph.get(), handler); + AddStatis(found_fc_count); return graph; } diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.h b/paddle/fluid/framework/ir/fc_fuse_pass.h index 31ed0e362f7..6c69539d1e4 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_fuse_pass.h @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" @@ -23,7 +24,7 @@ namespace ir { /* * Fuse the MUL and ELEMENTWISE_ADD to a FCOp. */ -class FCFusePass : public Pass { +class FCFusePass : public FusePassBase { public: virtual ~FCFusePass() {} diff --git a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc index 87ba417b1a4..06286a109d0 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass_tester.cc @@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::vector& outputs) { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); - op->SetInput("Xs", inputs); - op->SetOutput("Ys", outputs); + if (type == "mul") { + op->SetInput("X", {inputs[0]}); + op->SetInput("Y", {inputs[1]}); + } else if (type == "elementwise_add") { + op->SetInput("X", inputs); + } + op->SetOutput("Out", outputs); } // a->OP0->b diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index daecf3b407c..5852705b6b8 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -36,7 +36,7 @@ std::unique_ptr FCLstmFusePass::ApplyImpl( auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node")); + auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node")); marked_nodes.insert(id); }; gpd(graph.get(), handler); @@ -64,9 +64,9 @@ std::unique_ptr FCLstmFusePass::ApplyImpl( #undef GET_NODE #undef SET_IN - LOG(INFO) << "hidden_n: " << hidden_n->Name(); - LOG(INFO) << "cell: " << cell_n->Name(); - LOG(INFO) << "xx: " << xx_n->Name(); + VLOG(4) << "hidden_n: " << hidden_n->Name(); + VLOG(4) << "cell: " << cell_n->Name(); + VLOG(4) << "xx: " << xx_n->Name(); op_desc.SetInput("H0", {}); op_desc.SetInput("C0", {}); diff --git a/paddle/fluid/framework/ir/fuse_pass_base.h b/paddle/fluid/framework/ir/fuse_pass_base.h index bf6a0ae8274..877bbeb5022 100644 --- a/paddle/fluid/framework/ir/fuse_pass_base.h +++ b/paddle/fluid/framework/ir/fuse_pass_base.h @@ -22,21 +22,37 @@ namespace paddle { namespace framework { namespace ir { -static const char kParamScopeAttr[] = "param_scope"; +static const char kParamScopeAttr[] = "__param_scope__"; +static const char kFuseStatisAttr[] = "__fuse_statis__"; class FusePassBase : public Pass { public: - void Init(Graph* graph) const { graph_ = graph; } + void Init(const std::string& repr, Graph* graph) const { + repr_ = repr; + graph_ = graph; + } Scope* param_scope() const { PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); return graph_->Get(kParamScopeAttr); } + void AddStatis(int count_of_fused) const { + PADDLE_ENFORCE(graph_); + PADDLE_ENFORCE(!repr_.empty()); + if (!graph_->Has(kFuseStatisAttr)) { + graph_->Set(kFuseStatisAttr, new std::unordered_map); + } + auto& info = + graph_->Get>(kFuseStatisAttr); + info[repr_] = count_of_fused; + } + virtual ~FusePassBase() {} protected: mutable Graph* graph_; + mutable std::string repr_; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index dce4be8ff04..945ab110b14 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -27,6 +27,19 @@ namespace ir { size_t PDPattern::id_ = 0UL; +PDNode* PDPattern::NewNode(const std::string& name) { + if (!name.empty()) { + PADDLE_ENFORCE_EQ(node_map_.count(name), 0, + "PDNode's name should be unique, get duplicate [%s]", + name); + } + + nodes_.emplace_back(new PDNode(this, name)); + auto* cur = nodes_.back().get(); + node_map_[name] = cur; + return cur; +} + PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { if (!name.empty()) { PADDLE_ENFORCE_EQ(node_map_.count(name), 0, @@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { return cur; } -PDNode* PDPattern::RetriveNode(const std::string& id) const { +PDNode* PDPattern::RetrieveNode(const std::string& id) const { auto it = node_map_.find(id); if (it == node_map_.end()) { return nullptr; @@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph, auto subgraphs = DetectPatterns(); UniquePatterns(&subgraphs); RemoveOverlappedMatch(&subgraphs); + ValidateByNodeRole(&subgraphs); + if (subgraphs.empty()) return; LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern"; int id = 0; for (auto& g : subgraphs) { @@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { } } } + // Check to early stop if some PDNode can't find matched Node. + for (auto& pdnode : pattern_.nodes()) { + if (!pdnodes2nodes_.count(pdnode.get())) { + VLOG(4) << pdnode->name() << " can't find matched Node, early stop"; + + return false; + } + } VLOG(3) << pdnodes2nodes_.size() << " nodes marked"; return !pdnodes2nodes_.empty(); } +// The intermediate Nodes can only link to the nodes inside the pattern, or this +// subgraph will be droped. +void GraphPatternDetector::ValidateByNodeRole( + std::vector* subgraphs) { + std::vector result; + + subgraphs->erase( + std::remove_if( + subgraphs->begin(), subgraphs->end(), + [](const GraphPatternDetector::subgraph_t& subgraph) -> bool { + // Collect the inputs and outputs. + std::unordered_set ios; + for (auto& item : subgraph) { + if (!item.first->IsIntermediate()) { + ios.insert(item.second); + } + } + for (auto& item : subgraph) { + if (item.first->IsIntermediate()) { + for (auto* x : item.second->inputs) { + if (!ios.count(x)) { + return true; + } + } + for (auto* x : item.second->outputs) { + if (!ios.count(x)) { + return true; + } + } + } + } + return false; + }), + subgraphs->end()); +} + struct HitGroup { std::unordered_map roles; @@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() { // in edges of PDNodes. for (const auto& edge : pattern_.edges()) { VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name(); + // TODO(Superjomn) Fix bug here, the groups might be duplicate here. // Each role has two PDNodes, which indicates two roles. // Detect two Nodes that can match these two roles and they are connected. auto& pre_groups = bi_records[step % 2]; @@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() { // source -> target for (Node* source : pdnodes2nodes_[edge.first]) { for (Node* target : pdnodes2nodes_[edge.second]) { + VLOG(8) << "check " << source->id() << " -- " << target->id(); // TODO(Superjomn) add some prune strategies. for (const auto& group : pre_groups) { HitGroup new_group = group; @@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() { } } VLOG(3) << "step " << step << " get records: " << cur_groups.size(); + for (auto& group : cur_groups) { + for (auto& item : group.roles) { + VLOG(4) << "node " << item.second->id() << " as " << item.first->name(); + } + VLOG(4) << "========================================================="; + } } for (auto& group : bi_records[step % 2]) { @@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector& others) { return *this; } +PDNode* PDNode::assert_is_op() { + asserts_.emplace_back([this](Node* x) { return x && x->IsOp(); }); + return this; +} +PDNode* PDNode::assert_is_op(const std::string& op_type) { + asserts_.emplace_back([this, op_type](Node* x) { + return x && x->IsOp() && x->Op()->Type() == op_type; + }); + return this; +} +PDNode* PDNode::assert_is_var() { + asserts_.emplace_back([this](Node* x) { return x && x->IsVar(); }); + return this; +} +PDNode* PDNode::assert_var_not_persistable() { + assert_is_var(); + asserts_.emplace_back([this](Node* x) { return !x->Var()->Persistable(); }); + return this; +} +PDNode* PDNode::assert_is_persistable_var() { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); }); + return this; +} +PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type, + const std::string& argument, int nth) { + assert_is_var(); + assert_is_op_input(op_type); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->outputs) { + if (IsNthInput(x, op, argument, nth)) return true; + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type, + const std::string& argument, int nth) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->inputs) { + if (IsNthOutput(x, op, argument, nth)) return true; + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->outputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type && + op->inputs.size() == 1) { + return true; + } + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->inputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type && + op->outputs.size() == 1) { + return true; + } + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_op_output(const std::string& op_type) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->inputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) { + return true; + } + } + return false; + }); + return this; +} +PDNode* PDNode::assert_is_op_input(const std::string& op_type) { + assert_is_var(); + asserts_.emplace_back([=](Node* x) { + for (auto* op : x->outputs) { + if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) { + return true; + } + } + return false; + }); + return this; +} +PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) { + assert_is_op(op_type); + asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; }); + return this; +} +PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) { + assert_is_op(op_type); + asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; }); + return this; +} +PDNode* PDNode::assert_more(PDNode::teller_t&& teller) { + asserts_.emplace_back(std::move(teller)); + return this; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 0ac34a57aac..f8488c84962 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -39,14 +39,24 @@ struct PDNode { // tell whether an ir::Node* is a candidation for a PDNode. using teller_t = std::function; enum class Type { kOp, kVar }; + enum class Role { + kUnknown, // No role, + kInput, // an input and will be retained, + kOutput, // an output and will be retained, + kIntermediate // will be removed after handler. + }; // this link to others PDNode& LinksTo(const std::vector& others); PDNode& LinksFrom(const std::vector& others); bool Tell(Node* node) const { - PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); - return teller_(node); + if (teller_) return teller_(node); + + for (auto& asrt : asserts_) { + if (!asrt(node)) return false; + } + return true; } bool IsOp() const { return type_ == Type::kOp; } @@ -54,10 +64,52 @@ struct PDNode { const std::string& name() const { return name_; } - PDNode(const PDNode&) = delete; PDNode& operator=(const PDNode&) = delete; + PDNode(const PDNode&) = delete; + + // Mark this node is an Input of a subgraph and will be retained. + PDNode* AsInput() { + role_ = Role::kInput; + return this; + } + // Mark this node is an Output of a subgraph and will be retained. + PDNode* AsOutput() { + role_ = Role::kOutput; + return this; + } + // Mark this node will be removed, so all the links should be inside a matched + // sub-graph. + PDNode* AsIntermediate() { + role_ = Role::kIntermediate; + return this; + } + + bool IsIntermediate() const { return role_ == Role::kIntermediate; } + bool IsInput() const { return role_ == Role::kInput; } + bool IsOutput() const { return role_ == Role::kOutput; } + + // Assertions, helper functions to simplify the pattern definition. + PDNode* assert_is_op(); + PDNode* assert_is_op(const std::string& op_type); + PDNode* assert_is_var(); + PDNode* assert_var_not_persistable(); + PDNode* assert_is_persistable_var(); + PDNode* assert_is_op_output(const std::string& op_type); + PDNode* assert_is_op_input(const std::string& op_type); + PDNode* assert_is_op_nth_input(const std::string& op_type, + const std::string& argument, int nth); + PDNode* assert_is_op_nth_output(const std::string& op_type, + const std::string& argument, int nth); + PDNode* assert_is_only_input_of_op(const std::string& op_type); + PDNode* assert_is_only_output_of_op(const std::string& op_type); + PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n); + PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n); + PDNode* assert_more(teller_t&& teller); private: + PDNode(PDPattern* pattern, const std::string& name = "", + Type type = Type::kVar) + : pattern_(pattern), name_(name), type_(type) {} PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "", Type type = Type::kVar) : teller_(std::move(teller)), @@ -71,10 +123,13 @@ struct PDNode { friend class PDPattern; + // Will removed latter. teller_t teller_; + std::vector asserts_; PDPattern* pattern_; std::string name_; Type type_; + Role role_{Role::kUnknown}; }; /* @@ -87,19 +142,18 @@ struct PDNode { * This pattern can be defined as with the following pseudo codes * * // Create two operator PDNodes. - * MUL = PDPattern.NewNode() - * ELE = PDPattern.NewNode() + * MUL = PDPattern.NewNode().assert_is_op("mul"); + * ELE = PDPattern.NewNode().assert_is_op("elementwise_add"); * // Create the variable PDNodes. - * MUL_out = PDPattern.NewNode() - * // Add teller to define some rules that help to filter the target Nodes. - * MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul"; - * ELE.teller = lambda(node): \ - * node->IsOp() && node->Op()->Type == "elementwise_add"; - * MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs) - * && (ELE in node->outputs) + * MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \ + * .assert_is_op_input("elementwise_add") \ + * .AsIntermediate(); + * // Add relations. + * MUL->LinksTo({MUL_out}); + * MUL_out->LinksTo({ELE}); * - * One can add more specific tellers for PDNodes or edges, both the Operator - * and Variable Nodes can be ruled in PDNode.teller. + * One can add more specific asserts for PDNodes or edges, both the Operator + * and Variable Nodes can be ruled in PDNode.assert_more(...). * * PDPattern can record the general patterns, such as the pattern represents * - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place. @@ -112,7 +166,8 @@ class PDPattern { void AddEdge(PDNode* a, PDNode* b); PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID()); - PDNode* RetriveNode(const std::string& id) const; + PDNode* NewNode(const std::string& name = NewID()); + PDNode* RetrieveNode(const std::string& id) const; const std::vector>& nodes() const { return nodes_; } const std::vector& edges() const { return edges_; } @@ -185,6 +240,9 @@ class GraphPatternDetector { // Remove overlapped match subgraphs, when overlapped, keep the previous one. void RemoveOverlappedMatch(std::vector* subgraphs); + // Validate whether the intermediate nodes are linked by external nodes. + void ValidateByNodeRole(std::vector* subgraphs); + #ifdef PADDLE_WITH_TESTING FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph); FRIEND_TEST(GraphPatternDetecter, DetectPatterns); @@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument, return var->Name() == op->Op()->Input(argument)[nth]; } +static bool IsNthOutput(Node* var, Node* op, const std::string& argument, + size_t nth) { + PADDLE_ENFORCE(var->IsVar()); + PADDLE_ENFORCE(op->IsOp()); + if (op->inputs.size() <= nth) return false; + return var->Name() == op->Op()->Output(argument)[nth]; +} + static void GraphSafeRemoveNodes(Graph* graph, const std::unordered_set& nodes) { for (auto* node : nodes) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc index a4d0646230c..7e5c86b033a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc @@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ASSERT_LE(count, 2); } +TEST(GraphPatternDetector, IntermediateCheck) { + ProgramDesc program; + Graph graph(program); + BuildGraph(&graph); + + // o2->v2->o3 + // o2->v2->o4 + // check o2+o3 fuse, should fail because v2 also link to o4. + GraphPatternDetector detector; + auto* op2 = detector.mutable_pattern()->NewNode( + [](Node* x) { return x && x->IsOp() && x->Name() == "op2"; }, "op2"); + auto* op3 = detector.mutable_pattern()->NewNode( + [](Node* x) { return x && x->IsOp() && x->Name() == "op3"; }, "op3"); + auto* v2 = + detector.mutable_pattern() + ->NewNode( + [](Node* x) { return x && x->IsVar() && x->Name() == "var2"; }, + "var2") + ->AsIntermediate(); + v2->LinksFrom({op2}).LinksTo({op3}); + + int count = 0; + detector(&graph, [&](const GraphPatternDetector::subgraph_t& g, + Graph* graph) { ++count; }); + EXPECT_EQ(count, 0); + + count = 0; + v2->AsInput(); + detector(&graph, [&](const GraphPatternDetector::subgraph_t& g, + Graph* graph) { ++count; }); + ASSERT_EQ(count, 1); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc index 9bb5c232e5c..a776a898a5e 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc @@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) { std::unique_ptr SeqConcatFcFusePass::ApplyImpl( std::unique_ptr graph) const { - FusePassBase::Init(graph.get()); + FusePassBase::Init("seq_concat_fc_fuse", graph.get()); GraphPatternDetector detector; auto* pattern = detector.mutable_pattern(); auto* concat_out = BuildSeqExpandConcatPattern(pattern); BuildFCPattern(pattern, concat_out); -#define GET_NODE(id, pattern) \ - PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \ - "pattern has no Node called %s", #id); \ - auto* id = subgraph.at(pattern.RetriveNode(#id)); \ +#define GET_NODE(id, pattern) \ + PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \ + "pattern has no Node called %s", #id); \ + auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph, diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 08a55a73e3a..e6e63544ffa 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager { void AddGraphvizDebugerPass(Pass* pass) { auto* debuger_pass = pass->CreateGraphvizDebugerPass(); if (debuger_pass) { - LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]"; Register(debuger_pass->repr(), debuger_pass); } } diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index 263fbb04490..2cc83c777ce 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -16,10 +16,13 @@ #include #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/inference/analysis/ut_helper.h" +#include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/platform/profiler.h" DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); @@ -31,6 +34,8 @@ namespace paddle { namespace inference { namespace analysis { +using namespace framework; + TEST(Analyzer, analysis_without_tensorrt) { FLAGS_IA_enable_tensorrt_subgraph_engine = false; Argument argument; @@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path, EXPECT_NEAR(data[i], base_data[i], 1e-3); } } + + if (use_analysis && activate_ir) { + AnalysisPredictor *analysis_predictor = + dynamic_cast(predictor.get()); + auto &fuse_statis = analysis_predictor->analysis_argument() + .Get>( + framework::ir::kFuseStatisAttr); + for (auto &item : fuse_statis) { + LOG(INFO) << "fused " << item.first << " " << item.second; + } + + ASSERT_TRUE(fuse_statis.count("fc")); + EXPECT_EQ(fuse_statis.at("fc"), 1); + } } // Directly infer with the original model. diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 4401d5c5a3c..3a4ffe967e6 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -64,7 +64,8 @@ struct Argument { template void Set(const std::string& key, T* data) { PADDLE_ENFORCE_NOT_NULL(data); - PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key); + PADDLE_ENFORCE(!attrs_.count(key), "Duplicate set Argument's attr [%s]", + key); attrs_[key] = data; attr_deleters_[key] = [data, key, this]() { VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 8ca402da31f..80c85555e72 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include #include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/inference/analysis/analyzer.h" @@ -34,7 +35,6 @@ std::vector ExtractParameters( bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { ANALYSIS_ARGUMENT_CHECK_FIELD(argument) ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc) - PADDLE_ENFORCE(!argument->transformed_program_desc); // The transformed_program_desc should inherit all the VarDesc and BlockDesc // from the original program desc. The operators of the main block(the first // block) should rewritten by data flow graph. @@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { } } - if (argument_->Has("param_scope")) { + if (argument_->Has(framework::ir::kParamScopeAttr)) { LOG(WARNING) << "parameter changes in the scope takes effect"; } diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc b/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc index 5e53fff3921..fc60ca3bd0b 100644 --- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h" #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/inference/io.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" @@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir, const std::string &prog_file, const std::string ¶m_file) { PADDLE_ENFORCE(argument_); - argument_->Set("param_scope", new framework::Scope); + argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope); // Load parameters. VLOG(3) << "Loading parameters from " << model_dir; - LoadParams(&argument_->Get("param_scope"), model_dir, - prog_file, param_file); + LoadParams(&argument_->Get(framework::ir::kParamScopeAttr), + model_dir, prog_file, param_file); } bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir, diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h index 29008105f82..6731b1f7593 100644 --- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h @@ -14,12 +14,14 @@ #pragma once +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/pass.h" namespace paddle { namespace inference { namespace analysis { +using namespace framework; static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__"; @@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass { ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path); // Load program. auto program = LoadProgramDesc(*argument->fluid_model_program_path); - argument->origin_program_desc.reset( - new framework::proto::ProgramDesc(program)); + argument->origin_program_desc.reset(new proto::ProgramDesc(program)); // Create main data flow graph. if (!argument->main_dfg) { argument->main_dfg.reset(new DataFlowGraph); } - argument->Set("ir_program_desc", new framework::ProgramDesc(program)); + argument->Set("ir_program_desc", new ProgramDesc(program)); LOG(INFO) << "Loading parameters"; // Load parameters to argument if needed. @@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass { void Run(DataFlowGraph *graph) override { // Call all the IR Passes - IRPassManager ir_passes( - argument_->Get("ir_program_desc"), nullptr); + IRPassManager ir_passes(argument_->Get("ir_program_desc"), + nullptr); // Pass the scope from analysis to IR if needed. - if (argument_->Has("param_scope")) { + if (argument_->Has(ir::kParamScopeAttr)) { // Here the address is passed, attention that IR doesn't own the scope, so // the real scope in analysis should live during the IR phase. ir_passes.graph().Set( - "param_scope", new framework::Scope *( - &argument_->Get("param_scope"))); + ir::kParamScopeAttr, + new Scope *(&argument_->Get(ir::kParamScopeAttr))); } const auto &ir_passes_to_apply = @@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass { PADDLE_ENFORCE(argument_->main_dfg.get()); argument_->main_dfg->Build(ir_passes.graph()); + // inherit the arguments from ir. + if (ir_passes.graph().Has(ir::kFuseStatisAttr)) { + argument_->Set( + ir::kFuseStatisAttr, + new std::unordered_map( + ir_passes.graph().Get>( + ir::kFuseStatisAttr))); + } } void EnableParamModify(const std::string &model_dir, @@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass { private: // Load parameters from a single file or from a directory. - bool LoadParams(framework::Scope *scope, const std::string &dir, + bool LoadParams(Scope *scope, const std::string &dir, const std::string &prog_file, const std::string ¶m_file); private: diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 5da5241e49a..ea0f2241d7d 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/scope.h" @@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program, framework::Scope *scope) : program_(program) { graph_.reset(new framework::ir::Graph(program)); - if (scope) graph_->Set("param_scope", new framework::Scope *(scope)); + if (scope) + graph_->Set(framework::ir::kParamScopeAttr, new framework::Scope *(scope)); } void IRPassManager::Apply(const std::vector &passes) { diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 7cdaf4c6a23..33862232bda 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -12,121 +12,96 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/inference/api/analysis_predictor.h" #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/inference/analysis/analyzer.h" -#include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/utils/singleton.h" namespace paddle { -using inference::analysis::Argument; -using inference::Singleton; -using inference::analysis::Analyzer; -using framework::proto::ProgramDesc; - -/* This predictor is based on the original native predictor with IR and Analysis - * support. It will optimize IR and Parameters in the runtime. - * TODO(Superjomn) Replace the Navive predictor? - */ -class AnalysisPredictor : public NativePaddlePredictor { - public: - explicit AnalysisPredictor(const NativeConfig& config) - : NativePaddlePredictor(config), config_(config) {} - - bool Init(const std::shared_ptr& parent_scope) { - VLOG(3) << "Predictor::init()"; - if (config_.use_gpu) { - place_ = paddle::platform::CUDAPlace(config_.device); - } else { - place_ = paddle::platform::CPUPlace(); - } - PADDLE_ENFORCE(!parent_scope); - if (parent_scope) { - scope_ = parent_scope; - sub_scope_ = &(parent_scope->NewScope()); - } else { - paddle::framework::InitDevices(false); - scope_.reset(new paddle::framework::Scope()); - } - - executor_.reset(new paddle::framework::Executor(place_)); - - // Initialize the inference program - if (!config_.model_dir.empty()) { - // Parameters are saved in separate files sited in - // the specified `dirname`. - inference_program_ = paddle::inference::Load( - executor_.get(), scope_.get(), config_.model_dir); - } else if (!config_.prog_file.empty() && !config_.param_file.empty()) { - // All parameters are saved in a single file. - // The file names should be consistent with that used - // in Python API `fluid.io.save_inference_model`. - inference_program_ = paddle::inference::Load( - executor_.get(), scope_.get(), config_.prog_file, config_.param_file); - } else { - LOG(ERROR) << "fail to load inference model."; - return false; - } - - OptimizeInferenceProgram(); - ctx_ = executor_->Prepare(*inference_program_, 0); - - VLOG(5) << "to create variables"; - PADDLE_ENFORCE(scope_.get()); - executor_->CreateVariables(*inference_program_, - sub_scope_ ? sub_scope_ : scope_.get(), 0); - - // Get the feed_target_names and fetch_target_names - PrepareFeedFetch(); - return true; +bool AnalysisPredictor::Init( + const std::shared_ptr& parent_scope) { + VLOG(3) << "Predictor::init()"; + if (config_.use_gpu) { + place_ = paddle::platform::CUDAPlace(config_.device); + } else { + place_ = paddle::platform::CPUPlace(); } - - bool Run(const std::vector& inputs, - std::vector* output_data, - int batch_size = -1) override { - return NativePaddlePredictor::Run(inputs, output_data, batch_size); + PADDLE_ENFORCE(!parent_scope); + if (parent_scope) { + scope_ = parent_scope; + sub_scope_ = &(parent_scope->NewScope()); + } else { + paddle::framework::InitDevices(false); + scope_.reset(new paddle::framework::Scope()); } - void OptimizeInferenceProgram() { - LOG(INFO) << "optimize begin"; - FLAGS_IA_enable_ir = true; - FLAGS_IA_enable_tensorrt_subgraph_engine = false; - FLAGS_IA_output_storage_path = ""; // Don't output the model. - // Analyze inference_program - Argument argument; - if (!config_.model_dir.empty()) { - argument.fluid_model_dir.reset(new std::string(config_.model_dir)); - } else { - PADDLE_ENFORCE( - !config_.param_file.empty(), - "Either model_dir or (param_file, prog_file) should be set."); - PADDLE_ENFORCE(!config_.prog_file.empty()); - argument.fluid_model_program_path.reset( - new std::string(config_.prog_file)); - argument.fluid_model_param_path.reset( - new std::string(config_.param_file)); - } - argument.origin_program_desc.reset( - new ProgramDesc(*inference_program_->Proto())); - Singleton::Global().Run(&argument); - CHECK(argument.transformed_program_desc); - VLOG(5) << "to prepare executor"; - // LOG(INFO) << "transformed_parogram_desc " << - // argument.transformed_program_desc->DebugString(); - inference_program_.reset( - new framework::ProgramDesc(*argument.transformed_program_desc)); - PADDLE_ENFORCE(argument.Has("param_scope")); - // Update scope. - scope_.reset(argument.Release("param_scope")); - LOG(INFO) << "optimize end =="; + executor_.reset(new paddle::framework::Executor(place_)); + + // Initialize the inference program + if (!config_.model_dir.empty()) { + // Parameters are saved in separate files sited in + // the specified `dirname`. + inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(), + config_.model_dir); + } else if (!config_.prog_file.empty() && !config_.param_file.empty()) { + // All parameters are saved in a single file. + // The file names should be consistent with that used + // in Python API `fluid.io.save_inference_model`. + inference_program_ = paddle::inference::Load( + executor_.get(), scope_.get(), config_.prog_file, config_.param_file); + } else { + LOG(ERROR) << "fail to load inference model."; + return false; } - private: - NativeConfig config_; -}; + OptimizeInferenceProgram(); + ctx_ = executor_->Prepare(*inference_program_, 0); + + VLOG(5) << "to create variables"; + PADDLE_ENFORCE(scope_.get()); + executor_->CreateVariables(*inference_program_, + sub_scope_ ? sub_scope_ : scope_.get(), 0); + // Get the feed_target_names and fetch_target_names + PrepareFeedFetch(); + return true; +} + +void AnalysisPredictor::OptimizeInferenceProgram() { + LOG(INFO) << "optimize begin"; + FLAGS_IA_enable_ir = true; + FLAGS_IA_enable_tensorrt_subgraph_engine = false; + FLAGS_IA_output_storage_path = ""; // Don't output the model. + // Analyze inference_program + if (!config_.model_dir.empty()) { + argument_.fluid_model_dir.reset(new std::string(config_.model_dir)); + } else { + PADDLE_ENFORCE( + !config_.param_file.empty(), + "Either model_dir or (param_file, prog_file) should be set."); + PADDLE_ENFORCE(!config_.prog_file.empty()); + argument_.fluid_model_program_path.reset( + new std::string(config_.prog_file)); + argument_.fluid_model_param_path.reset(new std::string(config_.param_file)); + } + argument_.origin_program_desc.reset( + new ProgramDesc(*inference_program_->Proto())); + Analyzer().Run(&argument_); + CHECK(argument_.transformed_program_desc); + VLOG(5) << "to prepare executor"; + // LOG(INFO) << "transformed_parogram_desc " << + // argument.transformed_program_desc->DebugString(); + inference_program_.reset( + new framework::ProgramDesc(*argument_.transformed_program_desc)); + PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr)); + // Update scope. + scope_.reset( + argument_.Release(framework::ir::kParamScopeAttr)); + LOG(INFO) << "optimize end =="; +} template <> std::unique_ptr CreatePaddlePredictor< diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h new file mode 100644 index 00000000000..e32b6185f60 --- /dev/null +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -0,0 +1,51 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/analysis/analyzer.h" +#include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/paddle_inference_api.h" + +namespace paddle { + +using inference::analysis::Argument; +using inference::analysis::Analyzer; +using framework::proto::ProgramDesc; + +/* This predictor is based on the original native predictor with IR and Analysis + * support. It will optimize IR and Parameters in the runtime. + * TODO(Superjomn) Replace the Navive predictor? + */ +class AnalysisPredictor : public NativePaddlePredictor { + public: + explicit AnalysisPredictor(const NativeConfig& config) + : NativePaddlePredictor(config), config_(config) {} + + bool Init(const std::shared_ptr& parent_scope); + + bool Run(const std::vector& inputs, + std::vector* output_data, + int batch_size = -1) override { + return NativePaddlePredictor::Run(inputs, output_data, batch_size); + } + + void OptimizeInferenceProgram(); + + Argument& analysis_argument() { return argument_; } + + private: + NativeConfig config_; + Argument argument_; +}; + +} // namespace paddle diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 7d7a14ed08a..da1c0b1fbc9 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -62,7 +62,7 @@ void NativePaddlePredictor::PrepareFeedFetch() { for (auto *op : inference_program_->Block(0).AllOps()) { if (op->Type() == "feed") { int idx = boost::get(op->GetAttr("col")); - if (feeds_.size() <= idx) { + if (feeds_.size() <= static_cast(idx)) { feeds_.resize(idx + 1); } feeds_[idx] = op; -- GitLab