未验证 提交 930ca3f4 编写于 作者: W Wangzheee 提交者: GitHub

pass enhance (#33661)

上级 39556a44
...@@ -140,6 +140,91 @@ void recompute_bias_and_weights(const Scope* scope, ...@@ -140,6 +140,91 @@ void recompute_bias_and_weights(const Scope* scope,
} }
} }
ConvBNFusePass::ConvBNFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddInput("ResidualData")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("Mean")
.IsTensor()
.End()
.AddInput("Variance")
.IsTensor()
.End()
.AddOutput("MeanOut")
.IsTensor()
.End()
.AddOutput("VarianceOut")
.IsTensor()
.End()
.AddOutput("SavedMean")
.IsTensor()
.End()
.AddOutput("SavedVariance")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumLE(0.001f)
.IsNumGE(0.0f)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
...@@ -161,8 +246,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -161,8 +246,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
int found_conv_bn_count = 0; int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle " + conv_type() + "BN fuse"; VLOG(4) << "handle " + conv_type() + "BN fuse";
// conv, batch_norm, // conv, batch_norm,
// conv_weight, conv_out, // conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance, // bn_scale, bn_bias, bn_mean, bn_variance,
...@@ -236,6 +324,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -236,6 +324,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
} }
conv->Op()->SetOutput("Output", conv->Op()->SetOutput("Output",
std::vector<std::string>({bn_out->Name()})); std::vector<std::string>({bn_out->Name()}));
if (!IsCompat(*conv->Op())) {
LOG(WARNING) << "conv_bn fuse pass in out conv op compat failed.";
return;
}
GraphSafeRemoveNodes( GraphSafeRemoveNodes(
graph, graph,
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
...@@ -251,6 +343,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -251,6 +343,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()})); desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
desc.SetType("elementwise_add"); desc.SetType("elementwise_add");
desc.SetAttr("axis", 1); desc.SetAttr("axis", 1);
if (!IsCompat(desc)) {
LOG(WARNING)
<< "conv_bn fuse pass in out elementwise_add op compat failed.";
return;
}
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance, GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
...@@ -269,6 +366,91 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -269,6 +366,91 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_bn_count); AddStatis(found_conv_bn_count);
} }
ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddInput("ResidualData")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("Mean")
.IsTensor()
.End()
.AddInput("Variance")
.IsTensor()
.End()
.AddOutput("MeanOut")
.IsTensor()
.End()
.AddOutput("VarianceOut")
.IsTensor()
.End()
.AddOutput("SavedMean")
.IsTensor()
.End()
.AddOutput("SavedVariance")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumLE(0.001f)
.IsNumGE(0.0f)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
...@@ -290,8 +472,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -290,8 +472,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
int found_conv_bn_count = 0; int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle " + conv_type() + "BN fuse"; VLOG(4) << "handle " + conv_type() + "BN fuse";
// conv, batch_norm, // conv, batch_norm,
// conv_weight, conv_out, // conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance, // bn_scale, bn_bias, bn_mean, bn_variance,
...@@ -361,7 +546,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -361,7 +546,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
// Update the elementwise_add node // Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1); eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()})); eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
if (!IsCompat(*eltwise->Op())) {
LOG(WARNING)
<< "conv_eltwise_bn fuse pass in out eltwise op compat failed.";
return;
}
GraphSafeRemoveNodes( GraphSafeRemoveNodes(
graph, graph,
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
...@@ -377,6 +566,70 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -377,6 +566,70 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_bn_count); AddStatis(found_conv_bn_count);
} }
ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -31,6 +31,7 @@ class Graph; ...@@ -31,6 +31,7 @@ class Graph;
class ConvBNFusePass : public FusePassBase { class ConvBNFusePass : public FusePassBase {
public: public:
ConvBNFusePass();
virtual ~ConvBNFusePass() {} virtual ~ConvBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; } virtual std::string conv_type() const { return "conv2d"; }
...@@ -41,6 +42,7 @@ class ConvBNFusePass : public FusePassBase { ...@@ -41,6 +42,7 @@ class ConvBNFusePass : public FusePassBase {
class ConvEltwiseAddBNFusePass : public FusePassBase { class ConvEltwiseAddBNFusePass : public FusePassBase {
public: public:
ConvEltwiseAddBNFusePass();
virtual ~ConvEltwiseAddBNFusePass() {} virtual ~ConvEltwiseAddBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; } virtual std::string conv_type() const { return "conv2d"; }
...@@ -51,11 +53,15 @@ class ConvEltwiseAddBNFusePass : public FusePassBase { ...@@ -51,11 +53,15 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
class ConvTransposeBNFusePass : public ConvBNFusePass { class ConvTransposeBNFusePass : public ConvBNFusePass {
public: public:
ConvTransposeBNFusePass();
virtual ~ConvTransposeBNFusePass() {}
std::string conv_type() const { return "conv2d_transpose"; } std::string conv_type() const { return "conv2d_transpose"; }
}; };
class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass { class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
public: public:
ConvTransposeEltwiseAddBNFusePass();
virtual ~ConvTransposeEltwiseAddBNFusePass() {}
std::string conv_type() const { return "conv2d_transpose"; } std::string conv_type() const { return "conv2d_transpose"; }
}; };
......
...@@ -39,28 +39,49 @@ struct Layers { ...@@ -39,28 +39,49 @@ struct Layers {
} }
VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias, VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
bool use_cudnn = false) { int groups = 1, std::vector<int> strides = {1, 1},
std::vector<int> paddings = {0, 0},
std::string padding_algorithm = "EXPLICIT",
std::vector<int> dilations = {1, 1},
std::string data_format = "NCHW", bool use_cudnn = false) {
VarDesc* out = lod_tensor(unique_name()); VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp(); OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("conv2d"); op->SetType("conv2d");
op->SetInput("Input", {input->Name()}); op->SetInput("Input", {input->Name()});
op->SetInput("Filter", {filter->Name()}); op->SetInput("Filter", {filter->Name()});
op->SetInput("Bias", {bias->Name()}); op->SetInput("Bias", {bias->Name()});
op->SetOutput("Out", {out->Name()}); op->SetOutput("Output", {out->Name()});
op->SetAttr("use_cudnn", use_cudnn); op->SetAttr("use_cudnn", use_cudnn);
op->SetAttr("groups", groups);
op->SetAttr("strides", strides);
op->SetAttr("paddings", paddings);
op->SetAttr("padding_algorithm", padding_algorithm);
op->SetAttr("dilations", dilations);
op->SetAttr("data_format", data_format);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward)); static_cast<int>(OpRole::kForward));
return out; return out;
} }
VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias) { VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias,
int groups = 1, std::vector<int> strides = {1, 1},
std::vector<int> paddings = {0, 0},
std::string padding_algorithm = "EXPLICIT",
std::vector<int> dilations = {1, 1},
std::string data_format = "NCHW") {
VarDesc* out = lod_tensor(unique_name()); VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp(); OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("conv2d_transpose"); op->SetType("conv2d_transpose");
op->SetInput("Input", {input->Name()}); op->SetInput("Input", {input->Name()});
op->SetInput("Filter", {filter->Name()}); op->SetInput("Filter", {filter->Name()});
op->SetInput("Bias", {bias->Name()}); op->SetInput("Bias", {bias->Name()});
op->SetOutput("Out", {out->Name()}); op->SetOutput("Output", {out->Name()});
op->SetAttr("groups", groups);
op->SetAttr("strides", strides);
op->SetAttr("paddings", paddings);
op->SetAttr("padding_algorithm", padding_algorithm);
op->SetAttr("dilations", dilations);
op->SetAttr("data_format", data_format);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward)); static_cast<int>(OpRole::kForward));
return out; return out;
......
...@@ -42,6 +42,10 @@ extra { ...@@ -42,6 +42,10 @@ extra {
inputs { inputs {
name: "MomentumTensor" name: "MomentumTensor"
} }
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs { attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN type: BOOLEAN
......
...@@ -41,6 +41,14 @@ def { ...@@ -41,6 +41,14 @@ def {
} }
} }
extra { extra {
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "skip_quant"
type: BOOLEAN
}
attrs { attrs {
name: "is_test" name: "is_test"
type: BOOLEAN type: BOOLEAN
......
...@@ -12,6 +12,14 @@ extra { ...@@ -12,6 +12,14 @@ extra {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "out_threshold"
type: FLOAT
}
attrs {
name: "Out0_threshold"
type: FLOAT
}
attrs { attrs {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册