未验证 提交 33edb62a 编写于 作者: W Wangzheee 提交者: GitHub

pass_enhance_conv_concat_relu_mkldnn (#33867)

上级 7c4e5150
...@@ -23,7 +23,67 @@ namespace paddle { ...@@ -23,7 +23,67 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class Graph; ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(0)
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
void ConvConcatReLUFusePass::FindConcatWithConvs( void ConvConcatReLUFusePass::FindConcatWithConvs(
ir::Graph* graph, ir::Graph* graph,
......
...@@ -18,9 +18,6 @@ ...@@ -18,9 +18,6 @@
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -31,10 +28,10 @@ namespace ir { ...@@ -31,10 +28,10 @@ namespace ir {
* to a: * to a:
* (multi ConvReLU) -> Concat -> next_op. * (multi ConvReLU) -> Concat -> next_op.
*/ */
class Graph;
class ConvConcatReLUFusePass : public FusePassBase { class ConvConcatReLUFusePass : public FusePassBase {
public: public:
ConvConcatReLUFusePass();
virtual ~ConvConcatReLUFusePass() {} virtual ~ConvConcatReLUFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册