未验证 提交 2ffb3371 编写于 作者: Z Zuza Gawrysiak 提交者: GitHub

Quantize elementwise sub (#42854)

* Add elementwise_sub quantization

* Remove unnecessary comments

* Specify names for tests

* Remove comments

* Remove comments leftovers
上级 7b6bf281
......@@ -1057,7 +1057,7 @@ struct Pool : public PatternBase {
};
// Elementwise ops
// Forward pass for element-wise operators (add, mul)
// Forward pass for element-wise operators
// elementwise_out is the result of the operator
struct Elementwise : public PatternBase {
Elementwise(PDPattern* pattern, const std::string& name_scope)
......
......@@ -1188,6 +1188,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeMatmul(graph);
QuantizeElementwise(graph, "elementwise_add");
QuantizeElementwise(graph, "elementwise_mul");
QuantizeElementwise(graph, "elementwise_sub");
QuantizeFusionGru(graph);
QuantizeMultiGru(graph);
QuantizeFusionLSTM(graph);
......
......@@ -90,7 +90,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f);
op->SetAttr("Scale_out", 1.0f);
} else if (type == "elementwise_add" || type == "elementwise_mul") {
} else if (type == "elementwise_add" || type == "elementwise_mul" ||
type == "elementwise_sub") {
op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
op->SetOutput("Out", {outputs[0]});
......@@ -168,7 +169,7 @@ void CheckScales(const OpDesc* op, float scale, float shift) {
scale_names.push_back("Scale_in");
scale_names.push_back("Scale_out");
} else if (type == "matmul" || type == "elementwise_add" ||
type == "elementwise_mul") {
type == "elementwise_mul" || type == "elementwise_sub") {
scale_names.push_back("Scale_x");
scale_names.push_back("Scale_y");
scale_names.push_back("Scale_out");
......@@ -565,60 +566,59 @@ ProgramDesc BuildProgramDescElementwise(const std::string elementwise_type,
return prog;
}
void TestElementwise(const std::string elementwise_type,
const std::string elementwise_name) {
void TestElementwise(std::vector<std::string> elementwise) {
// 2 Quant + 2 IN + 1 DeQuant + 1 OUT
int added_nodes = 6;
std::unordered_map<std::string, int> expected_operators = {
{elementwise_type, 1}, {"quantize", 2}, {"dequantize", 3}};
MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
{elementwise[0], 1}, {"quantize", 2}, {"dequantize", 3}};
MainTest(BuildProgramDescElementwise(elementwise[0], elementwise[1]),
variable_names_elementwise, expected_operators, added_nodes,
SCALE * S8_MAX);
}
void TestElementwiseOutputScaleMissing(const std::string elementwise_type,
const std::string elementwise_name) {
void TestElementwiseOutputScaleMissing(std::vector<std::string> elementwise) {
int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = {
{elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
{elementwise[0], 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwise(elementwise[0], elementwise[1]),
variable_names_elementwise, expected_operators, added_nodes, 1.f,
1.f, "e");
}
void TestElementwiseUnsignedAndSignedInput(const std::string elementwise_type,
const std::string elementwise_name) {
void TestElementwiseUnsignedAndSignedInput(
std::vector<std::string> elementwise) {
int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = {
{elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
{elementwise[0], 1}, {"quantize", 0}, {"dequantize", 2}};
MainTest(BuildProgramDescElementwise(elementwise[0], elementwise[1]),
variable_names_elementwise, expected_operators, added_nodes, 1.f,
1.f, "", "b");
}
TEST(CpuQuantizePass, elementwise_add) {
TestElementwise("elementwise_add", "ElementwiseAdd");
}
const std::vector<std::vector<std::string>> elementwises = {
{"elementwise_add", "ElementwiseAdd"},
{"elementwise_mul", "ElementwiseMul"},
{"elementwise_sub", "ElementwiseSub"}};
TEST(CpuQuantizePass, elementwise_add_output_scale_missing) {
TestElementwiseOutputScaleMissing("elementwise_add", "ElementwiseAdd");
}
class TestElementwises
: public testing::TestWithParam<std::vector<std::string>> {};
TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) {
TestElementwiseUnsignedAndSignedInput("elementwise_add", "ElementwiseAdd");
}
TEST_P(TestElementwises, elementwise_basic) { TestElementwise(GetParam()); }
TEST(CpuQuantizePass, elementwise_mul) {
TestElementwise("elementwise_mul", "ElementwiseMul");
TEST_P(TestElementwises, elementwise_output_scale_missing) {
TestElementwiseOutputScaleMissing(GetParam());
}
TEST(CpuQuantizePass, elementwise_mul_output_scale_missing) {
TestElementwiseOutputScaleMissing("elementwise_mul", "ElementwiseMul");
TEST_P(TestElementwises, elementwise_unsigned_and_signed_input) {
TestElementwiseUnsignedAndSignedInput(GetParam());
}
TEST(CpuQuantizePass, elementwise_mul_unsigned_and_signed_input) {
TestElementwiseUnsignedAndSignedInput("elementwise_mul", "ElementwiseMul");
}
INSTANTIATE_TEST_CASE_P(
Elementwises, TestElementwises, testing::ValuesIn(elementwises),
[](const ::testing::TestParamInfo<TestElementwises::ParamType>& info) {
std::string name = info.param[0];
return name;
});
const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
const std::string& prefix,
......
......@@ -27,9 +27,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "depthwise_conv2d", "elementwise_add",
"elementwise_mul", "fc", "matmul", "nearest_interp",
"nearest_interp_v2", "pool2d", "prior_box", "reshape2", "transpose2",
"fusion_gru", "fusion_lstm", "multi_gru", "slice"});
"elementwise_mul", "elementwise_sub", "fc", "matmul",
"nearest_interp", "nearest_interp_v2", "pool2d", "prior_box",
"reshape2", "transpose2", "fusion_gru", "fusion_lstm", "multi_gru",
"slice"});
const auto& excluded_ids_list =
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list =
......
......@@ -14,10 +14,6 @@
#pragma once
// #include <memory>
// #include <string>
// #include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册