提交 ca6fdc6e 编写于 作者: T tensor-tang

refine and fix test

test=develop
上级 a89296ac
...@@ -94,7 +94,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, ...@@ -94,7 +94,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
return false; return false;
} }
auto* relu_op = x->inputs[0]; auto* relu_op = x->inputs[0];
// std::cout << "xxxx" << std::endl;
bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 && bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 &&
relu_op->inputs[0]->IsVar() && relu_op->inputs[0]->IsVar() &&
VarLinksFromOp(relu_op->inputs[0], "fc") && VarLinksFromOp(relu_op->inputs[0], "fc") &&
...@@ -105,31 +104,18 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, ...@@ -105,31 +104,18 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
} }
auto* fc_op = relu_op->inputs[0]->inputs[0]; auto* fc_op = relu_op->inputs[0]->inputs[0];
bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3; bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3;
// std::cout << "*****" << fc_op->inputs.size() << std::endl;
if (!is_fc) { if (!is_fc) {
return false; return false;
} }
for (size_t kkk = 0; kkk < 3; ++kkk) { for (auto* fc_i : fc_op->inputs) {
// std::cout << "++++++" << kkk << std::endl; if (!fc_i->inputs.empty()) {
if (!fc_op->inputs[kkk]->inputs.empty()) {
if (at_top) { if (at_top) {
return true; return true;
} else { } else {
bool res = VarLinksFromOp(fc_op->inputs[kkk], "relu"); return VarLinksFromOp(fc_i, "relu");
// std::cout << fc_op->inputs[kkk]->Name() << "++++++-----" << kkk <<
// ":"
// << res << std::endl;
return res;
} }
} }
} }
// for (auto* fc_i : fc_op->inputs) {
// if (!fc_i->inputs.empty()) {
// std::cout << "++++++" << fc_op->inputs.size()<<std::endl;
// return VarLinksFromOp(fc_i, "relu");
// }
// }
return false; return false;
}; };
...@@ -147,7 +133,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, ...@@ -147,7 +133,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
Node* x, int repeated_times, Node* x, int repeated_times,
const std::string& act_type = "relu") -> bool { const std::string& act_type = "relu") -> bool {
for (int i = 0; i < repeated_times; ++i) { for (int i = 0; i < repeated_times; ++i) {
// std::cout << "----" << i << std::endl;
if (!var_before_is_fc_act(x, act_type, i == repeated_times - 1)) { if (!var_before_is_fc_act(x, act_type, i == repeated_times - 1)) {
return false; return false;
} }
...@@ -180,17 +165,9 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, ...@@ -180,17 +165,9 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
x, std::max(1, num_fc - i - 1), "relu"); x, std::max(1, num_fc - i - 1), "relu");
} }
} else { } else {
bool part1 = return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && x->inputs.size() > 0 &&
x->inputs.size() > 0; var_before_is_fc_act_repeated_n_times(x, i, "relu");
if (x->Name() == "fc_0.tmp_1" && x->IsVar() && part1) {
// std::cout << "testes" << std::endl;
}
bool part2 = var_before_is_fc_act_repeated_n_times(x, i, "relu");
if (x->Name() == "fc_0.tmp_1") {
// std::cout << "========" << part1 << "," << part2 << std::endl;
}
return part1 && part2;
} }
}, },
name_scope + "/fc_in_" + std::to_string(i)); name_scope + "/fc_in_" + std::to_string(i));
...@@ -394,7 +371,7 @@ std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl( ...@@ -394,7 +371,7 @@ std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
int fusion_count = 0; int fusion_count = 0;
for (int i = MAX_NUM_FC; i > 1; --i) { for (int i = MAX_NUM_FC; i > 1; --i) {
fusion_count += fusion_count +=
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(3), 3); BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
} }
AddStatis(fusion_count); AddStatis(fusion_count);
......
...@@ -190,8 +190,10 @@ void analysis_fuse_statis(bool use_zerocopy) { ...@@ -190,8 +190,10 @@ void analysis_fuse_statis(bool use_zerocopy) {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse")); ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2); EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
ASSERT_TRUE(fuse_statis.count("repeated_fc_relu"));
EXPECT_EQ(fuse_statis.at("repeated_fc_relu"), 2);
LOG(INFO) << "num_ops: " << num_ops; LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 195); EXPECT_EQ(num_ops, 185);
} }
// Check the fuse status // Check the fuse status
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册