未验证 提交 2ffb3371 编写于 作者: Z Zuza Gawrysiak 提交者: GitHub

Quantize elementwise sub (#42854)

* Add elementwise_sub quantization

* Remove unnecessary comments

* Specify names for tests

* Remove comments

* Remove comments leftovers
上级 7b6bf281
...@@ -1057,7 +1057,7 @@ struct Pool : public PatternBase { ...@@ -1057,7 +1057,7 @@ struct Pool : public PatternBase {
}; };
// Elementwise ops // Elementwise ops
// Forward pass for element-wise operators (add, mul) // Forward pass for element-wise operators
// elementwise_out is the result of the operator // elementwise_out is the result of the operator
struct Elementwise : public PatternBase { struct Elementwise : public PatternBase {
Elementwise(PDPattern* pattern, const std::string& name_scope) Elementwise(PDPattern* pattern, const std::string& name_scope)
......
...@@ -1188,6 +1188,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -1188,6 +1188,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeMatmul(graph); QuantizeMatmul(graph);
QuantizeElementwise(graph, "elementwise_add"); QuantizeElementwise(graph, "elementwise_add");
QuantizeElementwise(graph, "elementwise_mul"); QuantizeElementwise(graph, "elementwise_mul");
QuantizeElementwise(graph, "elementwise_sub");
QuantizeFusionGru(graph); QuantizeFusionGru(graph);
QuantizeMultiGru(graph); QuantizeMultiGru(graph);
QuantizeFusionLSTM(graph); QuantizeFusionLSTM(graph);
......
...@@ -90,7 +90,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -90,7 +90,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_x", 1.0f); op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f); op->SetAttr("Scale_y", 1.0f);
op->SetAttr("Scale_out", 1.0f); op->SetAttr("Scale_out", 1.0f);
} else if (type == "elementwise_add" || type == "elementwise_mul") { } else if (type == "elementwise_add" || type == "elementwise_mul" ||
type == "elementwise_sub") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]}); if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
...@@ -168,7 +169,7 @@ void CheckScales(const OpDesc* op, float scale, float shift) { ...@@ -168,7 +169,7 @@ void CheckScales(const OpDesc* op, float scale, float shift) {
scale_names.push_back("Scale_in"); scale_names.push_back("Scale_in");
scale_names.push_back("Scale_out"); scale_names.push_back("Scale_out");
} else if (type == "matmul" || type == "elementwise_add" || } else if (type == "matmul" || type == "elementwise_add" ||
type == "elementwise_mul") { type == "elementwise_mul" || type == "elementwise_sub") {
scale_names.push_back("Scale_x"); scale_names.push_back("Scale_x");
scale_names.push_back("Scale_y"); scale_names.push_back("Scale_y");
scale_names.push_back("Scale_out"); scale_names.push_back("Scale_out");
...@@ -565,60 +566,59 @@ ProgramDesc BuildProgramDescElementwise(const std::string elementwise_type, ...@@ -565,60 +566,59 @@ ProgramDesc BuildProgramDescElementwise(const std::string elementwise_type,
return prog; return prog;
} }
void TestElementwise(const std::string elementwise_type, void TestElementwise(std::vector<std::string> elementwise) {
const std::string elementwise_name) {
// 2 Quant + 2 IN + 1 DeQuant + 1 OUT // 2 Quant + 2 IN + 1 DeQuant + 1 OUT
int added_nodes = 6; int added_nodes = 6;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{elementwise_type, 1}, {"quantize", 2}, {"dequantize", 3}}; {elementwise[0], 1}, {"quantize", 2}, {"dequantize", 3}};
MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name), MainTest(BuildProgramDescElementwise(elementwise[0], elementwise[1]),
variable_names_elementwise, expected_operators, added_nodes, variable_names_elementwise, expected_operators, added_nodes,
SCALE * S8_MAX); SCALE * S8_MAX);
} }
void TestElementwiseOutputScaleMissing(const std::string elementwise_type, void TestElementwiseOutputScaleMissing(std::vector<std::string> elementwise) {
const std::string elementwise_name) {
int added_nodes = 0; int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}}; {elementwise[0], 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name), MainTest(BuildProgramDescElementwise(elementwise[0], elementwise[1]),
variable_names_elementwise, expected_operators, added_nodes, 1.f, variable_names_elementwise, expected_operators, added_nodes, 1.f,
1.f, "e"); 1.f, "e");
} }
void TestElementwiseUnsignedAndSignedInput(const std::string elementwise_type, void TestElementwiseUnsignedAndSignedInput(
const std::string elementwise_name) { std::vector<std::string> elementwise) {
int added_nodes = 0; int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = { std::unordered_map<std::string, int> expected_operators = {
{elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}}; {elementwise[0], 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name), MainTest(BuildProgramDescElementwise(elementwise[0], elementwise[1]),
variable_names_elementwise, expected_operators, added_nodes, 1.f, variable_names_elementwise, expected_operators, added_nodes, 1.f,
1.f, "", "b"); 1.f, "", "b");
} }
TEST(CpuQuantizePass, elementwise_add) { const std::vector<std::vector<std::string>> elementwises = {
TestElementwise("elementwise_add", "ElementwiseAdd"); {"elementwise_add", "ElementwiseAdd"},
} {"elementwise_mul", "ElementwiseMul"},
{"elementwise_sub", "ElementwiseSub"}};
TEST(CpuQuantizePass, elementwise_add_output_scale_missing) { class TestElementwises
TestElementwiseOutputScaleMissing("elementwise_add", "ElementwiseAdd"); : public testing::TestWithParam<std::vector<std::string>> {};
}
TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) { TEST_P(TestElementwises, elementwise_basic) { TestElementwise(GetParam()); }
TestElementwiseUnsignedAndSignedInput("elementwise_add", "ElementwiseAdd");
}
TEST(CpuQuantizePass, elementwise_mul) { TEST_P(TestElementwises, elementwise_output_scale_missing) {
TestElementwise("elementwise_mul", "ElementwiseMul"); TestElementwiseOutputScaleMissing(GetParam());
} }
TEST(CpuQuantizePass, elementwise_mul_output_scale_missing) { TEST_P(TestElementwises, elementwise_unsigned_and_signed_input) {
TestElementwiseOutputScaleMissing("elementwise_mul", "ElementwiseMul"); TestElementwiseUnsignedAndSignedInput(GetParam());
} }
TEST(CpuQuantizePass, elementwise_mul_unsigned_and_signed_input) { INSTANTIATE_TEST_CASE_P(
TestElementwiseUnsignedAndSignedInput("elementwise_mul", "ElementwiseMul"); Elementwises, TestElementwises, testing::ValuesIn(elementwises),
} [](const ::testing::TestParamInfo<TestElementwises::ParamType>& info) {
std::string name = info.param[0];
return name;
});
const std::vector<std::string> churn_out_vars(ProgramDesc* prog, const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
const std::string& prefix, const std::string& prefix,
......
...@@ -27,9 +27,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { ...@@ -27,9 +27,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> supported_op_types = std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>( std::unordered_set<std::string>(
{"concat", "conv2d", "depthwise_conv2d", "elementwise_add", {"concat", "conv2d", "depthwise_conv2d", "elementwise_add",
"elementwise_mul", "fc", "matmul", "nearest_interp", "elementwise_mul", "elementwise_sub", "fc", "matmul",
"nearest_interp_v2", "pool2d", "prior_box", "reshape2", "transpose2", "nearest_interp", "nearest_interp_v2", "pool2d", "prior_box",
"fusion_gru", "fusion_lstm", "multi_gru", "slice"}); "reshape2", "transpose2", "fusion_gru", "fusion_lstm", "multi_gru",
"slice"});
const auto& excluded_ids_list = const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids"); Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list = const auto& op_types_list =
......
...@@ -14,10 +14,6 @@ ...@@ -14,10 +14,6 @@
#pragma once #pragma once
// #include <memory>
// #include <string>
// #include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle { namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册