未验证 提交 930ca3f4 编写于 作者: W Wangzheee 提交者: GitHub

pass enhance (#33661)

上级 39556a44
......@@ -140,6 +140,91 @@ void recompute_bias_and_weights(const Scope* scope,
}
}
ConvBNFusePass::ConvBNFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddInput("ResidualData")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("Mean")
.IsTensor()
.End()
.AddInput("Variance")
.IsTensor()
.End()
.AddOutput("MeanOut")
.IsTensor()
.End()
.AddOutput("VarianceOut")
.IsTensor()
.End()
.AddOutput("SavedMean")
.IsTensor()
.End()
.AddOutput("SavedVariance")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumLE(0.001f)
.IsNumGE(0.0f)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
......@@ -161,8 +246,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle " + conv_type() + "BN fuse";
// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
......@@ -236,6 +324,10 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
}
conv->Op()->SetOutput("Output",
std::vector<std::string>({bn_out->Name()}));
if (!IsCompat(*conv->Op())) {
LOG(WARNING) << "conv_bn fuse pass in out conv op compat failed.";
return;
}
GraphSafeRemoveNodes(
graph,
{conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
......@@ -251,6 +343,11 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
if (!IsCompat(desc)) {
LOG(WARNING)
<< "conv_bn fuse pass in out elementwise_add op compat failed.";
return;
}
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
......@@ -269,6 +366,91 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_bn_count);
}
ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddInput("ResidualData")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("batch_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("Mean")
.IsTensor()
.End()
.AddInput("Variance")
.IsTensor()
.End()
.AddOutput("MeanOut")
.IsTensor()
.End()
.AddOutput("VarianceOut")
.IsTensor()
.End()
.AddOutput("SavedMean")
.IsTensor()
.End()
.AddOutput("SavedVariance")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumLE(0.001f)
.IsNumGE(0.0f)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(1)
.End();
}
void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
......@@ -290,8 +472,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
int found_conv_bn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle " + conv_type() + "BN fuse";
// conv, batch_norm,
// conv_weight, conv_out,
// bn_scale, bn_bias, bn_mean, bn_variance,
......@@ -361,7 +546,11 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
// Update the elementwise_add node
eltwise->Op()->SetAttr("axis", 1);
eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
if (!IsCompat(*eltwise->Op())) {
LOG(WARNING)
<< "conv_eltwise_bn fuse pass in out eltwise op compat failed.";
return;
}
GraphSafeRemoveNodes(
graph,
{bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
......@@ -377,6 +566,70 @@ void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_bn_count);
}
ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -31,6 +31,7 @@ class Graph;
class ConvBNFusePass : public FusePassBase {
public:
ConvBNFusePass();
virtual ~ConvBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
......@@ -41,6 +42,7 @@ class ConvBNFusePass : public FusePassBase {
class ConvEltwiseAddBNFusePass : public FusePassBase {
public:
ConvEltwiseAddBNFusePass();
virtual ~ConvEltwiseAddBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
......@@ -51,11 +53,15 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
class ConvTransposeBNFusePass : public ConvBNFusePass {
public:
ConvTransposeBNFusePass();
virtual ~ConvTransposeBNFusePass() {}
std::string conv_type() const { return "conv2d_transpose"; }
};
class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
public:
ConvTransposeEltwiseAddBNFusePass();
virtual ~ConvTransposeEltwiseAddBNFusePass() {}
std::string conv_type() const { return "conv2d_transpose"; }
};
......
......@@ -39,28 +39,49 @@ struct Layers {
}
VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias,
bool use_cudnn = false) {
int groups = 1, std::vector<int> strides = {1, 1},
std::vector<int> paddings = {0, 0},
std::string padding_algorithm = "EXPLICIT",
std::vector<int> dilations = {1, 1},
std::string data_format = "NCHW", bool use_cudnn = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("conv2d");
op->SetInput("Input", {input->Name()});
op->SetInput("Filter", {filter->Name()});
op->SetInput("Bias", {bias->Name()});
op->SetOutput("Out", {out->Name()});
op->SetOutput("Output", {out->Name()});
op->SetAttr("use_cudnn", use_cudnn);
op->SetAttr("groups", groups);
op->SetAttr("strides", strides);
op->SetAttr("paddings", paddings);
op->SetAttr("padding_algorithm", padding_algorithm);
op->SetAttr("dilations", dilations);
op->SetAttr("data_format", data_format);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
}
VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias) {
VarDesc* conv2d_transpose(VarDesc* input, VarDesc* filter, VarDesc* bias,
int groups = 1, std::vector<int> strides = {1, 1},
std::vector<int> paddings = {0, 0},
std::string padding_algorithm = "EXPLICIT",
std::vector<int> dilations = {1, 1},
std::string data_format = "NCHW") {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("conv2d_transpose");
op->SetInput("Input", {input->Name()});
op->SetInput("Filter", {filter->Name()});
op->SetInput("Bias", {bias->Name()});
op->SetOutput("Out", {out->Name()});
op->SetOutput("Output", {out->Name()});
op->SetAttr("groups", groups);
op->SetAttr("strides", strides);
op->SetAttr("paddings", paddings);
op->SetAttr("padding_algorithm", padding_algorithm);
op->SetAttr("dilations", dilations);
op->SetAttr("data_format", data_format);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return out;
......
......@@ -46,6 +46,10 @@ extra {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "is_test"
type: BOOLEAN
......
......@@ -41,6 +41,14 @@ def {
}
}
extra {
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "skip_quant"
type: BOOLEAN
}
attrs {
name: "is_test"
type: BOOLEAN
......
......@@ -12,6 +12,14 @@ extra {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "out_threshold"
type: FLOAT
}
attrs {
name: "Out0_threshold"
type: FLOAT
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部