未验证 提交 78ab656c 编写于 作者: F feng_shuai 提交者: GitHub

depthwise_conv_mkl_pass (#33936)

上级 033d736d
...@@ -31,6 +31,47 @@ class Graph; ...@@ -31,6 +31,47 @@ class Graph;
PADDLE_ENFORCE_NOT_NULL( \ PADDLE_ENFORCE_NOT_NULL( \
id, platform::errors::InvalidArgument("Subgraph has no node %s.", #id)); id, platform::errors::InvalidArgument("Subgraph has no node %s.", #id));
DepthwiseConvMKLDNNPass::DepthwiseConvMKLDNNPass() {
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.IsTensor()
.End()
.AddInput("ResidualData")
.IsOptional()
.IsTensor()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
// mobilenet-ssd has no "padding_algorithm"
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
}
void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
...@@ -45,6 +86,10 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { ...@@ -45,6 +86,10 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
int found_depthwise_conv_mkldnn_count = 0; int found_depthwise_conv_mkldnn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass op compat failed.";
return;
}
VLOG(3) << "handle DepthwiseConvMKLDNN fuse"; VLOG(3) << "handle DepthwiseConvMKLDNN fuse";
GET_NODE(depthwise_conv, (*pattern)); GET_NODE(depthwise_conv, (*pattern));
depthwise_conv->Op()->SetType("conv2d"); depthwise_conv->Op()->SetType("conv2d");
......
...@@ -24,6 +24,7 @@ class Graph; ...@@ -24,6 +24,7 @@ class Graph;
class DepthwiseConvMKLDNNPass : public FusePassBase { class DepthwiseConvMKLDNNPass : public FusePassBase {
public: public:
DepthwiseConvMKLDNNPass();
virtual ~DepthwiseConvMKLDNNPass() {} virtual ~DepthwiseConvMKLDNNPass() {}
protected: protected:
......
...@@ -29,10 +29,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -29,10 +29,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetType(type); op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name); op->SetAttr("name", name);
op->SetAttr("groups", 1);
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("data_format", std::string("NCHW"));
op->SetAttr("strides", std::vector<int>({1, 1}));
op->SetAttr("dilations", std::vector<int>({1, 1}));
op->SetAttr("paddings", std::vector<int>({0, 0}));
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]}); op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]}); op->SetInput("Bias", {inputs[2]});
op->SetOutput("Out", outputs); op->SetOutput("Output", outputs);
} }
// (a, weights, bias)->depthwise conv mkldnn->b // (a, weights, bias)->depthwise conv mkldnn->b
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册