提交 40f8456a 编写于 作者: T tensor-tang

refine fuse pattern and attr

test=develop
上级 cbbacb25
...@@ -349,11 +349,6 @@ PDNode *PDNode::assert_is_op() { ...@@ -349,11 +349,6 @@ PDNode *PDNode::assert_is_op() {
return this; return this;
} }
// PDNode *PDNode::assert_op_attr() {
// asserts_.emplace_back([](Node *x) { return x && x->IsOp(); });
// return this;
// }
PDNode *PDNode::assert_is_op(const std::string &op_type) { PDNode *PDNode::assert_is_op(const std::string &op_type) {
asserts_.emplace_back([op_type](Node *x) { asserts_.emplace_back([op_type](Node *x) {
return x && x->IsOp() && x->Op()->Type() == op_type; return x && x->IsOp() && x->Op()->Type() == op_type;
...@@ -770,10 +765,10 @@ PDNode *patterns::SeqConvEltAddRelu::operator()( ...@@ -770,10 +765,10 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) { paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators // Create Operators
seqconv_input->assert_is_op_input("sequence_conv", "X"); seqconv_input->assert_is_op_input("sequence_conv", "X");
auto *seqconv_op = auto *seqconv_op = pattern->NewNode(seqconv_repr())
pattern->NewNode(seqconv_repr())->assert_is_op("sequence_conv"); ->assert_is_op("sequence_conv")
// ->assert_op_attr("paddingTrainable", false) ->assert_op_attr<bool>("paddingTrainable", false)
// ->assert_op_attr("contextStride", 1) ->assert_op_attr<int>("contextStride", 1);
auto *eltadd_op = auto *eltadd_op =
pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add");
......
...@@ -128,6 +128,15 @@ struct PDNode { ...@@ -128,6 +128,15 @@ struct PDNode {
const std::unordered_set<std::string>& op_types, const std::unordered_set<std::string>& op_types,
const std::string& argument, int nth); const std::string& argument, int nth);
template <typename T>
PDNode* assert_op_attr(const std::string& attr_name, const T& attr) {
asserts_.emplace_back([=](Node* x) {
return x && x->IsOp() && x->Op()->HasAttr(attr_name) &&
boost::get<T>(x->Op()->GetAttr(attr_name)) == attr;
});
return this;
}
private: private:
PDNode(PDPattern* pattern, const std::string& name = "", PDNode(PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar) Type type = Type::kVar)
......
...@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) { ...@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
SetConfig(&cfg); SetConfig(&cfg);
int num_ops; int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg); auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
GetFuseStatis(predictor.get(), &num_ops);
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
EXPECT_EQ(num_ops, 32);
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册