Unverified commit 2800897a — authored by 王明冬, committed via GitHub

add compat precondition for cpu_quantize_squash_pass, test=develop (#33611)

Parent commit: 6cacd63e
......@@ -25,10 +25,60 @@ namespace paddle {
namespace framework {
namespace ir {
class Graph;
using string::PrettyLogDetail;
// Constructor: registers OpCompat preconditions for the ops this pass may
// rewrite. The IsCompat() calls in the squash handlers check a matched
// subgraph against these constraints before any fusion is performed, so an
// op with unexpected inputs/outputs/attributes is left untouched.
CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
// "scale" op: only a pure multiplicative scale is squashable, so require
// bias == 0 and a strictly positive scale factor.
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("bias")
.IsNumEQ(0.0f)
.End()
.AddAttr("scale")
.IsNumGT(0.0f)
.End()
.AddAttr("bias_after_scale") // bias equal to 0.0, so this attribute is
// unconstrained.
.End();
// "conv2d" op: accept the standard MKL-DNN conv layout. ResidualData is
// optional (only present when a residual connection is fused in).
// strides/paddings/dilations are registered but left unconstrained;
// presumably any value the op itself accepts is fine here — TODO confirm.
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.End();
}
void CPUQuantizeSquashPass::FindNodesToKeep(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
......@@ -354,6 +404,10 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
int found_dequant_scale_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "squash dequant-scale ops pair";
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, dequant_scale_pattern);
......@@ -362,9 +416,10 @@ void CPUQuantizeSquashPass::DequantScaleSquash(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, dequant_scale_pattern);
if (dequant_out->outputs.size() == 1 &&
scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias")) == 0.0f) {
auto dequant_scale = dequant_op->Op()->GetAttrIfExists<float>("Scale");
auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
float scale_scale =
BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
PADDLE_ENFORCE_GT(dequant_scale, 0.0f,
platform::errors::InvalidArgument(
......@@ -399,6 +454,10 @@ void CPUQuantizeSquashPass::ScaleQuantSquash(Graph* graph) const {
int found_scale_quant_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "squash scale-quant ops pair";
GET_IR_NODE_FROM_SUBGRAPH(scale_in, scale_in, scale_quant_pattern);
......@@ -407,9 +466,10 @@ void CPUQuantizeSquashPass::ScaleQuantSquash(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, scale_quant_pattern);
if (quant_in->outputs.size() == 1 &&
scale_op->Op()->GetAttrIfExists<float>("bias") == 0.0) {
BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias")) == 0.0f) {
auto quant_scale = quant_op->Op()->GetAttrIfExists<float>("Scale");
auto scale_scale = scale_op->Op()->GetAttrIfExists<float>("scale");
float scale_scale =
BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
PADDLE_ENFORCE_GT(
quant_scale, 0.0f,
......@@ -443,6 +503,11 @@ void CPUQuantizeSquashPass::QuantizeBf16Conv(Graph* graph) const {
int found_quant_conv_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "squash quant-conv2d ops pair";
GET_IR_NODE_FROM_SUBGRAPH(quant_in, quant_in, pattern);
......
......@@ -19,9 +19,6 @@
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
......@@ -30,10 +27,10 @@ namespace ir {
/*
* Squash dequantize->quantize pair pattern into requantize op
*/
class Graph;
class CPUQuantizeSquashPass : public FusePassBase {
public:
CPUQuantizeSquashPass();
virtual ~CPUQuantizeSquashPass() {}
protected:
......
......@@ -25,7 +25,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
const std::vector<float> scale = {}, float bias = 0.0,
const std::string& mkldnn_data_type = "float32") {
const std::string& mkldnn_data_type = "float32",
bool bias_after_scale = false, int groups = 1) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
......@@ -37,6 +38,15 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
op->SetOutput("Output", {outputs[0]});
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({1, 1});
const std::vector<int> dilations({1, 1});
op->SetAttr("strides", strides);
op->SetAttr("paddings", paddings);
op->SetAttr("dilations", dilations);
op->SetAttr("groups", groups);
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("data_format", std::string("NCHW"));
op->SetAttr("force_fp32_output", false);
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
} else if (type == "quantize") {
......@@ -74,6 +84,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetOutput("Out", {outputs[0]});
op->SetAttr("scale", scale[0]);
op->SetAttr("bias", bias);
op->SetAttr("bias_after_scale", bias_after_scale);
} else if (type == "matmul") {
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
......@@ -373,8 +384,8 @@ ProgramDesc BuildQuantConv2dProgramDesc(const bool& use_mkldnn,
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "quantize", "Quant", {"a"}, {"b"}, use_mkldnn, {quant_scale});
SetOp(&prog, "conv2d", "Conv2d", {"b"}, {"c"}, use_mkldnn, {}, 0.0f,
mkldnn_data_type);
SetOp(&prog, "conv2d", "Conv2d", {"b", "filter", "bias"}, {"c"}, use_mkldnn,
{}, 0.0f, mkldnn_data_type);
return prog;
}
......
......@@ -9,6 +9,9 @@ def {
inputs {
name: "Bias"
}
inputs {
name: "ResidualData"
}
outputs {
name: "Output"
}
......@@ -38,13 +41,14 @@ def {
}
}
extra {
inputs {
name: "ResidualData"
}
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "name"
type: STRING
}
attrs {
name: "use_cudnn"
type: BOOLEAN
......
......@@ -20,6 +20,14 @@ def {
}
}
extra {
attrs {
name: "name"
type: STRING
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册