Commit ce2464fd authored by Tomasz Patejko


MKLDNN conv + elementwise_add fusion: UT for missing bias added. UTs refactored. Some minor changes in the pass
Parent 4e72ab41
@@ -68,8 +68,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   conv_output->AsIntermediate();
 
-  auto conv_op_has_bias = [](const Node& conv_op,
-                             const Scope& scope) -> std::pair<bool, Node*> {
+  auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
     auto bias_input_names = conv_op.Op()->Inputs();
     auto bias_it = bias_input_names.find("Bias");
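With the `Scope` parameter dropped, the helper decides bias presence purely from the op description instead of querying the scope. The hunk cuts off right after the `find("Bias")` lookup; a hedged sketch of how such a lambda could plausibly finish (the traversal of `conv_op.inputs` and the exact return logic are assumptions, not shown in this diff):

```cpp
// Sketch only -- the tail of this lambda is elided by the diff view.
// Assumed completion: report whether the conv op declares a "Bias"
// input and, if so, hand back the graph node carrying that tensor.
auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
  auto bias_input_names = conv_op.Op()->Inputs();
  auto bias_it = bias_input_names.find("Bias");

  if (bias_it != std::end(bias_input_names) && !bias_it->second.empty()) {
    const auto& bias_name = bias_it->second[0];
    // ir::Node exposes its input nodes; match by variable name.
    for (auto* input : conv_op.inputs) {
      if (input->Name() == bias_name) return std::make_pair(true, input);
    }
  }
  return std::make_pair(false, static_cast<Node*>(nullptr));
};
```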
@@ -116,7 +115,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     bool has_bias;
     Node* conv_bias;
 
-    std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op, *param_scope());
+    std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
 
     if (has_bias) {
       op_desc.SetInput("Bias", {conv_bias->Name()});
......
@@ -22,29 +22,22 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
 namespace {
+constexpr int nodes_removed = 3;
+constexpr int nodes_added = 1;
+
 void SetOp(ProgramDesc* prog, const std::string& type,
-           const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs) {
+           const std::vector<std::pair<std::string, std::string>>& inputs,
+           const std::pair<std::string, std::string>& output) {
   auto op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
+  op->SetAttr("use_mkldnn", true);
 
-  if (type == "conv2d") {
-    op->SetAttr("use_mkldnn", true);
-    op->SetInput("Input", {inputs[0]});
-    op->SetInput("Bias", {inputs[1]});
-    op->SetInput("Filter", {inputs[2]});
-    op->SetOutput("Output", outputs);
-  } else if (type == "elementwise_add") {
-    op->SetInput("X", {inputs[0]});
-    op->SetInput("Y", {inputs[1]});
-    op->SetOutput("Out", outputs);
-  } else if (type == "relu" || type == "sigmoid") {
-    op->SetInput("X", {inputs[0]});
-    op->SetOutput("Out", outputs);
+  for (const auto& input : inputs) {
+    op->SetInput(input.first, {input.second});
   }
+
+  op->SetOutput(output.first, {output.second});
 }
 
 struct IsReachable {
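The refactored `SetOp` drops the per-type branching and sets `use_mkldnn` on every op; callers now pass explicit (parameter, variable) pairs. This is what makes the missing-bias UT expressible at all: the old positional API had no way to leave `Bias` out. Calls exactly as they appear in the refactored tests below:

```cpp
// Bias supplied as a named input:
SetOp(&prog, "conv2d",
      {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
      {"Output", "b"});
// Bias simply omitted -- the case the new UT covers:
SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
      {"Output", "b"});
```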
@@ -96,30 +89,59 @@ struct IsReachable {
   }
 };
 
+void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
+  int conv_count = 0;
+  int elementwise_add_count = 0;
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      ++conv_count;
+    }
+    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
+      ++elementwise_add_count;
+    }
+  }
+  EXPECT_EQ(conv_count, 1);
+  EXPECT_EQ(elementwise_add_count, 0);
+}
+
+ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
+                             const std::vector<std::string>& persistent_vars) {
+  ProgramDesc prog;
+
+  auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
+    auto var = prog.MutableBlock(0)->Var(var_name);
+    var->SetType(proto::VarType::LOD_TENSOR);
+
+    return var;
+  };
+
+  for (const auto& v : transient_vars) {
+    add_var_to_prog(v);
+  }
+
+  for (const auto& v : persistent_vars) {
+    auto var = add_var_to_prog(v);
+    var->SetPersistable(true);
+  }
+
+  return prog;
+}
 }  // namespace
 
 TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
-  auto build_program_desc = [&]() -> ProgramDesc {
-    ProgramDesc prog;
-    for (auto& v : std::vector<std::string>(
-             {"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
-      auto* var = prog.MutableBlock(0)->Var(v);
-      var->SetType(proto::VarType::LOD_TENSOR);
-      if (v == "weights" || v == "bias") {
-        var->SetPersistable(true);
-      }
-    }
-
-    SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
-    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
-    SetOp(&prog, "relu", {"d"}, {"e"});
-
-    return prog;
-  };
-
-  auto prog = build_program_desc();
+  auto prog =
+      BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
+
+  SetOp(&prog, "conv2d",
+        {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
+        {"Output", "b"});
+  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
+  SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
+
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
 
   IsReachable is_reachable;
   EXPECT_TRUE(is_reachable(graph)("a", "relu"));
 
   auto pass =
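`IsReachable`'s body is elided above (the hunk jumps from the struct's opening line to its closing braces). Judging from how the tests invoke it — `is_reachable(graph)("a", "relu")` — it curries the graph and then answers whether one named node can reach another. A hedged sketch under those assumptions; `Node::Name()`, `Node::outputs`, and `Graph::Nodes()` are real Paddle ir APIs, the rest is assumed shape:

```cpp
// Sketch, not the actual elided implementation: breadth-first search
// from any node named `from` to any node named `to`.
struct IsReachable {
  std::function<bool(const std::string&, const std::string&)> operator()(
      const std::unique_ptr<ir::Graph>& graph) const {
    return [&graph](const std::string& from, const std::string& to) {
      std::queue<ir::Node*> queue;
      std::unordered_set<ir::Node*> visited;

      // Seed the search with every node whose name matches `from`.
      for (auto* node : graph->Nodes()) {
        if (node->Name() == from) {
          queue.push(node);
          visited.insert(node);
        }
      }

      while (!queue.empty()) {
        auto* current = queue.front();
        queue.pop();
        if (current->Name() == to) return true;
        for (auto* next : current->outputs) {
          if (visited.insert(next).second) queue.push(next);
        }
      }
      return false;
    };
  }
};
```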
@@ -132,40 +154,45 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
 
   EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
             current_nodes_num);
-
-  // Assert conv_relu op in newly generated graph
-  int conv_count = 0;
-  int elementwise_add_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      ++conv_count;
-    }
-    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
-      ++elementwise_add_count;
-    }
-  }
-  EXPECT_EQ(conv_count, 1);
-  EXPECT_EQ(elementwise_add_count, 0);
+  AssertOpsCount(graph);
 }
 
+TEST(ConvElementwiseAddMKLDNNFusePass,
+     ConvolutionWithElementwiseAddReluNoBias) {
+  auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
+  SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
+        {"Output", "b"});
+  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
+  SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+
+  IsReachable is_reachable;
+  EXPECT_TRUE(is_reachable(graph)("a", "relu"));
+
+  auto pass =
+      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
+  int original_nodes_num = graph->Nodes().size();
+  graph = pass->Apply(std::move(graph));
+  int current_nodes_num = graph->Nodes().size();
+
+  EXPECT_TRUE(is_reachable(graph)("a", "relu"));
+
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
+            current_nodes_num);
+  AssertOpsCount(graph);
+}
+
 TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
-  auto build_program_desc = [&]() -> ProgramDesc {
-    ProgramDesc prog;
-    for (auto& v : std::vector<std::string>({"a", "b", "bias", "weights"})) {
-      auto* var = prog.MutableBlock(0)->Var(v);
-      var->SetType(proto::VarType::LOD_TENSOR);
-      if (v == "weights" || v == "bias") {
-        var->SetPersistable(true);
-      }
-    }
-
-    SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
-    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
-
-    return prog;
-  };
-
-  auto prog = build_program_desc();
+  auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
+  SetOp(&prog, "conv2d",
+        {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
+        {"Output", "b"});
+  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
 
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
 
   IsReachable is_reachable;
@@ -181,43 +208,19 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
 
   EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
             current_nodes_num);
-
-  // Assert conv_relu op in newly generated graph
-  int conv_count = 0;
-  int elementwise_add_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      ++conv_count;
-    }
-    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
-      ++elementwise_add_count;
-    }
-  }
-  EXPECT_EQ(conv_count, 1);
-  EXPECT_EQ(elementwise_add_count, 0);
+  AssertOpsCount(graph);
 }
 
 TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
-  auto build_program_desc = [&]() -> ProgramDesc {
-    ProgramDesc prog;
-    for (auto& v : std::vector<std::string>(
-             {"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
-      auto* var = prog.MutableBlock(0)->Var(v);
-      var->SetType(proto::VarType::LOD_TENSOR);
-      if (v.find("weights") || v.find("bias")) {
-        var->SetPersistable(true);
-      }
-    }
-
-    SetOp(&prog, "sigmoid", {"a"}, {"b"});
-    SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"});
-    SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
-    SetOp(&prog, "relu", {"e"}, {"f"});
-
-    return prog;
-  };
+  auto prog =
+      BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
+  SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
+  SetOp(&prog, "conv2d",
+        {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
+        {"Output", "c"});
+  SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
+  SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
 
-  auto prog = build_program_desc();
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
 
   IsReachable is_reachable;
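One removed line in the hunk above deserves a remark: the old persistability check `if (v.find("weights") || v.find("bias"))` misused `std::string::find`, which returns a position (or `npos`), never a bool. The condition is false only when both calls return position 0, which no single name can satisfy, so every variable in this test was silently marked persistable. The shared `BuildProgramDesc` removes the bug by taking explicit transient/persistent lists. A standalone illustration of the misuse:

```cpp
#include <iostream>
#include <string>

int main() {
  // std::string::find returns a position or npos. npos is nonzero and
  // therefore truthy, so the old condition held for every name here.
  for (std::string v : {"a", "bias", "weights", "e"}) {
    bool persistable = v.find("weights") || v.find("bias");
    std::cout << v << " -> persistable = " << persistable << '\n';  // always 1
  }
  return 0;
}
```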
@@ -234,20 +237,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
 
   EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
             current_nodes_num);
-
-  // Assert conv_relu op in newly generated graph
-  int conv_count = 0;
-  int elementwise_add_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      ++conv_count;
-    }
-    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
-      ++elementwise_add_count;
-    }
-  }
-  EXPECT_EQ(conv_count, 1);
-  EXPECT_EQ(elementwise_add_count, 0);
+  AssertOpsCount(graph);
 }
 
 }  // namespace ir
......
@@ -1014,7 +1014,7 @@ PDNode *patterns::Conv::operator()() {
       ->AsOutput()
       ->assert_is_op_output("conv2d", "Output");
 
-  conv_op->LinksFrom({input_var, /*bias_var,*/ filter_var});
+  conv_op->LinksFrom({input_var, filter_var});
   conv_op->LinksTo({output_var});
 
   return output_var;
......
@@ -617,7 +617,6 @@ struct Conv : public PatternBase {
   PATTERN_DECL_NODE(conv_op);
   PATTERN_DECL_NODE(conv_input);
-  PATTERN_DECL_NODE(conv_bias);
   PATTERN_DECL_NODE(conv_filter);
   PATTERN_DECL_NODE(conv_residual_data);
   PATTERN_DECL_NODE(conv_output);
......
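With `conv_bias` dropped from the declarations and from `LinksFrom`, the `patterns::Conv` pattern now matches convolutions with or without bias, and the fuse pass resolves the bias itself via `conv_op_has_bias` (first hunk). A hedged sketch of how a pass typically unpacks the slimmer pattern — `GET_IR_NODE_FROM_SUBGRAPH` and `subgraph_t` are the real Paddle pattern-detector API, but the surrounding handler is illustrative only:

```cpp
// Inside a fuse pass's pattern handler (illustrative):
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                   Graph* g) {
  // After this commit the Conv pattern yields only these four nodes;
  // bias is no longer part of the match.
  GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(conv_input, conv_input, conv_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
  // ...the bias node, if any, is looked up from conv_op's inputs instead.
};
```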