未验证 提交 d4372a1e 编写于 作者: Zuza Gawrysiak 提交者: GitHub

Quantize shape operator (#44124)

* Quantize shape operator

* Add shape op to propagate scales pass
上级 ab57cbf6
...@@ -1056,7 +1056,7 @@ struct ResidualElementwise : public PatternBase { ...@@ -1056,7 +1056,7 @@ struct ResidualElementwise : public PatternBase {
}; };
// General struct for immutable ops: // General struct for immutable ops:
// reshape, transpose, slice, nearest-interp // reshape, transpose, slice, shape, nearest-interp
// Forward pass for no weights-op. // Forward pass for no weights-op.
// immutable_out is a result of the operator. // immutable_out is a result of the operator.
struct Immutable : public PatternBase { struct Immutable : public PatternBase {
......
...@@ -372,7 +372,7 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales( ...@@ -372,7 +372,7 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
const auto op_name = op_node->Name(); const auto op_name = op_node->Name();
if (scale_immutable_ops.count(op_name)) { if (scale_immutable_ops.count(op_name)) {
std::string input_name; std::string input_name;
if (op_name == "slice") { if (op_name == "slice" || op_name == "shape") {
input_name = op_node->Op()->Input("Input")[0]; input_name = op_node->Op()->Input("Input")[0];
} else { } else {
input_name = op_node->Op()->Input("X")[0]; input_name = op_node->Op()->Input("X")[0];
...@@ -445,6 +445,7 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const { ...@@ -445,6 +445,7 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
"reshape2", "reshape2",
"pool2d", "pool2d",
"slice", "slice",
"shape",
"nearest_interp", "nearest_interp",
"nearest_interp_v2"}; "nearest_interp_v2"};
......
...@@ -1136,6 +1136,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -1136,6 +1136,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeImmutable(graph, "reshape2", "X"); QuantizeImmutable(graph, "reshape2", "X");
QuantizeImmutable(graph, "transpose2", "X"); QuantizeImmutable(graph, "transpose2", "X");
QuantizeImmutable(graph, "slice", "Input"); QuantizeImmutable(graph, "slice", "Input");
QuantizeImmutable(graph, "shape", "Input");
QuantizeImmutable(graph, "nearest_interp", "X"); QuantizeImmutable(graph, "nearest_interp", "X");
QuantizeImmutable(graph, "nearest_interp_v2", "X"); QuantizeImmutable(graph, "nearest_interp_v2", "X");
QuantizeElementwise(graph, "elementwise_add"); QuantizeElementwise(graph, "elementwise_add");
......
...@@ -66,7 +66,7 @@ void SetOp(ProgramDesc* prog, ...@@ -66,7 +66,7 @@ void SetOp(ProgramDesc* prog,
type == "nearest_interp" || type == "nearest_interp_v2") { type == "nearest_interp" || type == "nearest_interp_v2") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
} else if (type == "slice") { } else if (type == "slice" || type == "shape") {
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
} else if (type == "dropout") { } else if (type == "dropout") {
...@@ -550,8 +550,12 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) { ...@@ -550,8 +550,12 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) {
SCALE * S8_MAX); SCALE * S8_MAX);
} }
const std::vector<std::string> immutables = { const std::vector<std::string> immutables = {"reshape2",
"reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"}; "transpose2",
"slice",
"shape",
"nearest_interp",
"nearest_interp_v2"};
class TestImmutables : public testing::TestWithParam<std::string> {}; class TestImmutables : public testing::TestWithParam<std::string> {};
......
...@@ -142,7 +142,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs( ...@@ -142,7 +142,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
scales_[var_name] = scales_[input_var_name]; scales_[var_name] = scales_[input_var_name];
} }
compute_scale = false; compute_scale = false;
} else if (op->Type() == "slice") { } else if (op->Type() == "slice" || op->Type() == "shape") {
auto input_var_name = op->Input("Input")[0]; auto input_var_name = op->Input("Input")[0];
PADDLE_ENFORCE_NE(scales_.find(input_var_name), PADDLE_ENFORCE_NE(scales_.find(input_var_name),
scales_.end(), scales_.end(),
......
...@@ -45,6 +45,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { ...@@ -45,6 +45,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["slice"]["Input"] = ScaleAlgo::KL; rules_["slice"]["Input"] = ScaleAlgo::KL;
rules_["slice"]["Out"] = ScaleAlgo::NONE; rules_["slice"]["Out"] = ScaleAlgo::NONE;
rules_["shape"]["Input"] = ScaleAlgo::KL;
rules_["shape"]["Out"] = ScaleAlgo::NONE;
rules_["fc"]["Input"] = ScaleAlgo::KL; rules_["fc"]["Input"] = ScaleAlgo::KL;
rules_["fc"]["W"] = ScaleAlgo::MAX_CH_T; rules_["fc"]["W"] = ScaleAlgo::MAX_CH_T;
rules_["fc"]["Bias"] = ScaleAlgo::NONE; rules_["fc"]["Bias"] = ScaleAlgo::NONE;
...@@ -62,6 +65,10 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { ...@@ -62,6 +65,10 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
rules_["elementwise_mul"]["Y"] = ScaleAlgo::KL; rules_["elementwise_mul"]["Y"] = ScaleAlgo::KL;
rules_["elementwise_mul"]["Out"] = ScaleAlgo::KL; rules_["elementwise_mul"]["Out"] = ScaleAlgo::KL;
rules_["elementwise_sub"]["X"] = ScaleAlgo::KL;
rules_["elementwise_sub"]["Y"] = ScaleAlgo::KL;
rules_["elementwise_sub"]["Out"] = ScaleAlgo::KL;
// Reshape2 does not perform calculation on the data and shapes are not // Reshape2 does not perform calculation on the data and shapes are not
// changed. Scale is calculated on input data and assigned to Quantize and // changed. Scale is calculated on input data and assigned to Quantize and
// Dequantize scale. // Dequantize scale.
......
...@@ -63,8 +63,8 @@ class Quant2Int8MkldnnPass(object): ...@@ -63,8 +63,8 @@ class Quant2Int8MkldnnPass(object):
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set( self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
[-1]) [-1])
self._scale_immutable_ops = [ self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'slice', 'nearest_interp', 'transpose2', 'reshape2', 'pool2d', 'slice', 'shape',
'nearest_interp_v2' 'nearest_interp', 'nearest_interp_v2'
] ]
self._scale_ops = ['scale'] self._scale_ops = ['scale']
self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._conv_ops = ['conv2d', 'depthwise_conv2d']
...@@ -247,7 +247,7 @@ class Quant2Int8MkldnnPass(object): ...@@ -247,7 +247,7 @@ class Quant2Int8MkldnnPass(object):
waiting_for_scale = set() waiting_for_scale = set()
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._scale_immutable_ops: if op.name() in self._scale_immutable_ops:
if op.name() == 'slice': if op.name() == 'slice' or op.name() == 'shape':
input_name = op.input("Input")[0] input_name = op.input("Input")[0]
else: else:
input_name = op.input("X")[0] input_name = op.input("X")[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册