Unverified commit dd33d28d authored by Wangzheee, committed by GitHub

[pass_enhance] conv_elementwise_add_mkldnn_fuse_pass (#33931)

Parent: ae74c404
@@ -51,7 +51,7 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
   VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
   if (!IsCompat(subgraph, g)) {
-    LOG(WARNING) << "Pass op compat failed.";
+    LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed.";
     return;
   }
   GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
@@ -114,6 +114,10 @@ ConvActivationFusePass::ConvActivationFusePass() {
       .IsOptional()
       .IsTensor()
       .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
       .AddOutput("Output")
       .IsTensor()
       .End()
...
@@ -81,16 +81,72 @@ boost::optional<T> HasAttribute(const Node& op, const std::string& attr) {
   return boost::none;
 }

+ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("ResidualData")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
+      .End();
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsIntIn({-1, 0})
+      .End();
+}
 ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
     const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
     const ResidualConnectionMKLDNNFusePass::IdentityConvFunc&
         get_node_from_conv_op,
     const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc&
-        get_node_from_elementwise_add_op)
+        get_node_from_elementwise_add_op,
+    const ResidualConnectionMKLDNNFusePass* pass)
     : fusion_stats{std::make_shared<int>(0)},
       can_fuse_func{can_fuse_func},
       get_node_from_conv_op{get_node_from_conv_op},
-      get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
+      get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
+      pass_{pass} {}

 void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
     const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
@@ -102,6 +158,11 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
   Node* elementwise_add_op;
   Node* elementwise_add_identity;
   Node* elementwise_add_out;
+  if (!pass_->IsCompat(subgraph, graph)) {
+    LOG(WARNING)
+        << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
+    return;
+  }
   std::tie(conv_op, conv_input, conv_filter, conv_output) =
       get_node_from_conv_op(subgraph);
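
The two additions above follow a declare-then-validate pattern: the constructor registers an attribute whitelist per op type, and each matched subgraph is checked against it before any graph rewrite. The self-contained C++ sketch below mirrors that idea only; every name in it (`OpCompatSketch`, `Require`, and so on) is hypothetical and is not Paddle's actual API.

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

// One rule per attribute: whether it may be absent, and a predicate
// that must hold when it is present.
struct AttrRule {
  bool optional = false;
  std::function<bool(const std::string&)> check;
};

class OpCompatSketch {
 public:
  OpCompatSketch& Require(const std::string& name,
                          std::function<bool(const std::string&)> check,
                          bool optional = false) {
    rules_[name] = AttrRule{optional, std::move(check)};
    return *this;  // chaining, loosely like .AddAttr(...).End() above
  }

  bool IsCompat(const std::map<std::string, std::string>& attrs) const {
    for (const auto& kv : rules_) {
      auto it = attrs.find(kv.first);
      if (it == attrs.end()) {
        if (!kv.second.optional) return false;  // required attr missing
        continue;                               // optional attr absent: ok
      }
      if (!kv.second.check(it->second)) return false;  // constraint failed
    }
    return true;
  }

 private:
  std::map<std::string, AttrRule> rules_;
};

int main() {
  // Mirrors .AddAttr("axis").IsIntIn({-1, 0}) from the constructor above,
  // with attribute values reduced to strings for brevity.
  OpCompatSketch add_compat;
  add_compat.Require(
      "axis", [](const std::string& v) { return v == "-1" || v == "0"; });

  std::cout << add_compat.IsCompat({{"axis", "-1"}}) << "\n";  // 1: fuse
  std::cout << add_compat.IsCompat({{"axis", "1"}}) << "\n";   // 0: skip
  return 0;
}
```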
@@ -133,12 +194,14 @@ ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle(
     const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
         get_node_from_conv_y_op,
     const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc&
-        get_node_from_elementwise_add_op)
+        get_node_from_elementwise_add_op,
+    const ResidualConnectionMKLDNNFusePass* pass)
     : fusion_stats{std::make_shared<int>(0)},
       can_fuse_func{can_fuse_func},
       get_node_from_conv_x_op{get_node_from_conv_x_op},
       get_node_from_conv_y_op{get_node_from_conv_y_op},
-      get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
+      get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
+      pass_{pass} {}

 void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
     const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
@@ -155,6 +218,12 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
   Node* elementwise_add_op;
   Node* elementwise_add_out;
+  if (!pass_->IsCompat(subgraph, graph)) {
+    LOG(WARNING)
+        << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
+    return;
+  }
   std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) =
       get_node_from_conv_x_op(subgraph);
   std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) =
@@ -247,7 +316,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
       [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
         return GetNodesFromConv(conv_pattern, subgraph);
       },
-      get_node_from_elementwise_add);
+      get_node_from_elementwise_add, this);
 }

 GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
@@ -284,7 +353,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
       [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
         return GetNodesFromConv(conv_pattern, subgraph);
       },
-      get_node_from_elementwise_add);
+      get_node_from_elementwise_add, this);
 }

 GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
@@ -325,7 +394,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
       &conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
         return GetNodesFromConv(conv_y_pattern, subgraph);
       },
-      get_node_from_elementwise_add);
+      get_node_from_elementwise_add, this);
 }

 void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
...
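
Taken together, the changes in this file thread a back-pointer from the pass into its fuse-handle functors: the handles are invoked by the pattern detector, but `IsCompat` lives on the pass, so each handle is now constructed with `this` and stores it as `pass_`. A minimal sketch of that delegation pattern, with all names hypothetical:

```cpp
#include <iostream>

// The pass owns the compatibility check...
class PassSketch {
 public:
  bool IsCompatSketch(int subgraph_id) const { return subgraph_id % 2 == 0; }
};

// ...while the handle is a standalone functor handed to the detector,
// so it keeps a non-owning back-pointer to its owning pass.
struct FuseHandleSketch {
  explicit FuseHandleSketch(const PassSketch* pass) : pass_(pass) {}

  void operator()(int subgraph_id) const {
    if (!pass_->IsCompatSketch(subgraph_id)) {
      std::cout << "op compat failed, skipping " << subgraph_id << "\n";
      return;  // mirrors the early return added in the handlers above
    }
    std::cout << "fusing subgraph " << subgraph_id << "\n";
  }

  const PassSketch* pass_;  // analogous to the new pass_ member
};

int main() {
  PassSketch pass;
  FuseHandleSketch handle(&pass);  // the real code passes `this`
  handle(1);  // rejected
  handle(2);  // fused
  return 0;
}
```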
...@@ -84,7 +84,6 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { ...@@ -84,7 +84,6 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
auto can_fuse = [this](Node* op1, Node* op2) -> bool { auto can_fuse = [this](Node* op1, Node* op2) -> bool {
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
}; };
auto fuse_handle = HandleType{can_fuse, std::forward<OpFuncs>(op_funcs)...}; auto fuse_handle = HandleType{can_fuse, std::forward<OpFuncs>(op_funcs)...};
(*gpd)(graph, fuse_handle); (*gpd)(graph, fuse_handle);
@@ -96,7 +95,8 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
     IdentityFuseHandle(
         const CanFuseFunc& can_fuse_func,
         const IdentityConvFunc& get_node_from_conv_op,
-        const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op);
+        const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op,
+        const ResidualConnectionMKLDNNFusePass* pass);

     void operator()(const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph);
@@ -107,6 +107,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
     CanFuseFunc can_fuse_func;
     IdentityConvFunc get_node_from_conv_op;
     IdentityElementwiseAddFunc get_node_from_elementwise_add_op;
+    const ResidualConnectionMKLDNNFusePass* pass_;
   };

   struct ProjectionFuseHandle {
@@ -114,7 +115,8 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
         const CanFuseFunc& can_fuse_func,
         const ProjectionConvFunc& get_node_from_conv_x_op,
         const ProjectionConvFunc& get_node_from_conv_y_op,
-        const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op);
+        const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op,
+        const ResidualConnectionMKLDNNFusePass* pass);

     void operator()(const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph);
@@ -126,9 +128,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
     ProjectionConvFunc get_node_from_conv_x_op;
     ProjectionConvFunc get_node_from_conv_y_op;
     ProjectionElementwiseAddFunc get_node_from_elementwise_add_op;
+    const ResidualConnectionMKLDNNFusePass* pass_;
   };

  public:
+  ResidualConnectionMKLDNNFusePass();
   virtual ~ResidualConnectionMKLDNNFusePass() {}

  protected:
...
@@ -16,6 +16,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/ir/pass_test_util.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/op_version_registry.h"

 namespace paddle {
@@ -25,16 +26,67 @@ namespace ir {
 constexpr int nodes_removed = 3;
 constexpr int nodes_added = 1;

+OpDesc* Create_Op_con2d(ProgramDesc* prog, const std::string& op_type_name,
+                        const std::vector<test::InOutVarNamePair>& inputs,
+                        const std::vector<test::InOutVarNamePair>& outputs,
+                        const bool use_mkldnn = true) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  const std::vector<int> strides({1, 1});
+  const std::vector<int> paddings({0, 0});
+  const std::vector<int> dilations({1, 1});
+  op->SetType(op_type_name);
+  op->SetAttr("use_mkldnn", use_mkldnn);
+  op->SetAttr("strides", strides);
+  op->SetAttr("groups", 1);
+  op->SetAttr("paddings", paddings);
+  op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
+  op->SetAttr("dilations", dilations);
+  op->SetAttr("data_format", std::string("NCHW"));
+
+  for (const auto& input : inputs) {
+    op->SetInput(input.first, {input.second});
+  }
+  for (const auto& output : outputs) {
+    op->SetOutput(output.first, {output.second});
+  }
+
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
+  return op;
+}
+
+OpDesc* Create_Op_elemntwise_add(
+    ProgramDesc* prog, const std::string& op_type_name,
+    const std::vector<test::InOutVarNamePair>& inputs,
+    const std::vector<test::InOutVarNamePair>& outputs,
+    bool use_mkldnn = true) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  op->SetType(op_type_name);
+  op->SetAttr("use_mkldnn", use_mkldnn);
+  op->SetAttr("axis", -1);
+
+  for (const auto& input : inputs) {
+    op->SetInput(input.first, {input.second});
+  }
+  for (const auto& output : outputs) {
+    op->SetOutput(output.first, {output.second});
+  }
+
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
+  return op;
+}
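
A plausible reason for these helpers, which the commit itself does not spell out: the pass now whitelists conv2d and elementwise_add attributes, so a bare op like the one below, which the tests previously built with `test::CreateOp`, would presumably fail the new `IsCompat` check and never fuse. The helpers therefore attach every attribute the compat lists require.

```cpp
// Hypothetical "before" shape of the test ops: no strides/paddings/
// dilations/groups set, so the new conv2d OpCompat whitelist would
// reject the subgraph and the positive-fusion tests would stop matching.
test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
               {{"Output", "c"}});
```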
 TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
   auto prog =
       test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});

   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d",
-                 {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
-                 {{"Output", "c"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
-                 {{"Out", "d"}});
+  Create_Op_con2d(&prog, "conv2d",
+                  {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
+                  {{"Output", "c"}});
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
+                           {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});

   Graph graph(prog);
@@ -53,17 +105,17 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
   // right branch
-  test::CreateOp(&prog, "conv2d",
-                 {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
-                 {{"Output", "c"}});
+  Create_Op_con2d(&prog, "conv2d",
+                  {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
+                  {{"Output", "c"}});
   // left branch
-  test::CreateOp(&prog, "conv2d",
-                 {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}},
-                 {{"Output", "f"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}},
-                 {{"Out", "d"}});
+  Create_Op_con2d(&prog, "conv2d",
+                  {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}},
+                  {{"Output", "f"}});
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}},
+                           {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});

   Graph graph(prog);
@@ -80,10 +132,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
   auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
-                 {{"Output", "c"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
-                 {{"Out", "d"}});
+  Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
+                  {{"Output", "c"}});
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
+                           {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});

   Graph graph(prog);
@@ -100,12 +152,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
       test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d",
-                 {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
-                 {{"Output", "c"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
-                 {{"Out", "d"}});
+  Create_Op_con2d(&prog, "conv2d",
+                  {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
+                  {{"Output", "c"}});
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
+                           {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});

   Graph graph(prog);
@@ -122,10 +174,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
   auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
-                 {{"Output", "c"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
-                 {{"Out", "d"}});
+  Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
+                  {{"Output", "c"}});
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
+                           {{"Out", "d"}});
   test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});

   Graph graph(prog);
@@ -142,14 +194,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
       test::BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"});
   test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
-  test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
-                 {{"Output", "c"}});
-  test::CreateOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
-                 {{"Output", "e"}});
-  test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}},
-                 {{"Out", "f"}});
+  Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
+                  {{"Output", "c"}});
+  Create_Op_con2d(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
+                  {{"Output", "e"}});
+  Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}},
+                           {{"Out", "f"}});
   test::CreateOp(&prog, "relu", {{"X", "f"}}, {{"Out", "g"}});

   Graph graph(prog);
...
@@ -67,6 +67,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
       .AddAttr("paddings")
       .End()
       .AddAttr("padding_algorithm")
+      .IsOptional()
       .IsStringIn({"EXPLICIT", "SAME", "VALID"})
       .End()
       .AddAttr("groups")
@@ -75,6 +76,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
       .AddAttr("dilations")
       .End()
       .AddAttr("data_format")
+      .IsOptional()
       .IsStringIn({"NCHW", "NHWC"})
       .End();
 }
...
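
The two `IsOptional()` additions appear to relax CPUQuantizeSquashPass's conv2d compat check in the same spirit as the rest of the commit: graphs whose conv2d ops omit `padding_algorithm` or `data_format` presumably still pass the check, while the `IsStringIn` constraints continue to apply whenever the attribute is present.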
@@ -15,6 +15,10 @@ def {
   }
 }
 extra {
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
   attrs {
     name: "out_threshold"
     type: FLOAT
...