Commit ce2464fd authored by Tomasz Patejko


MKLDNN conv + elementwise_add fusion: UT for missing bias added. UTs refactored. Some minor changes in the pass
Parent 4e72ab41
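For context: the pass looks for a conv2d whose output feeds an elementwise_add and, when MKLDNN is used, is meant to fold the addition into the convolution. A minimal sketch of driving the pass, built only from calls that appear in the diff below:

    // Sketch: prog is a ProgramDesc describing conv2d followed by elementwise_add.
    std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
    auto pass =
        PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
    graph = pass->Apply(std::move(graph));
    // Afterwards the tests expect exactly one conv2d node and no elementwise_add.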
@@ -68,8 +68,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   conv_output->AsIntermediate();
 
-  auto conv_op_has_bias = [](const Node& conv_op,
-                             const Scope& scope) -> std::pair<bool, Node*> {
+  auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
     auto bias_input_names = conv_op.Op()->Inputs();
     auto bias_it = bias_input_names.find("Bias");
@@ -116,7 +115,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   bool has_bias;
   Node* conv_bias;
-  std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op, *param_scope());
+  std::tie(has_bias, conv_bias) = conv_op_has_bias(*conv_op);
 
   if (has_bias) {
     op_desc.SetInput("Bias", {conv_bias->Name()});
...
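Dropping the Scope parameter from conv_op_has_bias works because the bias, when declared, can be found among the matched op node's own inputs. The helper's full body is not shown in this diff; a hedged sketch of how the Scope-free lookup can proceed (Node::inputs and Node::Name() are assumed to behave as in Paddle's ir::Node):

    auto conv_op_has_bias = [](const Node& conv_op) -> std::pair<bool, Node*> {
      auto bias_input_names = conv_op.Op()->Inputs();
      auto bias_it = bias_input_names.find("Bias");
      // No "Bias" slot, or an empty one: the conv2d has no bias.
      if (bias_it == std::end(bias_input_names) || bias_it->second.empty())
        return std::make_pair(false, nullptr);
      const std::string& bias_name = bias_it->second.front();
      // Assumption: the bias variable node is among the op node's graph inputs.
      for (Node* in : conv_op.inputs) {
        if (in->Name() == bias_name) return std::make_pair(true, in);
      }
      return std::make_pair(false, nullptr);
    };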
@@ -22,29 +22,22 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+namespace {
 constexpr int nodes_removed = 3;
 constexpr int nodes_added = 1;
 
 void SetOp(ProgramDesc* prog, const std::string& type,
-           const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs) {
+           const std::vector<std::pair<std::string, std::string>>& inputs,
+           const std::pair<std::string, std::string>& output) {
   auto op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
+  op->SetAttr("use_mkldnn", true);
 
-  if (type == "conv2d") {
-    op->SetAttr("use_mkldnn", true);
-    op->SetInput("Input", {inputs[0]});
-    op->SetInput("Bias", {inputs[1]});
-    op->SetInput("Filter", {inputs[2]});
-    op->SetOutput("Output", outputs);
-  } else if (type == "elementwise_add") {
-    op->SetInput("X", {inputs[0]});
-    op->SetInput("Y", {inputs[1]});
-    op->SetOutput("Out", outputs);
-  } else if (type == "relu" || type == "sigmoid") {
-    op->SetInput("X", {inputs[0]});
-    op->SetOutput("Out", outputs);
+  for (const auto& input : inputs) {
+    op->SetInput(input.first, {input.second});
   }
+
+  op->SetOutput(output.first, {output.second});
 }
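With the table-driven SetOp, tests name every input and output slot explicitly instead of relying on per-op positional conventions, e.g. (call taken from the tests below):

    SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});

Note that use_mkldnn is now set unconditionally on every op rather than only on conv2d, which these tests tolerate.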
 struct IsReachable {
@@ -96,30 +89,59 @@ struct IsReachable {
   }
 };
 
-TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
-  auto build_program_desc = [&]() -> ProgramDesc {
-    ProgramDesc prog;
-    for (auto& v : std::vector<std::string>(
-             {"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
-      auto* var = prog.MutableBlock(0)->Var(v);
-      var->SetType(proto::VarType::LOD_TENSOR);
-      if (v == "weights" || v == "bias") {
-        var->SetPersistable(true);
-      }
-    }
-
-    SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
-    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
-    SetOp(&prog, "relu", {"d"}, {"e"});
-
-    return prog;
-  };
+void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
+  int conv_count = 0;
+  int elementwise_add_count = 0;
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      ++conv_count;
+    }
+    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
+      ++elementwise_add_count;
+    }
+  }
+  EXPECT_EQ(conv_count, 1);
+  EXPECT_EQ(elementwise_add_count, 0);
+}
+
+ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
+                             const std::vector<std::string>& persistent_vars) {
+  ProgramDesc prog;
+
+  auto add_var_to_prog = [&prog](const std::string& var_name) -> VarDesc* {
+    auto var = prog.MutableBlock(0)->Var(var_name);
+    var->SetType(proto::VarType::LOD_TENSOR);
+
+    return var;
+  };
+
+  for (const auto& v : transient_vars) {
+    add_var_to_prog(v);
+  }
+
+  for (const auto& v : persistent_vars) {
+    auto var = add_var_to_prog(v);
+    var->SetPersistable(true);
+  }
+
+  return prog;
+}
+}  // namespace
+
+TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
+  auto prog =
+      BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
+
+  SetOp(&prog, "conv2d",
+        {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
+        {"Output", "b"});
+  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
+  SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
 
-  auto prog = build_program_desc();
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
 
   IsReachable is_reachable;
   EXPECT_TRUE(is_reachable(graph)("a", "relu"));
 
   auto pass =
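The hunks that follow replace three copies of the same node-counting loop with the AssertOpsCount helper defined above, so each test now ends with a single call:

    AssertOpsCount(graph);  // exactly one conv2d left, elementwise_add fused away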
@@ -132,40 +154,45 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
   EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
             current_nodes_num);
 
-  // Assert conv_relu op in newly generated graph
-  int conv_count = 0;
-  int elementwise_add_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      ++conv_count;
-    }
-    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
-      ++elementwise_add_count;
-    }
-  }
-
-  EXPECT_EQ(conv_count, 1);
-  EXPECT_EQ(elementwise_add_count, 0);
+  AssertOpsCount(graph);
 }
 
-TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
-  auto build_program_desc = [&]() -> ProgramDesc {
-    ProgramDesc prog;
-    for (auto& v : std::vector<std::string>({"a", "b", "bias", "weights"})) {
-      auto* var = prog.MutableBlock(0)->Var(v);
-      var->SetType(proto::VarType::LOD_TENSOR);
-      if (v == "weights" || v == "bias") {
-        var->SetPersistable(true);
-      }
-    }
-
-    SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
-    SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
-
-    return prog;
-  };
+TEST(ConvElementwiseAddMKLDNNFusePass,
+     ConvolutionWithElementwiseAddReluNoBias) {
+  auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
+  SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
+        {"Output", "b"});
+  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
+  SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+
+  IsReachable is_reachable;
+
+  EXPECT_TRUE(is_reachable(graph)("a", "relu"));
+
+  auto pass =
+      PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
+  int original_nodes_num = graph->Nodes().size();
+  graph = pass->Apply(std::move(graph));
+  int current_nodes_num = graph->Nodes().size();
+
+  EXPECT_TRUE(is_reachable(graph)("a", "relu"));
+
+  EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
+            current_nodes_num);
+  AssertOpsCount(graph);
+}
+
+TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
+  auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
+  SetOp(&prog, "conv2d",
+        {{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
+        {"Output", "b"});
+  SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
 
-  auto prog = build_program_desc();
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
 
   IsReachable is_reachable;
@@ -181,43 +208,19 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
   EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
             current_nodes_num);
 
-  // Assert conv_relu op in newly generated graph
-  int conv_count = 0;
-  int elementwise_add_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      ++conv_count;
-    }
-    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
-      ++elementwise_add_count;
-    }
-  }
-
-  EXPECT_EQ(conv_count, 1);
-  EXPECT_EQ(elementwise_add_count, 0);
+  AssertOpsCount(graph);
 }
 
 TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
-  auto build_program_desc = [&]() -> ProgramDesc {
-    ProgramDesc prog;
-    for (auto& v : std::vector<std::string>(
-             {"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
-      auto* var = prog.MutableBlock(0)->Var(v);
-      var->SetType(proto::VarType::LOD_TENSOR);
-      if (v.find("weights") || v.find("bias")) {
-        var->SetPersistable(true);
-      }
-    }
-
-    SetOp(&prog, "sigmoid", {"a"}, {"b"});
-    SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"});
-    SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
-    SetOp(&prog, "relu", {"e"}, {"f"});
-
-    return prog;
-  };
-
-  auto prog = build_program_desc();
+  auto prog =
+      BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
+  SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
+  SetOp(&prog, "conv2d",
+        {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
+        {"Output", "c"});
+  SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
+  SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
 
   std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
 
   IsReachable is_reachable;
@@ -234,20 +237,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
   EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
             current_nodes_num);
 
-  // Assert conv_relu op in newly generated graph
-  int conv_count = 0;
-  int elementwise_add_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      ++conv_count;
-    }
-    if (node->IsOp() && node->Op()->Type() == "elementwise_add") {
-      ++elementwise_add_count;
-    }
-  }
-
-  EXPECT_EQ(conv_count, 1);
-  EXPECT_EQ(elementwise_add_count, 0);
+  AssertOpsCount(graph);
 }
 
 }  // namespace ir
...
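The sigmoid test's rewrite also removes a latent bug visible in the deleted lines: std::string::find returns a position, not a bool. For a hypothetical variable name v mirroring the deleted loop:

    std::string v = "a";
    // v.find("weights") == std::string::npos, which is non-zero and thus truthy,
    // so the branch was taken for every name and even transient variables
    // were marked persistable:
    if (v.find("weights") || v.find("bias")) { /* always reached */ }

BuildProgramDesc sidesteps the problem by taking explicit lists of transient and persistent names.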
@@ -1014,7 +1014,7 @@ PDNode *patterns::Conv::operator()() {
       ->AsOutput()
       ->assert_is_op_output("conv2d", "Output");
 
-  conv_op->LinksFrom({input_var, /*bias_var,*/ filter_var});
+  conv_op->LinksFrom({input_var, filter_var});
   conv_op->LinksTo({output_var});
 
   return output_var;
...
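Removing bias_var from LinksFrom lets the Conv pattern match conv2d nodes with or without a bias input; the bias, when present, is recovered later by conv_op_has_bias in the pass instead of being baked into the pattern.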
@@ -617,7 +617,6 @@ struct Conv : public PatternBase {
   PATTERN_DECL_NODE(conv_op);
   PATTERN_DECL_NODE(conv_input);
-  PATTERN_DECL_NODE(conv_bias);
   PATTERN_DECL_NODE(conv_filter);
   PATTERN_DECL_NODE(conv_residual_data);
   PATTERN_DECL_NODE(conv_output);
...
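The matching header change removes the now-unused conv_bias declaration, keeping the Conv pattern's interface in sync with its implementation.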