Unverified commit 561fd8c8, authored by Y yeliang2258, committed by GitHub

Fix quantized-model deployment bugs when using MKLDNN (#45920)

* fix immutable op quantize bugs

* fix

* fix build bug

* fix test

* notest,test=inference

* fix ppyoloe acc drop bugs

* fix test

* fix test

* add test

* fix

* fix

* fix test

* fix refined name bug

* fix test

* bias fix

* fix matmul weight dequant bug

* re-ci

* fix tester

* fix test

* fix tester

* update weight dequantize func

* update code

* update test for coverage

* update test

* update cmake

* update cmakelist

* update code

* rerun ci

* remove useless code
Parent 910e1b6a
......@@ -67,7 +67,7 @@ std::vector<float> ComputePropagateScalesMkldnnPass::GetScales(
for (int i = 0; i < columns; i++) {
float max_value = FLT_MIN;
for (int j = 0; j < rows; j++) {
max_value = std::max(max_value, std::abs(data[i + j * columns]));
max_value = std::max(max_value, std::abs(data[j + i * rows]));
}
max_value = 1.0 / max_value;
if (std::isinf(max_value) || std::isnan(max_value)) {
......
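For orientation, the indexing fix above treats the weight buffer as `columns` contiguous blocks of `rows` elements each, so the per-output-channel maximum is taken over data[j + i * rows]. A minimal, hypothetical sketch of that per-channel scale computation (the standalone signature and names are assumptions, not the pass's API):

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical helper: reciprocal max-abs scale per output channel, assuming
// the weights are laid out as `columns` contiguous blocks of `rows` values
// (the layout implied by the corrected indexing).
std::vector<float> PerChannelReciprocalScales(const float* data,
                                              int rows,
                                              int columns) {
  std::vector<float> scales(columns, 0.0f);
  for (int i = 0; i < columns; ++i) {
    float max_value = 0.0f;
    for (int j = 0; j < rows; ++j) {
      max_value = std::max(max_value, std::abs(data[j + i * rows]));
    }
    float scale = 1.0f / max_value;
    // Same guard as the pass: an all-zero channel would yield inf/nan.
    if (std::isinf(scale) || std::isnan(scale)) scale = 0.0f;
    scales[i] = scale;
  }
  return scales;
}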
......@@ -411,7 +411,16 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
auto filter_scale_tensor = GetScaleTensorForNode(conv_filter);
EigenVectorArrayMap eigen_tensor{filter_scale_tensor.data<double>(),
filter_scale_tensor.numel()};
eigen_tensor *= static_cast<double>(S8_MAX);
// If the scale value of a weight is already multiplied by S8_MAX, it does
// not need to be multiplied again
if (std::find(change_weight_->begin(),
change_weight_->end(),
conv_filter->Name()) == change_weight_->end()) {
eigen_tensor *= static_cast<double>(S8_MAX);
change_weight_->push_back(conv_filter->Name());
}
std::vector<float> filter_scale{
filter_scale_tensor.data<double>(),
filter_scale_tensor.data<double>() + filter_scale_tensor.numel()};
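The change_weight_ lookup above exists because two quantized convs can share one persistable filter; without the guard the filter scale would be multiplied by S8_MAX twice. A simplified, hypothetical illustration of that bookkeeping (using a set where the pass itself keeps a std::vector<std::string>):

#include <string>
#include <unordered_set>
#include <vector>

constexpr double S8_MAX = 127.0;

// Multiply a filter's scales by S8_MAX only the first time the filter name is
// seen; later convs sharing the same weight reuse the already-scaled values.
void ScaleFilterOnce(const std::string& filter_name,
                     std::vector<double>* filter_scales,
                     std::unordered_set<std::string>* already_scaled) {
  if (already_scaled->count(filter_name)) return;
  for (double& s : *filter_scales) s *= S8_MAX;
  already_scaled->insert(filter_name);
}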
......@@ -693,6 +702,13 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph,
return;
}
// skip if the dtype of immutable_in is not float32
auto dtype = immutable_in->Var()->GetDataType();
if (dtype != proto::VarType::FP32) {
MarkAndLogCannotQuantizeOp(immutable_op, "The input dtype is not float.");
return;
}
if (!AreScalesPresentForNodes({immutable_out})) {
MarkAndLogCannotQuantizeOp(immutable_op,
"No scale available for the operator");
......@@ -1164,7 +1180,6 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeImmutable(graph, "reshape2", "X");
QuantizeImmutable(graph, "transpose2", "X");
QuantizeImmutable(graph, "slice", "Input");
QuantizeImmutable(graph, "shape", "Input");
QuantizeImmutable(graph, "nearest_interp", "X");
QuantizeImmutable(graph, "nearest_interp_v2", "X");
QuantizeElementwise(graph, "elementwise_add");
......
......@@ -110,6 +110,11 @@ class CPUQuantizePass : public FusePassBase {
VarQuantScale string_pair_map = {};
VarQuantScale* const var_quant_scales_ = &string_pair_map;
// Save the scale values of which weights have been processed to avoid
// secondary processing
std::vector<std::string> change_weight = {};
std::vector<std::string>* const change_weight_ = &change_weight;
void GetQuantInfo(Graph* graph) const;
};
......
......@@ -66,7 +66,7 @@ void SetOp(ProgramDesc* prog,
type == "nearest_interp" || type == "nearest_interp_v2") {
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "slice" || type == "shape") {
} else if (type == "slice") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
} else if (type == "dropout") {
......@@ -467,7 +467,7 @@ static const std::initializer_list<std::string> variable_names_immutable_ops = {
void TestImmutableOp(const std::string tested_op) {
ProgramDesc prog;
for (auto& v : variable_names_immutable_ops) {
prog.MutableBlock(0)->Var(v);
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
SetOp(&prog, tested_op, tested_op, {"b"}, {"c"}, true, "int8");
......@@ -520,7 +520,7 @@ void TestImmutableOpBetweenNonQuantizedOp(const std::string tested_op) {
void TestImmutableOpWithManyOutputs(const std::string tested_op) {
ProgramDesc prog;
for (auto& v : variable_names_immutable_ops) {
prog.MutableBlock(0)->Var(v);
prog.MutableBlock(0)->Var(v)->SetDataType(proto::VarType::FP32);
}
SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, true, "float32");
......@@ -556,12 +556,8 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) {
SCALE * S8_MAX);
}
const std::vector<std::string> immutables = {"reshape2",
"transpose2",
"slice",
"shape",
"nearest_interp",
"nearest_interp_v2"};
const std::vector<std::string> immutables = {
"reshape2", "transpose2", "slice", "nearest_interp", "nearest_interp_v2"};
class TestImmutables : public testing::TestWithParam<std::string> {};
......
......@@ -52,36 +52,25 @@ bool HasBias(ir::Node* conv_op) {
conv_op->Op()->Input("Bias").size() > 0;
}
bool ShouldSkipConv(ir::Node* conv_op, Scope* scope, ir::Node* conv_filter) {
if (!platform::HasOpINT8DataType(conv_op->Op())) {
VLOG(4) << "Skipping non-int8 convolution (id: " << conv_op->id() << ").";
return true;
}
auto filter_var = scope->GetVar(conv_filter->Name());
if (filter_var->Get<LoDTensor>().dtype() != phi::DataType::FLOAT32) {
VLOG(4) << "Skipping convolution (id: " << conv_op->id()
<< ") because it's a bug that it is detected again.";
return true;
}
VLOG(4) << "Not skipping convolution (id: " << conv_op->id() << ")";
return false;
}
template <typename T>
void QuantizeConvInput(Scope* scope,
ir::Graph* g,
ir::Node* conv_op,
const std::string& input_name,
const std::string& scales_attr_name) {
const auto scales =
conv_op->Op()->GetAttrIfExists<std::vector<float>>(scales_attr_name);
auto* tensor = scope->GetVar(input_name)->GetMutable<LoDTensor>();
QuantizeParams<T>(tensor, scales);
conv_op->Op()->SetAttr(scales_attr_name, std::vector<float>(1, 1));
auto var = scope->GetVar(input_name);
if (var->Get<LoDTensor>().dtype() != phi::DataType::FLOAT32) {
VLOG(0) << "Skipping convolution filter: " << input_name
<< " because it is detected again.";
conv_op->Op()->SetAttr(scales_attr_name, std::vector<float>(1, 1));
} else {
const auto scales =
conv_op->Op()->GetAttrIfExists<std::vector<float>>(scales_attr_name);
auto* tensor = scope->GetVar(input_name)->GetMutable<LoDTensor>();
QuantizeParams<T>(tensor, scales);
conv_op->Op()->SetAttr(scales_attr_name, std::vector<float>(1, 1));
}
}
} // namespace
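The rewritten QuantizeConvInput replaces the removed ShouldSkipConv filter check: when a shared filter has already been converted away from FP32 by an earlier conv, only the scale attribute is normalized. A rough sketch of that control flow (types and names are hypothetical, and the toy scaling stands in for QuantizeParams<T>):

#include <vector>

enum class DType { FP32, INT8 };

struct FakeTensor {
  DType dtype = DType::FP32;
  std::vector<float> data;
};

// Quantize the filter only if it is still FP32; a second conv sharing the
// same filter just resets its scale attribute so downstream passes agree.
void QuantizeSharedFilter(FakeTensor* filter, std::vector<float>* scales_attr) {
  if (filter->dtype != DType::FP32) {
    *scales_attr = std::vector<float>(1, 1.0f);
    return;
  }
  for (float& v : filter->data) {
    v *= (*scales_attr)[0];  // toy stand-in for QuantizeParams<T>(tensor, scales)
  }
  filter->dtype = DType::INT8;
  *scales_attr = std::vector<float>(1, 1.0f);
}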
......@@ -151,7 +140,8 @@ void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph,
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
if (ShouldSkipConv(conv_op, scope, conv_filter)) {
// If not a quantized OP
if (!platform::HasOpINT8DataType(conv_op->Op())) {
return;
}
......
......@@ -89,8 +89,14 @@ struct ProgramStrategy {
virtual void CheckOp(const OpDesc& op) const = 0;
VarDesc* AddInput(OpDesc* op, std::string input_name, const Data& data) {
const std::string var_name = input_name + "_var";
VarDesc* AddInput(OpDesc* op,
std::string input_name,
const Data& data,
const std::string user_var_name = "") {
std::string var_name = user_var_name;
if (var_name.empty()) {
var_name = input_name + "_var";
}
op->SetInput(input_name, {var_name});
auto var = program.MutableBlock(0)->Var(var_name);
var->SetShape(data.getShape());
......@@ -98,8 +104,14 @@ struct ProgramStrategy {
return var;
}
void AddOutput(OpDesc* op, std::string output_name, const Data& data) {
const std::string var_name = output_name + "_var";
void AddOutput(OpDesc* op,
std::string output_name,
const Data& data,
const std::string user_var_name = "") {
std::string var_name = user_var_name;
if (var_name.empty()) {
var_name = output_name + "_var";
}
op->SetOutput(output_name, {var_name});
program.MutableBlock(0)->Var(var_name);
test_scope.CreateTensor(var_name, data);
......@@ -117,21 +129,23 @@ struct ConvProgramStrategy : public ProgramStrategy {
std::vector<float>&& scale_weights,
int groups = 1,
Data&& bias = Data(),
std::vector<float>&& scale_bias = {})
std::vector<float>&& scale_bias = {},
bool share_weight = false)
: input(std::move(input)),
filter(std::move(filter)),
output(std::move(output)),
scale_weights(std::move(scale_weights)),
groups(std::move(groups)),
bias(std::move(bias)),
scale_bias(std::move(scale_bias)) {}
scale_bias(std::move(scale_bias)),
share_weight(std::move(share_weight)) {}
protected:
OpDesc* CreateBasicConvOp() {
OpDesc* CreateBasicConvOp(const std::string conv_name = "Conv1") {
auto op = program.MutableBlock(0)->AppendOp();
op->SetType("conv2d");
op->SetAttr("use_mkldnn", true);
op->SetAttr("name", std::string{"Conv1"});
op->SetAttr("name", conv_name);
op->SetAttr("mkldnn_data_type", std::string{"int8"});
op->SetAttr("data_format", std::string{"NCHW"});
op->SetAttr("dilations", std::vector<int>({1, 1}));
......@@ -155,6 +169,20 @@ struct ConvProgramStrategy : public ProgramStrategy {
AddInput(op, "Bias", bias);
op->SetAttr("Bias_scales", scale_bias);
}
if (share_weight) {
OpDesc* op2 = CreateBasicConvOp("Conv2");
AddInput(op2, "Input", input);
AddInput(op2, "Filter", filter)->SetPersistable(true);
AddOutput(op2, "Output", output, "output2");
op2->SetAttr("Scale_weights", scale_weights);
op2->SetAttr("Scale_in", 1.0f);
op2->SetAttr("groups", groups);
if (HasBias()) {
AddInput(op2, "Bias", bias, "Bias2");
op2->SetAttr("Bias_scales", scale_bias);
}
}
}
void CheckOp(const OpDesc& op) const override {
......@@ -210,9 +238,9 @@ struct ConvProgramStrategy : public ProgramStrategy {
const Data output;
const std::vector<float> scale_weights;
const int groups;
const Data bias;
const std::vector<float> scale_bias;
const bool share_weight;
};
struct ParamsQuantizationMkldnnPassTestFixture : public ::testing::Test {
......@@ -340,6 +368,19 @@ TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1w) {
RunPassTest(std::move(program));
}
TEST_F(ParamsQuantizationMkldnnPassTestFixture, conv_with_bias_2g2o2i1h1ws) {
auto program = std::make_unique<ConvProgramStrategy>(
GenericInput(),
Data({2, 2, 2, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}),
GenericOutput(),
std::vector<float>{2.f, 2.f, 4.f, 4.f},
2,
Data({2, 2, 1, 1, 1}, {1.5f, 1.5f, 1.5f, 1.5f}),
std::vector<float>{2.f, 2.f, 4.f, 4.f},
true);
RunPassTest(std::move(program));
}
} // namespace
} // namespace ir
} // namespace framework
......
......@@ -109,27 +109,34 @@ void QuantDequantMkldnnPass::CollectWeightScalesInfoFromONNXFormatDequantize(
if (op_node->Name() == "dequantize_linear") {
auto* op_desc = op_node->Op();
auto scale_name = op_desc->Input("Scale")[0];
auto* var = scope->FindVar(scale_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The Scales variable [%s] of dequantize op is not found.", var));
auto* scale_tensor = var->GetMutable<LoDTensor>();
auto* scale_data = scale_tensor->data<float>();
auto x_var_name = op_desc->Input("X")[0];
auto* weight_var = scope->FindVar(x_var_name);
if (!weight_var) {
auto out_var_name = op_desc->Output("Y")[0];
if (var_quant_scales->count(x_var_name) &&
!var_quant_scales->count(out_var_name)) {
std::vector<float> scale_v = var_quant_scales->at(x_var_name);
float scale = 1.0 / scale_data[0];
if (std::isinf(scale) || std::isnan(scale)) {
scale = 0.0;
}
std::vector<float> scale_v = {scale};
if (!var_quant_scales->count(out_var_name)) {
var_quant_scales->insert(std::make_pair(out_var_name, scale_v));
}
if (!var_quant_scales->count(x_var_name)) {
var_quant_scales->insert(std::make_pair(x_var_name, scale_v));
}
} else {
*onnx_format_quantize_model = true;
auto scale_name = op_desc->Input("Scale")[0];
auto* var = scope->FindVar(scale_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"The Scales variable [%s] of dequantize op is not found.",
var));
auto* scale_tensor = var->GetMutable<LoDTensor>();
auto* scale_data = scale_tensor->data<float>();
std::vector<float> thresholds(scale_data,
scale_data + scale_tensor->numel());
weight_thresholds->insert(std::make_pair(x_var_name, thresholds));
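For the non-weight branch added above (a dequantize_linear whose X is an activation rather than a persistable weight), the scale is stored as the reciprocal of the dequantization scale, guarded against inf/nan, and registered for both the X and Y variable names. A hedged sketch, assuming a plain std::map in place of the pass's scale container:

#include <cmath>
#include <map>
#include <string>
#include <vector>

// Record an activation quantization scale for both the input (X) and output
// (Y) names of a dequantize_linear op, skipping names that already have one.
void RecordActivationScale(const std::string& x_name,
                           const std::string& out_name,
                           float dequant_scale,
                           std::map<std::string, std::vector<float>>* var_scales) {
  float scale = 1.0f / dequant_scale;
  if (std::isinf(scale) || std::isnan(scale)) scale = 0.0f;
  const std::vector<float> scale_v = {scale};
  if (!var_scales->count(out_name)) var_scales->emplace(out_name, scale_v);
  if (!var_scales->count(x_name)) var_scales->emplace(x_name, scale_v);
}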
......@@ -182,7 +189,7 @@ void QuantDequantMkldnnPass::CollectInputScalesFromQuantize(
auto* scale_data = scale_tensor->data<float>();
float scale = 1.0 / scale_data[0];
if (std::isinf(scale) || std::isnan(scale)) {
scale = 0.0;
continue;
}
if (!var_quant_scales->count(x_var_name)) {
......@@ -520,12 +527,10 @@ void QuantDequantMkldnnPass::ConvertFromINT8ToFP32(
int step_c = step_n / size;
for (int i = 0; i < weight_dims[0]; i++) {
int begin_n = i * step_n;
for (int j = begin_n; j < begin_n + step_n; j++) {
for (int k = 0; k < size; k++) {
int begin_c = k * step_c;
for (int m = begin_c; m < begin_c + step_c; m++) {
weight_data[m] *= scales[k];
}
for (int j = 0; j < size; j++) {
int begin_c = begin_n + j * step_c;
for (int k = 0; k < step_c; k++) {
weight_data[begin_c + k] *= scales[j];
}
}
}
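The rewritten loop above fixes the channel offsets when expanding an int8 weight back to FP32: each of the weight_dims[0] outer blocks of step_n values is split into scales.size() contiguous runs of step_c values, and run j is multiplied by scales[j]. A minimal sketch of that arithmetic as a hypothetical standalone helper:

#include <vector>

// Per-channel dequantization over a flat weight buffer: dims0 blocks of
// step_n elements, each split into scales.size() runs of step_c elements.
void DequantizePerChannel(std::vector<float>* weight_data,
                          const std::vector<float>& scales,
                          int dims0) {
  const int step_n = static_cast<int>(weight_data->size()) / dims0;
  const int step_c = step_n / static_cast<int>(scales.size());
  for (int i = 0; i < dims0; ++i) {
    const int begin_n = i * step_n;
    for (int j = 0; j < static_cast<int>(scales.size()); ++j) {
      const int begin_c = begin_n + j * step_c;
      for (int k = 0; k < step_c; ++k) {
        (*weight_data)[begin_c + k] *= scales[j];
      }
    }
  }
}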
......@@ -588,7 +593,8 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
Scope* scope,
const std::string& weight_name,
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const {
weight_thresholds,
std::vector<std::string>* dequantized_weights_names) const {
auto* op_desc = op_node->Op();
std::string weight_var_name = op_desc->Input(weight_name)[0];
......@@ -596,6 +602,13 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
auto iter = weight_thresholds.find(weight_var_name);
if (iter != weight_thresholds.end()) {
scales = iter->second;
auto name_iter = std::find(dequantized_weights_names->begin(),
dequantized_weights_names->end(),
weight_var_name);
// Has been dequantized
if (name_iter != dequantized_weights_names->end()) {
return;
}
} else {
if (!IsInt8Weight(op_node, scope, weight_name)) {
return;
......@@ -605,7 +618,7 @@ void QuantDequantMkldnnPass::DequantizeOpWeightsFromONNXFormat(
"the model is correct.",
weight_var_name));
}
dequantized_weights_names->push_back(weight_var_name);
auto* var = scope->FindVar(weight_var_name);
PADDLE_ENFORCE_NOT_NULL(
var,
......@@ -634,14 +647,17 @@ void QuantDequantMkldnnPass::DequantizeWeights(
<< "No need to dequantize weights because weight_thresholds is empty.";
return;
}
std::vector<std::string> dequantized_weights_names;
for (auto* op_node :
ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
if (!op_node->IsOp()) continue;
if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") {
if (onnx_format_quantize_model) {
DequantizeOpWeightsFromONNXFormat(
op_node, scope, "Filter", weight_thresholds);
DequantizeOpWeightsFromONNXFormat(op_node,
scope,
"Filter",
weight_thresholds,
&dequantized_weights_names);
} else if (IsInt8Weight(op_node, scope, "Filter")) {
DequantizeOpWeights(
op_node, scope, "Filter", "Output", weight_thresholds);
......@@ -650,7 +666,7 @@ void QuantDequantMkldnnPass::DequantizeWeights(
op_node->Name() == "matmul_v2") {
if (onnx_format_quantize_model) {
DequantizeOpWeightsFromONNXFormat(
op_node, scope, "Y", weight_thresholds);
op_node, scope, "Y", weight_thresholds, &dequantized_weights_names);
} else if (IsInt8Weight(op_node, scope, "Y")) {
DequantizeOpWeights(op_node, scope, "Y", "Out", weight_thresholds);
}
......
......@@ -125,7 +125,8 @@ class QuantDequantMkldnnPass : public FusePassBase {
Scope* scope,
const std::string& weight_name,
const std::unordered_map<std::string, std::vector<float>>&
weight_thresholds) const;
weight_thresholds,
std::vector<std::string>* dequantized_weights_names) const;
void DequantizeWeights(
ir::Graph* graph,
......
......@@ -17,7 +17,6 @@ import time
import sys
import random
import functools
import tempfile
import numpy as np
from PIL import Image
import paddle
......@@ -149,13 +148,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.infer_iterations = 50000 if os.environ.get(
'DATASET') == 'full' else 2
self.root_path = tempfile.TemporaryDirectory()
self.int8_model = os.path.join(self.root_path.name,
"post_training_quantization")
self.int8_model = "post_training_quantization"
print("self.int8_model: ", self.int8_model)
def tearDown(self):
self.root_path.cleanup()
cmd = 'rm -rf post_training_quantization'
os.system(cmd)
pass
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
......@@ -262,16 +261,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_use_cache_file=False,
is_optimize_model=False,
onnx_format=False):
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
print("Failed to create {} due to {}".format(
self.int8_model, str(e)))
sys.exit(-1)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()
val_reader = val()
ptq = PostTrainingQuantization(executor=exe,
......@@ -305,12 +296,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_cache_folder = self.download_data(data_urls, data_md5s, model)
print("Start FP32 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
os.path.join(model_cache_folder, "model"), batch_size,
infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(os.path.join(model_cache_folder, "model"),
......@@ -318,6 +303,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_full_quantize, is_use_cache_file,
is_optimize_model, onnx_format)
print("Start FP32 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
os.path.join(model_cache_folder, "model"), batch_size,
infer_iterations)
print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
(int8_throughput, int8_latency,
......@@ -341,10 +332,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.assertLess(delta_value, diff_threshold)
class TestMKLDNNInt8ForMobilenetv1AvgONNXFormat(TestPostTrainingQuantization):
class TestMKLDNNInt8ForResnet50AvgONNXFormat(TestPostTrainingQuantization):
def test_onnx_format_avg_mobilenetv1(self):
model = "MobileNet-V1"
def test_onnx_format_avg_resnet50(self):
model = "resnet50"
algo = "avg"
round_type = "round"
data_urls = [
......@@ -373,66 +364,5 @@ class TestMKLDNNInt8ForMobilenetv1AvgONNXFormat(TestPostTrainingQuantization):
onnx_format=True)
class TestMKLDNNInt8ForMobilenetv1Avg(TestPostTrainingQuantization):
def test_avg_mobilenetv1(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
diff_threshold = 0
self.run_test(model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False)
class TestMKLDNNInt8ForMobilenetv1AbsMax(TestPostTrainingQuantization):
def test_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = False
# The accuracy diff of post-training quantization (abs_max) maybe bigger
diff_threshold = 0
self.run_test(model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False)
if __name__ == '__main__':
unittest.main()