未验证 提交 af15f6f0 编写于 作者: Y Yan Chunwei 提交者: GitHub

fea/refine fuse (#13076)

上级 819af27d
......@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* while_pat_node = gpd.pattern().RetriveNode("while");
auto* while_pat_node = gpd.pattern().RetrieveNode("while");
auto* while_node = subgraph.at(while_pat_node);
marked_nodes.insert(while_node);
};
......
......@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
}
void BuildFCPattern(PDPattern* pattern) {
// make sure the selected MUL op has one input argument is a parameter.
auto* mul_parameter_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->outputs.size() == 1UL &&
node->outputs.front()->Op()->Type() == "mul" && node->Var() &&
node->Var()->Persistable(); // check is a parameter
},
"mul_weight" /*name*/);
auto* mul_tmp_input_var = pattern->NewNode(
[](Node* node) {
bool result =
node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
!node->Var()->Persistable(); // this input is not an parameter.
if (!result) return false;
// check whether one output is MUL op.
for (auto* op : node->outputs) {
if (op->IsOp() && op->Op()->Type() == "mul") return true;
}
return false;
},
"mul_tmp_var" /*name*/);
// select a MUL op
auto* mul_op = pattern->NewNode(
[](Node* node) {
return node->IsOp() && // start from an Op
node->Op()->Type() == "mul"; // type is mul
// the output should be consumed only by one element_add, that check
// leaves in a Var PDNode.
},
"mul" /*name*/);
// make sure the MUL op's output has only one consumer and links to an
// ELEMENTWISE_ADD op.
auto* mul_out_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && // starts from a Var
node->outputs.size() == 1UL && // only has one consumer
node->outputs.front()->IsOp() && // check basic logic
node->Var() && // not a ControlDepVar
node->outputs.front()->Op()->Type() ==
"elementwise_add"; // a very strong validation
},
"mul_out");
// this check is not essential, just to make the corresponding variable Node
// retrival easier.
auto* elementwise_add_tmp_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
VarOutLinksToOp(node, "elementwise_add");
},
"elementwise_add_tmpvar");
// select an ELEMENTWISE_ADD op
auto* elementwise_add_op = pattern->NewNode(
[](Node* node) {
return node->IsOp() && node->Op()->Type() == "elementwise_add";
},
"elementwise_add" /*name*/);
// get the ELEMENTWISE_ADD op's output
auto* elementwise_add_out_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
node->inputs.front()->Op()->Type() == "elementwise_add";
},
"elementwise_add_out");
mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
.LinksTo({mul_out_var});
// Create Operators
auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul");
auto* elementwise_add_op =
pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add");
// Create variables
// w
auto* mul_weight_var = pattern->NewNode("mul_weight")
->AsInput()
->assert_is_op_nth_input("mul", "Y", 0);
// x
auto* mul_tmp_var = pattern->NewNode("mul_tmp_var")
->AsInput()
->assert_is_op_nth_input("mul", "X", 0);
// intermediate variable, will be removed in the IR after fuse.
auto* mul_out_var = pattern->NewNode("mul_out")
->AsIntermediate()
->assert_is_only_output_of_op("mul")
->assert_is_op_input("elementwise_add");
// bias
auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar")
->assert_is_op_input("elementwise_add")
->AsInput();
// output
auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out")
->AsOutput()
->assert_is_op_output("elementwise_add");
mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
.LinksTo({elementwise_add_out_var});
}
......@@ -120,18 +77,20 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("fc", graph.get());
std::unordered_set<Node*> nodes2delete;
GraphPatternDetector gpd;
BuildFCPattern(gpd.mutable_pattern());
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle FC fuse";
......@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
graph->RemoveNode(mul);
graph->RemoveNode(elementwise_add);
graph->RemoveNode(mul_out); // tmp variable
found_fc_count++;
};
gpd(graph.get(), handler);
AddStatis(found_fc_count);
return graph;
}
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -23,7 +24,7 @@ namespace ir {
/*
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
*/
class FCFusePass : public Pass {
class FCFusePass : public FusePassBase {
public:
virtual ~FCFusePass() {}
......
......@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("Xs", inputs);
op->SetOutput("Ys", outputs);
if (type == "mul") {
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
} else if (type == "elementwise_add") {
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
}
// a->OP0->b
......
......@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
marked_nodes.insert(id);
};
gpd(graph.get(), handler);
......@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
#undef GET_NODE
#undef SET_IN
LOG(INFO) << "hidden_n: " << hidden_n->Name();
LOG(INFO) << "cell: " << cell_n->Name();
LOG(INFO) << "xx: " << xx_n->Name();
VLOG(4) << "hidden_n: " << hidden_n->Name();
VLOG(4) << "cell: " << cell_n->Name();
VLOG(4) << "xx: " << xx_n->Name();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
......
......@@ -22,21 +22,37 @@ namespace paddle {
namespace framework {
namespace ir {
static const char kParamScopeAttr[] = "param_scope";
static const char kParamScopeAttr[] = "__param_scope__";
static const char kFuseStatisAttr[] = "__fuse_statis__";
class FusePassBase : public Pass {
public:
void Init(Graph* graph) const { graph_ = graph; }
void Init(const std::string& repr, Graph* graph) const {
repr_ = repr;
graph_ = graph;
}
Scope* param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
void AddStatis(int count_of_fused) const {
PADDLE_ENFORCE(graph_);
PADDLE_ENFORCE(!repr_.empty());
if (!graph_->Has(kFuseStatisAttr)) {
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
}
auto& info =
graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
info[repr_] = count_of_fused;
}
virtual ~FusePassBase() {}
protected:
mutable Graph* graph_;
mutable std::string repr_;
};
} // namespace ir
......
......@@ -27,6 +27,19 @@ namespace ir {
size_t PDPattern::id_ = 0UL;
PDNode* PDPattern::NewNode(const std::string& name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
"PDNode's name should be unique, get duplicate [%s]",
name);
}
nodes_.emplace_back(new PDNode(this, name));
auto* cur = nodes_.back().get();
node_map_[name] = cur;
return cur;
}
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
......@@ -40,7 +53,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
return cur;
}
PDNode* PDPattern::RetriveNode(const std::string& id) const {
PDNode* PDPattern::RetrieveNode(const std::string& id) const {
auto it = node_map_.find(id);
if (it == node_map_.end()) {
return nullptr;
......@@ -62,7 +75,9 @@ void GraphPatternDetector::operator()(Graph* graph,
auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs);
RemoveOverlappedMatch(&subgraphs);
ValidateByNodeRole(&subgraphs);
if (subgraphs.empty()) return;
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
int id = 0;
for (auto& g : subgraphs) {
......@@ -83,10 +98,54 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
}
}
}
// Check to early stop if some PDNode can't find matched Node.
for (auto& pdnode : pattern_.nodes()) {
if (!pdnodes2nodes_.count(pdnode.get())) {
VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
return false;
}
}
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
return !pdnodes2nodes_.empty();
}
// The intermediate Nodes can only link to the nodes inside the pattern, or this
// subgraph will be droped.
void GraphPatternDetector::ValidateByNodeRole(
std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
std::vector<GraphPatternDetector::subgraph_t> result;
subgraphs->erase(
std::remove_if(
subgraphs->begin(), subgraphs->end(),
[](const GraphPatternDetector::subgraph_t& subgraph) -> bool {
// Collect the inputs and outputs.
std::unordered_set<Node*> ios;
for (auto& item : subgraph) {
if (!item.first->IsIntermediate()) {
ios.insert(item.second);
}
}
for (auto& item : subgraph) {
if (item.first->IsIntermediate()) {
for (auto* x : item.second->inputs) {
if (!ios.count(x)) {
return true;
}
}
for (auto* x : item.second->outputs) {
if (!ios.count(x)) {
return true;
}
}
}
}
return false;
}),
subgraphs->end());
}
struct HitGroup {
std::unordered_map<PDNode*, Node*> roles;
......@@ -140,6 +199,7 @@ GraphPatternDetector::DetectPatterns() {
// in edges of PDNodes.
for (const auto& edge : pattern_.edges()) {
VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
// TODO(Superjomn) Fix bug here, the groups might be duplicate here.
// Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected.
auto& pre_groups = bi_records[step % 2];
......@@ -149,6 +209,7 @@ GraphPatternDetector::DetectPatterns() {
// source -> target
for (Node* source : pdnodes2nodes_[edge.first]) {
for (Node* target : pdnodes2nodes_[edge.second]) {
VLOG(8) << "check " << source->id() << " -- " << target->id();
// TODO(Superjomn) add some prune strategies.
for (const auto& group : pre_groups) {
HitGroup new_group = group;
......@@ -165,6 +226,12 @@ GraphPatternDetector::DetectPatterns() {
}
}
VLOG(3) << "step " << step << " get records: " << cur_groups.size();
for (auto& group : cur_groups) {
for (auto& item : group.roles) {
VLOG(4) << "node " << item.second->id() << " as " << item.first->name();
}
VLOG(4) << "=========================================================";
}
}
for (auto& group : bi_records[step % 2]) {
......@@ -260,6 +327,118 @@ PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
return *this;
}
PDNode* PDNode::assert_is_op() {
asserts_.emplace_back([this](Node* x) { return x && x->IsOp(); });
return this;
}
PDNode* PDNode::assert_is_op(const std::string& op_type) {
asserts_.emplace_back([this, op_type](Node* x) {
return x && x->IsOp() && x->Op()->Type() == op_type;
});
return this;
}
PDNode* PDNode::assert_is_var() {
asserts_.emplace_back([this](Node* x) { return x && x->IsVar(); });
return this;
}
PDNode* PDNode::assert_var_not_persistable() {
assert_is_var();
asserts_.emplace_back([this](Node* x) { return !x->Var()->Persistable(); });
return this;
}
PDNode* PDNode::assert_is_persistable_var() {
assert_is_var();
asserts_.emplace_back([=](Node* x) { return x->Var()->Persistable(); });
return this;
}
PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth) {
assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
if (IsNthInput(x, op, argument, nth)) return true;
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
if (IsNthOutput(x, op, argument, nth)) return true;
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_only_input_of_op(const std::string& op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
op->inputs.size() == 1) {
return true;
}
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_only_output_of_op(const std::string& op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type &&
op->outputs.size() == 1) {
return true;
}
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
return true;
}
}
return false;
});
return this;
}
PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
assert_is_var();
asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) {
if (op && op->IsOp() && op->Op() && op->Op()->Type() == op_type) {
return true;
}
}
return false;
});
return this;
}
PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) {
assert_is_op(op_type);
asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; });
return this;
}
PDNode* PDNode::assert_op_has_n_outputs(const std::string& op_type, size_t n) {
assert_is_op(op_type);
asserts_.emplace_back([=](Node* x) { return x->outputs.size() == n; });
return this;
}
PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
asserts_.emplace_back(std::move(teller));
return this;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -39,14 +39,24 @@ struct PDNode {
// tell whether an ir::Node* is a candidation for a PDNode.
using teller_t = std::function<bool(Node*)>;
enum class Type { kOp, kVar };
enum class Role {
kUnknown, // No role,
kInput, // an input and will be retained,
kOutput, // an output and will be retained,
kIntermediate // will be removed after handler.
};
// this link to others
PDNode& LinksTo(const std::vector<PDNode*>& others);
PDNode& LinksFrom(const std::vector<PDNode*>& others);
bool Tell(Node* node) const {
PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
return teller_(node);
if (teller_) return teller_(node);
for (auto& asrt : asserts_) {
if (!asrt(node)) return false;
}
return true;
}
bool IsOp() const { return type_ == Type::kOp; }
......@@ -54,10 +64,52 @@ struct PDNode {
const std::string& name() const { return name_; }
PDNode(const PDNode&) = delete;
PDNode& operator=(const PDNode&) = delete;
PDNode(const PDNode&) = delete;
// Mark this node is an Input of a subgraph and will be retained.
PDNode* AsInput() {
role_ = Role::kInput;
return this;
}
// Mark this node is an Output of a subgraph and will be retained.
PDNode* AsOutput() {
role_ = Role::kOutput;
return this;
}
// Mark this node will be removed, so all the links should be inside a matched
// sub-graph.
PDNode* AsIntermediate() {
role_ = Role::kIntermediate;
return this;
}
bool IsIntermediate() const { return role_ == Role::kIntermediate; }
bool IsInput() const { return role_ == Role::kInput; }
bool IsOutput() const { return role_ == Role::kOutput; }
// Assertions, helper functions to simplify the pattern definition.
PDNode* assert_is_op();
PDNode* assert_is_op(const std::string& op_type);
PDNode* assert_is_var();
PDNode* assert_var_not_persistable();
PDNode* assert_is_persistable_var();
PDNode* assert_is_op_output(const std::string& op_type);
PDNode* assert_is_op_input(const std::string& op_type);
PDNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth);
PDNode* assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth);
PDNode* assert_is_only_input_of_op(const std::string& op_type);
PDNode* assert_is_only_output_of_op(const std::string& op_type);
PDNode* assert_op_has_n_inputs(const std::string& op_type, size_t n);
PDNode* assert_op_has_n_outputs(const std::string& op_type, size_t n);
PDNode* assert_more(teller_t&& teller);
private:
PDNode(PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: pattern_(pattern), name_(name), type_(type) {}
PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: teller_(std::move(teller)),
......@@ -71,10 +123,13 @@ struct PDNode {
friend class PDPattern;
// Will removed latter.
teller_t teller_;
std::vector<teller_t> asserts_;
PDPattern* pattern_;
std::string name_;
Type type_;
Role role_{Role::kUnknown};
};
/*
......@@ -87,19 +142,18 @@ struct PDNode {
* This pattern can be defined as with the following pseudo codes
*
* // Create two operator PDNodes.
* MUL = PDPattern.NewNode()
* ELE = PDPattern.NewNode()
* MUL = PDPattern.NewNode().assert_is_op("mul");
* ELE = PDPattern.NewNode().assert_is_op("elementwise_add");
* // Create the variable PDNodes.
* MUL_out = PDPattern.NewNode()
* // Add teller to define some rules that help to filter the target Nodes.
* MUL.teller = lambda(node): node->IsOp() && node->Op()->Type == "mul";
* ELE.teller = lambda(node): \
* node->IsOp() && node->Op()->Type == "elementwise_add";
* MUL_out.teller = lambda(node): node->IsVar() && (MUL in node->inputs)
* && (ELE in node->outputs)
* MUL_out = PDPattern.NewNode().assert_is_op_output("mul") \
* .assert_is_op_input("elementwise_add") \
* .AsIntermediate();
* // Add relations.
* MUL->LinksTo({MUL_out});
* MUL_out->LinksTo({ELE});
*
* One can add more specific tellers for PDNodes or edges, both the Operator
* and Variable Nodes can be ruled in PDNode.teller.
* One can add more specific asserts for PDNodes or edges, both the Operator
* and Variable Nodes can be ruled in PDNode.assert_more(...).
*
* PDPattern can record the general patterns, such as the pattern represents
* - Op in CPU -> Op in GPU -> Op in CPU, to findout the IO abnormal place.
......@@ -112,7 +166,8 @@ class PDPattern {
void AddEdge(PDNode* a, PDNode* b);
PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID());
PDNode* RetriveNode(const std::string& id) const;
PDNode* NewNode(const std::string& name = NewID());
PDNode* RetrieveNode(const std::string& id) const;
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; }
......@@ -185,6 +240,9 @@ class GraphPatternDetector {
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
// Validate whether the intermediate nodes are linked by external nodes.
void ValidateByNodeRole(std::vector<subgraph_t>* subgraphs);
#ifdef PADDLE_WITH_TESTING
FRIEND_TEST(GraphPatternDetecter, MarkPDNodesInGraph);
FRIEND_TEST(GraphPatternDetecter, DetectPatterns);
......@@ -228,6 +286,14 @@ static bool IsNthInput(Node* var, Node* op, const std::string& argument,
return var->Name() == op->Op()->Input(argument)[nth];
}
static bool IsNthOutput(Node* var, Node* op, const std::string& argument,
size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->inputs.size() <= nth) return false;
return var->Name() == op->Op()->Output(argument)[nth];
}
static void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) {
......
......@@ -167,6 +167,39 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
ASSERT_LE(count, 2);
}
TEST(GraphPatternDetector, IntermediateCheck) {
ProgramDesc program;
Graph graph(program);
BuildGraph(&graph);
// o2->v2->o3
// o2->v2->o4
// check o2+o3 fuse, should fail because v2 also link to o4.
GraphPatternDetector detector;
auto* op2 = detector.mutable_pattern()->NewNode(
[](Node* x) { return x && x->IsOp() && x->Name() == "op2"; }, "op2");
auto* op3 = detector.mutable_pattern()->NewNode(
[](Node* x) { return x && x->IsOp() && x->Name() == "op3"; }, "op3");
auto* v2 =
detector.mutable_pattern()
->NewNode(
[](Node* x) { return x && x->IsVar() && x->Name() == "var2"; },
"var2")
->AsIntermediate();
v2->LinksFrom({op2}).LinksTo({op3});
int count = 0;
detector(&graph, [&](const GraphPatternDetector::subgraph_t& g,
Graph* graph) { ++count; });
EXPECT_EQ(count, 0);
count = 0;
v2->AsInput();
detector(&graph, [&](const GraphPatternDetector::subgraph_t& g,
Graph* graph) { ++count; });
ASSERT_EQ(count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -180,16 +180,16 @@ PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(graph.get());
FusePassBase::Init("seq_concat_fc_fuse", graph.get());
GraphPatternDetector detector;
auto* pattern = detector.mutable_pattern();
auto* concat_out = BuildSeqExpandConcatPattern(pattern);
BuildFCPattern(pattern, concat_out);
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetriveNode(#id)); \
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
......
......@@ -93,7 +93,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
void AddGraphvizDebugerPass(Pass* pass) {
auto* debuger_pass = pass->CreateGraphvizDebugerPass();
if (debuger_pass) {
LOG(INFO) << " - register debug pass [" << debuger_pass->repr() << "]";
Register(debuger_pass->repr(), debuger_pass);
}
}
......
......@@ -16,10 +16,13 @@
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
......@@ -31,6 +34,8 @@ namespace paddle {
namespace inference {
namespace analysis {
using namespace framework;
TEST(Analyzer, analysis_without_tensorrt) {
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
Argument argument;
......@@ -311,6 +316,20 @@ void TestDituRNNPrediction(const std::string &model_path,
EXPECT_NEAR(data[i], base_data[i], 1e-3);
}
}
if (use_analysis && activate_ir) {
AnalysisPredictor *analysis_predictor =
dynamic_cast<AnalysisPredictor *>(predictor.get());
auto &fuse_statis = analysis_predictor->analysis_argument()
.Get<std::unordered_map<std::string, int>>(
framework::ir::kFuseStatisAttr);
for (auto &item : fuse_statis) {
LOG(INFO) << "fused " << item.first << " " << item.second;
}
ASSERT_TRUE(fuse_statis.count("fc"));
EXPECT_EQ(fuse_statis.at("fc"), 1);
}
}
// Directly infer with the original model.
......
......@@ -64,7 +64,8 @@ struct Argument {
template <typename T>
void Set(const std::string& key, T* data) {
PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
PADDLE_ENFORCE(!attrs_.count(key), "Duplicate set Argument's attr [%s]",
key);
attrs_[key] = data;
attr_deleters_[key] = [data, key, this]() {
VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
......@@ -34,7 +35,6 @@ std::vector<std::string> ExtractParameters(
bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc)
PADDLE_ENFORCE(!argument->transformed_program_desc);
// The transformed_program_desc should inherit all the VarDesc and BlockDesc
// from the original program desc. The operators of the main block(the first
// block) should rewritten by data flow graph.
......@@ -66,7 +66,7 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
}
}
if (argument_->Has("param_scope")) {
if (argument_->Has(framework::ir::kParamScopeAttr)) {
LOG(WARNING) << "parameter changes in the scope takes effect";
}
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
......@@ -26,11 +27,11 @@ void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const std::string &prog_file,
const std::string &param_file) {
PADDLE_ENFORCE(argument_);
argument_->Set("param_scope", new framework::Scope);
argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope);
// Load parameters.
VLOG(3) << "Loading parameters from " << model_dir;
LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
prog_file, param_file);
LoadParams(&argument_->Get<framework::Scope>(framework::ir::kParamScopeAttr),
model_dir, prog_file, param_file);
}
bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
......
......@@ -14,12 +14,14 @@
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
using namespace framework;
static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
......@@ -45,13 +47,12 @@ class FluidToIrPass final : public DataFlowGraphPass {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
// Load program.
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
argument->origin_program_desc.reset(new proto::ProgramDesc(program));
// Create main data flow graph.
if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph);
}
argument->Set("ir_program_desc", new framework::ProgramDesc(program));
argument->Set("ir_program_desc", new ProgramDesc(program));
LOG(INFO) << "Loading parameters";
// Load parameters to argument if needed.
......@@ -73,15 +74,15 @@ class FluidToIrPass final : public DataFlowGraphPass {
void Run(DataFlowGraph *graph) override {
// Call all the IR Passes
IRPassManager ir_passes(
argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr);
IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"),
nullptr);
// Pass the scope from analysis to IR if needed.
if (argument_->Has("param_scope")) {
if (argument_->Has(ir::kParamScopeAttr)) {
// Here the address is passed, attention that IR doesn't own the scope, so
// the real scope in analysis should live during the IR phase.
ir_passes.graph().Set(
"param_scope", new framework::Scope *(
&argument_->Get<framework::Scope>("param_scope")));
ir::kParamScopeAttr,
new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr)));
}
const auto &ir_passes_to_apply =
......@@ -90,6 +91,14 @@ class FluidToIrPass final : public DataFlowGraphPass {
PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph());
// inherit the arguments from ir.
if (ir_passes.graph().Has(ir::kFuseStatisAttr)) {
argument_->Set(
ir::kFuseStatisAttr,
new std::unordered_map<std::string, int>(
ir_passes.graph().Get<std::unordered_map<std::string, int>>(
ir::kFuseStatisAttr)));
}
}
void EnableParamModify(const std::string &model_dir,
......@@ -100,7 +109,7 @@ class FluidToIrPass final : public DataFlowGraphPass {
private:
// Load parameters from a single file or from a directory.
bool LoadParams(framework::Scope *scope, const std::string &dir,
bool LoadParams(Scope *scope, const std::string &dir,
const std::string &prog_file, const std::string &param_file);
private:
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
......@@ -25,7 +26,8 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
framework::Scope *scope)
: program_(program) {
graph_.reset(new framework::ir::Graph(program));
if (scope) graph_->Set("param_scope", new framework::Scope *(scope));
if (scope)
graph_->Set(framework::ir::kParamScopeAttr, new framework::Scope *(scope));
}
void IRPassManager::Apply(const std::vector<std::string> &passes) {
......
......@@ -12,121 +12,96 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
using inference::analysis::Argument;
using inference::Singleton;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
explicit AnalysisPredictor(const NativeConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
} else {
place_ = paddle::platform::CPUPlace();
}
PADDLE_ENFORCE(!parent_scope);
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
}
executor_.reset(new paddle::framework::Executor(place_));
// Initialize the inference program
if (!config_.model_dir.empty()) {
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
} else {
LOG(ERROR) << "fail to load inference model.";
return false;
}
OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get());
executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names
PrepareFeedFetch();
return true;
bool AnalysisPredictor::Init(
const std::shared_ptr<framework::Scope>& parent_scope) {
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
} else {
place_ = paddle::platform::CPUPlace();
}
bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override {
return NativePaddlePredictor::Run(inputs, output_data, batch_size);
PADDLE_ENFORCE(!parent_scope);
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
}
void OptimizeInferenceProgram() {
LOG(INFO) << "optimize begin";
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = ""; // Don't output the model.
// Analyze inference_program
Argument argument;
if (!config_.model_dir.empty()) {
argument.fluid_model_dir.reset(new std::string(config_.model_dir));
} else {
PADDLE_ENFORCE(
!config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty());
argument.fluid_model_program_path.reset(
new std::string(config_.prog_file));
argument.fluid_model_param_path.reset(
new std::string(config_.param_file));
}
argument.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument);
CHECK(argument.transformed_program_desc);
VLOG(5) << "to prepare executor";
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_.reset(
new framework::ProgramDesc(*argument.transformed_program_desc));
PADDLE_ENFORCE(argument.Has("param_scope"));
// Update scope.
scope_.reset(argument.Release<framework::Scope>("param_scope"));
LOG(INFO) << "optimize end ==";
executor_.reset(new paddle::framework::Executor(place_));
// Initialize the inference program
if (!config_.model_dir.empty()) {
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(),
config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
} else {
LOG(ERROR) << "fail to load inference model.";
return false;
}
private:
NativeConfig config_;
};
OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get());
executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names
PrepareFeedFetch();
return true;
}
void AnalysisPredictor::OptimizeInferenceProgram() {
LOG(INFO) << "optimize begin";
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = ""; // Don't output the model.
// Analyze inference_program
if (!config_.model_dir.empty()) {
argument_.fluid_model_dir.reset(new std::string(config_.model_dir));
} else {
PADDLE_ENFORCE(
!config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty());
argument_.fluid_model_program_path.reset(
new std::string(config_.prog_file));
argument_.fluid_model_param_path.reset(new std::string(config_.param_file));
}
argument_.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Analyzer().Run(&argument_);
CHECK(argument_.transformed_program_desc);
VLOG(5) << "to prepare executor";
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_.reset(
new framework::ProgramDesc(*argument_.transformed_program_desc));
PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr));
// Update scope.
scope_.reset(
argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
LOG(INFO) << "optimize end ==";
}
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle {
using inference::analysis::Argument;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
explicit AnalysisPredictor(const NativeConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope);
bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override {
return NativePaddlePredictor::Run(inputs, output_data, batch_size);
}
void OptimizeInferenceProgram();
Argument& analysis_argument() { return argument_; }
private:
NativeConfig config_;
Argument argument_;
};
} // namespace paddle
......@@ -62,7 +62,7 @@ void NativePaddlePredictor::PrepareFeedFetch() {
for (auto *op : inference_program_->Block(0).AllOps()) {
if (op->Type() == "feed") {
int idx = boost::get<int>(op->GetAttr("col"));
if (feeds_.size() <= idx) {
if (feeds_.size() <= static_cast<size_t>(idx)) {
feeds_.resize(idx + 1);
}
feeds_[idx] = op;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册