提交 ca6fdc6e 编写于 作者: T tensor-tang

refine and fix test

test=develop
上级 a89296ac
......@@ -94,7 +94,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
return false;
}
auto* relu_op = x->inputs[0];
// std::cout << "xxxx" << std::endl;
bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 &&
relu_op->inputs[0]->IsVar() &&
VarLinksFromOp(relu_op->inputs[0], "fc") &&
......@@ -105,31 +104,18 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
}
auto* fc_op = relu_op->inputs[0]->inputs[0];
bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3;
// std::cout << "*****" << fc_op->inputs.size() << std::endl;
if (!is_fc) {
return false;
}
for (size_t kkk = 0; kkk < 3; ++kkk) {
// std::cout << "++++++" << kkk << std::endl;
if (!fc_op->inputs[kkk]->inputs.empty()) {
for (auto* fc_i : fc_op->inputs) {
if (!fc_i->inputs.empty()) {
if (at_top) {
return true;
} else {
bool res = VarLinksFromOp(fc_op->inputs[kkk], "relu");
// std::cout << fc_op->inputs[kkk]->Name() << "++++++-----" << kkk <<
// ":"
// << res << std::endl;
return res;
return VarLinksFromOp(fc_i, "relu");
}
}
}
// for (auto* fc_i : fc_op->inputs) {
// if (!fc_i->inputs.empty()) {
// std::cout << "++++++" << fc_op->inputs.size()<<std::endl;
// return VarLinksFromOp(fc_i, "relu");
// }
// }
return false;
};
......@@ -147,7 +133,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
Node* x, int repeated_times,
const std::string& act_type = "relu") -> bool {
for (int i = 0; i < repeated_times; ++i) {
// std::cout << "----" << i << std::endl;
if (!var_before_is_fc_act(x, act_type, i == repeated_times - 1)) {
return false;
}
......@@ -180,17 +165,9 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern,
x, std::max(1, num_fc - i - 1), "relu");
}
} else {
bool part1 =
var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.size() > 0;
if (x->Name() == "fc_0.tmp_1" && x->IsVar() && part1) {
// std::cout << "testes" << std::endl;
}
bool part2 = var_before_is_fc_act_repeated_n_times(x, i, "relu");
if (x->Name() == "fc_0.tmp_1") {
// std::cout << "========" << part1 << "," << part2 << std::endl;
}
return part1 && part2;
return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") &&
x->inputs.size() > 0 &&
var_before_is_fc_act_repeated_n_times(x, i, "relu");
}
},
name_scope + "/fc_in_" + std::to_string(i));
......@@ -394,7 +371,7 @@ std::unique_ptr<ir::Graph> RepeatedFCReluFusePass::ApplyImpl(
int fusion_count = 0;
for (int i = MAX_NUM_FC; i > 1; --i) {
fusion_count +=
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(3), 3);
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
}
AddStatis(fusion_count);
......
......@@ -190,8 +190,10 @@ void analysis_fuse_statis(bool use_zerocopy) {
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
ASSERT_TRUE(fuse_statis.count("repeated_fc_relu"));
EXPECT_EQ(fuse_statis.at("repeated_fc_relu"), 2);
LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 195);
EXPECT_EQ(num_ops, 185);
}
// Check the fuse status
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册