Commit 347bf904 authored by Tomasz Patejko

MKLDNN conv + elementwise_add fusion: bias is also handled

Parent bf95ac36
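In effect, the pass now wires the convolution's bias through the fusion: the pattern matches the conv2d Bias input, the fused OpDesc receives it, and the corresponding graph edge is linked. Below is a minimal sketch of the inputs the fused conv2d op ends up with, assuming the fluid OpDesc API used in the pass; the variable names ("a", "bias", "weights", "c", "d") are illustrative, borrowed from the tests in this commit.

```cpp
// Sketch only: the I/O of the fused conv2d op after this commit.
#include "paddle/fluid/framework/op_desc.h"

void BuildFusedConvDesc(paddle::framework::OpDesc* desc) {
  desc->SetType("conv2d");
  desc->SetInput("Input", {"a"});
  desc->SetInput("Bias", {"bias"});       // newly handled by this commit
  desc->SetInput("Filter", {"weights"});
  desc->SetInput("ResidualData", {"c"});  // the other elementwise_add operand
  desc->SetOutput("Output", {"d"});
}
```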
@@ -49,6 +49,7 @@ struct Pattern : public PatternBase {
 struct Conv {
   std::string op_name() const { return "conv2d"; }
   std::string input_name() const { return "Input"; }
+  std::string bias_name() const { return "Bias"; }
   std::string filter_name() const { return "Filter"; }
   std::string residual_data_name() const { return "ResidualData"; }
   std::string output_name() const { return "Output"; }
@@ -60,13 +61,16 @@ struct Conv {
     auto input_var = pattern->new_node(input_name())
                          ->assert_is_op_input(op_name(), input_name());
+    auto bias_var = pattern->new_node(bias_name())
+                        ->assert_is_op_input(op_name(), bias_name());
     auto filter_var = pattern->new_node(filter_name())
                           ->assert_is_op_input(op_name(), filter_name());
     auto output_var = pattern->new_node(output_name())
                           ->assert_is_op_output(op_name(), output_name());
-    conv_op->LinksFrom({input_var, filter_var});
+    conv_op->LinksFrom({input_var, bias_var, filter_var});
     conv_op->LinksTo({output_var});
     return output_var;
@@ -178,13 +182,14 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
   conv_output->AsIntermediate();

-  auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input,
+  auto fuse_conv = [&conv_pattern](Graph* g, Node* conv_input, Node* conv_bias,
                                    Node* conv_filter, Node* conv_output,
                                    Node* elementwise_add_x) {
     OpDesc op_desc;
     op_desc.SetType(conv_pattern.op_name());

     op_desc.SetInput(conv_pattern.input_name(), {conv_input->Name()});
+    op_desc.SetInput(conv_pattern.bias_name(), {conv_bias->Name()});
     op_desc.SetInput(conv_pattern.filter_name(), {conv_filter->Name()});
     op_desc.SetInput(conv_pattern.residual_data_name(),
                      {elementwise_add_x->Name()});
@@ -196,6 +201,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     auto fused_conv_op = g->CreateOpNode(&op_desc);

     patterns::LinkNodes(conv_input, fused_conv_op);
+    patterns::LinkNodes(conv_bias, fused_conv_op);
     patterns::LinkNodes(conv_filter, fused_conv_op);
     patterns::LinkNodes(elementwise_add_x, fused_conv_op);
     patterns::LinkNodes(fused_conv_op, conv_output);
@@ -208,6 +214,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
                                                     conv_pattern.op_name());
     auto conv_input = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
                                                     conv_pattern.input_name());
+    auto conv_bias = patterns::GetNodeFromSubgraph(subgraph, pattern_ptr,
+                                                   conv_pattern.bias_name());
     auto conv_filter = patterns::GetNodeFromSubgraph(
         subgraph, pattern_ptr, conv_pattern.filter_name());
     auto conv_output = patterns::GetNodeFromSubgraph(
@@ -220,7 +228,8 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
     auto elementwise_add_out = patterns::GetNodeFromSubgraph(
         subgraph, pattern_ptr, elementwise_add_pattern.out_name());

-    fuse_conv(g, conv_input, conv_filter, conv_output, elementwise_add_x);
+    fuse_conv(g, conv_input, conv_bias, conv_filter, conv_output,
+              elementwise_add_x);
     patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output);
     GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
   };
...
@@ -34,7 +34,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
   if (type == "conv2d") {
     op->SetAttr("use_mkldnn", true);
     op->SetInput("Input", {inputs[0]});
-    op->SetInput("Filter", {inputs[1]});
+    op->SetInput("Bias", {inputs[1]});
+    op->SetInput("Filter", {inputs[2]});
     op->SetOutput("Output", outputs);
   } else if (type == "elementwise_add") {
     op->SetInput("X", {inputs[0]});
@@ -98,8 +99,8 @@ struct IsReachable {
 TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
   auto build_program_desc = [&]() -> ProgramDesc {
     ProgramDesc prog;
-    for (auto& v :
-         std::vector<std::string>({"a", "b", "weights", "c", "d", "e"})) {
+    for (auto& v : std::vector<std::string>(
+             {"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
       auto* var = prog.MutableBlock(0)->Var(v);
       var->SetType(proto::VarType::LOD_TENSOR);
       if (v == "weights") {
@@ -107,7 +108,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
       }
     }

-    SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
+    SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
     SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
     SetOp(&prog, "relu", {"d"}, {"e"});
@@ -150,7 +151,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
 TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
   auto build_program_desc = [&]() -> ProgramDesc {
     ProgramDesc prog;
-    for (auto& v : std::vector<std::string>({"a", "b", "weights"})) {
+    for (auto& v : std::vector<std::string>({"a", "b", "bias", "weights"})) {
       auto* var = prog.MutableBlock(0)->Var(v);
       var->SetType(proto::VarType::LOD_TENSOR);
       if (v == "weights" || v == "bias") {
@@ -158,7 +159,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
       }
     }

-    SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
+    SetOp(&prog, "conv2d", {"a", "bias", "weights"}, {"b"});
     SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});

     return prog;
@@ -199,8 +200,8 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
 TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
   auto build_program_desc = [&]() -> ProgramDesc {
     ProgramDesc prog;
-    for (auto& v :
-         std::vector<std::string>({"a", "b", "weights", "c", "d", "e", "f"})) {
+    for (auto& v : std::vector<std::string>(
+             {"a", "b", "bias", "weights", "c", "d", "e", "f"})) {
       auto* var = prog.MutableBlock(0)->Var(v);
       var->SetType(proto::VarType::LOD_TENSOR);
       if (v.find("weights")) {
@@ -209,7 +210,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
     }

     SetOp(&prog, "sigmoid", {"a"}, {"b"});
-    SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
+    SetOp(&prog, "conv2d", {"b", "bias", "weights"}, {"c"});
     SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
     SetOp(&prog, "relu", {"e"}, {"f"});
...