未验证 提交 d4372a1e 编写于 作者: Z Zuza Gawrysiak 提交者: GitHub

Quantize shape operator (#44124)

* Quantize shape operator

* Add shape op to propagate scales pass
上级 ab57cbf6
......@@ -1056,7 +1056,7 @@ struct ResidualElementwise : public PatternBase {
};
// General struct for immutable ops:
// reshape, transpose, slice, nearest-interp
// reshape, transpose, slice, shape, nearest-interp
// Forward pass for no weights-op.
// immutable_out is a result of the operator.
struct Immutable : public PatternBase {
......
......@@ -372,7 +372,7 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
const auto op_name = op_node->Name();
if (scale_immutable_ops.count(op_name)) {
std::string input_name;
if (op_name == "slice") {
if (op_name == "slice" || op_name == "shape") {
input_name = op_node->Op()->Input("Input")[0];
} else {
input_name = op_node->Op()->Input("X")[0];
......@@ -445,6 +445,7 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
"reshape2",
"pool2d",
"slice",
"shape",
"nearest_interp",
"nearest_interp_v2"};
......
......@@ -1136,6 +1136,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeImmutable(graph, "reshape2", "X");
QuantizeImmutable(graph, "transpose2", "X");
QuantizeImmutable(graph, "slice", "Input");
QuantizeImmutable(graph, "shape", "Input");
QuantizeImmutable(graph, "nearest_interp", "X");
QuantizeImmutable(graph, "nearest_interp_v2", "X");
QuantizeElementwise(graph, "elementwise_add");
......
......@@ -66,7 +66,7 @@ void SetOp(ProgramDesc* prog,
type == "nearest_interp" || type == "nearest_interp_v2") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "slice") {
} else if (type == "slice" || type == "shape") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "dropout") {
......@@ -550,8 +550,12 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) {
SCALE * S8_MAX);
}
const std::vector<std::string> immutables = {
"reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"};
const std::vector<std::string> immutables = {"reshape2",
"transpose2",
"slice",
"shape",
"nearest_interp",
"nearest_interp_v2"};
class TestImmutables : public testing::TestWithParam<std::string> {};
......
......@@ -142,7 +142,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
scales_[var_name] = scales_[input_var_name];
}
compute_scale = false;
} else if (op->Type() == "slice") {
} else if (op->Type() == "slice" || op->Type() == "shape") {
auto input_var_name = op->Input("Input")[0];
PADDLE_ENFORCE_NE(scales_.find(input_var_name),
scales_.end(),
......
......@@ -45,6 +45,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["slice"]["Input"] = ScaleAlgo::KL;
rules_["slice"]["Out"] = ScaleAlgo::NONE;
rules_["shape"]["Input"] = ScaleAlgo::KL;
rules_["shape"]["Out"] = ScaleAlgo::NONE;
rules_["fc"]["Input"] = ScaleAlgo::KL;
rules_["fc"]["W"] = ScaleAlgo::MAX_CH_T;
rules_["fc"]["Bias"] = ScaleAlgo::NONE;
......@@ -62,6 +65,10 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["elementwise_mul"]["Y"] = ScaleAlgo::KL;
rules_["elementwise_mul"]["Out"] = ScaleAlgo::KL;
rules_["elementwise_sub"]["X"] = ScaleAlgo::KL;
rules_["elementwise_sub"]["Y"] = ScaleAlgo::KL;
rules_["elementwise_sub"]["Out"] = ScaleAlgo::KL;
// Reshape2 does not perform calculation on the data and shapes are not
// changed. Scale is calculated on input data and assign to Quantize and
// Dequantize scale.
......
......@@ -63,8 +63,8 @@ class Quant2Int8MkldnnPass(object):
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
[-1])
self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'slice', 'nearest_interp',
'nearest_interp_v2'
'transpose2', 'reshape2', 'pool2d', 'slice', 'shape',
'nearest_interp', 'nearest_interp_v2'
]
self._scale_ops = ['scale']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
......@@ -247,7 +247,7 @@ class Quant2Int8MkldnnPass(object):
waiting_for_scale = set()
for op in graph.all_op_nodes():
if op.name() in self._scale_immutable_ops:
if op.name() == 'slice':
if op.name() == 'slice' or op.name() == 'shape':
input_name = op.input("Input")[0]
else:
input_name = op.input("X")[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册