未验证 提交 ff70a269 编写于 作者: G Guanghua Yu 提交者: GitHub

[cherry-pick]Update quantization round and clip calculation methods (#43829)

* update quantization clip and round

* fix quantization clip and round Attribute

* fix typo
上级 9e776f62
...@@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() { ...@@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
.End() .End()
.AddAttr("bit_length") .AddAttr("bit_length")
.IsIntIn({8, 16}) .IsIntIn({8, 16})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End(); .End();
AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max")) AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
.AddInput("X") .AddInput("X")
...@@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() { ...@@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
.End() .End()
.AddAttr("quant_axis") .AddAttr("quant_axis")
.IsIntIn({0, 1}) .IsIntIn({0, 1})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End(); .End();
} }
// Delete quant_dequant_op, then quantize and dequantize weight // Delete quant_dequant_op, then quantize and dequantize weight
...@@ -96,15 +104,18 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -96,15 +104,18 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
auto var_map = any_op2_desc->Inputs(); auto var_map = any_op2_desc->Inputs();
std::string arg_name = ""; std::string arg_name = "";
for (auto& name_m : var_map) { for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(), if (std::find(name_m.second.begin(),
name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) { quant_dequant_op_out_name) != name_m.second.end()) {
arg_name = name_m.first; arg_name = name_m.first;
break; break;
} }
} }
PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(
"can not find the input %s.", arg_name.size(),
quant_dequant_op_out_name)); 0,
platform::errors::InvalidArgument("can not find the input %s.",
quant_dequant_op_out_name));
// any_op2_desc->SetAttr("enable_int8", true); // any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr("bit_length", bit_length); any_op2_desc->SetAttr("bit_length", bit_length);
...@@ -123,7 +134,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -123,7 +134,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") { if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
int quant_axis = int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis")); BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but " "'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
...@@ -176,7 +188,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -176,7 +188,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
} }
} }
for (int i = 0; i < channel; i++) { for (int i = 0; i < channel; i++) {
PADDLE_ENFORCE_NE(weight_scale[i], 0, PADDLE_ENFORCE_NE(weight_scale[i],
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero.")); "Weight scale should be nonzero, but get zero."));
weight_scale[i] = weight_scale[i] / range; weight_scale[i] = weight_scale[i] / range;
...@@ -188,7 +201,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -188,7 +201,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
abs_max_weight = abs_max_weight =
std::max(abs_max_weight, std::abs(quantized_weight_data[j])); std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
} }
PADDLE_ENFORCE_NE(abs_max_weight, 0, PADDLE_ENFORCE_NE(abs_max_weight,
0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero")); "Weight scale should be nonzero, but get zero"));
weight_scale.push_back(abs_max_weight / range); weight_scale.push_back(abs_max_weight / range);
......
...@@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { ...@@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
.End() .End()
.AddAttr("quant_axis") .AddAttr("quant_axis")
.IsType<int>() .IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End(); .End();
AddOpCompat(OpCompat("dequantize_linear")) AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X") .AddInput("X")
...@@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { ...@@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
.End() .End()
.AddAttr("quant_axis") .AddAttr("quant_axis")
.IsType<int>() .IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End(); .End();
} }
// Delete quantize_linear_op dequantize_linear_op, then add input_scales // Delete quantize_linear_op dequantize_linear_op, then add input_scales
...@@ -112,7 +120,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -112,7 +120,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const LoDTensor& input_scale_tensor = const LoDTensor& input_scale_tensor =
scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>(); scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true, paddle::platform::is_cpu_place(input_scale_tensor.place()),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU.")); "Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>(); const float* input_scale_data = input_scale_tensor.data<float>();
......
...@@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() { ...@@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
.End() .End()
.AddAttr("quant_axis") .AddAttr("quant_axis")
.IsType<int>() .IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End(); .End();
AddOpCompat(OpCompat("dequantize_linear")) AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X") .AddInput("X")
...@@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() { ...@@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
.End() .End()
.AddAttr("quant_axis") .AddAttr("quant_axis")
.IsType<int>() .IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End(); .End();
AddOpCompat(OpCompat("conv2d")) AddOpCompat(OpCompat("conv2d"))
.AddInput("Input") .AddInput("Input")
...@@ -322,7 +330,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -322,7 +330,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
int quant_axis = BOOST_GET_CONST( int quant_axis = BOOST_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis")); int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums, 1, PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer " "When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1.")); "quant_dequant, weight_scale'number should be 1."));
...@@ -335,11 +344,13 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -335,11 +344,13 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
} else if (quant_axis == 0) { // per_channel quant_dequant: conv2d, } else if (quant_axis == 0) { // per_channel quant_dequant: conv2d,
// depthwise_conv2d, conv2d_fusion // depthwise_conv2d, conv2d_fusion
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_scale_nums, w_dims[quant_axis], weight_scale_nums,
w_dims[quant_axis],
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel quant_dequant, " "When quant_axis == 0 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels.")); "weight_scale'numbers should be equal channels."));
PADDLE_ENFORCE_EQ(w_dims.size(), 4, PADDLE_ENFORCE_EQ(w_dims.size(),
4,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel " "When quant_axis == 0 means use per_channel "
"quant_dequant, (conv2d, depthwise_conv2d, " "quant_dequant, (conv2d, depthwise_conv2d, "
...@@ -352,7 +363,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -352,7 +363,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
} }
} else if (quant_axis == 1) { } else if (quant_axis == 1) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_scale_nums, w_dims[quant_axis], weight_scale_nums,
w_dims[quant_axis],
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, " "When quant_axis == 1 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels.")); "weight_scale'numbers should be equal channels."));
...@@ -360,7 +372,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -360,7 +372,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
if (w_dims.size() == 4) { // conv2d_transpose if (w_dims.size() == 4) { // conv2d_transpose
std::string quantized_op_type = any_op2->Op()->Type(); std::string quantized_op_type = any_op2->Op()->Type();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quantized_op_type, "conv2d_transpose", quantized_op_type,
"conv2d_transpose",
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, " "When quant_axis == 1 means use per_channel quant_dequant, "
"only conv2d_transpose weight dims equal 4.")); "only conv2d_transpose weight dims equal 4."));
...@@ -388,7 +401,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -388,7 +401,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims))); weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
float* new_quantized_weight_data = float* new_quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace()); weight_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_quantized_weight_data, weight_data_tmp.data(), memcpy(new_quantized_weight_data,
weight_data_tmp.data(),
weight_tensor->numel() * sizeof(float)); weight_tensor->numel() * sizeof(float));
nodes2rm.insert(weight_dequantize_linear_op_scale); nodes2rm.insert(weight_dequantize_linear_op_scale);
......
...@@ -49,6 +49,10 @@ QuantDequantFusePass::QuantDequantFusePass() { ...@@ -49,6 +49,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
.End() .End()
.AddAttr("bit_length") .AddAttr("bit_length")
.IsIntIn({8, 16}) .IsIntIn({8, 16})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End(); .End();
AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max")) AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max"))
.AddInput("X") .AddInput("X")
...@@ -85,6 +89,10 @@ QuantDequantFusePass::QuantDequantFusePass() { ...@@ -85,6 +89,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
.End() .End()
.AddAttr("bit_length") .AddAttr("bit_length")
.IsIntIn({8, 16}) .IsIntIn({8, 16})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End(); .End();
AddOpCompat(OpCompat("fake_dequantize_max_abs")) AddOpCompat(OpCompat("fake_dequantize_max_abs"))
.AddInput("X") .AddInput("X")
...@@ -309,7 +317,8 @@ QuantDequantFusePass::QuantDequantFusePass() { ...@@ -309,7 +317,8 @@ QuantDequantFusePass::QuantDequantFusePass() {
} }
// Delete quant op before quantized ops, and set input scale in the attr of // Delete quant op before quantized ops, and set input scale in the attr of
// quantized ops // quantized ops
void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, void QuantDequantFusePass::DeleteQuant(ir::Graph* graph,
Scope* scope,
const std::string& quant_type) const { const std::string& quant_type) const {
const std::string pattern_name = "delete_quant_fuse"; const std::string pattern_name = "delete_quant_fuse";
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -331,7 +340,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -331,7 +340,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
return; return;
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
subgraph.count(input_act_node), true, subgraph.count(input_act_node),
true,
platform::errors::NotFound( platform::errors::NotFound(
"Input act node(%s) not found in QuantDequantFuse pass.", "Input act node(%s) not found in QuantDequantFuse pass.",
input_act_node->name())); input_act_node->name()));
...@@ -345,12 +355,14 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -345,12 +355,14 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
// Get input scale from tensor // Get input scale from tensor
std::string input_scale_var_name = quant->Op()->Input("InScale").front(); std::string input_scale_var_name = quant->Op()->Input("InScale").front();
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument( scope,
"Scope in QuantDequantFuse pass should not be null.")); platform::errors::InvalidArgument(
"Scope in QuantDequantFuse pass should not be null."));
const LoDTensor& input_scale_tensor = const LoDTensor& input_scale_tensor =
scope->FindVar(input_scale_var_name)->Get<LoDTensor>(); scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true, paddle::platform::is_cpu_place(input_scale_tensor.place()),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU.")); "Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>(); const float* input_scale_data = input_scale_tensor.data<float>();
...@@ -382,8 +394,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -382,8 +394,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
IR_NODE_LINK_TO(input_act, quantized_node); IR_NODE_LINK_TO(input_act, quantized_node);
} }
// Delete nodes and edges // Delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale, quant, std::unordered_set<const Node*> nodes2rm = {
output_scale, output_act}; input_scale, quant, output_scale, output_act};
GraphSafeRemoveNodes(graph, nodes2rm); GraphSafeRemoveNodes(graph, nodes2rm);
}; };
gpd(graph, handler); gpd(graph, handler);
...@@ -391,7 +403,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope, ...@@ -391,7 +403,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
// Delete dequant op after quantized ops, and convert weight from fp32 range to // Delete dequant op after quantized ops, and convert weight from fp32 range to
// int8 range // int8 range
void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
Scope* scope,
const std::string& quantized_op_type, const std::string& quantized_op_type,
const std::string& dequant_type) const { const std::string& dequant_type) const {
std::string weight_name = ""; std::string weight_name = "";
...@@ -436,7 +449,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -436,7 +449,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
return; return;
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
subgraph.count(quantized_op_input), true, subgraph.count(quantized_op_input),
true,
platform::errors::NotFound("Quantized op input node(%s) did not find " platform::errors::NotFound("Quantized op input node(%s) did not find "
"in QuantDequantFuse pass.", "in QuantDequantFuse pass.",
quantized_op_input->name())); quantized_op_input->name()));
...@@ -464,14 +478,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -464,14 +478,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
subgraph.at(pattern.GetPDNode("dequant_channel_scale")); subgraph.at(pattern.GetPDNode("dequant_channel_scale"));
auto scales_name = dequant_op_node->Op()->Input("Scales"); auto scales_name = dequant_op_node->Op()->Input("Scales");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scales_name.size(), 2, scales_name.size(),
2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scales size in channel-wise dequantize op should be 2, got %d.", "Scales size in channel-wise dequantize op should be 2, got %d.",
scales_name.size())); scales_name.size()));
const LoDTensor& channel_scale_tensor = const LoDTensor& channel_scale_tensor =
scope->FindVar(scales_name[0])->Get<LoDTensor>(); scope->FindVar(scales_name[0])->Get<LoDTensor>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(channel_scale_tensor.place()), true, paddle::platform::is_cpu_place(channel_scale_tensor.place()),
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU.")); "Channel scale tensor's place should be CPU."));
const float* channel_scale_data = channel_scale_tensor.data<float>(); const float* channel_scale_data = channel_scale_tensor.data<float>();
...@@ -497,7 +513,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -497,7 +513,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
if (quantized_op_type == "mul" || quantized_op_type == "matmul" || if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "matmul_v2" || quantized_op_type == "fc") { quantized_op_type == "matmul_v2" || quantized_op_type == "fc") {
if (dequant_type == "fake_dequantize_max_abs") { if (dequant_type == "fake_dequantize_max_abs") {
PADDLE_ENFORCE_EQ(weight_scale.size(), 1, PADDLE_ENFORCE_EQ(weight_scale.size(),
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"mul/matmul/matmul_v2 op weight dequantized by " "mul/matmul/matmul_v2 op weight dequantized by "
"[fake_dequantize_max_abs] " "[fake_dequantize_max_abs] "
...@@ -511,7 +528,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -511,7 +528,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
if (quant_axis == 0) { if (quant_axis == 0) {
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 1, true, quant_axis == 1,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' of mul/matmul/fc/matmul_v2 op weight " "'quant_axis' of mul/matmul/fc/matmul_v2 op weight "
"dequantized by " "dequantized by "
...@@ -520,14 +538,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -520,14 +538,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
quant_axis)); quant_axis));
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]), weight_scale.size(),
static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"mul/matmul/matmul_v2 op weight dequantized by " "mul/matmul/matmul_v2 op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs] requires weight scale " "[fake_channel_wise_dequantize_max_abs] requires weight scale "
"size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, " "size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, "
"but got " "but got "
"%d.", "%d.",
static_cast<size_t>(w_dims[1]), weight_scale.size())); static_cast<size_t>(w_dims[1]),
weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[j % w_dims[1]]; quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
} }
...@@ -535,7 +555,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -535,7 +555,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
} else if (quantized_op_type == "conv2d" || } else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") { quantized_op_type == "depthwise_conv2d") {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dequant_type, "fake_channel_wise_dequantize_max_abs", dequant_type,
"fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"conv2d op must be dequantized by " "conv2d op must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s. " "[fake_channel_wise_dequantize_max_abs], but got %s. "
...@@ -546,7 +567,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -546,7 +567,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
if (quant_axis == 0) { if (quant_axis == 0) {
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0, true, quant_axis == 0,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' of conv2d/depthwise_conv2d op weight dequantized " "'quant_axis' of conv2d/depthwise_conv2d op weight dequantized "
"by [fake_channel_wise_dequantize_max_abs]should be 0, but " "by [fake_channel_wise_dequantize_max_abs]should be 0, but "
...@@ -554,18 +576,21 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -554,18 +576,21 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
quant_axis)); quant_axis));
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]), weight_scale.size(),
static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the " "conv2d op requires weight scale size = channel size of the "
"weight, which is %d, but got %d.", "weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[0]), weight_scale.size())); static_cast<size_t>(w_dims[0]),
weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size]; quantized_weight_data[j] *= weight_scale[j / inner_size];
} }
} else if (quantized_op_type == "conv2d_transpose") { } else if (quantized_op_type == "conv2d_transpose") {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dequant_type, "fake_channel_wise_dequantize_max_abs", dequant_type,
"fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"conv2d_transpose must be dequantized by " "conv2d_transpose must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s", "[fake_channel_wise_dequantize_max_abs], but got %s",
...@@ -573,7 +598,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -573,7 +598,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
if (quant_axis == 0) { if (quant_axis == 0) {
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 1, true, quant_axis == 1,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' of conv2d_transpose op weight dequantized by " "'quant_axis' of conv2d_transpose op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs]should be 1, but " "[fake_channel_wise_dequantize_max_abs]should be 1, but "
...@@ -581,11 +607,13 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope, ...@@ -581,11 +607,13 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
quant_axis)); quant_axis));
} }
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]), weight_scale.size(),
static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"conv2d_transpose op requires weight scale size = channel size " "conv2d_transpose op requires weight scale size = channel size "
"of the weight, which is %d, but got %d.", "of the weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[1]), weight_scale.size())); static_cast<size_t>(w_dims[1]),
weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[2] * w_dims[3]; int inner_size = w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]]; quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]];
...@@ -639,8 +667,13 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -639,8 +667,13 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> quant_types = { std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"}; "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = { std::unordered_set<std::string> quantized_op_types = {
"conv2d", "mul", "matmul", "depthwise_conv2d", "conv2d",
"conv2d_transpose", "fc", "matmul_v2", "mul",
"matmul",
"depthwise_conv2d",
"conv2d_transpose",
"fc",
"matmul_v2",
}; };
auto* scope = param_scope(); auto* scope = param_scope();
......
...@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and ...@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/quantize_linear_op.h" #include "paddle/fluid/operators/quantize_linear_op.h"
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
...@@ -24,14 +26,17 @@ namespace operators { ...@@ -24,14 +26,17 @@ namespace operators {
template <typename T> template <typename T>
struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx, void operator()(const platform::CPUDeviceContext &dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale, const framework::Tensor *in,
T max_range, const int quant_axis, framework::Tensor* out) { const framework::Tensor *scale,
T max_range,
const int quant_axis,
framework::Tensor *out) {
// Dequant op is before quantized op // Dequant op is before quantized op
// Dequantize the weight of quantized op // Dequantize the weight of quantized op
auto in_dims = in->dims(); auto in_dims = in->dims();
const int64_t channel = in_dims[quant_axis]; const int64_t channel = in_dims[quant_axis];
const T* scale_factor = scale->data<T>(); const T *scale_factor = scale->data<T>();
if (quant_axis == 0) { if (quant_axis == 0) {
for (int64_t i = 0; i < channel; i++) { for (int64_t i = 0; i < channel; i++) {
T s = scale_factor[i]; T s = scale_factor[i];
...@@ -39,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { ...@@ -39,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
framework::Tensor one_channel_out = out->Slice(i, i + 1); framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in); auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out); auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device(); auto &dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s / max_range; out_e.device(dev) = in_e * s / max_range;
} }
} else if (quant_axis == 1) { } else if (quant_axis == 1) {
...@@ -49,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { ...@@ -49,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
} }
int64_t step_i = in->numel() / out_iter; int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel); int64_t step_j = in->numel() / (out_iter * channel);
auto* in_data = in->data<T>(); auto *in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); auto *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
for (int64_t i = 0; i < out_iter; i++) { for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) { for (int64_t j = 0; j < channel; j++) {
auto* cur_in = in_data + i * step_i + j * step_j; auto *cur_in = in_data + i * step_i + j * step_j;
auto* cur_out = out_data + i * step_i + j * step_j; auto *cur_out = out_data + i * step_i + j * step_j;
T s = scale_factor[j]; T s = scale_factor[j];
for (int64_t k = 0; k < step_j; k++) { for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s / max_range; *cur_out = (*cur_in) * s / max_range;
...@@ -67,19 +72,17 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { ...@@ -67,19 +72,17 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
} }
}; };
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, float>; template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, float>;
template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>; template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>;
class QuantizeLinearOp : public framework::OperatorWithKernel { class QuantizeLinearOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", OP_INOUT_CHECK(
"QuantizeLinear"); ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear");
ctx->SetOutputDim("Y", ctx->GetInputDim("X")); ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
int quant_axis = ctx->Attrs().Get<int>("quant_axis"); int quant_axis = ctx->Attrs().Get<int>("quant_axis");
...@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel { ...@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
...@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"For conv2d, depthwise_conv2d, conv2d_transpose " "For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.") "and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0) .SetDefault(0)
.AddCustomChecker([](const int& quant_axis) { .AddCustomChecker([](const int &quant_axis) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1 || quant_axis == -1, true, quant_axis == 0 || quant_axis == 1 || quant_axis == -1,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but " "'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
...@@ -126,13 +130,32 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -126,13 +130,32 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
}); });
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but " "'bit_length' should be between 1 and 16, but "
"the received is %d", "the received is %d",
bit_length)); bit_length));
}); });
AddAttr<int>(
"round_type",
"(int, default 0) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
.SetDefault(0)
.AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument(
"'round_type' should be 0 or 1, 0 rounding to "
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
...@@ -156,14 +179,18 @@ namespace ops = paddle::operators; ...@@ -156,14 +179,18 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR( REGISTER_OPERATOR(
quantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, quantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>); REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
dequantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, dequantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
...@@ -29,9 +29,13 @@ namespace operators { ...@@ -29,9 +29,13 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctorV2 { struct ChannelDequantizeFunctorV2 {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, void operator()(const DeviceContext& dev_ctx,
const framework::Tensor** scales, const int scale_num, const framework::Tensor* in,
T max_range, const int quant_axis, framework::Tensor* out); const framework::Tensor** scales,
const int scale_num,
T max_range,
const int quant_axis,
framework::Tensor* out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -44,6 +48,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -44,6 +48,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Y"); auto* out = context.Output<framework::Tensor>("Y");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis"); int quant_axis = context.Attr<int>("quant_axis");
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
...@@ -53,25 +58,25 @@ class QuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -53,25 +58,25 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
if (!is_test) { if (!is_test) {
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_s = out_scale->mutable_data<T>(context.GetPlace()); T* out_s = out_scale->mutable_data<T>(context.GetPlace());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), FindAbsMaxFunctor<DeviceContext, T>()(
in->numel(), out_s); dev_ctx, in->data<T>(), in->numel(), out_s);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(
bin_cnt, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} else { } else {
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(
bin_cnt, out); dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
} }
} else { } else {
if (!is_test) { if (!is_test) {
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace()); T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis, FindChannelAbsMaxFunctor<DeviceContext, T>()(
out_scale_data); dev_ctx, *in, quant_axis, out_scale_data);
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()( ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
} else { } else {
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()( ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out); dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out);
} }
} }
} }
...@@ -87,7 +92,8 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -87,7 +92,8 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
auto in_tmp = phi::Cast<T>( auto in_tmp = phi::Cast<T>(
static_cast<const typename paddle::framework::ConvertToPhiContext< static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx), DeviceContext>::TYPE&>(dev_ctx),
*in, experimental::CppTypeToDataType<D>::Type()); *in,
experimental::CppTypeToDataType<D>::Type());
auto* scale = context.Input<framework::Tensor>("Scale"); auto* scale = context.Input<framework::Tensor>("Scale");
auto* out = context.Output<framework::Tensor>("Y"); auto* out = context.Output<framework::Tensor>("Y");
...@@ -97,16 +103,18 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -97,16 +103,18 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
if (quant_axis < 0) { if (quant_axis < 0) {
float max_range = (std::pow(2, bit_length - 1) - 1); float max_range = (std::pow(2, bit_length - 1) - 1);
DequantizeFunctor<DeviceContext, D>()(dev_ctx, &in_tmp, scale, DequantizeFunctor<DeviceContext, D>()(
static_cast<D>(max_range), out); dev_ctx, &in_tmp, scale, static_cast<D>(max_range), out);
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale->numel(), in_tmp.dims()[quant_axis], scale->numel(),
in_tmp.dims()[quant_axis],
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The number of first scale values must be the same with " "The number of first scale values must be the same with "
"quant_axis dimension value of Input(X) when the `scale` has " "quant_axis dimension value of Input(X) when the `scale` has "
"only one element, but %ld != %ld here.", "only one element, but %ld != %ld here.",
scale->numel(), in_tmp.dims()[quant_axis])); scale->numel(),
in_tmp.dims()[quant_axis]));
int max_range = (std::pow(2, bit_length - 1) - 1); int max_range = (std::pow(2, bit_length - 1) - 1);
ChannelDequantizeFunctorV2<DeviceContext, D>()( ChannelDequantizeFunctorV2<DeviceContext, D>()(
......
...@@ -20,26 +20,31 @@ import logging ...@@ -20,26 +20,31 @@ import logging
import paddle.fluid as fluid import paddle.fluid as fluid
from ....log_helper import get_logger from ....log_helper import get_logger
from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error, bias_correction_w
_logger = get_logger( _logger = get_logger(__name__,
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
GAMMA = -0.1 GAMMA = -0.1
ZETA = 1.1 ZETA = 1.1
def compute_soft_rounding(alpha_v): def compute_soft_rounding(alpha_v):
return fluid.layers.clip( return fluid.layers.clip(fluid.layers.sigmoid(alpha_v) * (ZETA - GAMMA) +
fluid.layers.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, min=0, max=1) GAMMA,
min=0,
max=1)
def compute_soft_rounding_np(alpha_v): def compute_soft_rounding_np(alpha_v):
return np.clip( return np.clip(stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, a_min=0, a_max=1) a_min=0,
a_max=1)
class AdaRoundLoss(object): class AdaRoundLoss(object):
def __init__(self, reg_param=0.01, default_beta_range=(20, 2)): def __init__(self, reg_param=0.01, default_beta_range=(20, 2)):
self.default_reg_param = reg_param self.default_reg_param = reg_param
self.default_beta_range = default_beta_range self.default_beta_range = default_beta_range
...@@ -48,26 +53,29 @@ class AdaRoundLoss(object): ...@@ -48,26 +53,29 @@ class AdaRoundLoss(object):
square_cost = fluid.layers.square_error_cost(ada_quantized_output, square_cost = fluid.layers.square_error_cost(ada_quantized_output,
orig_output) orig_output)
recon_loss = fluid.layers.reduce_mean( recon_loss = fluid.layers.reduce_mean(
fluid.layers.reduce_sum( fluid.layers.reduce_sum(square_cost, dim=-1))
square_cost, dim=-1))
return recon_loss return recon_loss
def compute_round_loss(self, alpha_v, warm_start, beta): def compute_round_loss(self, alpha_v, warm_start, beta):
def round_loss_fn(): def round_loss_fn():
# compute rectified sigmoid of parameter 'alpha' which maps it between zero and one # compute rectified sigmoid of parameter 'alpha' which maps it between zero and one
h_v = compute_soft_rounding(alpha_v) h_v = compute_soft_rounding(alpha_v)
# calculate regularization term - which ensures parameter to converge to exactly zeros and ones # calculate regularization term - which ensures parameter to converge to exactly zeros and ones
# at the end of optimization # at the end of optimization
reg_term = fluid.layers.reduce_sum(-fluid.layers.pow( reg_term = fluid.layers.reduce_sum(
fluid.layers.abs(2 * h_v - 1), factor=beta) + 1) -fluid.layers.pow(fluid.layers.abs(2 * h_v - 1), factor=beta) +
1)
# calculate the rounding loss # calculate the rounding loss
round_loss = self.default_reg_param * reg_term round_loss = self.default_reg_param * reg_term
return round_loss return round_loss
round_loss = fluid.layers.cond(warm_start, lambda: fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.0), round_loss_fn) round_loss = fluid.layers.cond(
warm_start, lambda: fluid.layers.fill_constant(
shape=[1], dtype='float32', value=0.0), round_loss_fn)
return round_loss return round_loss
...@@ -80,15 +88,16 @@ class AdaRoundLoss(object): ...@@ -80,15 +88,16 @@ class AdaRoundLoss(object):
warm_start_end_iter = warm_start * max_iter warm_start_end_iter = warm_start * max_iter
# compute relative iteration of current iteration # compute relative iteration of current iteration
rel_iter = (cur_iter - warm_start_end_iter) / ( rel_iter = (cur_iter - warm_start_end_iter) / (max_iter -
max_iter - warm_start_end_iter) warm_start_end_iter)
beta = end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter * beta = end_beta + 0.5 * (start_beta -
np.pi)) end_beta) * (1 + np.cos(rel_iter * np.pi))
return beta return beta
class AdaRound(object): class AdaRound(object):
def __init__(self, def __init__(self,
scale, scale,
weight_tensor, weight_tensor,
...@@ -145,10 +154,9 @@ class AdaRound(object): ...@@ -145,10 +154,9 @@ class AdaRound(object):
h_alpha = compute_soft_rounding_np(np_alpha) h_alpha = compute_soft_rounding_np(np_alpha)
# Scale the tensor # Scale the tensor
tensor_scale = quant_tensor( tensor_scale = quant_tensor(self.ori_weight_tensor.copy(),
self.ori_weight_tensor.copy(), self.scale,
self.scale, quant_axis=self.quant_axis)
quant_axis=self.quant_axis)
weight_tensor = np.floor(tensor_scale) weight_tensor = np.floor(tensor_scale)
...@@ -160,10 +168,10 @@ class AdaRound(object): ...@@ -160,10 +168,10 @@ class AdaRound(object):
weight_tensor_quant = self._calculate_quant_weight() weight_tensor_quant = self._calculate_quant_weight()
# Dequantize the tensor # Dequantize the tensor
weight_tensor_dequant = dequant_tensor( weight_tensor_dequant = dequant_tensor(weight_tensor_quant +
weight_tensor_quant + self.offset, self.offset,
self.scale, self.scale,
quant_axis=self.quant_axis) quant_axis=self.quant_axis)
return weight_tensor_dequant return weight_tensor_dequant
def update_final_weights(self): def update_final_weights(self):
...@@ -171,10 +179,10 @@ class AdaRound(object): ...@@ -171,10 +179,10 @@ class AdaRound(object):
return weight_tensor_quant return weight_tensor_quant
def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor): def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor):
round_loss = self.adaround_loss.compute_round_loss(self.alpha_v, round_loss = self.adaround_loss.compute_round_loss(
warm_start, beta) self.alpha_v, warm_start, beta)
recon_loss = self.adaround_loss.compute_recon_loss(adaround_out_tensor, recon_loss = self.adaround_loss.compute_recon_loss(
orig_out_tensor) adaround_out_tensor, orig_out_tensor)
loss = round_loss + recon_loss loss = round_loss + recon_loss
losses = { losses = {
'loss': loss, 'loss': loss,
...@@ -201,6 +209,7 @@ def run_adaround(data_loader, ...@@ -201,6 +209,7 @@ def run_adaround(data_loader,
scale_dict, scale_dict,
num_iterations=1000, num_iterations=1000,
lr=0.001, lr=0.001,
bias_correction=False,
fast_mode=True): fast_mode=True):
fetch_op_name = fetch_list[0].name fetch_op_name = fetch_list[0].name
final_weight_tensor_quant_dict = {} final_weight_tensor_quant_dict = {}
...@@ -226,29 +235,29 @@ def run_adaround(data_loader, ...@@ -226,29 +235,29 @@ def run_adaround(data_loader,
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
# initialize adaround # initialize adaround
adaround = AdaRound( adaround = AdaRound(scale,
scale, weight_var_tensor,
weight_var_tensor, scope=scope,
scope=scope, weight_var_name=weight_var_name,
weight_var_name=weight_var_name, weight_op_type=weight_op_type,
weight_op_type=weight_op_type, num_iterations=num_iterations)
num_iterations=num_iterations) orig_out_tensor = fluid.data(name='orig_out_tensor',
orig_out_tensor = fluid.data( shape=fp32_fetch_list.shape,
name='orig_out_tensor', dtype='float32')
shape=fp32_fetch_list.shape, adaround_out_tensor = fluid.data(name='adaround_out_tensor',
dtype='float32') shape=fp32_fetch_list.shape,
adaround_out_tensor = fluid.data( dtype='float32')
name='adaround_out_tensor', beta_tensor = fluid.data(name='beta',
shape=fp32_fetch_list.shape, shape=[1],
dtype='float32') dtype='float32')
beta_tensor = fluid.data( warm_start_tensor = fluid.data(name='warm_start',
name='beta', shape=[1], dtype='float32') shape=[1],
warm_start_tensor = fluid.data( dtype='bool')
name='warm_start', shape=[1], dtype='bool')
train_fetches_loss = adaround.get_loss(beta_tensor,
train_fetches_loss = adaround.get_loss( warm_start_tensor,
beta_tensor, warm_start_tensor, adaround_out_tensor, adaround_out_tensor,
orig_out_tensor) orig_out_tensor)
optimizer = fluid.optimizer.Adam(learning_rate=lr) optimizer = fluid.optimizer.Adam(learning_rate=lr)
loss = train_fetches_loss['loss'] loss = train_fetches_loss['loss']
optimizer.minimize(loss) optimizer.minimize(loss)
...@@ -291,16 +300,23 @@ def run_adaround(data_loader, ...@@ -291,16 +300,23 @@ def run_adaround(data_loader,
fetch_list=[v.name for v in train_fetches_loss.values()], fetch_list=[v.name for v in train_fetches_loss.values()],
return_numpy=True) return_numpy=True)
_logger.info( _logger.info(
"Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s". "Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s"
format(i, lr, .format(i, lr, np.mean(out[0]), np.mean(out[1]),
np.mean(out[0]), np.mean(out[2]), start_time - prev_start_time))
np.mean(out[1]),
np.mean(out[2]), start_time - prev_start_time))
sys.stdout.flush() sys.stdout.flush()
if i == num_iterations: if i == num_iterations:
break break
final_weight_tensor_quant_dict[ final_weight_tensor_quant_dict[
weight_var_name] = adaround.update_final_weights() weight_var_name] = adaround.update_final_weights()
if bias_correction:
final_weight_tensor_quant_dict[weight_var_name] = bias_correction_w(
weight_var_tensor,
final_weight_tensor_quant_dict[weight_var_name],
scale,
adaround.quant_axis,
weight_bits=adaround.weight_bits)
del adaround del adaround
# update adarounded calibrated weights # update adarounded calibrated weights
......
...@@ -36,8 +36,9 @@ from . import utils ...@@ -36,8 +36,9 @@ from . import utils
__all__ = ['PostTrainingQuantization', 'WeightQuantization'] __all__ = ['PostTrainingQuantization', 'WeightQuantization']
_logger = get_logger( _logger = get_logger(__name__,
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
def _all_persistable_var_names(program): def _all_persistable_var_names(program):
...@@ -88,7 +89,8 @@ def _apply_pass(scope, ...@@ -88,7 +89,8 @@ def _apply_pass(scope,
cpp_graph.set_not_owned('__param_scope__', scope) cpp_graph.set_not_owned('__param_scope__', scope)
if attrs: if attrs:
assert attr_values and len(attrs) == len( assert attr_values and len(attrs) == len(
attr_values), "Different number of pass attributes and their values." attr_values
), "Different number of pass attributes and their values."
for attr, value in zip(attrs, attr_values): for attr, value in zip(attrs, attr_values):
ir_pass.set(attr, value) ir_pass.set(attr, value)
ir_pass.apply(cpp_graph) ir_pass.apply(cpp_graph)
...@@ -180,7 +182,8 @@ class PostTrainingQuantization(object): ...@@ -180,7 +182,8 @@ class PostTrainingQuantization(object):
"mul"]. "mul"].
round_type(str, optional): The method of converting the quantized weights round_type(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods. value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the nearest whole number. Default is `round`, which is rounding nearest to the integer.
'adaround' is refer to https://arxiv.org/abs/2004.10568.
learning_rate(float, optional): The learning rate of adaround method. learning_rate(float, optional): The learning rate of adaround method.
is_full_quantized(bool, optional): If set is_full_quantized as True, is_full_quantized(bool, optional): If set is_full_quantized as True,
apply quantization to all supported quantizable op type. If set apply quantization to all supported quantizable op type. If set
...@@ -364,7 +367,8 @@ class PostTrainingQuantization(object): ...@@ -364,7 +367,8 @@ class PostTrainingQuantization(object):
batch_id = 0 batch_id = 0
with tqdm( with tqdm(
total=self._batch_nums, total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', bar_format=
'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for data in self._data_loader(): for data in self._data_loader():
self._executor.run(program=self._program, self._executor.run(program=self._program,
...@@ -380,10 +384,10 @@ class PostTrainingQuantization(object): ...@@ -380,10 +384,10 @@ class PostTrainingQuantization(object):
self._init_sampling_act_histogram() self._init_sampling_act_histogram()
batch_id = 0 batch_id = 0
with tqdm( with tqdm(total=self._batch_nums,
total=self._batch_nums, bar_format=
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', 'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for data in self._data_loader(): for data in self._data_loader():
self._executor.run(program=self._program, self._executor.run(program=self._program,
feed=data, feed=data,
...@@ -446,18 +450,18 @@ class PostTrainingQuantization(object): ...@@ -446,18 +450,18 @@ class PostTrainingQuantization(object):
scale_dict = self._quantized_var_threshold scale_dict = self._quantized_var_threshold
else: else:
scale_dict = self._quantized_threshold scale_dict = self._quantized_threshold
run_adaround( run_adaround(self._data_loader,
self._data_loader, self._program,
self._program, self._fetch_list,
self._fetch_list, self._executor,
self._executor, self._scope,
self._scope, self._place,
self._place, self._quantized_op_pairs,
self._quantized_op_pairs, self._weight_op_pairs,
self._weight_op_pairs, scale_dict,
scale_dict, num_iterations=self._batch_nums,
num_iterations=self._batch_nums, bias_correction=self._bias_correction,
lr=self._learning_rate) lr=self._learning_rate)
def save_quantized_model(self, def save_quantized_model(self,
save_model_path, save_model_path,
...@@ -478,15 +482,14 @@ class PostTrainingQuantization(object): ...@@ -478,15 +482,14 @@ class PostTrainingQuantization(object):
None None
''' '''
clip_extra = True if self._onnx_format else False clip_extra = True if self._onnx_format else False
io.save_inference_model( io.save_inference_model(dirname=save_model_path,
dirname=save_model_path, model_filename=model_filename,
model_filename=model_filename, params_filename=params_filename,
params_filename=params_filename, feeded_var_names=self._feed_list,
feeded_var_names=self._feed_list, target_vars=self._fetch_list,
target_vars=self._fetch_list, executor=self._executor,
executor=self._executor, main_program=self._program,
main_program=self._program, clip_extra=clip_extra)
clip_extra=clip_extra)
_logger.info("The quantized model is saved in " + save_model_path) _logger.info("The quantized model is saved in " + save_model_path)
def _load_model_data(self): def _load_model_data(self):
...@@ -508,17 +511,18 @@ class PostTrainingQuantization(object): ...@@ -508,17 +511,18 @@ class PostTrainingQuantization(object):
if self._data_loader is not None: if self._data_loader is not None:
return return
self._data_loader = io.DataLoader.from_generator( self._data_loader = io.DataLoader.from_generator(feed_list=feed_vars,
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) capacity=3 *
self._batch_size,
iterable=True)
if self._sample_generator is not None: if self._sample_generator is not None:
self._data_loader.set_sample_generator( self._data_loader.set_sample_generator(self._sample_generator,
self._sample_generator, batch_size=self._batch_size,
batch_size=self._batch_size, drop_last=True,
drop_last=True, places=self._place)
places=self._place)
elif self._batch_generator is not None: elif self._batch_generator is not None:
self._data_loader.set_batch_generator( self._data_loader.set_batch_generator(self._batch_generator,
self._batch_generator, places=self._place) places=self._place)
def _optimize_fp32_model(self): def _optimize_fp32_model(self):
''' '''
...@@ -569,12 +573,10 @@ class PostTrainingQuantization(object): ...@@ -569,12 +573,10 @@ class PostTrainingQuantization(object):
" is not supported for quantization.") " is not supported for quantization.")
# For quantized ops, sample inputs and outputs # For quantized ops, sample inputs and outputs
if op_type in self._quantizable_op_type: if op_type in self._quantizable_op_type:
collect_var_name( collect_var_name(utils._get_op_input_var_names(op),
utils._get_op_input_var_names(op), persistable_var_names, op_type)
persistable_var_names, op_type) collect_var_name(utils._get_op_output_var_names(op),
collect_var_name( persistable_var_names, op_type)
utils._get_op_output_var_names(op),
persistable_var_names, op_type)
# collect quanted op output var name # collect quanted op output var name
for out_var_name in utils._get_op_output_var_names(op): for out_var_name in utils._get_op_output_var_names(op):
for in_var_name in utils._get_op_input_var_names(op): for in_var_name in utils._get_op_input_var_names(op):
...@@ -583,9 +585,8 @@ class PostTrainingQuantization(object): ...@@ -583,9 +585,8 @@ class PostTrainingQuantization(object):
in_var_name] = out_var_name in_var_name] = out_var_name
# For other op, only sample output scale # For other op, only sample output scale
elif op_type in self._out_scale_op_list: elif op_type in self._out_scale_op_list:
collect_var_name( collect_var_name(utils._get_op_output_var_names(op),
utils._get_op_output_var_names(op), persistable_var_names, op_type)
persistable_var_names, op_type)
def _set_activation_persistable(self): def _set_activation_persistable(self):
''' '''
...@@ -655,9 +656,14 @@ class PostTrainingQuantization(object): ...@@ -655,9 +656,14 @@ class PostTrainingQuantization(object):
scale = s * abs_max_value scale = s * abs_max_value
s += 0.02 s += 0.02
bins = 2**(self._activation_bits - 1) - 1 bins = 2**(self._activation_bits - 1) - 1
quant_dequant_var = np.round( if self._onnx_format:
np.clip(var_tensor, 0.0, scale) / scale * quant_var = np.clip(np.round(var_tensor / scale * bins),
bins) / bins * scale -bins - 1, bins)
quant_dequant_var = quant_var / bins * scale
else:
quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
mse_loss = ((var_tensor - quant_dequant_var)**2).mean() mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
if mse_loss <= self._best_calibration_loss[var_name]: if mse_loss <= self._best_calibration_loss[var_name]:
self._best_calibration_loss[var_name] = mse_loss self._best_calibration_loss[var_name] = mse_loss
...@@ -694,9 +700,14 @@ class PostTrainingQuantization(object): ...@@ -694,9 +700,14 @@ class PostTrainingQuantization(object):
scale = s * abs_max_value scale = s * abs_max_value
s += 0.02 s += 0.02
bins = 2**(self._activation_bits - 1) - 1 bins = 2**(self._activation_bits - 1) - 1
quant_dequant_var = np.round( if self._onnx_format:
np.clip(var_tensor, 0.0, scale) / scale * quant_var = np.clip(np.round(var_tensor / scale * bins),
bins) / bins * scale -bins - 1, bins)
quant_dequant_var = quant_var / bins * scale
else:
quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale
emd_loss = np.abs( emd_loss = np.abs(
np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs( np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs(
np.std(var_tensor) - np.std(quant_dequant_var)) np.std(var_tensor) - np.std(quant_dequant_var))
...@@ -846,8 +857,9 @@ class PostTrainingQuantization(object): ...@@ -846,8 +857,9 @@ class PostTrainingQuantization(object):
if var_name not in self._sampling_act_histogram: if var_name not in self._sampling_act_histogram:
min_val = self._sampling_act_abs_min_max[var_name][0] min_val = self._sampling_act_abs_min_max[var_name][0]
max_val = self._sampling_act_abs_min_max[var_name][1] max_val = self._sampling_act_abs_min_max[var_name][1]
hist, hist_edeges = np.histogram( hist, hist_edeges = np.histogram([],
[], bins=self._histogram_bins, range=(min_val, max_val)) bins=self._histogram_bins,
range=(min_val, max_val))
self._sampling_act_histogram[var_name] = [hist, hist_edeges] self._sampling_act_histogram[var_name] = [hist, hist_edeges]
def _calculate_kl_hist_threshold(self): def _calculate_kl_hist_threshold(self):
...@@ -951,18 +963,11 @@ class PostTrainingQuantization(object): ...@@ -951,18 +963,11 @@ class PostTrainingQuantization(object):
else: else:
scale_dict = self._quantized_threshold scale_dict = self._quantized_threshold
for key, val in scale_dict.items(): for key, val in scale_dict.items():
utils.set_variable_data( utils.set_variable_data(self._scope, self._place, key + ".scale",
self._scope, np.array([val], dtype=np.float32))
self._place, utils.set_variable_data(self._scope, self._place,
key + ".scale", key + ".quant_dequant.scale",
np.array( np.array([val], dtype=np.float32))
[val], dtype=np.float32))
utils.set_variable_data(
self._scope,
self._place,
key + ".quant_dequant.scale",
np.array(
[val], dtype=np.float32))
if not self._onnx_format: if not self._onnx_format:
# apply QuantizationFreezePass, and obtain the final quant model # apply QuantizationFreezePass, and obtain the final quant model
...@@ -1038,8 +1043,8 @@ class PostTrainingQuantization(object): ...@@ -1038,8 +1043,8 @@ class PostTrainingQuantization(object):
for block_id in range(len(self._program.blocks)): for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops: for op in self._program.blocks[block_id].ops:
if op.type in ( if op.type in (self._quantizable_op_type +
self._quantizable_op_type + self._out_scale_op_list): self._out_scale_op_list):
out_var_names = utils._get_op_output_var_names(op) out_var_names = utils._get_op_output_var_names(op)
for var_name in out_var_names: for var_name in out_var_names:
analysis_and_save_info(op, var_name) analysis_and_save_info(op, var_name)
...@@ -1175,10 +1180,11 @@ class WeightQuantization(object): ...@@ -1175,10 +1180,11 @@ class WeightQuantization(object):
if generate_test_model: if generate_test_model:
test_model_dir = os.path.join(save_model_dir, "test_model") test_model_dir = os.path.join(save_model_dir, "test_model")
self._quantize_weight_to_int( self._quantize_weight_to_int(test_model_dir, save_model_filename,
test_model_dir, save_model_filename, save_params_filename, save_params_filename,
quantizable_op_type, weight_bits, weight_quantize_type, True, quantizable_op_type, weight_bits,
threshold_rate) weight_quantize_type, True,
threshold_rate)
def convert_weight_to_fp16(self, save_model_dir): def convert_weight_to_fp16(self, save_model_dir):
""" """
...@@ -1216,16 +1222,17 @@ class WeightQuantization(object): ...@@ -1216,16 +1222,17 @@ class WeightQuantization(object):
if self._params_filename is not None: if self._params_filename is not None:
save_var_map[new_var.name] = new_var save_var_map[new_var.name] = new_var
else: else:
save_file_path = os.path.join( save_file_path = os.path.join(os.path.normpath(save_model_dir),
os.path.normpath(save_model_dir), new_var.name) new_var.name)
save_block.append_op( save_block.append_op(type='save',
type='save', inputs={'X': [new_var]},
inputs={'X': [new_var]}, outputs={},
outputs={}, attrs={
attrs={ 'file_path':
'file_path': os.path.normpath(save_file_path), os.path.normpath(save_file_path),
'save_as_fp16': True 'save_as_fp16':
}) True
})
if self._params_filename is not None: if self._params_filename is not None:
save_var_list = [] save_var_list = []
...@@ -1237,14 +1244,15 @@ class WeightQuantization(object): ...@@ -1237,14 +1244,15 @@ class WeightQuantization(object):
name=unique_name.generate("saved_params")) name=unique_name.generate("saved_params"))
saved_params_var.desc.set_persistable(True) saved_params_var.desc.set_persistable(True)
save_path = os.path.join( save_path = os.path.join(os.path.normpath(save_model_dir),
os.path.normpath(save_model_dir), self._params_filename) self._params_filename)
save_block.append_op( save_block.append_op(type='save_combine',
type='save_combine', inputs={'X': save_var_list},
inputs={'X': save_var_list}, outputs={'Y': saved_params_var},
outputs={'Y': saved_params_var}, attrs={
attrs={'file_path': save_path, 'file_path': save_path,
'save_as_fp16': True}) 'save_as_fp16': True
})
save_program._sync_with_cpp() save_program._sync_with_cpp()
exe.run(save_program) exe.run(save_program)
...@@ -1293,14 +1301,13 @@ class WeightQuantization(object): ...@@ -1293,14 +1301,13 @@ class WeightQuantization(object):
self._weight_channel_wise_abs_max_quantization( self._weight_channel_wise_abs_max_quantization(
scope, place, weight_bits, op, var_name, for_test) scope, place, weight_bits, op, var_name, for_test)
io.save_inference_model( io.save_inference_model(dirname=save_model_dir,
dirname=save_model_dir, feeded_var_names=feed_list,
feeded_var_names=feed_list, target_vars=fetch_list,
target_vars=fetch_list, executor=exe,
executor=exe, main_program=program,
main_program=program, model_filename=save_model_filename,
model_filename=save_model_filename, params_filename=save_params_filename)
params_filename=save_params_filename)
def _weight_abs_max_quantization(self, scope, place, weight_bits, def _weight_abs_max_quantization(self, scope, place, weight_bits,
threshold_rate, op, var_name, for_test): threshold_rate, op, var_name, for_test):
...@@ -1339,8 +1346,9 @@ class WeightQuantization(object): ...@@ -1339,8 +1346,9 @@ class WeightQuantization(object):
op._set_attr(var_name + "_quant_scale", [scale]) # Save as list op._set_attr(var_name + "_quant_scale", [scale]) # Save as list
op._set_attr("with_quant_attr", True) op._set_attr("with_quant_attr", True)
def _weight_channel_wise_abs_max_quantization( def _weight_channel_wise_abs_max_quantization(self, scope, place,
self, scope, place, weight_bits, op, var_name, for_test): weight_bits, op, var_name,
for_test):
''' '''
Use channel_wise_abs_max method to quantize weight. Use channel_wise_abs_max method to quantize weight.
''' '''
...@@ -1390,8 +1398,8 @@ class WeightQuantization(object): ...@@ -1390,8 +1398,8 @@ class WeightQuantization(object):
and quantize the weights. and quantize the weights.
''' '''
scales = [] scales = []
quantized_weight_data = np.zeros_like( quantized_weight_data = np.zeros_like(weight_data,
weight_data, dtype=save_weight_dtype) dtype=save_weight_dtype)
channel_num = weight_data.shape[0] channel_num = weight_data.shape[0]
for i in range(channel_num): for i in range(channel_num):
scale = np.max(np.abs(weight_data[i])) / quantize_range scale = np.max(np.abs(weight_data[i])) / quantize_range
...@@ -1404,8 +1412,8 @@ class WeightQuantization(object): ...@@ -1404,8 +1412,8 @@ class WeightQuantization(object):
''' '''
For conv2d and depthwise_conv2d, dequantize the weights to fp32. For conv2d and depthwise_conv2d, dequantize the weights to fp32.
''' '''
dequantized_weight_data = np.zeros_like( dequantized_weight_data = np.zeros_like(quantized_weight_data,
quantized_weight_data, dtype=np.float32) dtype=np.float32)
for i in range(len(scales)): for i in range(len(scales)):
dequantized_weight_data[i] = \ dequantized_weight_data[i] = \
(quantized_weight_data[i] * scales[i]).astype(np.float32) (quantized_weight_data[i] * scales[i]).astype(np.float32)
...@@ -1418,8 +1426,8 @@ class WeightQuantization(object): ...@@ -1418,8 +1426,8 @@ class WeightQuantization(object):
and quantize the weights. and quantize the weights.
''' '''
scales = [] scales = []
quantized_weight_data = np.zeros_like( quantized_weight_data = np.zeros_like(weight_data,
weight_data, dtype=save_weight_dtype) dtype=save_weight_dtype)
channel_num = weight_data.shape[-1] channel_num = weight_data.shape[-1]
for i in range(channel_num): for i in range(channel_num):
scale = np.max(np.abs(weight_data[:, i])) / quantize_range scale = np.max(np.abs(weight_data[:, i])) / quantize_range
...@@ -1432,8 +1440,8 @@ class WeightQuantization(object): ...@@ -1432,8 +1440,8 @@ class WeightQuantization(object):
''' '''
For mul, dequantize the weights to fp32. For mul, dequantize the weights to fp32.
''' '''
dequantized_weight_data = np.zeros_like( dequantized_weight_data = np.zeros_like(quantized_weight_data,
quantized_weight_data, dtype=np.float32) dtype=np.float32)
for i in range(len(scales)): for i in range(len(scales)):
dequantized_weight_data[:, i] = \ dequantized_weight_data[:, i] = \
(quantized_weight_data[:, i] * scales[i]).astype(np.float32) (quantized_weight_data[:, i] * scales[i]).astype(np.float32)
...@@ -1441,8 +1449,9 @@ class WeightQuantization(object): ...@@ -1441,8 +1449,9 @@ class WeightQuantization(object):
def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000): def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000):
input_abs = np.abs(input) input_abs = np.abs(input)
hist, hist_edeges = np.histogram( hist, hist_edeges = np.histogram(input_abs,
input_abs, bins=histogram_bins, range=(0, np.max(input_abs))) bins=histogram_bins,
range=(0, np.max(input_abs)))
hist = hist / float(sum(hist)) hist = hist / float(sum(hist))
hist_sum = 0 hist_sum = 0
hist_index = 0 hist_index = 0
......
...@@ -321,7 +321,7 @@ def set_variable_data(scope, place, var_name, np_value): ...@@ -321,7 +321,7 @@ def set_variable_data(scope, place, var_name, np_value):
tensor.set(np_value, place) tensor.set(np_value, place)
def quant_tensor(x, scale, quant_axis=0, weight_bits=8): def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
# symmetry quant # symmetry quant
def _clip(x, scale): def _clip(x, scale):
x[x > scale] = scale x[x > scale] = scale
...@@ -335,15 +335,27 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8): ...@@ -335,15 +335,27 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8):
if s == 0.0: if s == 0.0:
s = 1e-8 s = 1e-8
if quant_axis == 0: if quant_axis == 0:
x[i] = _clip(x[i], s) if onnx_format:
x[i] = x[i] / s * bnt x[i] = np.round(x[i] / s * bnt)
x[i] = np.clip(x[i], -bnt - 1, bnt)
else:
x[i] = _clip(x[i], s)
x[i] = x[i] / s * bnt
else: else:
x[:, i] = _clip(x[:, i], s) if onnx_format:
x[:, i] = x[:, i] / s * bnt x[:, i] = np.round(x[:, i] / s * bnt)
x[:, i] = np.clip(x[:, i], -bnt - 1, bnt)
else:
x[:, i] = _clip(x[:, i], s)
x[:, i] = x[:, i] / s * bnt
else: else:
scale = 1e-8 if scale == 0.0 else scale scale = 1e-8 if scale == 0.0 else scale
x = _clip(x, scale) if onnx_format:
x = x / scale * bnt x = np.round(x / scale * bnt)
x = np.clip(x, -bnt - 1, bnt)
else:
x = _clip(x, scale)
x = x / scale * bnt
return x return x
...@@ -416,6 +428,7 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): ...@@ -416,6 +428,7 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
class tqdm(object): class tqdm(object):
def __init__(self, total, bar_format='Loading|{bar}', ncols=80): def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
self.total = total self.total = total
self.bar_format = bar_format self.bar_format = bar_format
......
...@@ -35,8 +35,9 @@ from paddle.fluid.framework import _test_eager_guard ...@@ -35,8 +35,9 @@ from paddle.fluid.framework import _test_eager_guard
from imperative_test_utils import fix_model_dict, ImperativeLenet, ImperativeLinearBn from imperative_test_utils import fix_model_dict, ImperativeLenet, ImperativeLinearBn
from imperative_test_utils import ImperativeLinearBn_hook from imperative_test_utils import ImperativeLinearBn_hook
_logger = get_logger( _logger = get_logger(__name__,
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
class TestFuseLinearBn(unittest.TestCase): class TestFuseLinearBn(unittest.TestCase):
...@@ -55,15 +56,15 @@ class TestFuseLinearBn(unittest.TestCase): ...@@ -55,15 +56,15 @@ class TestFuseLinearBn(unittest.TestCase):
quant_h = ptq.quantize(model_h, fuse=True, fuse_list=f_l) quant_h = ptq.quantize(model_h, fuse=True, fuse_list=f_l)
for name, layer in quant_model.named_sublayers(): for name, layer in quant_model.named_sublayers():
if name in f_l: if name in f_l:
assert not (isinstance(layer, nn.BatchNorm1D) or assert not (isinstance(layer, nn.BatchNorm1D)
isinstance(layer, nn.BatchNorm2D)) or isinstance(layer, nn.BatchNorm2D))
out = model(inputs) out = model(inputs)
out_h = model_h(inputs) out_h = model_h(inputs)
out_quant = quant_model(inputs) out_quant = quant_model(inputs)
out_quant_h = quant_h(inputs) out_quant_h = quant_h(inputs)
cos_sim_func = nn.CosineSimilarity(axis=0) cos_sim_func = nn.CosineSimilarity(axis=0)
print('fuse linear+bn', print('fuse linear+bn', cos_sim_func(out.flatten(),
cos_sim_func(out.flatten(), out_quant.flatten())) out_quant.flatten()))
print(cos_sim_func(out_h.flatten(), out_quant_h.flatten())) print(cos_sim_func(out_h.flatten(), out_quant_h.flatten()))
...@@ -87,8 +88,8 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -87,8 +88,8 @@ class TestImperativePTQ(unittest.TestCase):
def cache_unzipping(self, target_folder, zip_path): def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder): if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
zip_path) target_folder, zip_path)
os.system(cmd) os.system(cmd)
def download_model(self, data_url, data_md5, folder_name): def download_model(self, data_url, data_md5, folder_name):
...@@ -123,8 +124,8 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -123,8 +124,8 @@ class TestImperativePTQ(unittest.TestCase):
def model_test(self, model, batch_num=-1, batch_size=8): def model_test(self, model, batch_num=-1, batch_size=8):
model.eval() model.eval()
test_reader = paddle.batch( test_reader = paddle.batch(paddle.dataset.mnist.test(),
paddle.dataset.mnist.test(), batch_size=batch_size) batch_size=batch_size)
eval_acc_top1_list = [] eval_acc_top1_list = []
for batch_id, data in enumerate(test_reader()): for batch_id, data in enumerate(test_reader()):
...@@ -157,8 +158,8 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -157,8 +158,8 @@ class TestImperativePTQ(unittest.TestCase):
[inference_program, feed_target_names, fetch_targets [inference_program, feed_target_names, fetch_targets
] = (paddle.static.load_inference_model(program_path, exe)) ] = (paddle.static.load_inference_model(program_path, exe))
test_reader = paddle.batch( test_reader = paddle.batch(paddle.dataset.mnist.test(),
paddle.dataset.mnist.test(), batch_size=batch_size) batch_size=batch_size)
top1_correct_num = 0. top1_correct_num = 0.
total_num = 0. total_num = 0.
...@@ -203,13 +204,13 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -203,13 +204,13 @@ class TestImperativePTQ(unittest.TestCase):
self.batch_size) self.batch_size)
input_spec = [ input_spec = [
paddle.static.InputSpec( paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
shape=[None, 1, 28, 28], dtype='float32')
] ]
with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir: with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir:
save_path = os.path.join(tmpdir, "model") save_path = os.path.join(tmpdir, "model")
self.ptq.save_quantized_model( self.ptq.save_quantized_model(model=quant_model,
model=quant_model, path=save_path, input_spec=input_spec) path=save_path,
input_spec=input_spec)
print('Quantized model saved in {%s}' % save_path) print('Quantized model saved in {%s}' % save_path)
after_acc_top1 = self.model_test(quant_model, self.batch_num, after_acc_top1 = self.model_test(quant_model, self.batch_num,
...@@ -225,13 +226,11 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -225,13 +226,11 @@ class TestImperativePTQ(unittest.TestCase):
print('After converted acc_top1: %s' % after_acc_top1) print('After converted acc_top1: %s' % after_acc_top1)
print('Infer acc_top1: %s' % infer_acc_top1) print('Infer acc_top1: %s' % infer_acc_top1)
self.assertTrue( self.assertTrue(after_acc_top1 >= self.eval_acc_top1,
after_acc_top1 >= self.eval_acc_top1, msg="The test acc {%f} is less than {%f}." %
msg="The test acc {%f} is less than {%f}." % (after_acc_top1, self.eval_acc_top1))
(after_acc_top1, self.eval_acc_top1)) self.assertTrue(infer_acc_top1 >= after_acc_top1,
self.assertTrue( msg='The acc is lower after converting model.')
infer_acc_top1 >= after_acc_top1,
msg='The acc is lower after converting model.')
end_time = time.time() end_time = time.time()
print("total time: %ss \n" % (end_time - start_time)) print("total time: %ss \n" % (end_time - start_time))
...@@ -243,6 +242,7 @@ class TestImperativePTQ(unittest.TestCase): ...@@ -243,6 +242,7 @@ class TestImperativePTQ(unittest.TestCase):
class TestImperativePTQfuse(TestImperativePTQ): class TestImperativePTQfuse(TestImperativePTQ):
def func_ptq(self): def func_ptq(self):
start_time = time.time() start_time = time.time()
...@@ -261,19 +261,19 @@ class TestImperativePTQfuse(TestImperativePTQ): ...@@ -261,19 +261,19 @@ class TestImperativePTQfuse(TestImperativePTQ):
quant_model = self.ptq.quantize(model, fuse=True, fuse_list=f_l) quant_model = self.ptq.quantize(model, fuse=True, fuse_list=f_l)
for name, layer in quant_model.named_sublayers(): for name, layer in quant_model.named_sublayers():
if name in f_l: if name in f_l:
assert not (isinstance(layer, nn.BatchNorm1D) or assert not (isinstance(layer, nn.BatchNorm1D)
isinstance(layer, nn.BatchNorm2D)) or isinstance(layer, nn.BatchNorm2D))
before_acc_top1 = self.model_test(quant_model, self.batch_num, before_acc_top1 = self.model_test(quant_model, self.batch_num,
self.batch_size) self.batch_size)
input_spec = [ input_spec = [
paddle.static.InputSpec( paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
shape=[None, 1, 28, 28], dtype='float32')
] ]
with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir: with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir:
save_path = os.path.join(tmpdir, "model") save_path = os.path.join(tmpdir, "model")
self.ptq.save_quantized_model( self.ptq.save_quantized_model(model=quant_model,
model=quant_model, path=save_path, input_spec=input_spec) path=save_path,
input_spec=input_spec)
print('Quantized model saved in {%s}' % save_path) print('Quantized model saved in {%s}' % save_path)
after_acc_top1 = self.model_test(quant_model, self.batch_num, after_acc_top1 = self.model_test(quant_model, self.batch_num,
...@@ -291,15 +291,13 @@ class TestImperativePTQfuse(TestImperativePTQ): ...@@ -291,15 +291,13 @@ class TestImperativePTQfuse(TestImperativePTQ):
#Check whether the quant_model is correct after converting. #Check whether the quant_model is correct after converting.
#The acc of quantized model should be higher than 0.95. #The acc of quantized model should be higher than 0.95.
self.assertTrue( self.assertTrue(after_acc_top1 >= self.eval_acc_top1,
after_acc_top1 >= self.eval_acc_top1, msg="The test acc {%f} is less than {%f}." %
msg="The test acc {%f} is less than {%f}." % (after_acc_top1, self.eval_acc_top1))
(after_acc_top1, self.eval_acc_top1))
#Check the saved infer_model.The acc of infer model #Check the saved infer_model.The acc of infer model
#should not be lower than the one of dygraph model. #should not be lower than the one of dygraph model.
self.assertTrue( self.assertTrue(infer_acc_top1 >= after_acc_top1,
infer_acc_top1 >= after_acc_top1, msg='The acc is lower after converting model.')
msg='The acc is lower after converting model.')
end_time = time.time() end_time = time.time()
print("total time: %ss \n" % (end_time - start_time)) print("total time: %ss \n" % (end_time - start_time))
...@@ -311,6 +309,7 @@ class TestImperativePTQfuse(TestImperativePTQ): ...@@ -311,6 +309,7 @@ class TestImperativePTQfuse(TestImperativePTQ):
class TestImperativePTQHist(TestImperativePTQ): class TestImperativePTQHist(TestImperativePTQ):
def set_vars(self): def set_vars(self):
config = PTQConfig(HistQuantizer(), AbsmaxQuantizer()) config = PTQConfig(HistQuantizer(), AbsmaxQuantizer())
self.ptq = ImperativePTQ(config) self.ptq = ImperativePTQ(config)
...@@ -332,13 +331,14 @@ class TestImperativePTQHist(TestImperativePTQ): ...@@ -332,13 +331,14 @@ class TestImperativePTQHist(TestImperativePTQ):
class TestImperativePTQKL(TestImperativePTQ): class TestImperativePTQKL(TestImperativePTQ):
def set_vars(self): def set_vars(self):
config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer()) config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())
self.ptq = ImperativePTQ(config) self.ptq = ImperativePTQ(config)
self.batch_num = 10 self.batch_num = 10
self.batch_size = 10 self.batch_size = 10
self.eval_acc_top1 = 1.0 self.eval_acc_top1 = 0.98
conv2d_1_wt_thresholds = [ conv2d_1_wt_thresholds = [
0.18116560578346252, 0.17079241573810577, 0.1702047884464264, 0.18116560578346252, 0.17079241573810577, 0.1702047884464264,
......
...@@ -34,6 +34,7 @@ np.random.seed(0) ...@@ -34,6 +34,7 @@ np.random.seed(0)
class TestPostTrainingQuantization(unittest.TestCase): class TestPostTrainingQuantization(unittest.TestCase):
def setUp(self): def setUp(self):
self.download_path = 'int8/download' self.download_path = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
...@@ -44,8 +45,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -44,8 +45,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
try: try:
os.system("mkdir -p " + self.int8_model_path) os.system("mkdir -p " + self.int8_model_path)
except Exception as e: except Exception as e:
print("Failed to create {} due to {}".format(self.int8_model_path, print("Failed to create {} due to {}".format(
str(e))) self.int8_model_path, str(e)))
sys.exit(-1) sys.exit(-1)
def tearDown(self): def tearDown(self):
...@@ -53,8 +54,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -53,8 +54,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def cache_unzipping(self, target_folder, zip_path): def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder): if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
zip_path) target_folder, zip_path)
os.system(cmd) os.system(cmd)
def download_model(self, data_url, data_md5, folder_name): def download_model(self, data_url, data_md5, folder_name):
...@@ -68,6 +69,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -68,6 +69,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
return data_cache_folder return data_cache_folder
def get_batch_reader(self, data_path, place): def get_batch_reader(self, data_path, place):
def reader(): def reader():
with open(data_path, 'rb') as in_file: with open(data_path, 'rb') as in_file:
while True: while True:
...@@ -80,15 +82,14 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -80,15 +82,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
seq_len = (alllen >> 16) & 0xFFFF seq_len = (alllen >> 16) & 0xFFFF
label = in_file.read(4 * label_len) label = in_file.read(4 * label_len)
label = np.frombuffer( label = np.frombuffer(label, dtype=np.int32).reshape(
label, dtype=np.int32).reshape([len(label) // 4]) [len(label) // 4])
if label.shape[0] != 1 or label[0] > 6350: if label.shape[0] != 1 or label[0] > 6350:
continue continue
feat = in_file.read(4 * seq_len * 8) feat = in_file.read(4 * seq_len * 8)
feat = np.frombuffer( feat = np.frombuffer(feat, dtype=np.float32).reshape(
feat, [len(feat) // 4 // 8, 8])
dtype=np.float32).reshape([len(feat) // 4 // 8, 8])
lod_feat = [feat.shape[0]] lod_feat = [feat.shape[0]]
minputs = fluid.create_lod_tensor(feat, [lod_feat], place) minputs = fluid.create_lod_tensor(feat, [lod_feat], place)
...@@ -97,6 +98,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -97,6 +98,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
return reader return reader
def get_simple_reader(self, data_path, place): def get_simple_reader(self, data_path, place):
def reader(): def reader():
with open(data_path, 'rb') as in_file: with open(data_path, 'rb') as in_file:
while True: while True:
...@@ -109,15 +111,14 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -109,15 +111,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
seq_len = (alllen >> 16) & 0xFFFF seq_len = (alllen >> 16) & 0xFFFF
label = in_file.read(4 * label_len) label = in_file.read(4 * label_len)
label = np.frombuffer( label = np.frombuffer(label, dtype=np.int32).reshape(
label, dtype=np.int32).reshape([len(label) // 4]) [len(label) // 4])
if label.shape[0] != 1 or label[0] > 6350: if label.shape[0] != 1 or label[0] > 6350:
continue continue
feat = in_file.read(4 * seq_len * 8) feat = in_file.read(4 * seq_len * 8)
feat = np.frombuffer( feat = np.frombuffer(feat, dtype=np.float32).reshape(
feat, [len(feat) // 4 // 8, 8])
dtype=np.float32).reshape([len(feat) // 4 // 8, 8])
lod_feat = [feat.shape[0]] lod_feat = [feat.shape[0]]
minputs = fluid.create_lod_tensor(feat, [lod_feat], place) minputs = fluid.create_lod_tensor(feat, [lod_feat], place)
...@@ -178,18 +179,17 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -178,18 +179,17 @@ class TestPostTrainingQuantization(unittest.TestCase):
scope = fluid.global_scope() scope = fluid.global_scope()
batch_generator = self.get_batch_reader(data_path, place) batch_generator = self.get_batch_reader(data_path, place)
ptq = PostTrainingQuantization( ptq = PostTrainingQuantization(executor=exe,
executor=exe, model_dir=model_path,
model_dir=model_path, batch_generator=batch_generator,
batch_generator=batch_generator, batch_nums=batch_nums,
batch_nums=batch_nums, algo=algo,
algo=algo, quantizable_op_type=quantizable_op_type,
quantizable_op_type=quantizable_op_type, round_type=round_type,
round_type=round_type, is_full_quantize=is_full_quantize,
is_full_quantize=is_full_quantize, optimize_model=is_optimize_model,
optimize_model=is_optimize_model, onnx_format=onnx_format,
onnx_format=onnx_format, is_use_cache_file=is_use_cache_file)
is_use_cache_file=is_use_cache_file)
ptq.quantize() ptq.quantize()
ptq.save_quantized_model(self.int8_model_path) ptq.save_quantized_model(self.int8_model_path)
...@@ -223,10 +223,11 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -223,10 +223,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start post training quantization for {0} on {1} samples ...". print("Start post training quantization for {0} on {1} samples ...".
format(model_name, quant_iterations)) format(model_name, quant_iterations))
self.generate_quantized_model( self.generate_quantized_model(fp32_model_path, data_path, algo,
fp32_model_path, data_path, algo, round_type, quantizable_op_type, round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file, is_optimize_model, is_full_quantize, is_use_cache_file,
quant_iterations, onnx_format) is_optimize_model, quant_iterations,
onnx_format)
print("Start INT8 inference for {0} on {1} samples ...".format( print("Start INT8 inference for {0} on {1} samples ...".format(
model_name, infer_iterations)) model_name, infer_iterations))
...@@ -245,6 +246,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -245,6 +246,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
def test_post_training_avg(self): def test_post_training_avg(self):
model_name = "nlp_lstm_fp32_model" model_name = "nlp_lstm_fp32_model"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
...@@ -268,6 +270,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): ...@@ -268,6 +270,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
def test_post_training_avg_onnx_format(self): def test_post_training_avg_onnx_format(self):
model_name = "nlp_lstm_fp32_model" model_name = "nlp_lstm_fp32_model"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
...@@ -285,23 +288,22 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): ...@@ -285,23 +288,22 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
infer_iterations = 100 infer_iterations = 100
quant_iterations = 10 quant_iterations = 10
onnx_format = True onnx_format = True
self.run_test( self.run_test(model_name,
model_name, model_url,
model_url, model_md5,
model_md5, data_name,
data_name, data_url,
data_url, data_md5,
data_md5, algo,
algo, round_type,
round_type, quantizable_op_type,
quantizable_op_type, is_full_quantize,
is_full_quantize, is_use_cache_file,
is_use_cache_file, is_optimize_model,
is_optimize_model, diff_threshold,
diff_threshold, infer_iterations,
infer_iterations, quant_iterations,
quant_iterations, onnx_format=onnx_format)
onnx_format=onnx_format)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -33,6 +33,7 @@ np.random.seed(0) ...@@ -33,6 +33,7 @@ np.random.seed(0)
class TestPostTrainingQuantization(unittest.TestCase): class TestPostTrainingQuantization(unittest.TestCase):
def setUp(self): def setUp(self):
self.root_path = tempfile.TemporaryDirectory() self.root_path = tempfile.TemporaryDirectory()
self.int8_model_path = os.path.join(self.root_path.name, self.int8_model_path = os.path.join(self.root_path.name,
...@@ -43,8 +44,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -43,8 +44,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
try: try:
os.system("mkdir -p " + self.int8_model_path) os.system("mkdir -p " + self.int8_model_path)
except Exception as e: except Exception as e:
print("Failed to create {} due to {}".format(self.int8_model_path, print("Failed to create {} due to {}".format(
str(e))) self.int8_model_path, str(e)))
sys.exit(-1) sys.exit(-1)
def tearDown(self): def tearDown(self):
...@@ -52,8 +53,8 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -52,8 +53,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def cache_unzipping(self, target_folder, zip_path): def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder): if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
zip_path) target_folder, zip_path)
os.system(cmd) os.system(cmd)
def download_model(self, data_url, data_md5, folder_name): def download_model(self, data_url, data_md5, folder_name):
...@@ -115,26 +116,27 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -115,26 +116,27 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_size=10, batch_size=10,
batch_nums=10, batch_nums=10,
onnx_format=False, onnx_format=False,
skip_tensor_list=None): skip_tensor_list=None,
bias_correction=False):
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
val_reader = paddle.dataset.mnist.train() val_reader = paddle.dataset.mnist.train()
ptq = PostTrainingQuantization( ptq = PostTrainingQuantization(executor=exe,
executor=exe, model_dir=model_path,
model_dir=model_path, sample_generator=val_reader,
sample_generator=val_reader, batch_size=batch_size,
batch_size=batch_size, batch_nums=batch_nums,
batch_nums=batch_nums, algo=algo,
algo=algo, quantizable_op_type=quantizable_op_type,
quantizable_op_type=quantizable_op_type, round_type=round_type,
round_type=round_type, is_full_quantize=is_full_quantize,
is_full_quantize=is_full_quantize, optimize_model=is_optimize_model,
optimize_model=is_optimize_model, bias_correction=bias_correction,
onnx_format=onnx_format, onnx_format=onnx_format,
skip_tensor_list=skip_tensor_list, skip_tensor_list=skip_tensor_list,
is_use_cache_file=is_use_cache_file) is_use_cache_file=is_use_cache_file)
ptq.quantize() ptq.quantize()
ptq.save_quantized_model(self.int8_model_path) ptq.save_quantized_model(self.int8_model_path)
...@@ -152,6 +154,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -152,6 +154,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_size=10, batch_size=10,
infer_iterations=10, infer_iterations=10,
quant_iterations=5, quant_iterations=5,
bias_correction=False,
onnx_format=False, onnx_format=False,
skip_tensor_list=None): skip_tensor_list=None):
...@@ -160,20 +163,23 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -160,20 +163,23 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start FP32 inference for {0} on {1} images ...".format( print("Start FP32 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size)) model_name, infer_iterations * batch_size))
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program( (fp32_throughput, fp32_latency,
origin_model_path, batch_size, infer_iterations) fp32_acc1) = self.run_program(origin_model_path, batch_size,
infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size)) format(model_name, quant_iterations * batch_size))
self.generate_quantized_model( self.generate_quantized_model(origin_model_path, algo, round_type,
origin_model_path, algo, round_type, quantizable_op_type, quantizable_op_type, is_full_quantize,
is_full_quantize, is_use_cache_file, is_optimize_model, batch_size, is_use_cache_file, is_optimize_model,
quant_iterations, onnx_format, skip_tensor_list) batch_size, quant_iterations, onnx_format,
skip_tensor_list, bias_correction)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size)) model_name, infer_iterations * batch_size))
(int8_throughput, int8_latency, int8_acc1) = self.run_program( (int8_throughput, int8_latency,
self.int8_model_path, batch_size, infer_iterations) int8_acc1) = self.run_program(self.int8_model_path, batch_size,
infer_iterations)
print("---Post training quantization of {} method---".format(algo)) print("---Post training quantization of {} method---".format(algo))
print( print(
...@@ -191,6 +197,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -191,6 +197,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
class TestPostTrainingKLForMnist(TestPostTrainingQuantization): class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
def test_post_training_kl(self): def test_post_training_kl(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -212,6 +219,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization): ...@@ -212,6 +219,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
class TestPostTraininghistForMnist(TestPostTrainingQuantization): class TestPostTraininghistForMnist(TestPostTrainingQuantization):
def test_post_training_hist(self): def test_post_training_hist(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -233,6 +241,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization): ...@@ -233,6 +241,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseForMnist(TestPostTrainingQuantization): class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self): def test_post_training_mse(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -254,6 +263,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization): ...@@ -254,6 +263,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
class TestPostTrainingemdForMnist(TestPostTrainingQuantization): class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self): def test_post_training_mse(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -275,6 +285,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization): ...@@ -275,6 +285,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
class TestPostTrainingavgForMnist(TestPostTrainingQuantization): class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
def test_post_training_avg(self): def test_post_training_avg(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -296,6 +307,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization): ...@@ -296,6 +307,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
def test_post_training_abs_max(self): def test_post_training_abs_max(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -317,6 +329,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): ...@@ -317,6 +329,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self): def test_post_training_mse(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -331,13 +344,25 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): ...@@ -331,13 +344,25 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test(model_name, data_url, data_md5, algo, round_type, bias_correction = True
quantizable_op_type, is_full_quantize, is_use_cache_file, self.run_test(model_name,
is_optimize_model, diff_threshold, batch_size, data_url,
infer_iterations, quant_iterations) data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_size,
infer_iterations,
quant_iterations,
bias_correction=bias_correction)
class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
def test_post_training_kl(self): def test_post_training_kl(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -359,6 +384,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization): ...@@ -359,6 +384,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
def test_post_training_mse_onnx_format(self): def test_post_training_mse_onnx_format(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -374,25 +400,25 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization): ...@@ -374,25 +400,25 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test( self.run_test(model_name,
model_name, data_url,
data_url, data_md5,
data_md5, algo,
algo, round_type,
round_type, quantizable_op_type,
quantizable_op_type, is_full_quantize,
is_full_quantize, is_use_cache_file,
is_use_cache_file, is_optimize_model,
is_optimize_model, diff_threshold,
diff_threshold, batch_size,
batch_size, infer_iterations,
infer_iterations, quant_iterations,
quant_iterations, onnx_format=onnx_format)
onnx_format=onnx_format)
class TestPostTrainingmseForMnistONNXFormatFullQuant( class TestPostTrainingmseForMnistONNXFormatFullQuant(
TestPostTrainingQuantization): TestPostTrainingQuantization):
def test_post_training_mse_onnx_format_full_quant(self): def test_post_training_mse_onnx_format_full_quant(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -408,24 +434,24 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant( ...@@ -408,24 +434,24 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
batch_size = 10 batch_size = 10
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
self.run_test( self.run_test(model_name,
model_name, data_url,
data_url, data_md5,
data_md5, algo,
algo, round_type,
round_type, quantizable_op_type,
quantizable_op_type, is_full_quantize,
is_full_quantize, is_use_cache_file,
is_use_cache_file, is_optimize_model,
is_optimize_model, diff_threshold,
diff_threshold, batch_size,
batch_size, infer_iterations,
infer_iterations, quant_iterations,
quant_iterations, onnx_format=onnx_format)
onnx_format=onnx_format)
class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization): class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
def test_post_training_avg_skip_op(self): def test_post_training_avg_skip_op(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -441,21 +467,20 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization): ...@@ -441,21 +467,20 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
infer_iterations = 50 infer_iterations = 50
quant_iterations = 5 quant_iterations = 5
skip_tensor_list = ["fc_0.w_0"] skip_tensor_list = ["fc_0.w_0"]
self.run_test( self.run_test(model_name,
model_name, data_url,
data_url, data_md5,
data_md5, algo,
algo, round_type,
round_type, quantizable_op_type,
quantizable_op_type, is_full_quantize,
is_full_quantize, is_use_cache_file,
is_use_cache_file, is_optimize_model,
is_optimize_model, diff_threshold,
diff_threshold, batch_size,
batch_size, infer_iterations,
infer_iterations, quant_iterations,
quant_iterations, skip_tensor_list=skip_tensor_list)
skip_tensor_list=skip_tensor_list)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -21,6 +21,7 @@ paddle.enable_static() ...@@ -21,6 +21,7 @@ paddle.enable_static()
class TestPostTrainingForResnet50(TestPostTrainingQuantization): class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self): def test_post_training_resnet50(self):
model = "ResNet-50" model = "ResNet-50"
algo = "min_max" algo = "min_max"
...@@ -40,6 +41,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): ...@@ -40,6 +41,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
def test_post_training_resnet50(self): def test_post_training_resnet50(self):
model = "ResNet-50" model = "ResNet-50"
algo = "min_max" algo = "min_max"
...@@ -54,18 +56,17 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): ...@@ -54,18 +56,17 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
is_optimize_model = False is_optimize_model = False
diff_threshold = 0.025 diff_threshold = 0.025
onnx_format = True onnx_format = True
self.run_test( self.run_test(model,
model, algo,
algo, round_type,
round_type, data_urls,
data_urls, data_md5s,
data_md5s, quantizable_op_type,
quantizable_op_type, is_full_quantize,
is_full_quantize, is_use_cache_file,
is_use_cache_file, is_optimize_model,
is_optimize_model, diff_threshold,
diff_threshold, onnx_format=onnx_format)
onnx_format=onnx_format)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册