未验证 提交 defae0ef 编写于 作者: W wenbin 提交者: GitHub

ConvAffineChannelFusePass and ConvEltwiseAddAffineChannelFusePass, test=allcase (#33840)

上级 560617ca
...@@ -94,6 +94,77 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight, ...@@ -94,6 +94,77 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
} }
} }
ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.End();
AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
...@@ -116,6 +187,11 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -116,6 +187,11 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
int found_conv_ac_count = 0; int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "ConvAffineChannelFusePass in op compat failed.";
return;
}
VLOG(4) << "handle ConvAffineChannel fuse"; VLOG(4) << "handle ConvAffineChannel fuse";
GET_CONV_BN_NODES(conv_ac_pattern); GET_CONV_BN_NODES(conv_ac_pattern);
...@@ -149,6 +225,12 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -149,6 +225,12 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetType("elementwise_add"); desc.SetType("elementwise_add");
desc.SetAttr("axis", 1); desc.SetAttr("axis", 1);
desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists<bool>("use_mkldnn")); desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists<bool>("use_mkldnn"));
if (!IsCompat(desc)) {
LOG(WARNING) << "ConvAffineChannelFusePass in out fc op compat failed.";
return;
}
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel}); GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
...@@ -164,6 +246,75 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -164,6 +246,75 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_ac_count); AddStatis(found_conv_ac_count);
} }
ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.End();
AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
...@@ -186,6 +337,12 @@ void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -186,6 +337,12 @@ void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
int found_conv_ac_count = 0; int found_conv_ac_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "ConvEltwiseAddAffineChannelFusePass in op compat failed.";
return;
}
VLOG(4) << "handle ConvBN fuse"; VLOG(4) << "handle ConvBN fuse";
GET_CONV_BN_NODES(conv_ac_pattern); GET_CONV_BN_NODES(conv_ac_pattern);
......
...@@ -31,6 +31,7 @@ class Graph; ...@@ -31,6 +31,7 @@ class Graph;
class ConvAffineChannelFusePass : public FusePassBase { class ConvAffineChannelFusePass : public FusePassBase {
public: public:
ConvAffineChannelFusePass();
virtual ~ConvAffineChannelFusePass() {} virtual ~ConvAffineChannelFusePass() {}
protected: protected:
...@@ -40,6 +41,7 @@ class ConvAffineChannelFusePass : public FusePassBase { ...@@ -40,6 +41,7 @@ class ConvAffineChannelFusePass : public FusePassBase {
class ConvEltwiseAddAffineChannelFusePass : public FusePassBase { class ConvEltwiseAddAffineChannelFusePass : public FusePassBase {
public: public:
ConvEltwiseAddAffineChannelFusePass();
virtual ~ConvEltwiseAddAffineChannelFusePass() {} virtual ~ConvEltwiseAddAffineChannelFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册