Unverified commit ff70a269, authored by Guanghua Yu, committed by GitHub

[cherry-pick] Update quantization round and clip calculation methods (#43829)

* update quantization clip and round

* fix quantization clip and round Attribute

* fix typo
Parent 9e776f62
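
This change introduces a round_type attribute on the fake quantize/dequantize ops: 0 selects round-to-nearest with ties to even, while 1 (the default) keeps the previous round-to-nearest with ties away from zero. A minimal standalone C++ sketch, not Paddle code, illustrating the two modes described in the attribute docs below:

#include <cmath>
#include <cstdio>

int main() {
  const double samples[] = {1.5, 2.5, -1.5, -2.5};
  for (double x : samples) {
    // round_type == 0: round to nearest, ties to even
    // (std::nearbyint under the default FE_TONEAREST rounding mode).
    double ties_to_even = std::nearbyint(x);
    // round_type == 1: round to nearest, ties away from zero (std::round).
    double ties_away = std::round(x);
    std::printf("%5.1f -> ties_to_even=%4.0f ties_away=%4.0f\n",
                x, ties_to_even, ties_away);
  }
  return 0;
}

With the default rounding mode this prints 2/2, 2/3, -2/-2, -2/-3, matching the round(1.5)=2, round(2.5)=2 versus round(2.5)=3 examples given in the attribute description.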
@@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
.End()
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End();
AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
.AddInput("X")
@@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
.End()
.AddAttr("quant_axis")
.IsIntIn({0, 1})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End();
}
// Delete quant_dequant_op, then quantize and dequantize weight
@@ -96,14 +104,17 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(),
name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
break;
}
}
PADDLE_ENFORCE_GT(
arg_name.size(),
0,
platform::errors::InvalidArgument("can not find the input %s.",
quant_dequant_op_out_name));
// any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr("bit_length", bit_length);
@@ -123,7 +134,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
@@ -176,7 +188,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
}
}
for (int i = 0; i < channel; i++) {
PADDLE_ENFORCE_NE(weight_scale[i],
0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero."));
weight_scale[i] = weight_scale[i] / range;
@@ -188,7 +201,8 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
abs_max_weight =
std::max(abs_max_weight, std::abs(quantized_weight_data[j]));
}
PADDLE_ENFORCE_NE(abs_max_weight,
0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
weight_scale.push_back(abs_max_weight / range);
......
@@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
@@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
}
// Delete quantize_linear_op dequantize_linear_op, then add input_scales
@@ -112,7 +120,8 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const LoDTensor& input_scale_tensor =
scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()),
true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
......
@@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
@@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
@@ -322,7 +330,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
int quant_axis = BOOST_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
@@ -335,11 +344,13 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
} else if (quant_axis == 0) { // per_channel quant_dequant: conv2d,
// depthwise_conv2d, conv2d_fusion
PADDLE_ENFORCE_EQ(
weight_scale_nums,
w_dims[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."));
PADDLE_ENFORCE_EQ(w_dims.size(),
4,
platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel "
"quant_dequant, (conv2d, depthwise_conv2d, "
@@ -352,7 +363,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
}
} else if (quant_axis == 1) {
PADDLE_ENFORCE_EQ(
weight_scale_nums,
w_dims[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."));
@@ -360,7 +372,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
if (w_dims.size() == 4) { // conv2d_transpose
std::string quantized_op_type = any_op2->Op()->Type();
PADDLE_ENFORCE_EQ(
quantized_op_type,
"conv2d_transpose",
platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, "
"only conv2d_transpose weight dims equal 4."));
@@ -388,7 +401,8 @@ void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
float* new_quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_quantized_weight_data,
weight_data_tmp.data(),
weight_tensor->numel() * sizeof(float));
nodes2rm.insert(weight_dequantize_linear_op_scale);
......
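
The checks above tie the number of weight scales to quant_axis: -1 means per-layer (exactly one scale), while 0 or 1 mean per-channel (one scale per element of that weight dimension, with quant_axis 1 reserved for conv2d_transpose when the weight is 4-D). A small illustrative helper, assumed names only and not part of the pass, that captures the same rule:

#include <cstdint>
#include <vector>

// Illustrative only: expected scale count for each quant_axis value.
bool ScaleCountMatches(int quant_axis,
                       int64_t weight_scale_nums,
                       const std::vector<int64_t>& w_dims) {
  if (quant_axis == -1) {                    // per-layer quant_dequant
    return weight_scale_nums == 1;
  }
  if (quant_axis == 0 || quant_axis == 1) {  // per-channel quant_dequant
    return weight_scale_nums == w_dims[quant_axis];
  }
  return false;
}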
@@ -49,6 +49,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
.End()
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End();
AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max"))
.AddInput("X")
@@ -85,6 +89,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
.End()
.AddAttr("bit_length")
.IsIntIn({8, 16})
.End()
.AddAttr("round_type")
.IsOptional()
.IsIntIn({0, 1})
.End();
AddOpCompat(OpCompat("fake_dequantize_max_abs"))
.AddInput("X")
@@ -309,7 +317,8 @@ QuantDequantFusePass::QuantDequantFusePass() {
}
// Delete quant op before quantized ops, and set input scale in the attr of
// quantized ops
void QuantDequantFusePass::DeleteQuant(ir::Graph* graph,
Scope* scope,
const std::string& quant_type) const {
const std::string pattern_name = "delete_quant_fuse";
GraphPatternDetector gpd;
@@ -331,7 +340,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
return;
}
PADDLE_ENFORCE_EQ(
subgraph.count(input_act_node),
true,
platform::errors::NotFound(
"Input act node(%s) not found in QuantDequantFuse pass.",
input_act_node->name()));
@@ -345,12 +355,14 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
// Get input scale from tensor
std::string input_scale_var_name = quant->Op()->Input("InScale").front();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::InvalidArgument(
"Scope in QuantDequantFuse pass should not be null."));
const LoDTensor& input_scale_tensor =
scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()),
true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
@@ -382,8 +394,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
IR_NODE_LINK_TO(input_act, quantized_node);
}
// Delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale, quant, output_scale, output_act};
GraphSafeRemoveNodes(graph, nodes2rm);
};
gpd(graph, handler);
@@ -391,7 +403,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
// Delete dequant op after quantized ops, and convert weight from fp32 range to
// int8 range
void QuantDequantFusePass::FuseDequant(ir::Graph* graph,
Scope* scope,
const std::string& quantized_op_type,
const std::string& dequant_type) const {
std::string weight_name = "";
@@ -436,7 +449,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
return;
}
PADDLE_ENFORCE_EQ(
subgraph.count(quantized_op_input),
true,
platform::errors::NotFound("Quantized op input node(%s) did not find "
"in QuantDequantFuse pass.",
quantized_op_input->name()));
@@ -464,14 +478,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
subgraph.at(pattern.GetPDNode("dequant_channel_scale"));
auto scales_name = dequant_op_node->Op()->Input("Scales");
PADDLE_ENFORCE_EQ(
scales_name.size(),
2,
platform::errors::InvalidArgument(
"Scales size in channel-wise dequantize op should be 2, got %d.",
scales_name.size()));
const LoDTensor& channel_scale_tensor =
scope->FindVar(scales_name[0])->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(channel_scale_tensor.place()),
true,
platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU."));
const float* channel_scale_data = channel_scale_tensor.data<float>();
@@ -497,7 +513,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "matmul_v2" || quantized_op_type == "fc") {
if (dequant_type == "fake_dequantize_max_abs") {
PADDLE_ENFORCE_EQ(weight_scale.size(),
1,
platform::errors::InvalidArgument(
"mul/matmul/matmul_v2 op weight dequantized by "
"[fake_dequantize_max_abs] "
@@ -511,7 +528,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
if (quant_axis == 0) {
} else {
PADDLE_ENFORCE_EQ(
quant_axis == 1,
true,
platform::errors::InvalidArgument(
"'quant_axis' of mul/matmul/fc/matmul_v2 op weight "
"dequantized by "
@@ -520,14 +538,16 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
quant_axis));
}
PADDLE_ENFORCE_EQ(
weight_scale.size(),
static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul/matmul/matmul_v2 op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs] requires weight scale "
"size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, "
"but got "
"%d.",
static_cast<size_t>(w_dims[1]),
weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
quantized_weight_data[j] *= weight_scale[j % w_dims[1]];
}
@@ -535,7 +555,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
} else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") {
PADDLE_ENFORCE_EQ(
dequant_type,
"fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument(
"conv2d op must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s. "
@@ -546,7 +567,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
if (quant_axis == 0) {
} else {
PADDLE_ENFORCE_EQ(
quant_axis == 0,
true,
platform::errors::InvalidArgument(
"'quant_axis' of conv2d/depthwise_conv2d op weight dequantized "
"by [fake_channel_wise_dequantize_max_abs]should be 0, but "
@@ -554,18 +576,21 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
quant_axis));
}
PADDLE_ENFORCE_EQ(
weight_scale.size(),
static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[0]),
weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size];
}
} else if (quantized_op_type == "conv2d_transpose") {
PADDLE_ENFORCE_EQ(
dequant_type,
"fake_channel_wise_dequantize_max_abs",
platform::errors::InvalidArgument(
"conv2d_transpose must be dequantized by "
"[fake_channel_wise_dequantize_max_abs], but got %s",
@@ -573,7 +598,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
if (quant_axis == 0) {
} else {
PADDLE_ENFORCE_EQ(
quant_axis == 1,
true,
platform::errors::InvalidArgument(
"'quant_axis' of conv2d_transpose op weight dequantized by "
"[fake_channel_wise_dequantize_max_abs]should be 1, but "
@@ -581,11 +607,13 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
quant_axis));
}
PADDLE_ENFORCE_EQ(
weight_scale.size(),
static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"conv2d_transpose op requires weight scale size = channel size "
"of the weight, which is %d, but got %d.",
static_cast<size_t>(w_dims[1]),
weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
int inner_size = w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[(j / inner_size) % w_dims[1]];
@@ -639,8 +667,13 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {
"conv2d",
"mul",
"matmul",
"depthwise_conv2d",
"conv2d_transpose",
"fc",
"matmul_v2",
};
auto* scope = param_scope();
......
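
For the per-channel weight dequantization in FuseDequant above, each element of a conv2d weight laid out as [C_out, C_in, kh, kw] is scaled by the scale of its output channel (weight_scale[j / inner_size]). A self-contained sketch of that indexing, with hypothetical names and illustrative sizes only:

#include <cstddef>
#include <vector>

// Mirrors: quantized_weight_data[j] *= weight_scale[j / inner_size];
void DequantizeConv2dWeight(std::vector<float>* w,            // flattened [C_out, C_in, kh, kw]
                            const std::vector<float>& scale,  // one scale per output channel
                            int c_in, int kh, int kw) {
  const int inner_size = c_in * kh * kw;  // elements per output channel
  for (std::size_t j = 0; j < w->size(); ++j) {
    (*w)[j] *= scale[j / inner_size];
  }
}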
@@ -13,8 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/transform.h"
@@ -31,8 +33,10 @@ struct Compare {
template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx,
const T *in,
const int num,
T *out) {
*out = std::abs(*(std::max_element(in + 0, in + num, Compare<T>())));
}
};
@@ -41,24 +45,26 @@ template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx,
const framework::Tensor &in_tensor,
const int quant_axis,
T *out_abs_max) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto *in_data = in_tensor.data<T>();
auto in_dims = in_tensor.dims();
const int64_t channel = in_dims[quant_axis];
if (quant_axis == 0) {
const int64_t channel_size = in_tensor.numel() / channel;
for (int64_t i = 0; i < channel; i++) {
auto *start = in_data + i * channel_size;
auto *end = in_data + (i + 1) * channel_size;
out_abs_max[i] =
std::abs(*(std::max_element(start, end, Compare<T>())));
}
@@ -70,8 +76,8 @@ struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]);
for (int64_t i = 0; i < in_dims[0]; i++) {
for (int64_t j = 0; j < in_dims[1]; j++) {
auto *start = in_data + i * step_i + j * step_j;
auto *end = in_data + i * step_i + (j + 1) * step_j;
T abs_max = std::abs(*(std::max_element(start, end, Compare<T>())));
out_abs_max[j] = std::max(out_abs_max[j], abs_max);
}
@@ -84,56 +90,90 @@ template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx,
const framework::Tensor &in,
const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
framework::Tensor *out) {
T s = scale.data<T>()[0];
T inv_s = inverse(s);
platform::Transform<platform::CPUDeviceContext> trans;
if (round_type == 0) {
trans(ctx,
in.data<T>(),
in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()),
QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
} else {
trans(ctx,
in.data<T>(),
in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()),
phi::ClipFunctor<T>(-s, s));
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
}
}
};
template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx,
const framework::Tensor &in,
const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
framework::Tensor *out) {
T s = scale.data<T>()[0];
T inv_s = inverse(s);
platform::Transform<platform::CPUDeviceContext> trans;
if (round_type == 0) {
trans(ctx,
in.data<T>(),
in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()),
QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(*ctx.eigen_device()) = out_e * s / static_cast<T>(bin_cnt);
} else {
trans(ctx,
in.data<T>(),
in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()),
phi::ClipFunctor<T>(-s, s));
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(*ctx.eigen_device()) =
(bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
}
}
};
template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
float>;
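
For reference, a simplified scalar sketch of the quant-dequant math the functors above apply elementwise (a hypothetical helper, not the Paddle kernels, which run over whole tensors via platform::Transform and Eigen). The round_type == 1 branch mirrors the clip-then-round path visible in the diff; the round_type == 0 branch is an assumption that the single-pass quantize functor scales, rounds ties to even, and clamps to the integer range:

#include <algorithm>
#include <cmath>

// bin_cnt is typically (1 << (bit_length - 1)) - 1, e.g. 127 for 8 bits.
float FakeQuantDequant(float x, float s, int bin_cnt, int round_type) {
  const float inv_s = 1.0f / s;
  float q;
  if (round_type == 0) {
    // Assumed single-pass path: scale, round ties to even, clamp.
    q = std::nearbyint(x * bin_cnt * inv_s);
    q = std::min(std::max(q, static_cast<float>(-bin_cnt)),
                 static_cast<float>(bin_cnt));
  } else {
    // Path shown in the diff: clip to [-s, s], then round ties away from zero.
    const float clipped = std::min(std::max(x, -s), s);
    q = std::round(bin_cnt * inv_s * clipped);
  }
  return q * s / static_cast<float>(bin_cnt);  // dequantize back to float
}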
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx,
const framework::Tensor &in,
const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
const int quant_axis,
framework::Tensor *out) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto *scale_data = scale.data<T>();
auto *in_data = in.data<T>();
auto *out_data = out->mutable_data<T>(ctx.GetPlace());
auto in_dims = in.dims();
const int64_t channel = in_dims[quant_axis];
platform::Transform<platform::CPUDeviceContext> trans;
@@ -141,11 +181,24 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
const int64_t channel_size = in.numel() / channel;
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
auto *start = in_data + i * channel_size;
auto *end = in_data + (i + 1) * channel_size;
T inv_s = inverse(s);
if (round_type == 0) {
trans(ctx,
start,
end,
out_data + i * channel_size,
QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
} else {
trans(ctx,
start,
end,
out_data + i * channel_size,
phi::ClipFunctor<T>(-s, s));
}
}
if (round_type == 1) {
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
T inv_s = inverse(s);
@@ -153,6 +206,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
}
}
} else if (quant_axis == 1) {
const int64_t step_i = in.numel() / in_dims[0];
const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]);
@@ -160,9 +214,16 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
for (int j = 0; j < in_dims[1]; j++) {
T s = scale_data[j];
T inv_s = inverse(s);
auto *start = in_data + i * step_i + j * step_j;
auto *end = in_data + i * step_i + (j + 1) * step_j;
auto *cur_out_data = out_data + i * step_i + j * step_j;
if (round_type == 0) {
trans(ctx,
start,
end,
cur_out_data,
QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
} else {
trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
for (int k = 0; k < step_j; k++) {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
@@ -171,25 +232,30 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
}
}
}
}
};
template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
float>;
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx,
const framework::Tensor &in,
const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
const int quant_axis,
framework::Tensor *out) {
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto *scale_data = scale.data<T>();
auto *in_data = in.data<T>();
auto *out_data = out->mutable_data<T>(ctx.GetPlace());
auto in_dims = in.dims();
const int64_t channel = in_dims[quant_axis];
platform::Transform<platform::CPUDeviceContext> trans;
@@ -197,19 +263,36 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
const int64_t channel_size = in.numel() / channel;
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
auto *start = in_data + i * channel_size;
auto *end = in_data + (i + 1) * channel_size;
if (round_type == 0) {
T inv_s = inverse(s);
trans(ctx,
start,
end,
out_data + i * channel_size,
QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
} else {
trans(ctx,
start,
end,
out_data + i * channel_size,
phi::ClipFunctor<T>(-s, s));
}
}
for (int i = 0; i < channel; i++) {
T s = scale_data[i];
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
if (round_type == 0) {
out_e.device(*ctx.eigen_device()) =
out_e * s / static_cast<T>(bin_cnt);
} else {
T inv_s = inverse(s);
out_e.device(*ctx.eigen_device()) =
(bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
}
}
} else if (quant_axis == 1) {
const int64_t step_i = in.numel() / in_dims[0];
const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]);
@@ -217,11 +300,22 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
for (int j = 0; j < in_dims[1]; j++) {
T s = scale_data[j];
T inv_s = inverse(s);
auto *start = in_data + i * step_i + j * step_j;
auto *end = in_data + i * step_i + (j + 1) * step_j;
auto *cur_out_data = out_data + i * step_i + j * step_j;
if (round_type == 0) {
trans(ctx,
start,
end,
cur_out_data,
QuantTensorFunctor<T>(static_cast<T>(bin_cnt), inv_s));
} else {
trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
}
for (int k = 0; k < step_j; k++) {
if (round_type == 0) {
cur_out_data[k] = cur_out_data[k] * s / static_cast<T>(bin_cnt);
} else {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) *
s / static_cast<T>(bin_cnt);
}
@@ -229,18 +323,21 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
}
}
}
}
};
template struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext,
float>;
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx,
const framework::Tensor &cur_scale,
const framework::Tensor &last_scale,
const framework::Tensor &iter,
const int window_size,
framework::Tensor *scales_arr,
framework::Tensor *out_scale) {
T *scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
int64_t it = iter.data<int64_t>()[0];
int idx = it % window_size;
T removed = scale_arr[idx];
@@ -252,8 +349,8 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
max = cur;
} else if (fabs(removed - max) < 1e-6) {
int size = (it > window_size) ? window_size : it;
FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(
ctx, scale_arr, size, &max);
}
out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
}
@@ -263,11 +360,14 @@ template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext &ctx,
const framework::Tensor &in_accum,
const framework::Tensor &in_state,
const T *cur_scale,
const float rate,
framework::Tensor *out_state,
framework::Tensor *out_accum,
framework::Tensor *out_scale) {
T accum = in_accum.data<T>()[0];
T state = in_state.data<T>()[0];
T scale = cur_scale[0];
@@ -287,18 +387,22 @@ template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantOrWithDequantAbsMaxOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("X"), "Input", "X", "FakeQuantOrWithDequantAbsMaxOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"),
"Output",
"Out",
"FakeQuantOrWithDequantAbsMaxOp");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
"Output",
"OutScale",
"FakeQuantOrWithDequantAbsMaxOp");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
@@ -307,7 +411,7 @@ class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
@@ -325,13 +429,32 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
AddOutput("OutScale", "(Tensor) Current scale");
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
});
AddAttr<int>(
"round_type",
"(int, default 1) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
.SetDefault(1)
.AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument(
"'round_type' should be 0 or 1, 0 rounding to "
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
AddComment(R"DOC(
This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
FakeQuantAbsMaxOp operator is used in the dynamic quantization.
@@ -354,12 +477,16 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("X"), "Input", "X", "FakeChannelWiseQuantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"),
"Output",
"Out",
"FakeChannelWiseQuantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
"Output",
"OutScale",
"FakeChannelWiseQuantizeAbsMax");
int quant_axis = ctx->Attrs().Get<int>("quant_axis");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
@@ -369,7 +496,7 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
@@ -389,8 +516,9 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0)
.AddCustomChecker([](const int &quant_axis) {
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
@@ -398,13 +526,32 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
});
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but "
"the received is %d",
bit_length));
});
AddAttr<int>(
"round_type",
"(int, default 1) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
.SetDefault(1)
.AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument(
"'round_type' should be 0 or 1, 0 rounding to "
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
@@ -427,12 +574,18 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"),
"Input",
"X",
"FakeChannelWiseQuantizeDequantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"),
"Output",
"Out",
"FakeChannelWiseQuantizeDequantizeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
"Output",
"OutScale",
"FakeChannelWiseQuantizeDequantizeAbsMax");
int quant_axis = ctx->Attrs().Get<int>("quant_axis");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
@@ -442,7 +595,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
...@@ -462,8 +615,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker ...@@ -462,8 +615,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
"For conv2d, depthwise_conv2d, conv2d_transpose " "For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.") "and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0) .SetDefault(0)
.AddCustomChecker([](const int& quant_axis) { .AddCustomChecker([](const int &quant_axis) {
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but " "'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
...@@ -471,13 +625,32 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker ...@@ -471,13 +625,32 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
}); });
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but " "'bit_length' should be between 1 and 16, but "
"the received is %d", "the received is %d",
bit_length)); bit_length));
}); });
AddAttr<int>(
"round_type",
"(int, default 1) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
.SetDefault(1)
.AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument(
"'round_type' should be 0 or 1, 0 rounding to "
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
AddComment(R"DOC( AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector. The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value. In detail, each channel of the input X has a scale value.
...@@ -493,17 +666,19 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$ ...@@ -493,17 +666,19 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
public: public:
FakeQuantizeRangeAbsMaxOp(const std::string& type, FakeQuantizeRangeAbsMaxOp(const std::string &type,
const framework::VariableNameMap& inputs, const framework::VariableNameMap &inputs,
const framework::VariableNameMap& outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap& attrs) const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", OP_INOUT_CHECK(
"FakeQuantizeRangeAbsMax"); ctx->HasOutput("Out"), "Output", "Out", "FakeQuantizeRangeAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
"Output",
"OutScale",
"FakeQuantizeRangeAbsMax"); "FakeQuantizeRangeAbsMax");
if (ctx->HasOutput("OutScales")) { if (ctx->HasOutput("OutScales")) {
int window_size = ctx->Attrs().Get<int>("window_size"); int window_size = ctx->Attrs().Get<int>("window_size");
...@@ -516,7 +691,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel { ...@@ -516,7 +691,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context()); ctx.device_context());
...@@ -537,13 +712,32 @@ class FakeQuantizeRangeAbsMaxOpMaker ...@@ -537,13 +712,32 @@ class FakeQuantizeRangeAbsMaxOpMaker
.SetDefault(10000); .SetDefault(10000);
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.") AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but " "'bit_length' should be between 1 and 16, but "
"the received is %d", "the received is %d",
bit_length)); bit_length));
}); });
AddAttr<int>(
"round_type",
"(int, default 1) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
.SetDefault(1)
.AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument(
"'round_type' should be 0 or 1, 0 rounding to "
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
...@@ -563,17 +757,24 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp ...@@ -563,17 +757,24 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
: public framework::OperatorWithKernel { : public framework::OperatorWithKernel {
public: public:
FakeQuantOrWithDequantMovingAverageAbsMaxOp( FakeQuantOrWithDequantMovingAverageAbsMaxOp(
const std::string& type, const framework::VariableNameMap& inputs, const std::string &type,
const framework::VariableNameMap& outputs, const framework::VariableNameMap &inputs,
const framework::AttributeMap& attrs) const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {} : OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", OP_INOUT_CHECK(ctx->HasInput("X"),
"Input",
"X",
"FakeQuantOrWithDequantMovingAverageAbsMax"); "FakeQuantOrWithDequantMovingAverageAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", OP_INOUT_CHECK(ctx->HasOutput("Out"),
"Output",
"Out",
"FakeQuantOrWithDequantMovingAverageAbsMax"); "FakeQuantOrWithDequantMovingAverageAbsMax");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
"Output",
"OutScale",
"FakeQuantOrWithDequantMovingAverageAbsMax"); "FakeQuantOrWithDequantMovingAverageAbsMax");
if (ctx->HasOutput("OutState")) { if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1}); ctx->SetOutputDim("OutState", {1});
...@@ -588,7 +789,7 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp ...@@ -588,7 +789,7 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context()); ctx.device_context());
...@@ -611,13 +812,32 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker ...@@ -611,13 +812,32 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
.SetDefault(0.9); .SetDefault(0.9);
AddAttr<int>("bit_length", "(int, default 8), quantization bit number.") AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but " "'bit_length' should be between 1 and 16, but "
"the received is %d", "the received is %d",
bit_length)); bit_length));
}); });
AddAttr<int>(
"round_type",
"(int, default 1) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
.SetDefault(1)
.AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument(
"'round_type' should be 0 or 1, 0 rounding to "
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
...@@ -644,10 +864,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { ...@@ -644,10 +864,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", OP_INOUT_CHECK(
"MovingAverageAbsMaxScale"); ctx->HasInput("X"), "Input", "X", "MovingAverageAbsMaxScale");
OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
"Output",
"OutScale",
"MovingAverageAbsMaxScale"); "MovingAverageAbsMaxScale");
if (ctx->HasOutput("OutState")) { if (ctx->HasOutput("OutState")) {
...@@ -665,7 +887,7 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel { ...@@ -665,7 +887,7 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
...@@ -705,19 +927,23 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel { ...@@ -705,19 +927,23 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
auto out_grad_name = framework::GradVarName("Out"); auto out_grad_name = framework::GradVarName("Out");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name, OP_INOUT_CHECK(ctx->HasInput(out_grad_name),
"Input",
out_grad_name,
"StrightThroughEstimatorGradOp"); "StrightThroughEstimatorGradOp");
OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name, OP_INOUT_CHECK(ctx->HasOutput(x_grad_name),
"Output",
x_grad_name,
"StrightThroughEstimatorGradOp"); "StrightThroughEstimatorGradOp");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name)); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType( auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
...@@ -745,7 +971,8 @@ namespace ops = paddle::operators; ...@@ -745,7 +971,8 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_quantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp, fake_quantize_abs_max,
ops::FakeQuantOrWithDequantAbsMaxOp,
ops::FakeQuantOrWithDequantAbsMaxOpMaker, ops::FakeQuantOrWithDequantAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
...@@ -753,7 +980,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max, ...@@ -753,7 +980,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
ops::FakeQuantizeAbsMaxKernel<CPU, float>); ops::FakeQuantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp, fake_quantize_dequantize_abs_max,
ops::FakeQuantOrWithDequantAbsMaxOp,
ops::FakeQuantOrWithDequantAbsMaxOpMaker, ops::FakeQuantOrWithDequantAbsMaxOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>, ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>); ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
...@@ -761,7 +989,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max, ...@@ -761,7 +989,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>); ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp, fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxOp,
ops::FakeQuantizeRangeAbsMaxOpMaker, ops::FakeQuantizeRangeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
...@@ -788,7 +1017,8 @@ REGISTER_OP_CPU_KERNEL( ...@@ -788,7 +1017,8 @@ REGISTER_OP_CPU_KERNEL(
ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>); ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
fake_channel_wise_quantize_abs_max, ops::FakeChannelWiseQuantizeAbsMaxOp, fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeAbsMaxOpMaker, ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
...@@ -796,7 +1026,8 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max, ...@@ -796,7 +1026,8 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>); ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp, moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleOp,
ops::MovingAverageAbsMaxScaleOpMaker, ops::MovingAverageAbsMaxScaleOpMaker,
ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>, ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>); ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
...@@ -832,7 +1063,7 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale) ...@@ -832,7 +1063,7 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale)
"Delete output in order to make the inference model not " "Delete output in order to make the inference model not "
"save moving_average_abs_max_scale operator. This will " "save moving_average_abs_max_scale operator. This will "
"make the quantitative model be correctly applied in inference.")) "make the quantitative model be correctly applied in inference."))
.AddCheckpoint( .AddCheckpoint(R"ROC(Incompatible upgrade of output [Out])ROC",
R"ROC(Incompatible upgrade of output [Out])ROC",
paddle::framework::compatible::OpVersionDesc().NewOutput( paddle::framework::compatible::OpVersionDesc().NewOutput(
"Out", "In order to support dygraph qat, add output again.")); "Out",
"In order to support dygraph qat, add output again."));
...@@ -36,12 +36,12 @@ struct QuantizeDataType<paddle::platform::float16> { ...@@ -36,12 +36,12 @@ struct QuantizeDataType<paddle::platform::float16> {
}; };
template <typename T> template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { __global__ void FindAbsMaxKernel(const T *in, const int n, T *out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x; int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x; int tid = threadIdx.x;
extern __shared__ char* shared_max_data_tmp[]; extern __shared__ char *shared_max_data_tmp[];
auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp); auto shared_max_data = reinterpret_cast<T *>(shared_max_data_tmp);
if (gridDim.x > 1) { if (gridDim.x > 1) {
T local_max_data = T(0); T local_max_data = T(0);
for (int i = bid; i < n; i += blockDim.x * gridDim.x) { for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
...@@ -73,18 +73,20 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { ...@@ -73,18 +73,20 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
template <typename T> template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> { struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const T* in, void operator()(const platform::CUDADeviceContext &ctx,
const int num, T* out) { const T *in,
const int num,
T *out) {
int block = 1024; int block = 1024;
int grid = (block - 1 + num) / block; int grid = (block - 1 + num) / block;
grid = (grid > block) ? block : grid; grid = (grid > block) ? block : grid;
framework::Tensor max; framework::Tensor max;
T* max_data = max.mutable_data<T>(phi::make_ddim({grid}), ctx.GetPlace()); T *max_data = max.mutable_data<T>(phi::make_ddim({grid}), ctx.GetPlace());
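    // Two-pass reduction: the first launch writes one partial abs-max per
    // block into max_data; the second launch (a single block) reduces those
    // `grid` partial results into the final abs-max in `out`.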
FindAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>( FindAbsMaxKernel<T>
in, num, max_data); <<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(in, num, max_data);
FindAbsMaxKernel<T><<<1, block, 1024 * sizeof(T), ctx.stream()>>>( FindAbsMaxKernel<T>
max_data, grid, out); <<<1, block, 1024 * sizeof(T), ctx.stream()>>>(max_data, grid, out);
} }
}; };
...@@ -93,13 +95,15 @@ template struct FindAbsMaxFunctor<platform::CUDADeviceContext, ...@@ -93,13 +95,15 @@ template struct FindAbsMaxFunctor<platform::CUDADeviceContext,
paddle::platform::float16>; paddle::platform::float16>;
template <typename T> template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, __global__ void FindChannelAbsMaxKernelQuantAxis0(const T *in,
const int c, T* out) { const int n,
const int c,
T *out) {
int tid = threadIdx.x; int tid = threadIdx.x;
int channel_size = n / c; int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size; const T *in_c = in + blockIdx.x * channel_size;
extern __shared__ char* shared_max_data_tmp[]; extern __shared__ char *shared_max_data_tmp[];
auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp); auto shared_max_data = reinterpret_cast<T *>(shared_max_data_tmp);
T local_max_data = T(0); T local_max_data = T(0);
for (int i = tid; i < channel_size; i += blockDim.x) { for (int i = tid; i < channel_size; i += blockDim.x) {
T tmp = static_cast<T>( T tmp = static_cast<T>(
...@@ -122,17 +126,16 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, ...@@ -122,17 +126,16 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
} }
template <typename T> template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, __global__ void FindChannelAbsMaxKernelQuantAxis1(
const int cin, const int cout, const T *in, const int n, const int cin, const int cout, T *out) {
T* out) { extern __shared__ char *shared_max_data_tmp[];
extern __shared__ char* shared_max_data_tmp[]; auto shared_max_data = reinterpret_cast<T *>(shared_max_data_tmp);
auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
int cout_wh_size = n / cin; int cout_wh_size = n / cin;
int wh_size = n / (cin * cout); int wh_size = n / (cin * cout);
int tid = threadIdx.x; int tid = threadIdx.x;
int bid = blockIdx.x; int bid = blockIdx.x;
const T* in_current = in + tid * cout_wh_size + bid * wh_size; const T *in_current = in + tid * cout_wh_size + bid * wh_size;
T local_max_data = T(0); T local_max_data = T(0);
for (int i = 0; i < wh_size; i++) { for (int i = 0; i < wh_size; i++) {
T tmp = static_cast<T>( T tmp = static_cast<T>(
...@@ -162,23 +165,25 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, ...@@ -162,23 +165,25 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
template <typename T> template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> { struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in_tensor, const int quant_axis, const framework::Tensor &in_tensor,
T* out_abs_max) { const int quant_axis,
T *out_abs_max) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true, quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
quant_axis)); quant_axis));
const int num = in_tensor.numel(); const int num = in_tensor.numel();
auto in_dims = in_tensor.dims(); auto in_dims = in_tensor.dims();
const T* in_data = in_tensor.data<T>(); const T *in_data = in_tensor.data<T>();
if (quant_axis == 0) { if (quant_axis == 0) {
int cout = in_dims[0]; int cout = in_dims[0];
int grid = cout; int grid = cout;
int block = 1024; int block = 1024;
FindChannelAbsMaxKernelQuantAxis0< FindChannelAbsMaxKernelQuantAxis0<T>
T><<<grid, block, block * sizeof(T), ctx.stream()>>>( <<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, cout, out_abs_max); in_data, num, cout, out_abs_max);
} else if (quant_axis == 1) { } else if (quant_axis == 1) {
int cin = in_dims[0]; int cin = in_dims[0];
...@@ -194,16 +199,16 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> { ...@@ -194,16 +199,16 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
for (int i = 0; i < cin / max_threads; i++) { for (int i = 0; i < cin / max_threads; i++) {
int block = max_threads; int block = max_threads;
FindChannelAbsMaxKernelQuantAxis1< FindChannelAbsMaxKernelQuantAxis1<T>
T><<<grid, block, block * sizeof(T), ctx.stream()>>>( <<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, cin, cout, out_abs_max); in_data, num, cin, cout, out_abs_max);
in_data += num / cin; in_data += num / cin;
} }
int block = cin % max_threads; int block = cin % max_threads;
if (block > 0) { if (block > 0) {
FindChannelAbsMaxKernelQuantAxis1< FindChannelAbsMaxKernelQuantAxis1<T>
T><<<grid, block, block * sizeof(T), ctx.stream()>>>( <<<grid, block, block * sizeof(T), ctx.stream()>>>(
in_data, num, in_dims[0], in_dims[1], out_abs_max); in_data, num, in_dims[0], in_dims[1], out_abs_max);
} }
} }
...@@ -213,8 +218,12 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> { ...@@ -213,8 +218,12 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>; template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T> template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale, __global__ void ClipAndQuantKernel(const T *in,
const int bin_cnt, const int n, T* out) { const T *scale,
const int bin_cnt,
const int round_type,
const int n,
T *out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x; int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x; int tid = threadIdx.x;
...@@ -226,17 +235,30 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, ...@@ -226,17 +235,30 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
for (int i = bid; i < n; i += blockDim.x * gridDim.x) { for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
ComputeDataType x = static_cast<ComputeDataType>(in[i]); ComputeDataType x = static_cast<ComputeDataType>(in[i]);
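    // round_type == 0: scale onto the integer grid first, round half to even,
    // then saturate to [-bin_cnt - 1, bin_cnt]; otherwise (round_type == 1)
    // clip to [-s, s] first and round half away from zero.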
if (round_type == 0) {
x = bin_cnt_t * inv_s * x;
x = roundWithTiesToEven(x);
ComputeDataType max_bound = bin_cnt_t;
ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x;
out[i] = static_cast<T>(x);
} else {
ComputeDataType v = x > s ? s : x; ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v; v = bin_cnt_t * inv_s * v;
out[i] = static_cast<T>(round(v)); out[i] = static_cast<T>(round(v));
} }
}
} }
template <typename T> template <typename T>
__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, __global__ void ClipAndQuantDequantKernel(const T *in,
const int bin_cnt, const int n, const T *scale,
T* out) { const int bin_cnt,
const int round_type,
const int n,
T *out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x; int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x; int tid = threadIdx.x;
...@@ -248,29 +270,42 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, ...@@ -248,29 +270,42 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
for (int i = bid; i < n; i += blockDim.x * gridDim.x) { for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
ComputeDataType x = static_cast<ComputeDataType>(in[i]); ComputeDataType x = static_cast<ComputeDataType>(in[i]);
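    // Same two rounding paths as ClipAndQuantKernel, but the quantized value
    // is immediately rescaled by s / bin_cnt so Out stays in the float domain.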
if (round_type == 0) {
x = bin_cnt_t * inv_s * x;
x = roundWithTiesToEven(x);
ComputeDataType max_bound = bin_cnt_t;
ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x;
out[i] = static_cast<T>((x * s) / bin_cnt_t);
} else {
x = x > s ? s : x; x = x > s ? s : x;
x = x < -s ? -s : x; x = x < -s ? -s : x;
x = bin_cnt_t * inv_s * x; x = bin_cnt_t * inv_s * x;
x = round(x); x = round(x);
out[i] = static_cast<T>((x * s) / bin_cnt_t); out[i] = static_cast<T>((x * s) / bin_cnt_t);
} }
}
} }
template <typename T> template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> { struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in, const framework::Tensor& scale, const framework::Tensor &in,
const int bin_cnt, framework::Tensor* out) { const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
framework::Tensor *out) {
int num = in.numel(); int num = in.numel();
int block = 1024; int block = 1024;
int grid = (block - 1 + num) / block; int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>(); const T *in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>( ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data); in_data, scale_data, bin_cnt, round_type, num, out_data);
} }
}; };
...@@ -278,33 +313,39 @@ template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>; ...@@ -278,33 +313,39 @@ template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
template <typename T> template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in, const framework::Tensor& scale, const framework::Tensor &in,
const int bin_cnt, framework::Tensor* out) { const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
framework::Tensor *out) {
int num = in.numel(); int num = in.numel();
int block = 1024; int block = 1024;
int grid = (block - 1 + num) / block; int grid = (block - 1 + num) / block;
const T* in_data = in.data<T>(); const T *in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>( ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, out_data); in_data, scale_data, bin_cnt, round_type, num, out_data);
} }
}; };
// ChannelClipAndQuantKernel for quant_axis is 0 // ChannelClipAndQuantKernel for quant_axis is 0
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, __global__ void ChannelClipAndQuantKernelQuantAxis0(const T *in,
const T *scale,
const int bin_cnt, const int bin_cnt,
const int round_type,
const int64_t n, const int64_t n,
const int c, T* out) { const int c,
T *out) {
int tid = threadIdx.x; int tid = threadIdx.x;
int64_t channel_size = n / c; int64_t channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size; const T *in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size; T *out_c = out + blockIdx.x * channel_size;
using ComputeDataType = typename QuantizeDataType<T>::type; using ComputeDataType = typename QuantizeDataType<T>::type;
...@@ -314,18 +355,33 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, ...@@ -314,18 +355,33 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
for (int64_t i = tid; i < channel_size; i += blockDim.x) { for (int64_t i = tid; i < channel_size; i += blockDim.x) {
ComputeDataType x = static_cast<ComputeDataType>(in_c[i]); ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
if (round_type == 0) {
x = bin_cnt_t * inv_s * x;
x = roundWithTiesToEven(x);
ComputeDataType max_bound = bin_cnt_t;
ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x;
out_c[i] = static_cast<T>(x);
} else {
ComputeDataType v = x > s ? s : x; ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v; v = bin_cnt_t * inv_s * v;
out_c[i] = static_cast<T>(round(v)); out_c[i] = static_cast<T>(round(v));
} }
}
} }
// ChannelClipAndQuantKernel for quant_axis is N // ChannelClipAndQuantKernel for quant_axis is N
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxisN( __global__ void ChannelClipAndQuantKernelQuantAxisN(const T *in,
const T* in, const T* scale, const int bin_cnt, const int64_t n, const T *scale,
const int nScale, const int quant_stride, T* out) { const int bin_cnt,
const int round_type,
const int64_t n,
const int nScale,
const int quant_stride,
T *out) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using ComputeDataType = typename QuantizeDataType<T>::type; using ComputeDataType = typename QuantizeDataType<T>::type;
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt); ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
...@@ -334,36 +390,50 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN( ...@@ -334,36 +390,50 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]); static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
ComputeDataType inv_s = inverse(s); ComputeDataType inv_s = inverse(s);
ComputeDataType x = static_cast<ComputeDataType>(in[i]); ComputeDataType x = static_cast<ComputeDataType>(in[i]);
if (round_type == 0) {
x = bin_cnt_t * inv_s * x;
x = roundWithTiesToEven(x);
ComputeDataType max_bound = bin_cnt_t;
ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x;
out[i] = static_cast<T>(x);
} else {
ComputeDataType v = x > s ? s : x; ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v; v = bin_cnt_t * inv_s * v;
out[i] = static_cast<T>(round(v)); out[i] = static_cast<T>(round(v));
} }
}
} }
template <typename T> template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> { struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in, const framework::Tensor& scale, const framework::Tensor &in,
const int bin_cnt, const int quant_axis, const framework::Tensor &scale,
framework::Tensor* out) { const int bin_cnt,
const int round_type,
const int quant_axis,
framework::Tensor *out) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true, quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
quant_axis)); quant_axis));
int64_t num = in.numel(); int64_t num = in.numel();
auto in_dims = in.dims(); auto in_dims = in.dims();
const T* in_data = in.data<T>(); const T *in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
if (quant_axis == 0) { if (quant_axis == 0) {
int grid = in_dims[0]; int grid = in_dims[0];
int block = 1024; int block = 1024;
ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>( ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
in_data, scale_data, bin_cnt, num, in_dims[0], out_data); in_data, scale_data, bin_cnt, round_type, num, in_dims[0], out_data);
} else { } else {
int quant_stride = 1; int quant_stride = 1;
for (int i = quant_axis + 1; i < in_dims.size(); i++) { for (int i = quant_axis + 1; i < in_dims.size(); i++) {
...@@ -379,8 +449,14 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> { ...@@ -379,8 +449,14 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
const int64_t grid_size = const int64_t grid_size =
std::min(max_blocks, (num + block_size - 1) / block_size); std::min(max_blocks, (num + block_size - 1) / block_size);
ChannelClipAndQuantKernelQuantAxisN<T><<<grid_size, block_size>>>( ChannelClipAndQuantKernelQuantAxisN<T>
in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride, <<<grid_size, block_size>>>(in_data,
scale_data,
bin_cnt,
round_type,
num,
in_dims[quant_axis],
quant_stride,
out_data); out_data);
} }
} }
...@@ -390,12 +466,14 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, ...@@ -390,12 +466,14 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
float>; float>;
template <typename T> template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, __global__ void FindRangeAbsMaxAndFillArray(const T *cur_scale,
const T* last_scale, const T *last_scale,
const int64_t* iter, const int64_t *iter,
const int window_size, T* scale_arr, const int window_size,
T* out_scale, int* need_find_max, T *scale_arr,
int* out_size) { T *out_scale,
int *need_find_max,
int *out_size) {
int it = iter[0]; int it = iter[0];
int idx = it % window_size; int idx = it % window_size;
T removed = scale_arr[idx]; T removed = scale_arr[idx];
...@@ -414,45 +492,63 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, ...@@ -414,45 +492,63 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
template <typename T> template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> { struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& cur_scale, const framework::Tensor &cur_scale,
const framework::Tensor& last_scale, const framework::Tensor &last_scale,
const framework::Tensor& iter, const int window_size, const framework::Tensor &iter,
framework::Tensor* scales_arr, framework::Tensor* out_scale) { const int window_size,
framework::Tensor *scales_arr,
framework::Tensor *out_scale) {
const auto gpu_place = ctx.GetPlace(); const auto gpu_place = ctx.GetPlace();
T* scale_arr = scales_arr->mutable_data<T>(gpu_place); T *scale_arr = scales_arr->mutable_data<T>(gpu_place);
T* out_scale_data = out_scale->mutable_data<T>(gpu_place); T *out_scale_data = out_scale->mutable_data<T>(gpu_place);
framework::Tensor need_find_max, out_size; framework::Tensor need_find_max, out_size;
int* find_max = need_find_max.mutable_data<int>({1}, gpu_place); int *find_max = need_find_max.mutable_data<int>({1}, gpu_place);
int* out_size_data = out_size.mutable_data<int>({1}, gpu_place); int *out_size_data = out_size.mutable_data<int>({1}, gpu_place);
FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>( FindRangeAbsMaxAndFillArray<T>
cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(), <<<1, 1, 0, ctx.stream()>>>(cur_scale.data<T>(),
window_size, scale_arr, out_scale_data, find_max, out_size_data); last_scale.data<T>(),
iter.data<int64_t>(),
window_size,
scale_arr,
out_scale_data,
find_max,
out_size_data);
int g_find_max; int g_find_max;
memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max, memory::Copy(platform::CPUPlace(),
sizeof(int), ctx.stream()); &g_find_max,
gpu_place,
find_max,
sizeof(int),
ctx.stream());
ctx.Wait(); ctx.Wait();
if (g_find_max) { if (g_find_max) {
int len; int len;
memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data, memory::Copy(platform::CPUPlace(),
sizeof(int), ctx.stream()); &len,
gpu_place,
out_size_data,
sizeof(int),
ctx.stream());
ctx.Wait(); ctx.Wait();
FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len, FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(
out_scale_data); ctx, scale_arr, len, out_scale_data);
} }
} }
}; };
template <typename T> template <typename T>
__global__ void FindMovingAverageAbsMaxKernel(const T* in_state, __global__ void FindMovingAverageAbsMaxKernel(const T *in_state,
const T* in_accum, const T *in_accum,
const T* cur_scale, const T rate, const T *cur_scale,
T* out_state, T* out_accum, const T rate,
T* out_scale) { T *out_state,
T *out_accum,
T *out_scale) {
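  // Moving-average update of the abs-max statistics:
  //   state = rate * state + 1,   accum = rate * accum + cur_scale,
  // so accum / state is an exponentially weighted running mean of the
  // recent scale values.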
T state = rate * (*in_state) + T(1.0f); T state = rate * (*in_state) + T(1.0f);
T accum = rate * (*in_accum) + (*cur_scale); T accum = rate * (*in_accum) + (*cur_scale);
*out_state = state; *out_state = state;
...@@ -464,78 +560,119 @@ template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>; ...@@ -464,78 +560,119 @@ template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
template <typename T> template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> { struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in_accum, const framework::Tensor &in_accum,
const framework::Tensor& in_state, const T* cur_scale, const framework::Tensor &in_state,
const float rate, framework::Tensor* out_state, const T *cur_scale,
framework::Tensor* out_accum, framework::Tensor* out_scale) { const float rate,
framework::Tensor *out_state,
framework::Tensor *out_accum,
framework::Tensor *out_scale) {
const auto gpu_place = ctx.GetPlace(); const auto gpu_place = ctx.GetPlace();
T rate_t = static_cast<T>(rate); T rate_t = static_cast<T>(rate);
T* out_state_data = out_state->mutable_data<T>(gpu_place); T *out_state_data = out_state->mutable_data<T>(gpu_place);
T* out_accum_data = out_accum->mutable_data<T>(gpu_place); T *out_accum_data = out_accum->mutable_data<T>(gpu_place);
T* out_scale_data = out_scale->mutable_data<T>(gpu_place); T *out_scale_data = out_scale->mutable_data<T>(gpu_place);
FindMovingAverageAbsMaxKernel<T><<<1, 1, 0, ctx.stream()>>>( FindMovingAverageAbsMaxKernel<T>
in_state.data<T>(), in_accum.data<T>(), cur_scale, rate_t, <<<1, 1, 0, ctx.stream()>>>(in_state.data<T>(),
out_state_data, out_accum_data, out_scale_data); in_accum.data<T>(),
cur_scale,
rate_t,
out_state_data,
out_accum_data,
out_scale_data);
} }
}; };
// ChannelClipAndQuantDequantKernel for quant_axis is 0 // ChannelClipAndQuantDequantKernel for quant_axis is 0
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0( __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
const T* in, const T* scale, const int bin_cnt, const int n, const int c, const T *scale,
T* out) { const int bin_cnt,
const int round_type,
const int n,
const int c,
T *out) {
int tid = threadIdx.x; int tid = threadIdx.x;
int channel_size = n / c; int channel_size = n / c;
const T* in_c = in + blockIdx.x * channel_size; const T *in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size; T *out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x]; T s = scale[blockIdx.x];
T inv_s = inverse(s); T inv_s = inverse(s);
for (int i = tid; i < channel_size; i += blockDim.x) { for (int i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i]; T x = in_c[i];
if (round_type == 0) {
x = bin_cnt * inv_s * x;
x = roundWithTiesToEven(x);
T max_bound = bin_cnt;
T min_bound = -bin_cnt - static_cast<T>(1);
x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x;
out_c[i] = (x * s) / bin_cnt;
} else {
T v = x > s ? s : x; T v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
v = bin_cnt * inv_s * v; v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt; out_c[i] = round(v) * s / bin_cnt;
} }
}
} }
// ChannelClipAndQuantDequantKernel for quant_axis is 1 // ChannelClipAndQuantDequantKernel for quant_axis is 1
template <typename T> template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1( __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in,
const T* in, const T* scale, const int bin_cnt, const int n, const int cin, const T *scale,
const int cout, T* out) { const int bin_cnt,
const int round_type,
const int n,
const int cin,
const int cout,
T *out) {
T s = scale[blockIdx.x % cout]; T s = scale[blockIdx.x % cout];
T inv_s = inverse(s); T inv_s = inverse(s);
int wh_size = n / (cin * cout); int wh_size = n / (cin * cout);
const T* in_c = in + blockIdx.x * wh_size; const T *in_c = in + blockIdx.x * wh_size;
T* out_c = out + blockIdx.x * wh_size; T *out_c = out + blockIdx.x * wh_size;
for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
T x = in_c[i]; T x = in_c[i];
if (round_type == 0) {
x = bin_cnt * inv_s * x;
x = roundWithTiesToEven(x);
T max_bound = bin_cnt;
T min_bound = -bin_cnt - static_cast<T>(1);
x = x > max_bound ? max_bound : x;
x = x < min_bound ? min_bound : x;
out_c[i] = (x * s) / bin_cnt;
} else {
T v = x > s ? s : x; T v = x > s ? s : x;
v = v < -s ? -s : v; v = v < -s ? -s : v;
v = bin_cnt * inv_s * v; v = bin_cnt * inv_s * v;
out_c[i] = round(v) * s / bin_cnt; out_c[i] = round(v) * s / bin_cnt;
} }
}
} }
template <typename T> template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, void operator()(const platform::CUDADeviceContext &ctx,
const framework::Tensor& in, const framework::Tensor& scale, const framework::Tensor &in,
const int bin_cnt, const int quant_axis, const framework::Tensor &scale,
framework::Tensor* out) { const int bin_cnt,
const int round_type,
const int quant_axis,
framework::Tensor *out) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d // At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul // conv2d_transpose and mul
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true, quant_axis == 0 || quant_axis == 1,
true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
quant_axis)); quant_axis));
...@@ -543,23 +680,34 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> { ...@@ -543,23 +680,34 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
int num = in.numel(); int num = in.numel();
auto in_dims = in.dims(); auto in_dims = in.dims();
const T* in_data = in.data<T>(); const T *in_data = in.data<T>();
const T* scale_data = scale.data<T>(); const T *scale_data = scale.data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace()); T *out_data = out->mutable_data<T>(ctx.GetPlace());
if (quant_axis == 0) { if (quant_axis == 0) {
int grid = in_dims[0]; int grid = in_dims[0];
int block = 1024; int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis0< ChannelClipAndQuantDequantKernelQuantAxis0<T>
T><<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, <<<grid, block, 0, ctx.stream()>>>(in_data,
num, in_dims[0], out_data); scale_data,
bin_cnt,
round_type,
num,
in_dims[0],
out_data);
} else if (quant_axis == 1) { } else if (quant_axis == 1) {
int grid = in_dims[0] * in_dims[1]; int grid = in_dims[0] * in_dims[1];
int block = 1024; int block = 1024;
ChannelClipAndQuantDequantKernelQuantAxis1< ChannelClipAndQuantDequantKernelQuantAxis1<T>
T><<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(in_data,
in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data); scale_data,
bin_cnt,
round_type,
num,
in_dims[0],
in_dims[1],
out_data);
} }
} }
}; };
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
...@@ -33,97 +34,158 @@ inline HOSTDEVICE T inverse(T s) { ...@@ -33,97 +34,158 @@ inline HOSTDEVICE T inverse(T s) {
return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s; return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s;
} }
template <typename T>
inline HOSTDEVICE T roundWithTiesToEven(T x) {
T xLower = floor(x);
T xUpper = ceil(x);
// x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to
// even.
T dLower = x - xLower;
T dUpper = xUpper - x;
return static_cast<T>(
(dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
? xLower
: xUpper);
}
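// A few sample values for the tie-breaking behaviour (assuming float inputs):
// roundWithTiesToEven(0.5f) == 0, roundWithTiesToEven(1.5f) == 2,
// roundWithTiesToEven(2.5f) == 2, roundWithTiesToEven(-2.5f) == -2,
// whereas round(2.5f) == 3 and round(-2.5f) == -3 (ties away from zero).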
template <typename T>
class QuantTensorFunctor {
public:
explicit QuantTensorFunctor(const T bin_cnt, const T inv_s)
: bin_cnt_(bin_cnt), inv_s_(inv_s) {}
HOSTDEVICE T operator()(const T x) const {
T out = bin_cnt_ * inv_s_ * x;
out = roundWithTiesToEven(out);
T max_bound = bin_cnt_;
T min_bound = -bin_cnt_ - static_cast<T>(1);
out = out > max_bound ? max_bound : out;
out = out < min_bound ? min_bound : out;
return out;
}
private:
T bin_cnt_;
T inv_s_;
};
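// Worked example (assuming bit_length = 8, hence bin_cnt = 127): with scale
// s = 0.5 (inv_s = 2), QuantTensorFunctor maps x = 0.25 to
// roundWithTiesToEven(127 * 2 * 0.25) = 64, while x = 0.6 saturates at the
// upper bound 127 and x = -0.6 at the lower bound -128.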
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindAbsMaxFunctor { struct FindAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const T* in, const int num, T* out); void operator()(const DeviceContext &ctx, const T *in, const int num, T *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ClipAndFakeQuantFunctor { struct ClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext &ctx,
const framework::Tensor& scale, const int bin_cnt, const framework::Tensor &in,
framework::Tensor* out); const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
framework::Tensor *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ClipAndFakeQuantDequantFunctor { struct ClipAndFakeQuantDequantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext &ctx,
const framework::Tensor& scale, const int bin_cnt, const framework::Tensor &in,
framework::Tensor* out); const framework::Tensor &scale,
const int bin_cnt,
int round_type,
framework::Tensor *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor { struct FindRangeAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale, void operator()(const DeviceContext &ctx,
const framework::Tensor& last_scale, const framework::Tensor &cur_scale,
const framework::Tensor& iter, const int window_size, const framework::Tensor &last_scale,
framework::Tensor* scales_arr, framework::Tensor* out_scale); const framework::Tensor &iter,
const int window_size,
framework::Tensor *scales_arr,
framework::Tensor *out_scale);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor { struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor, void operator()(const DeviceContext &ctx,
const int quant_axis, T* out_abs_max); const framework::Tensor &in_tensor,
const int quant_axis,
T *out_abs_max);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor { struct ChannelClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext &ctx,
const framework::Tensor& scale, const int bin_cnt, const framework::Tensor &in,
const int quant_axis, framework::Tensor* out); const framework::Tensor &scale,
const int bin_cnt,
const int round_type,
const int quant_axis,
framework::Tensor *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ChannelClipFakeQuantDequantFunctor { struct ChannelClipFakeQuantDequantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in, void operator()(const DeviceContext &ctx,
const framework::Tensor& scale, const int bin_cnt, const framework::Tensor &in,
const int quant_axis, framework::Tensor* out); const framework::Tensor &scale,
const int bin_cnt,
int round_type,
const int quant_axis,
framework::Tensor *out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor { struct FindMovingAverageAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_accum, void operator()(const DeviceContext &ctx,
const framework::Tensor& in_state, const framework::Tensor &in_accum,
const framework::Tensor& cur_scale, const framework::Tensor &in_state,
framework::Tensor* out_state, framework::Tensor* out_accum, const framework::Tensor &cur_scale,
framework::Tensor* out_scale); framework::Tensor *out_state,
framework::Tensor *out_accum,
framework::Tensor *out_scale);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeAbsMaxKernelBase : public framework::OpKernel<T> { class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
T* out_s = out_scale->mutable_data<T>(context.GetPlace()); T *out_s = out_scale->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
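    // e.g. bit_length = 8 gives bin_cnt = 127, so quantized values fall in
    // [-127, 127] for round_type 1 and in [-128, 127] for round_type 0.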
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
const T* in_data = in->data<T>(); const T *in_data = in->data<T>();
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s); FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out); RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} }
virtual ~FakeAbsMaxKernelBase() = default; virtual ~FakeAbsMaxKernelBase() = default;
protected: protected:
virtual void RunClipFunctor(const DeviceContext& dev_ctx, virtual void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& in, const framework::Tensor &in,
const framework::Tensor& scale, int bin_cnt, const framework::Tensor &scale,
framework::Tensor* out) const = 0; int bin_cnt,
int round_type,
framework::Tensor *out) const = 0;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> { class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& scale, int bin_cnt, const framework::Tensor &in,
framework::Tensor* out) const override { const framework::Tensor &scale,
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, scale, bin_cnt, int bin_cnt,
out); int round_type,
framework::Tensor *out) const override {
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, in, scale, bin_cnt, round_type, out);
} }
}; };
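For intuition, the per-tensor abs-max path in the kernels above (FindAbsMaxFunctor followed by ClipAndFakeQuantFunctor) can be sketched in numpy. The helper below is illustrative only; the exact clip/round ordering inside the functor is an assumption, not its literal implementation.

import numpy as np

def fake_quant_abs_max(x, bit_length=8, round_type=0):
    # bin_cnt = 2^(bit_length - 1) - 1, e.g. 127 for 8-bit
    bin_cnt = 2 ** (bit_length - 1) - 1
    # per-tensor scale: the max absolute value of the input
    scale = np.abs(x).max()
    # clip to [-scale, scale], map onto the integer grid, then round
    v = np.clip(x, -scale, scale) / scale * bin_cnt
    if round_type == 0:
        q = np.round(v)                                # ties to even
    else:
        q = np.copysign(np.floor(np.abs(v) + 0.5), v)  # ties away from zero
    return q, scale

q, s = fake_quant_abs_max(np.array([-1.0, -0.3, 0.2, 0.9]))
# q is the fake-quantized tensor on the integer grid, s is the per-tensor scale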
...@@ -131,37 +193,41 @@ template <typename DeviceContext, typename T> ...@@ -131,37 +193,41 @@ template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeAbsMaxKernel class FakeQuantizeDequantizeAbsMaxKernel
: public FakeAbsMaxKernelBase<DeviceContext, T> { : public FakeAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& scale, int bin_cnt, const framework::Tensor &in,
framework::Tensor* out) const override { const framework::Tensor &scale,
ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, scale, int bin_cnt,
bin_cnt, out); int round_type,
framework::Tensor *out) const override {
ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(
dev_ctx, in, scale, bin_cnt, round_type, out);
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> { class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis"); int quant_axis = context.Attr<int>("quant_axis");
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
if (!is_test) { if (!is_test) {
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace()); T *out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis, FindChannelAbsMaxFunctor<DeviceContext, T>()(
out_scale_data); dev_ctx, *in, quant_axis, out_scale_data);
} }
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()( ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
} }
}; };
...@@ -169,130 +235,147 @@ template <typename DeviceContext, typename T> ...@@ -169,130 +235,147 @@ template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeDequantizeAbsMaxKernel class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace()); T *out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace()); out->mutable_data<T>(dev_ctx.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis"); int quant_axis = context.Attr<int>("quant_axis");
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis, FindChannelAbsMaxFunctor<DeviceContext, T>()(
out_scale_data); dev_ctx, *in, quant_axis, out_scale_data);
ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()( ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
} }
}; };
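Similarly, the channel-wise quant-dequant path above keeps one scale per slice along quant_axis. A hedged numpy sketch of that idea (the function and variable names here are illustrative, not the functors' actual signatures):

import numpy as np

def channel_wise_fake_quant_dequant(x, quant_axis=0, bit_length=8):
    bin_cnt = 2 ** (bit_length - 1) - 1
    # abs-max over every axis except quant_axis, giving one scale per channel
    reduce_axes = tuple(i for i in range(x.ndim) if i != quant_axis)
    scales = np.abs(x).max(axis=reduce_axes, keepdims=True)
    # quantize to the integer grid, then dequantize back with the same scales
    q = np.round(np.clip(x, -scales, scales) / scales * bin_cnt)
    return q * scales / bin_cnt, scales

w = np.random.randn(4, 3, 3, 3).astype(np.float32)  # e.g. a conv2d weight, cout axis 0
w_qdq, s = channel_wise_fake_quant_dequant(w, quant_axis=0)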
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> { class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* in_scale = context.Input<framework::Tensor>("InScale"); auto *in_scale = context.Input<framework::Tensor>("InScale");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
// testing // testing
if (is_test) { if (is_test) {
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(
bin_cnt, out); dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
return; return;
} }
// training // training
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
auto* out_scales = context.Output<framework::Tensor>("OutScales"); auto *out_scales = context.Output<framework::Tensor>("OutScales");
auto* iter = context.Input<framework::Tensor>("Iter"); auto *iter = context.Input<framework::Tensor>("Iter");
int window_size = context.Attr<int>("window_size"); int window_size = context.Attr<int>("window_size");
out_scale->mutable_data<T>(context.GetPlace()); out_scale->mutable_data<T>(context.GetPlace());
framework::Tensor cur_scale; framework::Tensor cur_scale;
T* cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace()); T *cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(), FindAbsMaxFunctor<DeviceContext, T>()(
cur_scale_data); dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx, cur_scale, *in_scale, FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
*iter, window_size, out_scales, cur_scale,
*in_scale,
*iter,
window_size,
out_scales,
out_scale); out_scale);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(
bin_cnt, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> { class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto* in_scale = context.Input<framework::Tensor>("InScale"); auto *in_scale = context.Input<framework::Tensor>("InScale");
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
// testing // testing
if (is_test) { if (is_test) {
RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, out); RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
return; return;
} }
// training // training
auto* in_accum = context.Input<framework::Tensor>("InAccum"); auto *in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState"); auto *in_state = context.Input<framework::Tensor>("InState");
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T)); auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr()); T *cur_scale_data = static_cast<T *>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(), FindAbsMaxFunctor<DeviceContext, T>()(
cur_scale_data); dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
auto* out_state = context.Output<framework::Tensor>("OutState"); auto *out_state = context.Output<framework::Tensor>("OutState");
auto* out_accum = context.Output<framework::Tensor>("OutAccum"); auto *out_accum = context.Output<framework::Tensor>("OutAccum");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
out_state->mutable_data<T>(context.GetPlace()); out_state->mutable_data<T>(context.GetPlace());
out_accum->mutable_data<T>(context.GetPlace()); out_accum->mutable_data<T>(context.GetPlace());
out_scale->mutable_data<T>(context.GetPlace()); out_scale->mutable_data<T>(context.GetPlace());
float moving_rate = context.Attr<float>("moving_rate"); float moving_rate = context.Attr<float>("moving_rate");
FindMovingAverageAbsMaxFunctor<DeviceContext, T>()( FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state, *in_accum,
out_accum, out_scale); *in_state,
cur_scale_data,
moving_rate,
out_state,
out_accum,
out_scale);
RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out); RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} }
virtual ~FakeMovingAverageAbsMaxKernelBase() = default; virtual ~FakeMovingAverageAbsMaxKernelBase() = default;
protected: protected:
virtual void RunClipFunctor(const DeviceContext& dev_ctx, virtual void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& in, const framework::Tensor &in,
const framework::Tensor& in_scale, int bin_cnt, const framework::Tensor &in_scale,
framework::Tensor* out) const = 0; int bin_cnt,
int round_type,
framework::Tensor *out) const = 0;
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class FakeQuantizeMovingAverageAbsMaxKernel class FakeQuantizeMovingAverageAbsMaxKernel
: public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> { : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& in_scale, int bin_cnt, const framework::Tensor &in,
framework::Tensor* out) const override { const framework::Tensor &in_scale,
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale, bin_cnt, int bin_cnt,
out); int round_type,
framework::Tensor *out) const override {
ClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, in, in_scale, bin_cnt, round_type, out);
} }
}; };
...@@ -300,23 +383,26 @@ template <typename DeviceContext, typename T> ...@@ -300,23 +383,26 @@ template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeMovingAverageAbsMaxKernel class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
: public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> { : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
protected: protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in, void RunClipFunctor(const DeviceContext &dev_ctx,
const framework::Tensor& in_scale, int bin_cnt, const framework::Tensor &in,
framework::Tensor* out) const override { const framework::Tensor &in_scale,
ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale, int bin_cnt,
bin_cnt, out); int round_type,
framework::Tensor *out) const override {
ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(
dev_ctx, in, in_scale, bin_cnt, round_type, out);
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> { class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto *in = context.Input<framework::Tensor>("X");
auto& dev_ctx = context.template device_context<DeviceContext>(); auto &dev_ctx = context.template device_context<DeviceContext>();
if (context.HasOutput("Out")) { if (context.HasOutput("Out")) {
auto* out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
} }
...@@ -328,37 +414,43 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> { ...@@ -328,37 +414,43 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
} }
// training // training
auto* in_accum = context.Input<framework::Tensor>("InAccum"); auto *in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState"); auto *in_state = context.Input<framework::Tensor>("InState");
auto cur_scale = memory::Alloc(dev_ctx, sizeof(T)); auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr()); T *cur_scale_data = static_cast<T *>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(), FindAbsMaxFunctor<DeviceContext, T>()(
cur_scale_data); dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
auto* out_state = context.Output<framework::Tensor>("OutState"); auto *out_state = context.Output<framework::Tensor>("OutState");
auto* out_accum = context.Output<framework::Tensor>("OutAccum"); auto *out_accum = context.Output<framework::Tensor>("OutAccum");
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto *out_scale = context.Output<framework::Tensor>("OutScale");
out_state->mutable_data<T>(context.GetPlace()); out_state->mutable_data<T>(context.GetPlace());
out_accum->mutable_data<T>(context.GetPlace()); out_accum->mutable_data<T>(context.GetPlace());
out_scale->mutable_data<T>(context.GetPlace()); out_scale->mutable_data<T>(context.GetPlace());
float moving_rate = context.Attr<float>("moving_rate"); float moving_rate = context.Attr<float>("moving_rate");
FindMovingAverageAbsMaxFunctor<DeviceContext, T>()( FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state, *in_accum,
out_accum, out_scale); *in_state,
cur_scale_data,
moving_rate,
out_state,
out_accum,
out_scale);
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> { class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* d_out = auto *d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out")); context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto* d_x = context.Output<framework::LoDTensor>(x_grad_name); auto *d_x = context.Output<framework::LoDTensor>(x_grad_name);
PADDLE_ENFORCE_NOT_NULL(d_x, platform::errors::PreconditionNotMet( PADDLE_ENFORCE_NOT_NULL(d_x,
platform::errors::PreconditionNotMet(
"StrightThroughEstimatorGradKernel " "StrightThroughEstimatorGradKernel "
"doesn't have the output named %s.", "doesn't have the output named %s.",
x_grad_name)); x_grad_name));
......
...@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and ...@@ -10,9 +10,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/quantize_linear_op.h" #include "paddle/fluid/operators/quantize_linear_op.h"
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
...@@ -24,14 +26,17 @@ namespace operators { ...@@ -24,14 +26,17 @@ namespace operators {
template <typename T> template <typename T>
struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx, void operator()(const platform::CPUDeviceContext &dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale, const framework::Tensor *in,
T max_range, const int quant_axis, framework::Tensor* out) { const framework::Tensor *scale,
T max_range,
const int quant_axis,
framework::Tensor *out) {
// Dequant op is before quantized op // Dequant op is before quantized op
// Dequantize the weight of quantized op // Dequantize the weight of quantized op
auto in_dims = in->dims(); auto in_dims = in->dims();
const int64_t channel = in_dims[quant_axis]; const int64_t channel = in_dims[quant_axis];
const T* scale_factor = scale->data<T>(); const T *scale_factor = scale->data<T>();
if (quant_axis == 0) { if (quant_axis == 0) {
for (int64_t i = 0; i < channel; i++) { for (int64_t i = 0; i < channel; i++) {
T s = scale_factor[i]; T s = scale_factor[i];
...@@ -39,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { ...@@ -39,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
framework::Tensor one_channel_out = out->Slice(i, i + 1); framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto in_e = framework::EigenVector<T>::Flatten(one_channel_in); auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out); auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
auto& dev = *dev_ctx.eigen_device(); auto &dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s / max_range; out_e.device(dev) = in_e * s / max_range;
} }
} else if (quant_axis == 1) { } else if (quant_axis == 1) {
...@@ -49,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { ...@@ -49,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
} }
int64_t step_i = in->numel() / out_iter; int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel); int64_t step_j = in->numel() / (out_iter * channel);
auto* in_data = in->data<T>(); auto *in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(dev_ctx.GetPlace()); auto *out_data = out->mutable_data<T>(dev_ctx.GetPlace());
for (int64_t i = 0; i < out_iter; i++) { for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) { for (int64_t j = 0; j < channel; j++) {
auto* cur_in = in_data + i * step_i + j * step_j; auto *cur_in = in_data + i * step_i + j * step_j;
auto* cur_out = out_data + i * step_i + j * step_j; auto *cur_out = out_data + i * step_i + j * step_j;
T s = scale_factor[j]; T s = scale_factor[j];
for (int64_t k = 0; k < step_j; k++) { for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s / max_range; *cur_out = (*cur_in) * s / max_range;
...@@ -67,19 +72,17 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> { ...@@ -67,19 +72,17 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
} }
}; };
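The per-channel dequantization implemented by this CPU functor is simply out = in * scale / max_range, broadcast along quant_axis. A small numpy analogue of the same arithmetic (an illustrative sketch, not the functor itself):

import numpy as np

def channel_dequantize_v2(x_int, scales, max_range, quant_axis):
    # broadcast one scale per channel along quant_axis, as the CPU loops above do
    shape = [1] * x_int.ndim
    shape[quant_axis] = -1
    s = scales.reshape(shape)
    return x_int.astype(np.float32) * s / max_range

x_int = np.array([[-127, 0, 64], [32, -16, 127]], dtype=np.int32)
scales = np.array([0.5, 0.25], dtype=np.float32)  # one scale per row (quant_axis = 0)
print(channel_dequantize_v2(x_int, scales, max_range=127, quant_axis=0))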
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, float>; template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, float>;
template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>; template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>;
class QuantizeLinearOp : public framework::OperatorWithKernel { class QuantizeLinearOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", OP_INOUT_CHECK(
"QuantizeLinear"); ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear");
ctx->SetOutputDim("Y", ctx->GetInputDim("X")); ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
int quant_axis = ctx->Attrs().Get<int>("quant_axis"); int quant_axis = ctx->Attrs().Get<int>("quant_axis");
...@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel { ...@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
...@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"For conv2d, depthwise_conv2d, conv2d_transpose " "For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.") "and mul, the quant_axis is equal to the cout axis.")
.SetDefault(0) .SetDefault(0)
.AddCustomChecker([](const int& quant_axis) { .AddCustomChecker([](const int &quant_axis) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1 || quant_axis == -1, true, quant_axis == 0 || quant_axis == 1 || quant_axis == -1,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but " "'quant_axis' should be 0 or 1, but "
"the received is %d", "the received is %d",
...@@ -126,13 +130,32 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -126,13 +130,32 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
}); });
AddAttr<int>("bit_length", "(int, default 8)") AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8) .SetDefault(8)
.AddCustomChecker([](const int& bit_length) { .AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true, PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"'bit_length' should be between 1 and 16, but " "'bit_length' should be between 1 and 16, but "
"the received is %d", "the received is %d",
bit_length)); bit_length));
}); });
AddAttr<int>(
"round_type",
"(int, default 0) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
.SetDefault(0)
.AddCustomChecker([](const int &round_type) {
PADDLE_ENFORCE_EQ(
round_type == 0 || round_type == 1,
true,
platform::errors::InvalidArgument(
"'round_type' should be 0 or 1, 0 rounding to "
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d",
round_type));
})
.AsExtra();
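The two round_type modes only disagree on exact .5 ties. A minimal numpy check of the behaviour described above (the helper names are assumptions for illustration, not part of the op):

import numpy as np

def round_ties_to_even(x):            # round_type = 0
    return np.round(x)                # numpy rounds halfway cases to the nearest even value

def round_ties_away_from_zero(x):     # round_type = 1
    return np.copysign(np.floor(np.abs(x) + 0.5), x)

x = np.array([1.5, 2.5, -2.5])
print(round_ties_to_even(x))          # [ 2.  2. -2.]
print(round_ties_away_from_zero(x))   # [ 2.  3. -3.]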
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
...@@ -156,14 +179,18 @@ namespace ops = paddle::operators; ...@@ -156,14 +179,18 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext; using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR( REGISTER_OPERATOR(
quantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, quantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>); REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>);
REGISTER_OPERATOR( REGISTER_OPERATOR(
dequantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, dequantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......
...@@ -29,9 +29,13 @@ namespace operators { ...@@ -29,9 +29,13 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctorV2 { struct ChannelDequantizeFunctorV2 {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in, void operator()(const DeviceContext& dev_ctx,
const framework::Tensor** scales, const int scale_num, const framework::Tensor* in,
T max_range, const int quant_axis, framework::Tensor* out); const framework::Tensor** scales,
const int scale_num,
T max_range,
const int quant_axis,
framework::Tensor* out);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -44,6 +48,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -44,6 +48,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Y"); auto* out = context.Output<framework::Tensor>("Y");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
int bit_length = context.Attr<int>("bit_length"); int bit_length = context.Attr<int>("bit_length");
int round_type = context.Attr<int>("round_type");
int bin_cnt = std::pow(2, bit_length - 1) - 1; int bin_cnt = std::pow(2, bit_length - 1) - 1;
int quant_axis = context.Attr<int>("quant_axis"); int quant_axis = context.Attr<int>("quant_axis");
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
...@@ -53,25 +58,25 @@ class QuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -53,25 +58,25 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
if (!is_test) { if (!is_test) {
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_s = out_scale->mutable_data<T>(context.GetPlace()); T* out_s = out_scale->mutable_data<T>(context.GetPlace());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), FindAbsMaxFunctor<DeviceContext, T>()(
in->numel(), out_s); dev_ctx, in->data<T>(), in->numel(), out_s);
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(
bin_cnt, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
} else { } else {
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale, ClipAndFakeQuantFunctor<DeviceContext, T>()(
bin_cnt, out); dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
} }
} else { } else {
if (!is_test) { if (!is_test) {
auto* out_scale = context.Output<framework::Tensor>("OutScale"); auto* out_scale = context.Output<framework::Tensor>("OutScale");
T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace()); T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis, FindChannelAbsMaxFunctor<DeviceContext, T>()(
out_scale_data); dev_ctx, *in, quant_axis, out_scale_data);
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()( ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
} else { } else {
ChannelClipAndFakeQuantFunctor<DeviceContext, T>()( ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out); dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out);
} }
} }
} }
...@@ -87,7 +92,8 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -87,7 +92,8 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
auto in_tmp = phi::Cast<T>( auto in_tmp = phi::Cast<T>(
static_cast<const typename paddle::framework::ConvertToPhiContext< static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx), DeviceContext>::TYPE&>(dev_ctx),
*in, experimental::CppTypeToDataType<D>::Type()); *in,
experimental::CppTypeToDataType<D>::Type());
auto* scale = context.Input<framework::Tensor>("Scale"); auto* scale = context.Input<framework::Tensor>("Scale");
auto* out = context.Output<framework::Tensor>("Y"); auto* out = context.Output<framework::Tensor>("Y");
...@@ -97,16 +103,18 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> { ...@@ -97,16 +103,18 @@ class DeQuantizeLinearKernel : public framework::OpKernel<T> {
if (quant_axis < 0) { if (quant_axis < 0) {
float max_range = (std::pow(2, bit_length - 1) - 1); float max_range = (std::pow(2, bit_length - 1) - 1);
DequantizeFunctor<DeviceContext, D>()(dev_ctx, &in_tmp, scale, DequantizeFunctor<DeviceContext, D>()(
static_cast<D>(max_range), out); dev_ctx, &in_tmp, scale, static_cast<D>(max_range), out);
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
scale->numel(), in_tmp.dims()[quant_axis], scale->numel(),
in_tmp.dims()[quant_axis],
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"The number of first scale values must be the same with " "The number of first scale values must be the same with "
"quant_axis dimension value of Input(X) when the `scale` has " "quant_axis dimension value of Input(X) when the `scale` has "
"only one element, but %ld != %ld here.", "only one element, but %ld != %ld here.",
scale->numel(), in_tmp.dims()[quant_axis])); scale->numel(),
in_tmp.dims()[quant_axis]));
int max_range = (std::pow(2, bit_length - 1) - 1); int max_range = (std::pow(2, bit_length - 1) - 1);
ChannelDequantizeFunctorV2<DeviceContext, D>()( ChannelDequantizeFunctorV2<DeviceContext, D>()(
......
...@@ -20,26 +20,31 @@ import logging ...@@ -20,26 +20,31 @@ import logging
import paddle.fluid as fluid import paddle.fluid as fluid
from ....log_helper import get_logger from ....log_helper import get_logger
from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error, bias_correction_w
_logger = get_logger( _logger = get_logger(__name__,
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
GAMMA = -0.1 GAMMA = -0.1
ZETA = 1.1 ZETA = 1.1
def compute_soft_rounding(alpha_v): def compute_soft_rounding(alpha_v):
return fluid.layers.clip( return fluid.layers.clip(fluid.layers.sigmoid(alpha_v) * (ZETA - GAMMA) +
fluid.layers.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, min=0, max=1) GAMMA,
min=0,
max=1)
def compute_soft_rounding_np(alpha_v): def compute_soft_rounding_np(alpha_v):
return np.clip( return np.clip(stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, a_min=0, a_max=1) a_min=0,
a_max=1)
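With GAMMA = -0.1 and ZETA = 1.1, the rectified sigmoid above stretches sigmoid(alpha) slightly beyond [0, 1] and clips it back, so the soft rounding value reaches exactly 0 and 1 for moderately large |alpha| while sitting at 0.5 for alpha = 0. A quick standalone numeric check:

import numpy as np

GAMMA, ZETA = -0.1, 1.1

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

for alpha in [-10.0, -2.0, 0.0, 2.0, 10.0]:
    h = np.clip(sigmoid(alpha) * (ZETA - GAMMA) + GAMMA, 0.0, 1.0)
    print(alpha, round(float(h), 4))
# alpha = 0 gives 0.5 * 1.2 - 0.1 = 0.5; alpha = +10 clips to 1.0, alpha = -10 clips to 0.0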
class AdaRoundLoss(object): class AdaRoundLoss(object):
def __init__(self, reg_param=0.01, default_beta_range=(20, 2)): def __init__(self, reg_param=0.01, default_beta_range=(20, 2)):
self.default_reg_param = reg_param self.default_reg_param = reg_param
self.default_beta_range = default_beta_range self.default_beta_range = default_beta_range
...@@ -48,26 +53,29 @@ class AdaRoundLoss(object): ...@@ -48,26 +53,29 @@ class AdaRoundLoss(object):
square_cost = fluid.layers.square_error_cost(ada_quantized_output, square_cost = fluid.layers.square_error_cost(ada_quantized_output,
orig_output) orig_output)
recon_loss = fluid.layers.reduce_mean( recon_loss = fluid.layers.reduce_mean(
fluid.layers.reduce_sum( fluid.layers.reduce_sum(square_cost, dim=-1))
square_cost, dim=-1))
return recon_loss return recon_loss
def compute_round_loss(self, alpha_v, warm_start, beta): def compute_round_loss(self, alpha_v, warm_start, beta):
def round_loss_fn(): def round_loss_fn():
# compute rectified sigmoid of parameter 'alpha' which maps it between zero and one # compute rectified sigmoid of parameter 'alpha' which maps it between zero and one
h_v = compute_soft_rounding(alpha_v) h_v = compute_soft_rounding(alpha_v)
# calculate the regularization term, which encourages the parameter to converge to exactly zero or one # calculate the regularization term, which encourages the parameter to converge to exactly zero or one
# at the end of optimization # at the end of optimization
reg_term = fluid.layers.reduce_sum(-fluid.layers.pow( reg_term = fluid.layers.reduce_sum(
fluid.layers.abs(2 * h_v - 1), factor=beta) + 1) -fluid.layers.pow(fluid.layers.abs(2 * h_v - 1), factor=beta) +
1)
# calculate the rounding loss # calculate the rounding loss
round_loss = self.default_reg_param * reg_term round_loss = self.default_reg_param * reg_term
return round_loss return round_loss
round_loss = fluid.layers.cond(warm_start, lambda: fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.0), round_loss_fn) round_loss = fluid.layers.cond(
warm_start, lambda: fluid.layers.fill_constant(
shape=[1], dtype='float32', value=0.0), round_loss_fn)
return round_loss return round_loss
...@@ -80,15 +88,16 @@ class AdaRoundLoss(object): ...@@ -80,15 +88,16 @@ class AdaRoundLoss(object):
warm_start_end_iter = warm_start * max_iter warm_start_end_iter = warm_start * max_iter
# compute relative iteration of current iteration # compute relative iteration of current iteration
rel_iter = (cur_iter - warm_start_end_iter) / ( rel_iter = (cur_iter - warm_start_end_iter) / (max_iter -
max_iter - warm_start_end_iter) warm_start_end_iter)
beta = end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter * beta = end_beta + 0.5 * (start_beta -
np.pi)) end_beta) * (1 + np.cos(rel_iter * np.pi))
return beta return beta
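The schedule above is a cosine anneal of beta from start_beta down to end_beta over the iterations that follow the warm-start phase. A standalone numpy sketch using the default beta range (20, 2); the warm_start fraction and max_iter below are assumed values for illustration:

import numpy as np

def anneal_beta(cur_iter, max_iter, warm_start=0.1, start_beta=20.0, end_beta=2.0):
    warm_start_end_iter = warm_start * max_iter
    rel_iter = (cur_iter - warm_start_end_iter) / (max_iter - warm_start_end_iter)
    return end_beta + 0.5 * (start_beta - end_beta) * (1 + np.cos(rel_iter * np.pi))

# beta starts at 20 right after warm start and decays smoothly to 2 at max_iter
for it in [100, 400, 700, 1000]:
    print(it, round(anneal_beta(it, max_iter=1000), 2))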
class AdaRound(object): class AdaRound(object):
def __init__(self, def __init__(self,
scale, scale,
weight_tensor, weight_tensor,
...@@ -145,8 +154,7 @@ class AdaRound(object): ...@@ -145,8 +154,7 @@ class AdaRound(object):
h_alpha = compute_soft_rounding_np(np_alpha) h_alpha = compute_soft_rounding_np(np_alpha)
# Scale the tensor # Scale the tensor
tensor_scale = quant_tensor( tensor_scale = quant_tensor(self.ori_weight_tensor.copy(),
self.ori_weight_tensor.copy(),
self.scale, self.scale,
quant_axis=self.quant_axis) quant_axis=self.quant_axis)
...@@ -160,8 +168,8 @@ class AdaRound(object): ...@@ -160,8 +168,8 @@ class AdaRound(object):
weight_tensor_quant = self._calculate_quant_weight() weight_tensor_quant = self._calculate_quant_weight()
# Dequantize the tensor # Dequantize the tensor
weight_tensor_dequant = dequant_tensor( weight_tensor_dequant = dequant_tensor(weight_tensor_quant +
weight_tensor_quant + self.offset, self.offset,
self.scale, self.scale,
quant_axis=self.quant_axis) quant_axis=self.quant_axis)
return weight_tensor_dequant return weight_tensor_dequant
...@@ -171,10 +179,10 @@ class AdaRound(object): ...@@ -171,10 +179,10 @@ class AdaRound(object):
return weight_tensor_quant return weight_tensor_quant
def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor): def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor):
round_loss = self.adaround_loss.compute_round_loss(self.alpha_v, round_loss = self.adaround_loss.compute_round_loss(
warm_start, beta) self.alpha_v, warm_start, beta)
recon_loss = self.adaround_loss.compute_recon_loss(adaround_out_tensor, recon_loss = self.adaround_loss.compute_recon_loss(
orig_out_tensor) adaround_out_tensor, orig_out_tensor)
loss = round_loss + recon_loss loss = round_loss + recon_loss
losses = { losses = {
'loss': loss, 'loss': loss,
...@@ -201,6 +209,7 @@ def run_adaround(data_loader, ...@@ -201,6 +209,7 @@ def run_adaround(data_loader,
scale_dict, scale_dict,
num_iterations=1000, num_iterations=1000,
lr=0.001, lr=0.001,
bias_correction=False,
fast_mode=True): fast_mode=True):
fetch_op_name = fetch_list[0].name fetch_op_name = fetch_list[0].name
final_weight_tensor_quant_dict = {} final_weight_tensor_quant_dict = {}
...@@ -226,28 +235,28 @@ def run_adaround(data_loader, ...@@ -226,28 +235,28 @@ def run_adaround(data_loader,
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
# initialize adaround # initialize adaround
adaround = AdaRound( adaround = AdaRound(scale,
scale,
weight_var_tensor, weight_var_tensor,
scope=scope, scope=scope,
weight_var_name=weight_var_name, weight_var_name=weight_var_name,
weight_op_type=weight_op_type, weight_op_type=weight_op_type,
num_iterations=num_iterations) num_iterations=num_iterations)
orig_out_tensor = fluid.data( orig_out_tensor = fluid.data(name='orig_out_tensor',
name='orig_out_tensor',
shape=fp32_fetch_list.shape, shape=fp32_fetch_list.shape,
dtype='float32') dtype='float32')
adaround_out_tensor = fluid.data( adaround_out_tensor = fluid.data(name='adaround_out_tensor',
name='adaround_out_tensor',
shape=fp32_fetch_list.shape, shape=fp32_fetch_list.shape,
dtype='float32') dtype='float32')
beta_tensor = fluid.data( beta_tensor = fluid.data(name='beta',
name='beta', shape=[1], dtype='float32') shape=[1],
warm_start_tensor = fluid.data( dtype='float32')
name='warm_start', shape=[1], dtype='bool') warm_start_tensor = fluid.data(name='warm_start',
shape=[1],
dtype='bool')
train_fetches_loss = adaround.get_loss( train_fetches_loss = adaround.get_loss(beta_tensor,
beta_tensor, warm_start_tensor, adaround_out_tensor, warm_start_tensor,
adaround_out_tensor,
orig_out_tensor) orig_out_tensor)
optimizer = fluid.optimizer.Adam(learning_rate=lr) optimizer = fluid.optimizer.Adam(learning_rate=lr)
loss = train_fetches_loss['loss'] loss = train_fetches_loss['loss']
...@@ -291,16 +300,23 @@ def run_adaround(data_loader, ...@@ -291,16 +300,23 @@ def run_adaround(data_loader,
fetch_list=[v.name for v in train_fetches_loss.values()], fetch_list=[v.name for v in train_fetches_loss.values()],
return_numpy=True) return_numpy=True)
_logger.info( _logger.info(
"Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s". "Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s"
format(i, lr, .format(i, lr, np.mean(out[0]), np.mean(out[1]),
np.mean(out[0]),
np.mean(out[1]),
np.mean(out[2]), start_time - prev_start_time)) np.mean(out[2]), start_time - prev_start_time))
sys.stdout.flush() sys.stdout.flush()
if i == num_iterations: if i == num_iterations:
break break
final_weight_tensor_quant_dict[ final_weight_tensor_quant_dict[
weight_var_name] = adaround.update_final_weights() weight_var_name] = adaround.update_final_weights()
if bias_correction:
final_weight_tensor_quant_dict[weight_var_name] = bias_correction_w(
weight_var_tensor,
final_weight_tensor_quant_dict[weight_var_name],
scale,
adaround.quant_axis,
weight_bits=adaround.weight_bits)
del adaround del adaround
# update adarounded calibrated weights # update adarounded calibrated weights
......
...@@ -36,8 +36,9 @@ from . import utils ...@@ -36,8 +36,9 @@ from . import utils
__all__ = ['PostTrainingQuantization', 'WeightQuantization'] __all__ = ['PostTrainingQuantization', 'WeightQuantization']
_logger = get_logger( _logger = get_logger(__name__,
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
def _all_persistable_var_names(program): def _all_persistable_var_names(program):
...@@ -88,7 +89,8 @@ def _apply_pass(scope, ...@@ -88,7 +89,8 @@ def _apply_pass(scope,
cpp_graph.set_not_owned('__param_scope__', scope) cpp_graph.set_not_owned('__param_scope__', scope)
if attrs: if attrs:
assert attr_values and len(attrs) == len( assert attr_values and len(attrs) == len(
attr_values), "Different number of pass attributes and their values." attr_values
), "Different number of pass attributes and their values."
for attr, value in zip(attrs, attr_values): for attr, value in zip(attrs, attr_values):
ir_pass.set(attr, value) ir_pass.set(attr, value)
ir_pass.apply(cpp_graph) ir_pass.apply(cpp_graph)
...@@ -180,7 +182,8 @@ class PostTrainingQuantization(object): ...@@ -180,7 +182,8 @@ class PostTrainingQuantization(object):
"mul"]. "mul"].
round_type(str, optional): The method of converting the quantized weights round_type(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods. value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the nearest whole number. Default is `round`, which rounds to the nearest integer.
'adaround' refers to https://arxiv.org/abs/2004.10568.
learning_rate(float, optional): The learning rate of adaround method. learning_rate(float, optional): The learning rate of adaround method.
is_full_quantized(bool, optional): If set is_full_quantized as True, is_full_quantized(bool, optional): If set is_full_quantized as True,
apply quantization to all supported quantizable op type. If set apply quantization to all supported quantizable op type. If set
...@@ -364,7 +367,8 @@ class PostTrainingQuantization(object): ...@@ -364,7 +367,8 @@ class PostTrainingQuantization(object):
batch_id = 0 batch_id = 0
with tqdm( with tqdm(
total=self._batch_nums, total=self._batch_nums,
bar_format='Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', bar_format=
'Preparation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for data in self._data_loader(): for data in self._data_loader():
self._executor.run(program=self._program, self._executor.run(program=self._program,
...@@ -380,9 +384,9 @@ class PostTrainingQuantization(object): ...@@ -380,9 +384,9 @@ class PostTrainingQuantization(object):
self._init_sampling_act_histogram() self._init_sampling_act_histogram()
batch_id = 0 batch_id = 0
with tqdm( with tqdm(total=self._batch_nums,
total=self._batch_nums, bar_format=
bar_format='Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}', 'Sampling stage, Run batch:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for data in self._data_loader(): for data in self._data_loader():
self._executor.run(program=self._program, self._executor.run(program=self._program,
...@@ -446,8 +450,7 @@ class PostTrainingQuantization(object): ...@@ -446,8 +450,7 @@ class PostTrainingQuantization(object):
scale_dict = self._quantized_var_threshold scale_dict = self._quantized_var_threshold
else: else:
scale_dict = self._quantized_threshold scale_dict = self._quantized_threshold
run_adaround( run_adaround(self._data_loader,
self._data_loader,
self._program, self._program,
self._fetch_list, self._fetch_list,
self._executor, self._executor,
...@@ -457,6 +460,7 @@ class PostTrainingQuantization(object): ...@@ -457,6 +460,7 @@ class PostTrainingQuantization(object):
self._weight_op_pairs, self._weight_op_pairs,
scale_dict, scale_dict,
num_iterations=self._batch_nums, num_iterations=self._batch_nums,
bias_correction=self._bias_correction,
lr=self._learning_rate) lr=self._learning_rate)
def save_quantized_model(self, def save_quantized_model(self,
...@@ -478,8 +482,7 @@ class PostTrainingQuantization(object): ...@@ -478,8 +482,7 @@ class PostTrainingQuantization(object):
None None
''' '''
clip_extra = True if self._onnx_format else False clip_extra = True if self._onnx_format else False
io.save_inference_model( io.save_inference_model(dirname=save_model_path,
dirname=save_model_path,
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename, params_filename=params_filename,
feeded_var_names=self._feed_list, feeded_var_names=self._feed_list,
...@@ -508,17 +511,18 @@ class PostTrainingQuantization(object): ...@@ -508,17 +511,18 @@ class PostTrainingQuantization(object):
if self._data_loader is not None: if self._data_loader is not None:
return return
self._data_loader = io.DataLoader.from_generator( self._data_loader = io.DataLoader.from_generator(feed_list=feed_vars,
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True) capacity=3 *
self._batch_size,
iterable=True)
if self._sample_generator is not None: if self._sample_generator is not None:
self._data_loader.set_sample_generator( self._data_loader.set_sample_generator(self._sample_generator,
self._sample_generator,
batch_size=self._batch_size, batch_size=self._batch_size,
drop_last=True, drop_last=True,
places=self._place) places=self._place)
elif self._batch_generator is not None: elif self._batch_generator is not None:
self._data_loader.set_batch_generator( self._data_loader.set_batch_generator(self._batch_generator,
self._batch_generator, places=self._place) places=self._place)
def _optimize_fp32_model(self): def _optimize_fp32_model(self):
''' '''
...@@ -569,11 +573,9 @@ class PostTrainingQuantization(object): ...@@ -569,11 +573,9 @@ class PostTrainingQuantization(object):
" is not supported for quantization.") " is not supported for quantization.")
# For quantized ops, sample inputs and outputs # For quantized ops, sample inputs and outputs
if op_type in self._quantizable_op_type: if op_type in self._quantizable_op_type:
collect_var_name( collect_var_name(utils._get_op_input_var_names(op),
utils._get_op_input_var_names(op),
persistable_var_names, op_type) persistable_var_names, op_type)
collect_var_name( collect_var_name(utils._get_op_output_var_names(op),
utils._get_op_output_var_names(op),
persistable_var_names, op_type) persistable_var_names, op_type)
# collect quanted op output var name # collect quanted op output var name
for out_var_name in utils._get_op_output_var_names(op): for out_var_name in utils._get_op_output_var_names(op):
...@@ -583,8 +585,7 @@ class PostTrainingQuantization(object): ...@@ -583,8 +585,7 @@ class PostTrainingQuantization(object):
in_var_name] = out_var_name in_var_name] = out_var_name
# For other op, only sample output scale # For other op, only sample output scale
elif op_type in self._out_scale_op_list: elif op_type in self._out_scale_op_list:
collect_var_name( collect_var_name(utils._get_op_output_var_names(op),
utils._get_op_output_var_names(op),
persistable_var_names, op_type) persistable_var_names, op_type)
def _set_activation_persistable(self): def _set_activation_persistable(self):
...@@ -655,6 +656,11 @@ class PostTrainingQuantization(object): ...@@ -655,6 +656,11 @@ class PostTrainingQuantization(object):
scale = s * abs_max_value scale = s * abs_max_value
s += 0.02 s += 0.02
bins = 2**(self._activation_bits - 1) - 1 bins = 2**(self._activation_bits - 1) - 1
if self._onnx_format:
quant_var = np.clip(np.round(var_tensor / scale * bins),
-bins - 1, bins)
quant_dequant_var = quant_var / bins * scale
else:
quant_dequant_var = np.round( quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale * np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale bins) / bins * scale
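The two simulation branches above differ in where the clipping happens: the onnx_format path rounds first and clips the integer values to [-bins - 1, bins], while the legacy path clips the float tensor to [0, scale] before rounding. A side-by-side numpy sketch of the same arithmetic (illustrative only):

import numpy as np

def simulate_qdq_onnx(var_tensor, scale, bins=127):
    # round onto the integer grid, then clip to the signed range [-bins - 1, bins]
    q = np.clip(np.round(var_tensor / scale * bins), -bins - 1, bins)
    return q / bins * scale

def simulate_qdq_legacy(var_tensor, scale, bins=127):
    # clip the float tensor to [0, scale] first, then round on the integer grid
    return np.round(np.clip(var_tensor, 0.0, scale) / scale * bins) / bins * scale

x = np.array([-0.8, -0.2, 0.1, 0.6, 1.3])
print(simulate_qdq_onnx(x, scale=1.0))
print(simulate_qdq_legacy(x, scale=1.0))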
...@@ -694,6 +700,11 @@ class PostTrainingQuantization(object): ...@@ -694,6 +700,11 @@ class PostTrainingQuantization(object):
scale = s * abs_max_value scale = s * abs_max_value
s += 0.02 s += 0.02
bins = 2**(self._activation_bits - 1) - 1 bins = 2**(self._activation_bits - 1) - 1
if self._onnx_format:
quant_var = np.clip(np.round(var_tensor / scale * bins),
-bins - 1, bins)
quant_dequant_var = quant_var / bins * scale
else:
quant_dequant_var = np.round( quant_dequant_var = np.round(
np.clip(var_tensor, 0.0, scale) / scale * np.clip(var_tensor, 0.0, scale) / scale *
bins) / bins * scale bins) / bins * scale
...@@ -846,8 +857,9 @@ class PostTrainingQuantization(object): ...@@ -846,8 +857,9 @@ class PostTrainingQuantization(object):
if var_name not in self._sampling_act_histogram: if var_name not in self._sampling_act_histogram:
min_val = self._sampling_act_abs_min_max[var_name][0] min_val = self._sampling_act_abs_min_max[var_name][0]
max_val = self._sampling_act_abs_min_max[var_name][1] max_val = self._sampling_act_abs_min_max[var_name][1]
hist, hist_edeges = np.histogram( hist, hist_edeges = np.histogram([],
[], bins=self._histogram_bins, range=(min_val, max_val)) bins=self._histogram_bins,
range=(min_val, max_val))
self._sampling_act_histogram[var_name] = [hist, hist_edeges] self._sampling_act_histogram[var_name] = [hist, hist_edeges]
def _calculate_kl_hist_threshold(self): def _calculate_kl_hist_threshold(self):
...@@ -951,18 +963,11 @@ class PostTrainingQuantization(object): ...@@ -951,18 +963,11 @@ class PostTrainingQuantization(object):
else: else:
scale_dict = self._quantized_threshold scale_dict = self._quantized_threshold
for key, val in scale_dict.items(): for key, val in scale_dict.items():
utils.set_variable_data( utils.set_variable_data(self._scope, self._place, key + ".scale",
self._scope, np.array([val], dtype=np.float32))
self._place, utils.set_variable_data(self._scope, self._place,
key + ".scale",
np.array(
[val], dtype=np.float32))
utils.set_variable_data(
self._scope,
self._place,
key + ".quant_dequant.scale", key + ".quant_dequant.scale",
np.array( np.array([val], dtype=np.float32))
[val], dtype=np.float32))
if not self._onnx_format: if not self._onnx_format:
# apply QuantizationFreezePass, and obtain the final quant model # apply QuantizationFreezePass, and obtain the final quant model
...@@ -1038,8 +1043,8 @@ class PostTrainingQuantization(object): ...@@ -1038,8 +1043,8 @@ class PostTrainingQuantization(object):
for block_id in range(len(self._program.blocks)): for block_id in range(len(self._program.blocks)):
for op in self._program.blocks[block_id].ops: for op in self._program.blocks[block_id].ops:
if op.type in ( if op.type in (self._quantizable_op_type +
self._quantizable_op_type + self._out_scale_op_list): self._out_scale_op_list):
out_var_names = utils._get_op_output_var_names(op) out_var_names = utils._get_op_output_var_names(op)
for var_name in out_var_names: for var_name in out_var_names:
analysis_and_save_info(op, var_name) analysis_and_save_info(op, var_name)
...@@ -1175,9 +1180,10 @@ class WeightQuantization(object): ...@@ -1175,9 +1180,10 @@ class WeightQuantization(object):
if generate_test_model: if generate_test_model:
test_model_dir = os.path.join(save_model_dir, "test_model") test_model_dir = os.path.join(save_model_dir, "test_model")
self._quantize_weight_to_int( self._quantize_weight_to_int(test_model_dir, save_model_filename,
test_model_dir, save_model_filename, save_params_filename, save_params_filename,
quantizable_op_type, weight_bits, weight_quantize_type, True, quantizable_op_type, weight_bits,
weight_quantize_type, True,
threshold_rate) threshold_rate)
def convert_weight_to_fp16(self, save_model_dir): def convert_weight_to_fp16(self, save_model_dir):
...@@ -1216,15 +1222,16 @@ class WeightQuantization(object): ...@@ -1216,15 +1222,16 @@ class WeightQuantization(object):
if self._params_filename is not None: if self._params_filename is not None:
save_var_map[new_var.name] = new_var save_var_map[new_var.name] = new_var
else: else:
save_file_path = os.path.join( save_file_path = os.path.join(os.path.normpath(save_model_dir),
os.path.normpath(save_model_dir), new_var.name) new_var.name)
save_block.append_op( save_block.append_op(type='save',
type='save',
inputs={'X': [new_var]}, inputs={'X': [new_var]},
outputs={}, outputs={},
attrs={ attrs={
'file_path': os.path.normpath(save_file_path), 'file_path':
'save_as_fp16': True os.path.normpath(save_file_path),
'save_as_fp16':
True
}) })
if self._params_filename is not None: if self._params_filename is not None:
...@@ -1237,14 +1244,15 @@ class WeightQuantization(object): ...@@ -1237,14 +1244,15 @@ class WeightQuantization(object):
name=unique_name.generate("saved_params")) name=unique_name.generate("saved_params"))
saved_params_var.desc.set_persistable(True) saved_params_var.desc.set_persistable(True)
save_path = os.path.join( save_path = os.path.join(os.path.normpath(save_model_dir),
os.path.normpath(save_model_dir), self._params_filename) self._params_filename)
save_block.append_op( save_block.append_op(type='save_combine',
type='save_combine',
inputs={'X': save_var_list}, inputs={'X': save_var_list},
outputs={'Y': saved_params_var}, outputs={'Y': saved_params_var},
attrs={'file_path': save_path, attrs={
'save_as_fp16': True}) 'file_path': save_path,
'save_as_fp16': True
})
save_program._sync_with_cpp() save_program._sync_with_cpp()
exe.run(save_program) exe.run(save_program)
...@@ -1293,8 +1301,7 @@ class WeightQuantization(object): ...@@ -1293,8 +1301,7 @@ class WeightQuantization(object):
self._weight_channel_wise_abs_max_quantization( self._weight_channel_wise_abs_max_quantization(
scope, place, weight_bits, op, var_name, for_test) scope, place, weight_bits, op, var_name, for_test)
io.save_inference_model( io.save_inference_model(dirname=save_model_dir,
dirname=save_model_dir,
feeded_var_names=feed_list, feeded_var_names=feed_list,
target_vars=fetch_list, target_vars=fetch_list,
executor=exe, executor=exe,
...@@ -1339,8 +1346,9 @@ class WeightQuantization(object): ...@@ -1339,8 +1346,9 @@ class WeightQuantization(object):
op._set_attr(var_name + "_quant_scale", [scale]) # Save as list op._set_attr(var_name + "_quant_scale", [scale]) # Save as list
op._set_attr("with_quant_attr", True) op._set_attr("with_quant_attr", True)
def _weight_channel_wise_abs_max_quantization( def _weight_channel_wise_abs_max_quantization(self, scope, place,
self, scope, place, weight_bits, op, var_name, for_test): weight_bits, op, var_name,
for_test):
''' '''
Use channel_wise_abs_max method to quantize weight. Use channel_wise_abs_max method to quantize weight.
''' '''
...@@ -1390,8 +1398,8 @@ class WeightQuantization(object): ...@@ -1390,8 +1398,8 @@ class WeightQuantization(object):
and quantize the weights. and quantize the weights.
''' '''
scales = [] scales = []
quantized_weight_data = np.zeros_like( quantized_weight_data = np.zeros_like(weight_data,
weight_data, dtype=save_weight_dtype) dtype=save_weight_dtype)
channel_num = weight_data.shape[0] channel_num = weight_data.shape[0]
for i in range(channel_num): for i in range(channel_num):
scale = np.max(np.abs(weight_data[i])) / quantize_range scale = np.max(np.abs(weight_data[i])) / quantize_range
...@@ -1404,8 +1412,8 @@ class WeightQuantization(object): ...@@ -1404,8 +1412,8 @@ class WeightQuantization(object):
''' '''
For conv2d and depthwise_conv2d, dequantize the weights to fp32. For conv2d and depthwise_conv2d, dequantize the weights to fp32.
''' '''
dequantized_weight_data = np.zeros_like( dequantized_weight_data = np.zeros_like(quantized_weight_data,
quantized_weight_data, dtype=np.float32) dtype=np.float32)
for i in range(len(scales)): for i in range(len(scales)):
dequantized_weight_data[i] = \ dequantized_weight_data[i] = \
(quantized_weight_data[i] * scales[i]).astype(np.float32) (quantized_weight_data[i] * scales[i]).astype(np.float32)
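For orientation, the two loops above implement per-channel (axis 0) abs_max quantization and dequantization of conv2d-style weights. Below is a self-contained numpy sketch of that round trip; names are illustrative, the 8-bit signed range is an assumption, and the elided loop bodies are reconstructed only for illustration, not copied from Paddle.

import numpy as np

def channel_wise_quant_dequant(weight, weight_bits=8):
    # mirrors the (2^(bits-1) - 1) range used above; 127 for 8 bits
    quantize_range = (1 << (weight_bits - 1)) - 1
    channel_num = weight.shape[0]
    scales = []
    quantized = np.zeros_like(weight, dtype=np.int8)
    for i in range(channel_num):
        scale = np.max(np.abs(weight[i])) / quantize_range
        scales.append(scale)
        quantized[i] = np.round(weight[i] / scale)
    dequantized = np.zeros_like(quantized, dtype=np.float32)
    for i in range(channel_num):
        dequantized[i] = (quantized[i] * scales[i]).astype(np.float32)
    return quantized, np.array(scales), dequantized

w = np.random.randn(4, 3, 3, 3).astype(np.float32)
q, s, dq = channel_wise_quant_dequant(w)
print(np.max(np.abs(w - dq)))  # error stays within half a quant step per channel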
...@@ -1418,8 +1426,8 @@ class WeightQuantization(object): ...@@ -1418,8 +1426,8 @@ class WeightQuantization(object):
and quantize the weights. and quantize the weights.
''' '''
scales = [] scales = []
quantized_weight_data = np.zeros_like( quantized_weight_data = np.zeros_like(weight_data,
weight_data, dtype=save_weight_dtype) dtype=save_weight_dtype)
channel_num = weight_data.shape[-1] channel_num = weight_data.shape[-1]
for i in range(channel_num): for i in range(channel_num):
scale = np.max(np.abs(weight_data[:, i])) / quantize_range scale = np.max(np.abs(weight_data[:, i])) / quantize_range
...@@ -1432,8 +1440,8 @@ class WeightQuantization(object): ...@@ -1432,8 +1440,8 @@ class WeightQuantization(object):
''' '''
For mul, dequantize the weights to fp32. For mul, dequantize the weights to fp32.
''' '''
dequantized_weight_data = np.zeros_like( dequantized_weight_data = np.zeros_like(quantized_weight_data,
quantized_weight_data, dtype=np.float32) dtype=np.float32)
for i in range(len(scales)): for i in range(len(scales)):
dequantized_weight_data[:, i] = \ dequantized_weight_data[:, i] = \
(quantized_weight_data[:, i] * scales[i]).astype(np.float32) (quantized_weight_data[:, i] * scales[i]).astype(np.float32)
...@@ -1441,8 +1449,9 @@ class WeightQuantization(object): ...@@ -1441,8 +1449,9 @@ class WeightQuantization(object):
def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000): def _calculate_threshold(self, input, threshold_rate, histogram_bins=5000):
input_abs = np.abs(input) input_abs = np.abs(input)
hist, hist_edeges = np.histogram( hist, hist_edeges = np.histogram(input_abs,
input_abs, bins=histogram_bins, range=(0, np.max(input_abs))) bins=histogram_bins,
range=(0, np.max(input_abs)))
hist = hist / float(sum(hist)) hist = hist / float(sum(hist))
hist_sum = 0 hist_sum = 0
hist_index = 0 hist_index = 0
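The rest of _calculate_threshold is cut off by the hunk above. For orientation only: a common way to turn such a normalized histogram into a clipping threshold is to walk the cumulative density until it covers (1 - threshold_rate) of the absolute values. The sketch below illustrates that idea with assumed names; it is not the elided Paddle code.

import numpy as np

def histogram_threshold(values, threshold_rate, histogram_bins=5000):
    values_abs = np.abs(values)
    hist, edges = np.histogram(values_abs, bins=histogram_bins,
                               range=(0, np.max(values_abs)))
    hist = hist / float(hist.sum())
    cum = np.cumsum(hist)
    # first bin edge whose cumulative mass reaches (1 - threshold_rate)
    index = int(np.searchsorted(cum, 1.0 - threshold_rate)) + 1
    return edges[min(index, histogram_bins)]

x = np.random.randn(10000).astype(np.float32)
print(histogram_threshold(x, threshold_rate=1e-4))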
......
...@@ -307,8 +307,9 @@ class QuantizationTransformPass(object): ...@@ -307,8 +307,9 @@ class QuantizationTransformPass(object):
var_node = self._insert_func( var_node = self._insert_func(
graph, self._weight_preprocess_func, var_node, op) graph, self._weight_preprocess_func, var_node, op)
elif not is_weight and self._act_preprocess_func is not None: elif not is_weight and self._act_preprocess_func is not None:
var_node = self._insert_func( var_node = self._insert_func(graph,
graph, self._act_preprocess_func, var_node, op) self._act_preprocess_func,
var_node, op)
# if var node is weight and weight_quantize_func is not None, # if var node is weight and weight_quantize_func is not None,
# will insert weight quantize func to quantize and dequantize weight # will insert weight quantize func to quantize and dequantize weight
...@@ -376,9 +377,9 @@ class QuantizationTransformPass(object): ...@@ -376,9 +377,9 @@ class QuantizationTransformPass(object):
graph.out_node_mapping_table = dict() graph.out_node_mapping_table = dict()
# The process of _transform_forward and _transform_backward is needed in two for loops. # The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph: # The loop for transforming the forward graph:
with tqdm( with tqdm(total=len(ops),
total=len(ops), bar_format=
bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}', 'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for op in ops: for op in ops:
if op.name() in self._quantizable_ops: if op.name() in self._quantizable_ops:
...@@ -405,12 +406,8 @@ class QuantizationTransformPass(object): ...@@ -405,12 +406,8 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1], shape=[1],
var_dtype=core.VarDesc.VarType.INT64) var_dtype=core.VarDesc.VarType.INT64)
_init_var_node( _init_var_node(global_step_in, np.zeros([1], dtype='int64'),
global_step_in, self._scope, self._place)
np.zeros(
[1], dtype='int64'),
self._scope,
self._place)
global_step_out = graph.create_var_node_from_desc( global_step_out = graph.create_var_node_from_desc(
global_step_in.var()) global_step_in.var())
# The attribute of `op_role` is needed by ParallelExecutor. # The attribute of `op_role` is needed by ParallelExecutor.
...@@ -459,12 +456,9 @@ class QuantizationTransformPass(object): ...@@ -459,12 +456,9 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(scale_var_node,
scale_var_node, np.zeros(scale_var_node.shape(), dtype=data_type),
np.zeros( self._scope, self._place)
scale_var_node.shape(), dtype=data_type),
self._scope,
self._place)
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max', op_type='fake_quantize_abs_max',
attrs={ attrs={
...@@ -472,8 +466,10 @@ class QuantizationTransformPass(object): ...@@ -472,8 +466,10 @@ class QuantizationTransformPass(object):
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}, },
inputs={'X': var_node}, inputs={'X': var_node},
outputs={'Out': quant_var_node, outputs={
'OutScale': scale_var_node}) 'Out': quant_var_node,
'OutScale': scale_var_node
})
graph.link_to(var_node, quant_op_node) graph.link_to(var_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node) graph.link_to(quant_op_node, quant_var_node)
graph.link_to(quant_op_node, scale_var_node) graph.link_to(quant_op_node, scale_var_node)
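For intuition, a rough numpy reading of the fake_quantize_abs_max op wired up above: the scale is the max-abs of the input, and Out stays in floating point but on the integer grid. This is a sketch, not the kernel; how the scaled value is rounded and clipped is precisely the behaviour this PR adjusts.

import numpy as np

def fake_quantize_abs_max(x, bit_length=8):
    bnt = (1 << (bit_length - 1)) - 1   # 127 for 8 bits
    scale = np.max(np.abs(x))           # becomes the OutScale output
    out = np.round(x / scale * bnt)     # float dtype, values in [-127, 127]
    return out, scale

x = np.random.randn(2, 8).astype(np.float32)
out, out_scale = fake_quantize_abs_max(x)
restored = out * out_scale / 127.0      # what the matching dequant op undoes later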
...@@ -498,12 +494,9 @@ class QuantizationTransformPass(object): ...@@ -498,12 +494,9 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(scale_in_node,
scale_in_node, np.array([_SCALE_DEFAULT_VALUE], dtype=data_type),
np.array( self._scope, self._place)
[_SCALE_DEFAULT_VALUE], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node} inputs = {'X': var_node, 'InScale': scale_in_node}
...@@ -518,12 +511,9 @@ class QuantizationTransformPass(object): ...@@ -518,12 +511,9 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(scales_node,
scales_node, np.zeros([self._window_size], dtype=data_type),
np.zeros( self._scope, self._place)
[self._window_size], dtype=data_type),
self._scope,
self._place)
inputs['Iter'] = self._global_step inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node outputs['OutScales'] = scales_node
...@@ -566,12 +556,9 @@ class QuantizationTransformPass(object): ...@@ -566,12 +556,9 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(scale_in_node,
scale_in_node, np.array([_SCALE_DEFAULT_VALUE], dtype=data_type),
np.array( self._scope, self._place)
[_SCALE_DEFAULT_VALUE], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
ins = {'X': var_node, 'InScale': scale_in_node} ins = {'X': var_node, 'InScale': scale_in_node}
...@@ -584,27 +571,19 @@ class QuantizationTransformPass(object): ...@@ -584,27 +571,19 @@ class QuantizationTransformPass(object):
shape=[1]) shape=[1])
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(state_in_node, np.ones([1], dtype=data_type),
state_in_node, self._scope, self._place)
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
accum_in_node = graph.create_persistable_node( accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'), name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(), var_dtype=var_node.dtype(),
shape=[1]) shape=[1])
_init_var_node( _init_var_node(accum_in_node, np.ones([1], dtype=data_type),
accum_in_node, self._scope, self._place)
np.ones( state_out_node = graph.create_var_node_from_desc(
[1], dtype=data_type), state_in_node.var())
self._scope, accum_out_node = graph.create_var_node_from_desc(
self._place) accum_in_node.var())
state_out_node = graph.create_var_node_from_desc(state_in_node.var(
))
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
))
ins['InState'] = state_in_node ins['InState'] = state_in_node
ins['InAccum'] = accum_in_node ins['InAccum'] = accum_in_node
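The InState/InAccum pair created above backs a moving-average estimate of the activation scale. A plausible numpy reading of the recurrence (moving_rate corresponds to the op attribute; this is a sketch of the update rule, not the CPU/CUDA kernel):

import numpy as np

def moving_average_abs_max_update(x, state, accum, moving_rate=0.9):
    state = moving_rate * state + 1.0
    accum = moving_rate * accum + np.max(np.abs(x))
    scale = accum / state                     # fed out as OutScale
    return state, accum, scale

state, accum = 1.0, 1.0                       # both initialized with ones above
for _ in range(5):
    state, accum, scale = moving_average_abs_max_update(
        np.random.randn(32, 64), state, accum)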
...@@ -656,12 +635,9 @@ class QuantizationTransformPass(object): ...@@ -656,12 +635,9 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(scale_var_node,
scale_var_node, np.zeros(scale_var_node.shape(), dtype=data_type),
np.zeros( self._scope, self._place)
scale_var_node.shape(), dtype=data_type),
self._scope,
self._place)
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_channel_wise_quantize_abs_max', op_type='fake_channel_wise_quantize_abs_max',
attrs={ attrs={
...@@ -671,8 +647,10 @@ class QuantizationTransformPass(object): ...@@ -671,8 +647,10 @@ class QuantizationTransformPass(object):
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}, },
inputs={'X': var_node}, inputs={'X': var_node},
outputs={'Out': quant_var_node, outputs={
'OutScale': scale_var_node}) 'Out': quant_var_node,
'OutScale': scale_var_node
})
graph.link_to(var_node, quant_op_node) graph.link_to(var_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node) graph.link_to(quant_op_node, quant_var_node)
graph.link_to(quant_op_node, scale_var_node) graph.link_to(quant_op_node, scale_var_node)
...@@ -696,8 +674,10 @@ class QuantizationTransformPass(object): ...@@ -696,8 +674,10 @@ class QuantizationTransformPass(object):
'max_range': float(max_range), 'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}, },
inputs={'X': var_node, inputs={
'Scale': scale_var_node}, 'X': var_node,
'Scale': scale_var_node
},
outputs={'Out': dequant_var_node}) outputs={'Out': dequant_var_node})
graph.link_to(var_node, dequant_op_node) graph.link_to(var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node) graph.link_to(scale_var_node, dequant_op_node)
...@@ -723,8 +703,10 @@ class QuantizationTransformPass(object): ...@@ -723,8 +703,10 @@ class QuantizationTransformPass(object):
'quant_axis': quant_axis, 'quant_axis': quant_axis,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}, },
inputs={'X': var_node, inputs={
'Scales': scale_var_nodes}, 'X': var_node,
'Scales': scale_var_nodes
},
outputs={'Out': dequant_var_node}) outputs={'Out': dequant_var_node})
graph.link_to(var_node, dequant_op_node) graph.link_to(var_node, dequant_op_node)
for scale_n in scale_var_nodes: for scale_n in scale_var_nodes:
...@@ -812,8 +794,7 @@ class QuantizationTransformPass(object): ...@@ -812,8 +794,7 @@ class QuantizationTransformPass(object):
startup_program = Program() startup_program = Program()
with program_guard(tmp_program, startup_program): with program_guard(tmp_program, startup_program):
with unique_name.guard(var_node.name() + "_"): with unique_name.guard(var_node.name() + "_"):
in_node = data( in_node = data(var_node.name() + '_tmp_input',
var_node.name() + '_tmp_input',
shape=var_node.shape(), shape=var_node.shape(),
dtype='float32') dtype='float32')
out_node = func(in_node) out_node = func(in_node)
...@@ -828,8 +809,8 @@ class QuantizationTransformPass(object): ...@@ -828,8 +809,8 @@ class QuantizationTransformPass(object):
with scope_guard(self._scope): with scope_guard(self._scope):
self._exe.run(startup_program) self._exe.run(startup_program)
tmp_graph = IrGraph( tmp_graph = IrGraph(core.Graph(tmp_program.desc),
core.Graph(tmp_program.desc), for_test=graph._for_test) for_test=graph._for_test)
in_node = tmp_graph._find_node_by_name(tmp_graph.all_var_nodes(), in_node = tmp_graph._find_node_by_name(tmp_graph.all_var_nodes(),
in_node.name) in_node.name)
out_node = tmp_graph._find_node_by_name(tmp_graph.all_var_nodes(), out_node = tmp_graph._find_node_by_name(tmp_graph.all_var_nodes(),
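The helper above builds func(in_node) inside a temporary program and then grafts the resulting subgraph into the quantized graph, so a user-supplied preprocess/quantize function is simply a callable from Variable to Variable. A minimal, hypothetical example of such a callable (paddle.clip on static Variables is assumed; the keyword name act_preprocess_func is an assumption as well):

import paddle

def clip_preprocess_func(x):
    # clamp activations to a fixed range before the fake-quant op is inserted
    return paddle.clip(x, min=-6.0, max=6.0)

# passed as e.g. act_preprocess_func=clip_preprocess_func when constructing the pass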
...@@ -870,9 +851,11 @@ class QuantizationTransformPass(object): ...@@ -870,9 +851,11 @@ class QuantizationTransformPass(object):
# find op's gradient op, such as conv2d_grad # find op's gradient op, such as conv2d_grad
op_grad = op_out_grad.outputs[0] op_grad = op_out_grad.outputs[0]
target_out_grad_node = graph._find_node_by_name( target_out_grad_node = graph._find_node_by_name(
graph.all_var_nodes(), target_out_node.name() + "@GRAD") graph.all_var_nodes(),
target_out_node.name() + "@GRAD")
in_node_grad = graph._find_node_by_name( in_node_grad = graph._find_node_by_name(
graph.all_var_nodes(), target_in_node.name() + "@GRAD") graph.all_var_nodes(),
target_in_node.name() + "@GRAD")
in_node_grad_op = in_node_grad.inputs in_node_grad_op = in_node_grad.inputs
# update op_grad's input # update op_grad's input
graph.update_input_link(var_node, target_out_node, op_grad) graph.update_input_link(var_node, target_out_node, op_grad)
...@@ -945,6 +928,7 @@ class QuantizationTransformPass(object): ...@@ -945,6 +928,7 @@ class QuantizationTransformPass(object):
class QuantizationFreezePass(object): class QuantizationFreezePass(object):
def __init__(self, def __init__(self,
scope, scope,
place, place,
...@@ -970,8 +954,9 @@ class QuantizationFreezePass(object): ...@@ -970,8 +954,9 @@ class QuantizationFreezePass(object):
weight_bits(int): quantization bit number for weights. weight_bits(int): quantization bit number for weights.
activation_bits(int): quantization bit number for activation. activation_bits(int): quantization bit number for activation.
round_type(str, optional): The method of converting the quantized weights round_type(str, optional): The method of converting the quantized weights
value from float to int. Currently supports ['round', 'adaround'] methods. value from float to int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the nearest whole number. Default is `round`, which rounds to the nearest integer.
'adaround' refers to https://arxiv.org/abs/2004.10568.
weight_quantize_type(str): quantization type for weights, support 'abs_max' and weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained. since weights are fixed once the model is well trained.
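For intuition on the round_type choices documented above: 'round' rounds the scaled weight to the nearest integer, while AdaRound (https://arxiv.org/abs/2004.10568) learns, per weight, whether to round up or down. The snippet below only illustrates that difference in the rounding decision; it is not the adaround implementation used by the pass.

import numpy as np

w_scaled = np.array([1.3, -0.7, 2.5, -3.1])   # weight / scale * bnt
round_nearest = np.round(w_scaled)             # round_type='round'
h = np.array([1.0, 0.0, 0.0, 1.0])             # learned 0/1 mask (illustrative only)
round_adaround = np.floor(w_scaled) + h        # round_type='adaround'
print(round_nearest, round_adaround)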
...@@ -1017,7 +1002,8 @@ class QuantizationFreezePass(object): ...@@ -1017,7 +1002,8 @@ class QuantizationFreezePass(object):
input_arg_name] input_arg_name]
if input_arg_name not in persistable_vars: if input_arg_name not in persistable_vars:
scale_v = graph._find_node_by_name( scale_v = graph._find_node_by_name(
op_node.outputs, op_node.output('OutScale')[0]) op_node.outputs,
op_node.output('OutScale')[0])
self._quant_var_scale_map[input_arg_name] = scale_v self._quant_var_scale_map[input_arg_name] = scale_v
else: else:
# Obtain scale from OutScale var node # Obtain scale from OutScale var node
...@@ -1033,8 +1019,8 @@ class QuantizationFreezePass(object): ...@@ -1033,8 +1019,8 @@ class QuantizationFreezePass(object):
scale_v = scale_v.tolist() scale_v = scale_v.tolist()
self._quant_var_scale_map[input_arg_name] = scale_v self._quant_var_scale_map[input_arg_name] = scale_v
# Quantize weight and restore # Quantize weight and restore
param_v = self._load_var(input_arg_name)
if self._round_type == 'round': if self._round_type == 'round':
param_v = self._load_var(input_arg_name)
if any( if any(
_check_grandchild_op_node(op_node, op) _check_grandchild_op_node(op_node, op)
for op in utils._channelwise_quant_axis1_ops): for op in utils._channelwise_quant_axis1_ops):
...@@ -1045,6 +1031,7 @@ class QuantizationFreezePass(object): ...@@ -1045,6 +1031,7 @@ class QuantizationFreezePass(object):
param_v.copy(), scale_v, quant_axis, param_v.copy(), scale_v, quant_axis,
self._weight_bits) self._weight_bits)
quantized_param_v = np.round(quantized_param_v) quantized_param_v = np.round(quantized_param_v)
# Weight bias correction
if self._bias_correction == True: if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w( quantized_param_v = utils.bias_correction_w(
param_v, param_v,
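A rough sketch of the weight bias-correction step applied right after rounding above: compare the dequantized weights with the original FP32 weights per output channel and fold the mean mismatch back into the quantized values. This only illustrates the concept; the exact correction lives in utils.bias_correction_w.

import numpy as np

def bias_correct_channel(w_fp32, w_quant, scale, bnt=127):
    w_dequant = w_quant * scale / bnt                 # back to float for comparison
    mean_err = w_fp32.mean() - w_dequant.mean()       # per-channel mean shift
    w_quant_corrected = w_quant + np.round(mean_err / scale * bnt)
    return np.clip(w_quant_corrected, -bnt, bnt)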
...@@ -1072,8 +1059,8 @@ class QuantizationFreezePass(object): ...@@ -1072,8 +1059,8 @@ class QuantizationFreezePass(object):
if self._weight_quantize_type == 'channel_wise_abs_max': if self._weight_quantize_type == 'channel_wise_abs_max':
quant_axis = 1 if op_node.name() in \ quant_axis = 1 if op_node.name() in \
utils._channelwise_quant_axis1_ops else 0 utils._channelwise_quant_axis1_ops else 0
self._insert_post_channel_dequant_op(graph, op_node, self._insert_post_channel_dequant_op(
quant_axis) graph, op_node, quant_axis)
else: else:
self._insert_post_dequant_op(graph, op_node) self._insert_post_dequant_op(graph, op_node)
...@@ -1128,7 +1115,8 @@ class QuantizationFreezePass(object): ...@@ -1128,7 +1115,8 @@ class QuantizationFreezePass(object):
" more than one output." % (op_node.name())) " more than one output." % (op_node.name()))
output_var_node = graph._find_node_by_name( output_var_node = graph._find_node_by_name(
op_node.outputs, op_node.output_arg_names()[0]) op_node.outputs,
op_node.output_arg_names()[0])
weight_scale_node = graph.create_persistable_node( weight_scale_node = graph.create_persistable_node(
name=unique_name.generate('channel_scale'), name=unique_name.generate('channel_scale'),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
...@@ -1136,9 +1124,8 @@ class QuantizationFreezePass(object): ...@@ -1136,9 +1124,8 @@ class QuantizationFreezePass(object):
var_dtype=output_var_node.dtype()) var_dtype=output_var_node.dtype())
data_type = 'float64' if output_var_node.dtype( data_type = 'float64' if output_var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node(weight_scale_node, _init_var_node(weight_scale_node, channel_scale.astype(data_type),
channel_scale.astype(data_type), self._scope, self._scope, self._place)
self._place)
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()), name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(), var_type=output_var_node.type(),
...@@ -1201,7 +1188,8 @@ class QuantizationFreezePass(object): ...@@ -1201,7 +1188,8 @@ class QuantizationFreezePass(object):
" more than one output." % (op_node.name())) " more than one output." % (op_node.name()))
output_var_node = graph._find_node_by_name( output_var_node = graph._find_node_by_name(
op_node.outputs, op_node.output_arg_names()[0]) op_node.outputs,
op_node.output_arg_names()[0])
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()), name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(), var_type=output_var_node.type(),
...@@ -1213,8 +1201,10 @@ class QuantizationFreezePass(object): ...@@ -1213,8 +1201,10 @@ class QuantizationFreezePass(object):
'max_range': float(max_range), 'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}, },
inputs={'X': output_var_node, inputs={
'Scale': scale_var_node}, 'X': output_var_node,
'Scale': scale_var_node
},
outputs={'Out': dequant_var_node}) outputs={'Out': dequant_var_node})
graph.link_to(output_var_node, dequant_op_node) graph.link_to(output_var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node) graph.link_to(scale_var_node, dequant_op_node)
...@@ -1273,6 +1263,7 @@ class QuantizationFreezePass(object): ...@@ -1273,6 +1263,7 @@ class QuantizationFreezePass(object):
class ConvertToInt8Pass(object): class ConvertToInt8Pass(object):
def __init__(self, scope, place, quantizable_op_type=None): def __init__(self, scope, place, quantizable_op_type=None):
""" """
Convert the weights into int8_t type. Convert the weights into int8_t type.
...@@ -1312,8 +1303,8 @@ class ConvertToInt8Pass(object): ...@@ -1312,8 +1303,8 @@ class ConvertToInt8Pass(object):
name = var_node.name() name = var_node.name()
if name in persistable_vars: if name in persistable_vars:
if name not in input_map: if name not in input_map:
int8_var_node = self._convert_to_int8(graph, int8_var_node = self._convert_to_int8(
var_node) graph, var_node)
input_map[name] = int8_var_node input_map[name] = int8_var_node
graph.update_input_link(var_node, input_map[name], graph.update_input_link(var_node, input_map[name],
op_node) op_node)
...@@ -1361,6 +1352,7 @@ class ConvertToInt8Pass(object): ...@@ -1361,6 +1352,7 @@ class ConvertToInt8Pass(object):
class TransformForMobilePass(object): class TransformForMobilePass(object):
def __init__(self): def __init__(self):
""" """
This pass is used to convert the frozen graph for paddle-mobile execution. This pass is used to convert the frozen graph for paddle-mobile execution.
...@@ -1403,6 +1395,7 @@ class TransformForMobilePass(object): ...@@ -1403,6 +1395,7 @@ class TransformForMobilePass(object):
class OutScaleForTrainingPass(object): class OutScaleForTrainingPass(object):
def __init__(self, scope=None, place=None, moving_rate=0.9): def __init__(self, scope=None, place=None, moving_rate=0.9):
""" """
This pass is used for calculating output scales of some operators. This pass is used for calculating output scales of some operators.
...@@ -1436,8 +1429,7 @@ class OutScaleForTrainingPass(object): ...@@ -1436,8 +1429,7 @@ class OutScaleForTrainingPass(object):
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._teller_set: if op.name() in self._teller_set:
target_ops.append(op) target_ops.append(op)
with tqdm( with tqdm(total=len(target_ops),
total=len(target_ops),
bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}', bar_format='Adding OutScale op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for op in target_ops: for op in target_ops:
...@@ -1455,12 +1447,8 @@ class OutScaleForTrainingPass(object): ...@@ -1455,12 +1447,8 @@ class OutScaleForTrainingPass(object):
var_dtype=in_node.dtype()) var_dtype=in_node.dtype())
data_type = 'float64' if in_node.dtype() \ data_type = 'float64' if in_node.dtype() \
== core.VarDesc.VarType.FP64 else 'float32' == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(scale_node, np.ones([1], dtype=data_type),
scale_node, self._scope, self._place)
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
ins = {'X': in_node} ins = {'X': in_node}
outs = {'OutScale': scale_node} outs = {'OutScale': scale_node}
if not self._is_test: if not self._is_test:
...@@ -1469,23 +1457,17 @@ class OutScaleForTrainingPass(object): ...@@ -1469,23 +1457,17 @@ class OutScaleForTrainingPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=in_node.dtype(), var_dtype=in_node.dtype(),
shape=[1]) shape=[1])
_init_var_node( _init_var_node(state_in_node,
state_in_node, np.ones([1], dtype=data_type),
np.ones( self._scope, self._place)
[1], dtype=data_type),
self._scope,
self._place)
accum_in_node = graph.create_persistable_node( accum_in_node = graph.create_persistable_node(
name=unique_name.generate('scale_accum@'), name=unique_name.generate('scale_accum@'),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=in_node.dtype(), var_dtype=in_node.dtype(),
shape=[1]) shape=[1])
_init_var_node( _init_var_node(accum_in_node,
accum_in_node, np.ones([1], dtype=data_type),
np.ones( self._scope, self._place)
[1], dtype=data_type),
self._scope,
self._place)
state_out_node = graph.create_var_node_from_desc( state_out_node = graph.create_var_node_from_desc(
state_in_node.var()) state_in_node.var())
accum_out_node = graph.create_var_node_from_desc( accum_out_node = graph.create_var_node_from_desc(
...@@ -1525,6 +1507,7 @@ class OutScaleForTrainingPass(object): ...@@ -1525,6 +1507,7 @@ class OutScaleForTrainingPass(object):
class OutScaleForInferencePass(object): class OutScaleForInferencePass(object):
def __init__(self, scope=None): def __init__(self, scope=None):
""" """
This pass is used for setting output scales of some operators. This pass is used for setting output scales of some operators.
...@@ -1566,8 +1549,8 @@ class OutScaleForInferencePass(object): ...@@ -1566,8 +1549,8 @@ class OutScaleForInferencePass(object):
# For compatibility, we save output threshold by two methods. # For compatibility, we save output threshold by two methods.
op_node.op()._set_attr("out_threshold", float(scale_value)) op_node.op()._set_attr("out_threshold", float(scale_value))
argname_index = utils._get_output_name_index(op_node, argname_index = utils._get_output_name_index(
var_name) op_node, var_name)
assert argname_index is not None, \ assert argname_index is not None, \
var_name + " is not the output of the op" var_name + " is not the output of the op"
op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \ op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \
...@@ -1660,9 +1643,9 @@ class AddQuantDequantPass(object): ...@@ -1660,9 +1643,9 @@ class AddQuantDequantPass(object):
# Forward stage, insert quant_dequant op # Forward stage, insert quant_dequant op
all_op_nodes = graph.all_op_nodes() all_op_nodes = graph.all_op_nodes()
with tqdm( with tqdm(total=len(all_op_nodes),
total=len(all_op_nodes), bar_format=
bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}', 'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for op_node in all_op_nodes: for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type: if op_node.name() in self._quantizable_op_type:
...@@ -1685,8 +1668,8 @@ class AddQuantDequantPass(object): ...@@ -1685,8 +1668,8 @@ class AddQuantDequantPass(object):
op_node.op()._set_attr("with_quant_attr", True) op_node.op()._set_attr("with_quant_attr", True)
arg_names = utils._get_op_input_var_names(op_node) arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names: for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, in_node = graph._find_node_by_name(
arg_name) op_node.inputs, arg_name)
if arg_name in dequantized_vars_map: if arg_name in dequantized_vars_map:
quant_var_node = dequantized_vars_map[arg_name] quant_var_node = dequantized_vars_map[arg_name]
else: else:
...@@ -1703,8 +1686,8 @@ class AddQuantDequantPass(object): ...@@ -1703,8 +1686,8 @@ class AddQuantDequantPass(object):
if op_node.name() in self._quantizable_grad_op_type: if op_node.name() in self._quantizable_grad_op_type:
for input_name in op_node.input_arg_names(): for input_name in op_node.input_arg_names():
if input_name in dequantized_vars_map: if input_name in dequantized_vars_map:
in_node = graph._find_node_by_name(op_node.inputs, in_node = graph._find_node_by_name(
input_name) op_node.inputs, input_name)
dequant_var_node = dequantized_vars_map[input_name] dequant_var_node = dequantized_vars_map[input_name]
graph.update_input_link(in_node, dequant_var_node, graph.update_input_link(in_node, dequant_var_node,
op_node) op_node)
...@@ -1716,8 +1699,8 @@ class AddQuantDequantPass(object): ...@@ -1716,8 +1699,8 @@ class AddQuantDequantPass(object):
quant_bits): quant_bits):
"""Insert fake_quantize_dequantize_moving_average_abs_max op. """Insert fake_quantize_dequantize_moving_average_abs_max op.
""" """
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(name="{}.quant_dequant".format(
name="{}.quant_dequant".format(var_node.name()), var_node.name()),
var_type=var_node.type(), var_type=var_node.type(),
shape=var_node.shape(), shape=var_node.shape(),
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
...@@ -1728,12 +1711,9 @@ class AddQuantDequantPass(object): ...@@ -1728,12 +1711,9 @@ class AddQuantDequantPass(object):
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(scale_in_node,
scale_in_node, np.array([_SCALE_DEFAULT_VALUE], dtype=data_type),
np.array( self._scope, self._place)
[_SCALE_DEFAULT_VALUE], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
ins = {'X': var_node, 'InScale': scale_in_node} ins = {'X': var_node, 'InScale': scale_in_node}
...@@ -1746,27 +1726,19 @@ class AddQuantDequantPass(object): ...@@ -1746,27 +1726,19 @@ class AddQuantDequantPass(object):
shape=[1]) shape=[1])
data_type = 'float64' if var_node.dtype( data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32' ) == core.VarDesc.VarType.FP64 else 'float32'
_init_var_node( _init_var_node(state_in_node, np.ones([1], dtype=data_type),
state_in_node, self._scope, self._place)
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
accum_in_node = graph.create_persistable_node( accum_in_node = graph.create_persistable_node(
name=unique_name.generate('quant_dequant.accum'), name=unique_name.generate('quant_dequant.accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(), var_dtype=var_node.dtype(),
shape=[1]) shape=[1])
_init_var_node( _init_var_node(accum_in_node, np.ones([1], dtype=data_type),
accum_in_node, self._scope, self._place)
np.ones( state_out_node = graph.create_var_node_from_desc(
[1], dtype=data_type), state_in_node.var())
self._scope, accum_out_node = graph.create_var_node_from_desc(
self._place) accum_in_node.var())
state_out_node = graph.create_var_node_from_desc(state_in_node.var(
))
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
))
ins['InState'] = state_in_node ins['InState'] = state_in_node
ins['InAccum'] = accum_in_node ins['InAccum'] = accum_in_node
...@@ -1833,8 +1805,8 @@ class InsertQuantizeLinear(object): ...@@ -1833,8 +1805,8 @@ class InsertQuantizeLinear(object):
def insert_quant_op(self, graph, var_node): def insert_quant_op(self, graph, var_node):
assert var_node.is_var(), '{} is not a var'.format(var_node.name()) assert var_node.is_var(), '{} is not a var'.format(var_node.name())
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(name=self._quantized_var_name(
name=self._quantized_var_name(var_node.name()), var_node.name()),
var_type=var_node.type(), var_type=var_node.type(),
shape=var_node.shape(), shape=var_node.shape(),
var_dtype=var_node.dtype()) var_dtype=var_node.dtype())
...@@ -1863,12 +1835,9 @@ class InsertQuantizeLinear(object): ...@@ -1863,12 +1835,9 @@ class InsertQuantizeLinear(object):
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=scale_var_node.shape(), shape=scale_var_node.shape(),
var_dtype=core.VarDesc.VarType.INT32) var_dtype=core.VarDesc.VarType.INT32)
_init_var_node( _init_var_node(zero_point_node,
zero_point_node, np.zeros(scale_var_node.shape(), dtype="int32"),
np.zeros( self._scope, self._place)
scale_var_node.shape(), dtype="int32"),
self._scope,
self._place)
inputs = {"X": var_node, "Scale": scale_var_node} inputs = {"X": var_node, "Scale": scale_var_node}
if zero_point_node is not None: if zero_point_node is not None:
...@@ -1879,12 +1848,11 @@ class InsertQuantizeLinear(object): ...@@ -1879,12 +1848,11 @@ class InsertQuantizeLinear(object):
if not self._is_test: if not self._is_test:
attrs["is_test"] = self._is_test attrs["is_test"] = self._is_test
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
scale_out_node = graph.create_var_node_from_desc(scale_var_node.var( scale_out_node = graph.create_var_node_from_desc(
)) scale_var_node.var())
outputs["OutScale"] = scale_out_node outputs["OutScale"] = scale_out_node
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(op_type="quantize_linear",
op_type="quantize_linear",
attrs=attrs, attrs=attrs,
inputs=inputs, inputs=inputs,
outputs=outputs) outputs=outputs)
...@@ -1914,12 +1882,9 @@ class InsertQuantizeLinear(object): ...@@ -1914,12 +1882,9 @@ class InsertQuantizeLinear(object):
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=scale_var_node.shape(), shape=scale_var_node.shape(),
var_dtype=core.VarDesc.VarType.INT32) var_dtype=core.VarDesc.VarType.INT32)
_init_var_node( _init_var_node(zero_point_node,
zero_point_node, np.zeros(scale_var_node.shape(), dtype="int32"),
np.zeros( self._scope, self._place)
scale_var_node.shape(), dtype="int32"),
self._scope,
self._place)
inputs = {"X": var_node, "Scale": scale_var_node} inputs = {"X": var_node, "Scale": scale_var_node}
if zero_point_node is not None: if zero_point_node is not None:
...@@ -1929,8 +1894,7 @@ class InsertQuantizeLinear(object): ...@@ -1929,8 +1894,7 @@ class InsertQuantizeLinear(object):
if not self._is_test: if not self._is_test:
attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward attrs["op_role"] = core.op_proto_and_checker_maker.OpRole.Forward
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(op_type="dequantize_linear",
op_type="dequantize_linear",
attrs=attrs, attrs=attrs,
inputs=inputs, inputs=inputs,
outputs={"Y": dequant_var_node}) outputs={"Y": dequant_var_node})
...@@ -2151,11 +2115,13 @@ class QuantizationTransformPassV2(object): ...@@ -2151,11 +2115,13 @@ class QuantizationTransformPassV2(object):
# will insert activation preprocess func # will insert activation preprocess func
# to preprocess activation before quantization # to preprocess activation before quantization
if is_weight and self._weight_preprocess_func is not None: if is_weight and self._weight_preprocess_func is not None:
var_node = self._insert_func( var_node = self._insert_func(graph,
graph, self._weight_preprocess_func, var_node, op) self._weight_preprocess_func,
var_node, op)
elif not is_weight and self._act_preprocess_func is not None: elif not is_weight and self._act_preprocess_func is not None:
var_node = self._insert_func( var_node = self._insert_func(graph,
graph, self._act_preprocess_func, var_node, op) self._act_preprocess_func,
var_node, op)
# if var node is weight and weight_quantize_func is not None, # if var node is weight and weight_quantize_func is not None,
# will insert weight quantize func to quantize and dequantize weight # will insert weight quantize func to quantize and dequantize weight
...@@ -2167,8 +2133,9 @@ class QuantizationTransformPassV2(object): ...@@ -2167,8 +2133,9 @@ class QuantizationTransformPassV2(object):
processed_vars.append(name) processed_vars.append(name)
continue continue
elif not is_weight and self._act_quantize_func is not None: elif not is_weight and self._act_quantize_func is not None:
target_out_node = self._insert_func( target_out_node = self._insert_func(graph,
graph, self._act_quantize_func, var_node, op) self._act_quantize_func,
var_node, op)
processed_vars.append(name) processed_vars.append(name)
continue continue
...@@ -2263,9 +2230,9 @@ class QuantizationTransformPassV2(object): ...@@ -2263,9 +2230,9 @@ class QuantizationTransformPassV2(object):
graph.out_node_mapping_table = dict() graph.out_node_mapping_table = dict()
# The process of _transform_forward and _transform_backward is needed in two for loops. # The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph: # The loop for transforming the forward graph:
with tqdm( with tqdm(total=len(ops),
total=len(ops), bar_format=
bar_format='Adding quant op for weight:|{bar}| {n_fmt}/{total_fmt}', 'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for op in ops: for op in ops:
if op.name() in self._quantizable_ops: if op.name() in self._quantizable_ops:
...@@ -2375,9 +2342,9 @@ class AddQuantDequantPassV2(object): ...@@ -2375,9 +2342,9 @@ class AddQuantDequantPassV2(object):
# Forward stage, insert quant_dequant op # Forward stage, insert quant_dequant op
all_op_nodes = graph.all_op_nodes() all_op_nodes = graph.all_op_nodes()
with tqdm( with tqdm(total=len(all_op_nodes),
total=len(all_op_nodes), bar_format=
bar_format='Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}', 'Adding quant activation op:|{bar}| {n_fmt}/{total_fmt}',
ncols=80) as t: ncols=80) as t:
for op_node in all_op_nodes: for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type: if op_node.name() in self._quantizable_op_type:
...@@ -2397,8 +2364,8 @@ class AddQuantDequantPassV2(object): ...@@ -2397,8 +2364,8 @@ class AddQuantDequantPassV2(object):
"qat_without_weight") "qat_without_weight")
arg_names = utils._get_op_input_var_names(op_node) arg_names = utils._get_op_input_var_names(op_node)
for arg_name in arg_names: for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, in_node = graph._find_node_by_name(
arg_name) op_node.inputs, arg_name)
if in_node.persistable(): if in_node.persistable():
continue continue
if arg_name in dequantized_vars_map: if arg_name in dequantized_vars_map:
...@@ -2425,8 +2392,8 @@ class AddQuantDequantPassV2(object): ...@@ -2425,8 +2392,8 @@ class AddQuantDequantPassV2(object):
if op_node.name() in self._quantizable_grad_op_type: if op_node.name() in self._quantizable_grad_op_type:
for input_name in op_node.input_arg_names(): for input_name in op_node.input_arg_names():
if input_name in dequantized_vars_map: if input_name in dequantized_vars_map:
in_node = graph._find_node_by_name(op_node.inputs, in_node = graph._find_node_by_name(
input_name) op_node.inputs, input_name)
dequant_var_node = dequantized_vars_map[input_name] dequant_var_node = dequantized_vars_map[input_name]
graph.update_input_link(in_node, dequant_var_node, graph.update_input_link(in_node, dequant_var_node,
op_node) op_node)
...@@ -2502,22 +2469,20 @@ class ReplaceFakeQuantDequantPass(object): ...@@ -2502,22 +2469,20 @@ class ReplaceFakeQuantDequantPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=scale_node.shape(), shape=scale_node.shape(),
var_dtype=core.VarDesc.VarType.INT32) var_dtype=core.VarDesc.VarType.INT32)
_init_var_node( _init_var_node(zero_point_node,
zero_point_node, np.zeros(scale_node.shape(), dtype="int32"),
np.zeros( self._scope, self._place)
scale_node.shape(), dtype="int32"),
self._scope,
self._place)
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(name=self._quantized_var_name(
name=self._quantized_var_name(x_node.name()), x_node.name()),
var_type=x_node.type(), var_type=x_node.type(),
shape=x_node.shape(), shape=x_node.shape(),
var_dtype=x_node.dtype()) var_dtype=x_node.dtype())
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(op_type="quantize_linear",
op_type="quantize_linear", attrs={
attrs={"quant_axis": quant_axis, "quant_axis": quant_axis,
"bit_length": bit_length}, "bit_length": bit_length
},
inputs={ inputs={
"X": x_node, "X": x_node,
"Scale": scale_node, "Scale": scale_node,
...@@ -2529,10 +2494,11 @@ class ReplaceFakeQuantDequantPass(object): ...@@ -2529,10 +2494,11 @@ class ReplaceFakeQuantDequantPass(object):
if zero_point_node is not None: if zero_point_node is not None:
graph.link_to(zero_point_node, quant_op_node) graph.link_to(zero_point_node, quant_op_node)
graph.link_to(quant_op_node, quant_var_node) graph.link_to(quant_op_node, quant_var_node)
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(op_type="dequantize_linear",
op_type="dequantize_linear", attrs={
attrs={"quant_axis": quant_axis, "quant_axis": quant_axis,
"bit_length": bit_length}, "bit_length": bit_length
},
inputs={ inputs={
"X": quant_var_node, "X": quant_var_node,
"Scale": scale_node, "Scale": scale_node,
...@@ -2617,7 +2583,8 @@ class QuantWeightPass(object): ...@@ -2617,7 +2583,8 @@ class QuantWeightPass(object):
scale_node = graph._find_node_by_name(_op.inputs, scale_node = graph._find_node_by_name(_op.inputs,
_op.input("Scale")[0]) _op.input("Scale")[0])
zero_point_node = graph._find_node_by_name( zero_point_node = graph._find_node_by_name(
_op.inputs, _op.input("ZeroPoint")[0]) _op.inputs,
_op.input("ZeroPoint")[0])
out_node = graph._find_node_by_name(_op.outputs, out_node = graph._find_node_by_name(_op.outputs,
_op.output("Y")[0]) _op.output("Y")[0])
...@@ -2633,8 +2600,11 @@ class QuantWeightPass(object): ...@@ -2633,8 +2600,11 @@ class QuantWeightPass(object):
param_v = self._load_var(x_node.name()) param_v = self._load_var(x_node.name())
quant_axis = _op.op().attr("quant_axis") quant_axis = _op.op().attr("quant_axis")
bits_length = _op.op().attr("bit_length") bits_length = _op.op().attr("bit_length")
quantized_param_v = utils.quant_tensor(param_v.copy(), scale_v, quantized_param_v = utils.quant_tensor(param_v.copy(),
quant_axis, bits_length) scale_v,
quant_axis,
bits_length,
onnx_format=True)
if self._bias_correction == True: if self._bias_correction == True:
quantized_param_v = utils.bias_correction_w( quantized_param_v = utils.bias_correction_w(
param_v, param_v,
......
...@@ -321,7 +321,7 @@ def set_variable_data(scope, place, var_name, np_value): ...@@ -321,7 +321,7 @@ def set_variable_data(scope, place, var_name, np_value):
tensor.set(np_value, place) tensor.set(np_value, place)
def quant_tensor(x, scale, quant_axis=0, weight_bits=8): def quant_tensor(x, scale, quant_axis=0, weight_bits=8, onnx_format=False):
# symmetry quant # symmetry quant
def _clip(x, scale): def _clip(x, scale):
x[x > scale] = scale x[x > scale] = scale
...@@ -335,13 +335,25 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8): ...@@ -335,13 +335,25 @@ def quant_tensor(x, scale, quant_axis=0, weight_bits=8):
if s == 0.0: if s == 0.0:
s = 1e-8 s = 1e-8
if quant_axis == 0: if quant_axis == 0:
if onnx_format:
x[i] = np.round(x[i] / s * bnt)
x[i] = np.clip(x[i], -bnt - 1, bnt)
else:
x[i] = _clip(x[i], s) x[i] = _clip(x[i], s)
x[i] = x[i] / s * bnt x[i] = x[i] / s * bnt
else:
if onnx_format:
x[:, i] = np.round(x[:, i] / s * bnt)
x[:, i] = np.clip(x[:, i], -bnt - 1, bnt)
else: else:
x[:, i] = _clip(x[:, i], s) x[:, i] = _clip(x[:, i], s)
x[:, i] = x[:, i] / s * bnt x[:, i] = x[:, i] / s * bnt
else: else:
scale = 1e-8 if scale == 0.0 else scale scale = 1e-8 if scale == 0.0 else scale
if onnx_format:
x = np.round(x / scale * bnt)
x = np.clip(x, -bnt - 1, bnt)
else:
x = _clip(x, scale) x = _clip(x, scale)
x = x / scale * bnt x = x / scale * bnt
return x return x
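The onnx_format branch added above changes the order of operations: the legacy path clips the float tensor to [-scale, scale] and then rescales (rounding happens later in the caller, e.g. the np.round in the freeze pass), while the onnx_format path rounds first and then clips on the integer grid [-bnt-1, bnt]. A small side-by-side numpy check of the two paths for one scalar scale:

import numpy as np

x = np.array([0.05, -0.12, 0.30], dtype=np.float32)
scale, bits = 0.25, 8
bnt = (1 << (bits - 1)) - 1                               # 127

legacy = np.clip(x, -scale, scale) / scale * bnt          # clip then scale: [25.4, -60.96, 127.0]
onnx = np.clip(np.round(x / scale * bnt), -bnt - 1, bnt)  # round then clip: [25., -61., 127.]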
...@@ -416,6 +428,7 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor): ...@@ -416,6 +428,7 @@ def calculate_quant_cos_error(orig_tensor, qdq_tensor):
class tqdm(object): class tqdm(object):
def __init__(self, total, bar_format='Loading|{bar}', ncols=80): def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
self.total = total self.total = total
self.bar_format = bar_format self.bar_format = bar_format
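The in-tree tqdm stand-in above is what renders the bar_format strings used throughout this diff (e.g. 'Adding quant op with weight:|{bar}| {n_fmt}/{total_fmt}'). A guess at how such a minimal progress bar can be driven with the with-statement usage seen above; the real update/close logic is elided here.

import sys

class MiniTqdm(object):
    def __init__(self, total, bar_format='Loading|{bar}', ncols=80):
        self.total = total
        self.bar_format = bar_format
        self.ncols = ncols
        self.n = 0

    def update(self, n=1):
        self.n += n
        bar = '=' * int(self.n / self.total * 40)
        # str.format ignores placeholders the bar_format does not use
        line = self.bar_format.format(bar=bar.ljust(40),
                                      n_fmt=self.n, total_fmt=self.total)
        sys.stderr.write('\r' + line[:self.ncols])

    def __enter__(self):
        return self

    def __exit__(self, *args):
        sys.stderr.write('\n')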
......
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") file(
GLOB TEST_OPS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
function(_inference_analysis_python_api_int8_test target model_dir data_path filename use_mkldnn) function(_inference_analysis_python_api_int8_test target model_dir data_path
py_test(${target} SRCS ${filename} filename use_mkldnn)
ENVS CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} py_test(
${target}
SRCS ${filename}
ENVS
CPU_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn} FLAGS_use_mkldnn=${use_mkldnn}
ARGS --infer_model ${model_dir}/model ARGS
--infer_data ${data_path} --infer_model
--int8_model_save_path int8_models/${target} ${model_dir}/model
--warmup_batch_size ${WARMUP_BATCH_SIZE} --infer_data
--batch_size 50) ${data_path}
--int8_model_save_path
int8_models/${target}
--warmup_batch_size
${WARMUP_BATCH_SIZE}
--batch_size
50)
endfunction() endfunction()
function(inference_analysis_python_api_int8_test target model_dir data_path filename) function(inference_analysis_python_api_int8_test target model_dir data_path
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} False) filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path}
${filename} False)
endfunction() endfunction()
function(inference_analysis_python_api_int8_test_custom_warmup_batch_size target model_dir data_dir filename warmup_batch_size) function(inference_analysis_python_api_int8_test_custom_warmup_batch_size
target model_dir data_dir filename warmup_batch_size)
set(WARMUP_BATCH_SIZE ${warmup_batch_size}) set(WARMUP_BATCH_SIZE ${warmup_batch_size})
inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir} ${filename}) inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_dir}
${filename})
endfunction() endfunction()
function(inference_analysis_python_api_int8_test_mkldnn target model_dir data_path filename) function(inference_analysis_python_api_int8_test_mkldnn target model_dir
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path} ${filename} True) data_path filename)
_inference_analysis_python_api_int8_test(${target} ${model_dir} ${data_path}
${filename} True)
endfunction() endfunction()
function(download_data install_dir url data_file check_sum) function(download_data install_dir url data_file check_sum)
if (NOT EXISTS ${install_dir}/${data_file}) if(NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${url} ${data_file} ${check_sum}) inference_download_and_uncompress(${install_dir} ${url} ${data_file}
${check_sum})
endif() endif()
endfunction() endfunction()
function(download_quant_data install_dir data_file check_sum) function(download_quant_data install_dir data_file check_sum)
if (NOT EXISTS ${install_dir}/${data_file}) if(NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file} ${check_sum}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8
${data_file} ${check_sum})
endif() endif()
endfunction() endfunction()
function(download_quant_model install_dir data_file check_sum) function(download_quant_model install_dir data_file check_sum)
if (NOT EXISTS ${install_dir}/${data_file}) if(NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file} ${check_sum}) inference_download_and_uncompress(
${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file} ${check_sum})
endif() endif()
endfunction() endfunction()
function(download_quant_fp32_model install_dir data_file check_sum) function(download_quant_fp32_model install_dir data_file check_sum)
if (NOT EXISTS ${install_dir}/${data_file}) if(NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file} ${check_sum}) inference_download_and_uncompress(
${install_dir} ${INFERENCE_URL}/int8/QAT_models/fp32 ${data_file}
${check_sum})
endif() endif()
endfunction() endfunction()
function(download_lstm_model install_dir data_file check_sum) function(download_lstm_model install_dir data_file check_sum)
if (NOT EXISTS ${install_dir}/${data_file}) if(NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/lstm ${data_file} ${check_sum}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/lstm
${data_file} ${check_sum})
endif() endif()
endfunction() endfunction()
function(inference_quant_int8_image_classification_test target quant_model_dir dataset_path) function(inference_quant_int8_image_classification_test target quant_model_dir
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant_int8_image_classification_comparison.py" dataset_path)
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} py_test(
${target}
SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant_int8_image_classification_comparison.py"
ENVS
FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=true FLAGS_use_mkldnn=true
ARGS --quant_model ${quant_model_dir} ARGS
--infer_data ${dataset_path} --quant_model
--batch_size 25 ${quant_model_dir}
--batch_num 2 --infer_data
--acc_diff_threshold 0.1) ${dataset_path}
--batch_size
25
--batch_num
2
--acc_diff_threshold
0.1)
endfunction() endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 25
function(inference_quant2_int8_image_classification_test target quant_model_dir fp32_model_dir dataset_path) function(inference_quant2_int8_image_classification_test target quant_model_dir
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_image_classification_comparison.py" fp32_model_dir dataset_path)
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} py_test(
${target}
SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_image_classification_comparison.py"
ENVS
FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=true FLAGS_use_mkldnn=true
ARGS --quant_model ${quant_model_dir} ARGS
--fp32_model ${fp32_model_dir} --quant_model
--infer_data ${dataset_path} ${quant_model_dir}
--batch_size 50 --fp32_model
--batch_num 2 ${fp32_model_dir}
--acc_diff_threshold 0.1) --infer_data
${dataset_path}
--batch_size
50
--batch_num
2
--acc_diff_threshold
0.1)
endfunction() endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir dataset_path labels_path ops_to_quantize) function(
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_nlp_comparison.py" inference_quant2_int8_nlp_test
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} target
quant_model_dir
fp32_model_dir
dataset_path
labels_path
ops_to_quantize)
py_test(
${target}
SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_nlp_comparison.py"
ENVS
FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=true FLAGS_use_mkldnn=true
ARGS --quant_model ${quant_model_dir} ARGS
--fp32_model ${fp32_model_dir} --quant_model
--infer_data ${dataset_path} ${quant_model_dir}
--labels ${labels_path} --fp32_model
--batch_size 10 ${fp32_model_dir}
--batch_num 2 --infer_data
--acc_diff_threshold 0.1 ${dataset_path}
--ops_to_quantize ${ops_to_quantize}) --labels
${labels_path}
--batch_size
10
--batch_num
2
--acc_diff_threshold
0.1
--ops_to_quantize
${ops_to_quantize})
endfunction() endfunction()
function(inference_quant2_int8_lstm_model_test target fp32_model quant_model dataset_path) function(inference_quant2_int8_lstm_model_test target fp32_model quant_model
py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_lstm_model.py" dataset_path)
ARGS --fp32_model ${fp32_model} py_test(
--quant_model ${quant_model} ${target}
--infer_data ${dataset_path} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_lstm_model.py"
--num_threads 1 ARGS
--mkldnn_cache_capacity 100 --fp32_model
--warmup_iter 100 ${fp32_model}
--acc_diff_threshold 0.11) --quant_model
${quant_model}
--infer_data
${dataset_path}
--num_threads
1
--mkldnn_cache_capacity
100
--warmup_iter
100
--acc_diff_threshold
0.11)
endfunction() endfunction()
function(download_quant_data install_dir data_file check_sum) function(download_quant_data install_dir data_file check_sum)
if (NOT EXISTS ${install_dir}/${data_file}) if(NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file} ${check_sum}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8
${data_file} ${check_sum})
endif() endif()
endfunction() endfunction()
function(download_quant_model install_dir data_file check_sum) function(download_quant_model install_dir data_file check_sum)
if (NOT EXISTS ${install_dir}/${data_file}) if(NOT EXISTS ${install_dir}/${data_file})
inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file} ${check_sum}) inference_download_and_uncompress(
${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file} ${check_sum})
endif() endif()
endfunction() endfunction()
function(save_quant_ic_model_test target quant_model_dir int8_model_save_path) function(save_quant_ic_model_test target quant_model_dir int8_model_save_path)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py py_test(
ARGS --quant_model_path ${quant_model_dir} ${target}
--int8_model_save_path ${int8_model_save_path} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
ARGS
--quant_model_path
${quant_model_dir}
--int8_model_save_path
${int8_model_save_path}
--debug) --debug)
endfunction() endfunction()
function(save_quant_nlp_model_test target quant_model_dir int8_model_save_path ops_to_quantize) function(save_quant_nlp_model_test target quant_model_dir int8_model_save_path
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py ops_to_quantize)
ARGS --quant_model_path ${quant_model_dir} py_test(
--int8_model_save_path ${int8_model_save_path} ${target}
--ops_to_quantize ${ops_to_quantize}) SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
ARGS
--quant_model_path
${quant_model_dir}
--int8_model_save_path
${int8_model_save_path}
--ops_to_quantize
${ops_to_quantize})
endfunction()
function(convert_model2dot_test target model_path save_graph_dir
save_graph_name)
py_test(
${target}
SRCS ${CMAKE_CURRENT_SOURCE_DIR}/convert_model2dot.py
ARGS
--model_path
${model_path}
--save_graph_dir
${save_graph_dir}
--save_graph_name
${save_graph_name})
endfunction()
if(WIN32)
...@@ -175,43 +271,63 @@ if(LINUX AND WITH_MKLDNN)
# Quant ResNet50
set(QUANT_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant")
set(QUANT_RESNET50_MODEL_ARCHIVE "ResNet50_qat_model.tar.gz")
download_quant_model(
${QUANT_RESNET50_MODEL_DIR} ${QUANT_RESNET50_MODEL_ARCHIVE}
ff89b934ab961c3a4a844193ece2e8a7)
inference_quant_int8_image_classification_test(
test_quant_int8_resnet50_mkldnn ${QUANT_RESNET50_MODEL_DIR}/model
${IMAGENET_DATA_PATH})
# Quant ResNet101
set(QUANT_RESNET101_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet101_quant")
set(QUANT_RESNET101_MODEL_ARCHIVE "ResNet101_qat_model.tar.gz")
download_quant_model(
${QUANT_RESNET101_MODEL_DIR} ${QUANT_RESNET101_MODEL_ARCHIVE}
95c6d01e3aeba31c13efb2ba8057d558)
# inference_quant_int8_image_classification_test(test_quant_int8_resnet101_mkldnn ${QUANT_RESNET101_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# Quant GoogleNet
set(QUANT_GOOGLENET_MODEL_DIR "${QUANT_INSTALL_DIR}/GoogleNet_quant")
set(QUANT_GOOGLENET_MODEL_ARCHIVE "GoogleNet_qat_model.tar.gz")
download_quant_model(
${QUANT_GOOGLENET_MODEL_DIR} ${QUANT_GOOGLENET_MODEL_ARCHIVE}
1d4a7383baa63e7d1c423e8db2b791d5)
inference_quant_int8_image_classification_test(
test_quant_int8_googlenet_mkldnn ${QUANT_GOOGLENET_MODEL_DIR}/model
${IMAGENET_DATA_PATH})
# Quant MobileNetV1
set(QUANT_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant")
set(QUANT_MOBILENETV1_MODEL_ARCHIVE "MobileNetV1_qat_model.tar.gz")
download_quant_model(
${QUANT_MOBILENETV1_MODEL_DIR} ${QUANT_MOBILENETV1_MODEL_ARCHIVE}
3b774d94a9fcbb604d09bdb731fc1162)
inference_quant_int8_image_classification_test(
test_quant_int8_mobilenetv1_mkldnn ${QUANT_MOBILENETV1_MODEL_DIR}/model
${IMAGENET_DATA_PATH})
# Quant MobileNetV2
set(QUANT_MOBILENETV2_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV2_quant")
set(QUANT_MOBILENETV2_MODEL_ARCHIVE "MobileNetV2_qat_model.tar.gz")
download_quant_model(
${QUANT_MOBILENETV2_MODEL_DIR} ${QUANT_MOBILENETV2_MODEL_ARCHIVE}
758a99d9225d8b73e1a8765883f96cdd)
inference_quant_int8_image_classification_test(
test_quant_int8_mobilenetv2_mkldnn ${QUANT_MOBILENETV2_MODEL_DIR}/model
${IMAGENET_DATA_PATH})
# Quant VGG16
set(QUANT_VGG16_MODEL_DIR "${QUANT_INSTALL_DIR}/VGG16_quant")
set(QUANT_VGG16_MODEL_ARCHIVE "VGG16_qat_model.tar.gz")
download_quant_model(${QUANT_VGG16_MODEL_DIR} ${QUANT_VGG16_MODEL_ARCHIVE}
c37e63ca82a102f47be266f8068b0b55)
# inference_quant_int8_image_classification_test(test_quant_int8_vgg16_mkldnn ${QUANT_VGG16_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# Quant VGG19
set(QUANT_VGG19_MODEL_DIR "${QUANT_INSTALL_DIR}/VGG19_quant")
set(QUANT_VGG19_MODEL_ARCHIVE "VGG19_qat_model.tar.gz")
download_quant_model(${QUANT_VGG19_MODEL_DIR} ${QUANT_VGG19_MODEL_ARCHIVE}
62bcd4b6c3ca2af67e8251d1c96ea18f)
# inference_quant_int8_image_classification_test(test_quant_int8_vgg19_mkldnn ${QUANT_VGG19_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
### Quant2 for image classification
...@@ -220,30 +336,54 @@ if(LINUX AND WITH_MKLDNN)
# with weight scales in `fake_dequantize_max_abs` operators
set(QUANT2_RESNET50_MODEL_DIR "${QUANT_INSTALL_DIR}/ResNet50_quant2")
set(QUANT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
download_quant_model(
${QUANT2_RESNET50_MODEL_DIR} ${QUANT2_RESNET50_MODEL_ARCHIVE}
e87309457e8c462a579340607f064d66)
set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
inference_quant2_int8_image_classification_test(
test_quant2_int8_resnet50_mkldnn
${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float
${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_dequantize_max_abs` operators
set(QUANT2_RESNET50_RANGE_MODEL_DIR
"${QUANT_INSTALL_DIR}/ResNet50_quant2_range")
set(QUANT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz")
download_quant_model(
${QUANT2_RESNET50_RANGE_MODEL_DIR} ${QUANT2_RESNET50_RANGE_MODEL_ARCHIVE}
2fdc8a139f041c0d270abec826b2d304)
inference_quant2_int8_image_classification_test(
test_quant2_int8_resnet50_range_mkldnn
${QUANT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range
${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_channel_wise_dequantize_max_abs` operators
set(QUANT2_RESNET50_CHANNELWISE_MODEL_DIR
"${QUANT_INSTALL_DIR}/ResNet50_quant2_channelwise")
set(QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE
"ResNet50_qat_channelwise.tar.gz")
download_quant_model(
${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR}
${QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE}
887a1b1b0e9a4efd10f263a43764db26)
inference_quant2_int8_image_classification_test(
test_quant2_int8_resnet50_channelwise_mkldnn
${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise
${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
# Quant2 MobileNetV1
set(QUANT2_MOBILENETV1_MODEL_DIR "${QUANT_INSTALL_DIR}/MobileNetV1_quant2")
set(QUANT2_MOBILENETV1_MODEL_ARCHIVE "MobileNet_qat_perf.tar.gz")
download_quant_model(
${QUANT2_MOBILENETV1_MODEL_DIR} ${QUANT2_MOBILENETV1_MODEL_ARCHIVE}
7f626e453db2d56fed6c2538621ffacf)
set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
inference_quant2_int8_image_classification_test(
test_quant2_int8_mobilenetv1_mkldnn
${QUANT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf/float
${FP32_MOBILENETV1_MODEL_DIR}/model ${IMAGENET_DATA_PATH})
### Quant2 for NLP
...@@ -251,74 +391,100 @@ if(LINUX AND WITH_MKLDNN)
set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")
set(NLP_DATA_PATH "${NLP_DATA_DIR}/Ernie_dataset/1.8w.bs1")
set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
download_quant_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE}
e650ce0cbc1fadbed5cc2c01d4e734dc)
# Quant2 Ernie
set(QUANT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
set(QUANT2_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_quant2")
download_quant_model(${QUANT2_ERNIE_MODEL_DIR} ${QUANT2_ERNIE_MODEL_ARCHIVE}
f7cdf4720755ecf66efbc8044e9922d9)
set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float")
download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE}
114f38804a3ef8c45e7259e68bbd838b)
set(QUANT2_ERNIE_OPS_TO_QUANTIZE
"fc,reshape2,transpose2,matmul,elementwise_add,slice")
inference_quant2_int8_nlp_test(
test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float
${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH}
${NLP_LABLES_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE})
# Quant2 GRU
set(QUANT2_GRU_MODEL_ARCHIVE "GRU_quant_acc.tar.gz")
set(QUANT2_GRU_MODEL_DIR "${QUANT_INSTALL_DIR}/GRU_quant2")
download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE}
cf207f8076dcfb8b74d8b6bdddf9090c)
set(QUANT2_GRU_OPS_TO_QUANTIZE "multi_gru")
# Quant2 LSTM
set(QUANT2_LSTM_MODEL_ARCHIVE "lstm_quant.tar.gz")
set(QUANT2_LSTM_MODEL_DIR "${QUANT_INSTALL_DIR}/lstm_quant_test")
download_quant_model(${QUANT2_LSTM_MODEL_DIR} ${QUANT2_LSTM_MODEL_ARCHIVE}
40a693803b12ee9e251258f32559abcb)
set(QUANT2_LSTM_OPS_TO_QUANTIZE "fusion_lstm")
### Save FP32 model or INT8 model from Quant model
set(QUANT2_INT8_RESNET50_SAVE_PATH
"${QUANT_INSTALL_DIR}/ResNet50_quant2_int8")
save_quant_ic_model_test(
save_quant2_model_resnet50
${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float
${QUANT2_INT8_RESNET50_SAVE_PATH})
set(QUANT2_INT8_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8")
save_quant_nlp_model_test(
save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float
${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE})
set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8")
save_quant_nlp_model_test(
save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc
${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE})
set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2_int8")
save_quant_nlp_model_test(
save_quant2_model_lstm ${QUANT2_LSTM_MODEL_DIR}/lstm_quant
${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_LSTM_OPS_TO_QUANTIZE})
# Convert Quant2 model to dot and pdf files
set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH
"${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file")
convert_model2dot_test(
convert_model2dot_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float
${QUANT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_quant2_int8")
### PTQ INT8
# PTQ int8 lstm model
set(LSTM_DATA_FILE "quant_lstm_input_data.tar.gz")
set(LSTM_URL "${INFERENCE_URL}/int8/unittest_model_data")
download_data(${QUANT2_INT8_LSTM_SAVE_PATH} ${LSTM_URL} ${LSTM_DATA_FILE}
add84c754e9b792fea1fbd728d134ab7)
set(QUANT2_FP32_LSTM_MODEL_ARCHIVE "lstm_fp32_model.tar.gz")
download_lstm_model(
${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_FP32_LSTM_MODEL_ARCHIVE}
eecd9f44d69a84acc1cf2235c4b8b743)
inference_quant2_int8_lstm_model_test(
test_quant2_int8_lstm_mkldnn ${QUANT2_INT8_LSTM_SAVE_PATH}/lstm_fp32_model
${QUANT2_LSTM_MODEL_DIR}/lstm_quant
${QUANT2_INT8_LSTM_SAVE_PATH}/quant_lstm_input_data)
endif()
# Since the tests for Quant & INT8 comparison support only testing on Linux
# with MKL-DNN, we remove it here to not test it on other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy
quant_int8_image_classification_comparison quant_int8_nlp_comparison)
#TODO(wanghaoshuang): Fix this unitest failed on GCC8.
list(REMOVE_ITEM TEST_OPS test_auto_pruning)
list(REMOVE_ITEM TEST_OPS test_filter_pruning)
# fix
if(WIN32)
set(SINGLE_CARD_TEST_OPS
test_user_defined_quantization
test_quantization_scale_pass
test_quantization_pass
...@@ -327,26 +493,31 @@ if(WIN32)
test_imperative_qat
test_imperative_out_scale
test_graph)
list(REMOVE_ITEM TEST_OPS ${SINGLE_CARD_TEST_OPS})
foreach(src ${SINGLE_CARD_TEST_OPS})
py_test(${src} SRCS ${src}.py ENVS CUDA_VISIBLE_DEVICES=0)
endforeach()
endif()
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
# setting timeout value for old unittests
if(NOT WIN32)
set_tests_properties(test_post_training_quantization_lstm_model
PROPERTIES TIMEOUT 120)
set_tests_properties(test_post_training_quantization_mobilenetv1
PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
set_tests_properties(test_post_training_quantization_resnet50
PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=NIGHTLY")
set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT
120)
set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT
120)
set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 120)
set_tests_properties(test_weight_quantization_mobilenetv1 PROPERTIES TIMEOUT
120)
endif()
set_tests_properties(test_graph PROPERTIES TIMEOUT 120)
...@@ -359,14 +530,19 @@ set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200)
set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200)
if(LINUX AND WITH_MKLDNN)
set_tests_properties(test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT
120)
set_tests_properties(convert_model2dot_ernie PROPERTIES TIMEOUT 120)
set_tests_properties(test_quant2_int8_resnet50_channelwise_mkldnn
PROPERTIES TIMEOUT 120)
set_tests_properties(test_quant_int8_mobilenetv2_mkldnn PROPERTIES TIMEOUT
120)
set_tests_properties(test_quant2_int8_resnet50_range_mkldnn PROPERTIES TIMEOUT
120)
set_tests_properties(save_quant2_model_resnet50 PROPERTIES TIMEOUT 120)
set_tests_properties(test_quant_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120)
set_tests_properties(test_quant_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT
120)
set_tests_properties(test_quant2_int8_ernie_mkldnn PROPERTIES TIMEOUT 120)
set_tests_properties(test_quant_int8_googlenet_mkldnn PROPERTIES TIMEOUT 120)
set_tests_properties(test_quant2_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120)
...@@ -374,8 +550,10 @@ if(LINUX AND WITH_MKLDNN)
endif()
if(APPLE)
set_tests_properties(test_post_training_quantization_mnist PROPERTIES TIMEOUT
300)
set_tests_properties(test_post_training_quantization_while PROPERTIES TIMEOUT
300)
set_tests_properties(test_imperative_ptq PROPERTIES TIMEOUT 300)
set_tests_properties(test_imperative_skip_op PROPERTIES TIMEOUT 300)
endif()
...@@ -35,8 +35,9 @@ from paddle.fluid.framework import _test_eager_guard
from imperative_test_utils import fix_model_dict, ImperativeLenet, ImperativeLinearBn
from imperative_test_utils import ImperativeLinearBn_hook
_logger = get_logger(__name__,
logging.INFO,
fmt='%(asctime)s-%(levelname)s: %(message)s')
class TestFuseLinearBn(unittest.TestCase):
...@@ -55,15 +56,15 @@ class TestFuseLinearBn(unittest.TestCase):
quant_h = ptq.quantize(model_h, fuse=True, fuse_list=f_l)
for name, layer in quant_model.named_sublayers():
if name in f_l:
assert not (isinstance(layer, nn.BatchNorm1D)
or isinstance(layer, nn.BatchNorm2D))
out = model(inputs)
out_h = model_h(inputs)
out_quant = quant_model(inputs)
out_quant_h = quant_h(inputs)
cos_sim_func = nn.CosineSimilarity(axis=0)
print('fuse linear+bn', cos_sim_func(out.flatten(),
out_quant.flatten()))
print(cos_sim_func(out_h.flatten(), out_quant_h.flatten()))
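Editor's note: the hunk above exercises the imperative (dygraph) PTQ fuse path end to end. A minimal, self-contained sketch of that flow, assuming the PTQ classes (ImperativePTQ, PTQConfig, KLQuantizer, AbsmaxQuantizer) and the ImperativeLenet helper are imported as in this test file; the fuse pairs and save path are illustrative only, not taken from this diff:
# --- Editor's sketch: minimal imperative PTQ fuse flow (names assumed imported as in this file).
import paddle
config = PTQConfig(KLQuantizer(), AbsmaxQuantizer())  # activation / weight quantizers
ptq = ImperativePTQ(config)
model = ImperativeLenet()                              # any dygraph model from the test utils
f_l = [['fc', 'bn']]                                   # hypothetical linear+bn sublayer names to fuse
quant_model = ptq.quantize(model, fuse=True, fuse_list=f_l)
# ...feed calibration batches through quant_model here...
input_spec = [paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')]
ptq.save_quantized_model(model=quant_model, path='./ptq_lenet', input_spec=input_spec)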
...@@ -87,8 +88,8 @@ class TestImperativePTQ(unittest.TestCase):
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
target_folder, zip_path)
os.system(cmd)
def download_model(self, data_url, data_md5, folder_name):
...@@ -123,8 +124,8 @@ class TestImperativePTQ(unittest.TestCase):
def model_test(self, model, batch_num=-1, batch_size=8):
model.eval()
test_reader = paddle.batch(paddle.dataset.mnist.test(),
batch_size=batch_size)
eval_acc_top1_list = []
for batch_id, data in enumerate(test_reader()):
...@@ -157,8 +158,8 @@ class TestImperativePTQ(unittest.TestCase):
[inference_program, feed_target_names, fetch_targets
] = (paddle.static.load_inference_model(program_path, exe))
test_reader = paddle.batch(paddle.dataset.mnist.test(),
batch_size=batch_size)
top1_correct_num = 0.
total_num = 0.
...@@ -203,13 +204,13 @@ class TestImperativePTQ(unittest.TestCase):
self.batch_size)
input_spec = [
paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
]
with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir:
save_path = os.path.join(tmpdir, "model")
self.ptq.save_quantized_model(model=quant_model,
path=save_path,
input_spec=input_spec)
print('Quantized model saved in {%s}' % save_path)
after_acc_top1 = self.model_test(quant_model, self.batch_num,
...@@ -225,12 +226,10 @@ class TestImperativePTQ(unittest.TestCase):
print('After converted acc_top1: %s' % after_acc_top1)
print('Infer acc_top1: %s' % infer_acc_top1)
self.assertTrue(after_acc_top1 >= self.eval_acc_top1,
msg="The test acc {%f} is less than {%f}." %
(after_acc_top1, self.eval_acc_top1))
self.assertTrue(infer_acc_top1 >= after_acc_top1,
msg='The acc is lower after converting model.')
end_time = time.time()
...@@ -243,6 +242,7 @@ class TestImperativePTQ(unittest.TestCase):
class TestImperativePTQfuse(TestImperativePTQ):
def func_ptq(self):
start_time = time.time()
...@@ -261,19 +261,19 @@ class TestImperativePTQfuse(TestImperativePTQ):
quant_model = self.ptq.quantize(model, fuse=True, fuse_list=f_l)
for name, layer in quant_model.named_sublayers():
if name in f_l:
assert not (isinstance(layer, nn.BatchNorm1D)
or isinstance(layer, nn.BatchNorm2D))
before_acc_top1 = self.model_test(quant_model, self.batch_num,
self.batch_size)
input_spec = [
paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
]
with tempfile.TemporaryDirectory(prefix="imperative_ptq_") as tmpdir:
save_path = os.path.join(tmpdir, "model")
self.ptq.save_quantized_model(model=quant_model,
path=save_path,
input_spec=input_spec)
print('Quantized model saved in {%s}' % save_path)
after_acc_top1 = self.model_test(quant_model, self.batch_num,
...@@ -291,14 +291,12 @@ class TestImperativePTQfuse(TestImperativePTQ):
#Check whether the quant_model is correct after converting.
#The acc of quantized model should be higher than 0.95.
self.assertTrue(after_acc_top1 >= self.eval_acc_top1,
msg="The test acc {%f} is less than {%f}." %
(after_acc_top1, self.eval_acc_top1))
#Check the saved infer_model.The acc of infer model
#should not be lower than the one of dygraph model.
self.assertTrue(infer_acc_top1 >= after_acc_top1,
msg='The acc is lower after converting model.')
end_time = time.time()
...@@ -311,6 +309,7 @@ class TestImperativePTQfuse(TestImperativePTQ):
class TestImperativePTQHist(TestImperativePTQ):
def set_vars(self):
config = PTQConfig(HistQuantizer(), AbsmaxQuantizer())
self.ptq = ImperativePTQ(config)
...@@ -332,13 +331,14 @@ class TestImperativePTQHist(TestImperativePTQ):
class TestImperativePTQKL(TestImperativePTQ):
def set_vars(self):
config = PTQConfig(KLQuantizer(), PerChannelAbsmaxQuantizer())
self.ptq = ImperativePTQ(config)
self.batch_num = 10
self.batch_size = 10
self.eval_acc_top1 = 0.98
conv2d_1_wt_thresholds = [
0.18116560578346252, 0.17079241573810577, 0.1702047884464264,
...
...@@ -34,6 +34,7 @@ np.random.seed(0)
class TestPostTrainingQuantization(unittest.TestCase):
def setUp(self):
self.download_path = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
...@@ -44,8 +45,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
try:
os.system("mkdir -p " + self.int8_model_path)
except Exception as e:
print("Failed to create {} due to {}".format(
self.int8_model_path, str(e)))
sys.exit(-1)
def tearDown(self):
...@@ -53,8 +54,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
target_folder, zip_path)
os.system(cmd)
def download_model(self, data_url, data_md5, folder_name):
...@@ -68,6 +69,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
return data_cache_folder
def get_batch_reader(self, data_path, place):
def reader():
with open(data_path, 'rb') as in_file:
while True:
...@@ -80,15 +82,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
seq_len = (alllen >> 16) & 0xFFFF
label = in_file.read(4 * label_len)
label = np.frombuffer(label, dtype=np.int32).reshape(
[len(label) // 4])
if label.shape[0] != 1 or label[0] > 6350:
continue
feat = in_file.read(4 * seq_len * 8)
feat = np.frombuffer(feat, dtype=np.float32).reshape(
[len(feat) // 4 // 8, 8])
lod_feat = [feat.shape[0]]
minputs = fluid.create_lod_tensor(feat, [lod_feat], place)
...@@ -97,6 +98,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
return reader
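Editor's note on the record format parsed above: the packed header `alllen` apparently stores the label length in its low 16 bits and the sequence length in its high 16 bits, which matches the `(alllen >> 16) & 0xFFFF` expression in the reader. A worked example with illustrative values:
# --- Editor's sketch: decoding the packed `alllen` header field (values are illustrative).
alllen = (75 << 16) | 1            # hypothetical record: seq_len = 75 frames, label_len = 1
label_len = alllen & 0xFFFF        # -> 1
seq_len = (alllen >> 16) & 0xFFFF  # -> 75, same expression as in the reader above
assert (label_len, seq_len) == (1, 75)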
def get_simple_reader(self, data_path, place):
def reader():
with open(data_path, 'rb') as in_file:
while True:
...@@ -109,15 +111,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
seq_len = (alllen >> 16) & 0xFFFF
label = in_file.read(4 * label_len)
label = np.frombuffer(label, dtype=np.int32).reshape(
[len(label) // 4])
if label.shape[0] != 1 or label[0] > 6350:
continue
feat = in_file.read(4 * seq_len * 8)
feat = np.frombuffer(feat, dtype=np.float32).reshape(
[len(feat) // 4 // 8, 8])
lod_feat = [feat.shape[0]]
minputs = fluid.create_lod_tensor(feat, [lod_feat], place)
...@@ -178,8 +179,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
scope = fluid.global_scope()
batch_generator = self.get_batch_reader(data_path, place)
ptq = PostTrainingQuantization(executor=exe,
model_dir=model_path,
batch_generator=batch_generator,
batch_nums=batch_nums,
...@@ -223,10 +223,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start post training quantization for {0} on {1} samples ...".
format(model_name, quant_iterations))
self.generate_quantized_model(fp32_model_path, data_path, algo,
round_type, quantizable_op_type,
is_full_quantize, is_use_cache_file,
is_optimize_model, quant_iterations,
onnx_format)
print("Start INT8 inference for {0} on {1} samples ...".format( print("Start INT8 inference for {0} on {1} samples ...".format(
model_name, infer_iterations)) model_name, infer_iterations))
...@@ -245,6 +246,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -245,6 +246,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
def test_post_training_avg(self): def test_post_training_avg(self):
model_name = "nlp_lstm_fp32_model" model_name = "nlp_lstm_fp32_model"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
...@@ -268,6 +270,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization): ...@@ -268,6 +270,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
def test_post_training_avg_onnx_format(self): def test_post_training_avg_onnx_format(self):
model_name = "nlp_lstm_fp32_model" model_name = "nlp_lstm_fp32_model"
model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz" model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
...@@ -285,8 +288,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization): ...@@ -285,8 +288,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
infer_iterations = 100 infer_iterations = 100
quant_iterations = 10 quant_iterations = 10
onnx_format = True onnx_format = True
self.run_test( self.run_test(model_name,
model_name,
model_url, model_url,
model_md5, model_md5,
data_name, data_name,
...
...@@ -33,6 +33,7 @@ np.random.seed(0)
class TestPostTrainingQuantization(unittest.TestCase):
def setUp(self):
self.root_path = tempfile.TemporaryDirectory()
self.int8_model_path = os.path.join(self.root_path.name,
...@@ -43,8 +44,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
try:
os.system("mkdir -p " + self.int8_model_path)
except Exception as e:
print("Failed to create {} due to {}".format(
self.int8_model_path, str(e)))
sys.exit(-1)
def tearDown(self):
...@@ -52,8 +53,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
target_folder, zip_path)
os.system(cmd)
def download_model(self, data_url, data_md5, folder_name):
...@@ -115,14 +116,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_size=10,
batch_nums=10,
onnx_format=False,
skip_tensor_list=None,
bias_correction=False):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
val_reader = paddle.dataset.mnist.train()
ptq = PostTrainingQuantization(executor=exe,
model_dir=model_path,
sample_generator=val_reader,
batch_size=batch_size,
...@@ -132,6 +133,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
round_type=round_type,
is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model,
bias_correction=bias_correction,
onnx_format=onnx_format,
skip_tensor_list=skip_tensor_list,
is_use_cache_file=is_use_cache_file)
...@@ -152,6 +154,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_size=10,
infer_iterations=10,
quant_iterations=5,
bias_correction=False,
onnx_format=False,
skip_tensor_list=None):
...@@ -160,20 +163,23 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start FP32 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size))
(fp32_throughput, fp32_latency,
fp32_acc1) = self.run_program(origin_model_path, batch_size,
infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model_name, quant_iterations * batch_size))
self.generate_quantized_model(origin_model_path, algo, round_type,
quantizable_op_type, is_full_quantize,
is_use_cache_file, is_optimize_model,
batch_size, quant_iterations, onnx_format,
skip_tensor_list, bias_correction)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
model_name, infer_iterations * batch_size)) model_name, infer_iterations * batch_size))
(int8_throughput, int8_latency, int8_acc1) = self.run_program( (int8_throughput, int8_latency,
self.int8_model_path, batch_size, infer_iterations) int8_acc1) = self.run_program(self.int8_model_path, batch_size,
infer_iterations)
print("---Post training quantization of {} method---".format(algo)) print("---Post training quantization of {} method---".format(algo))
print( print(
...@@ -191,6 +197,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -191,6 +197,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
class TestPostTrainingKLForMnist(TestPostTrainingQuantization): class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
def test_post_training_kl(self): def test_post_training_kl(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -212,6 +219,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization): ...@@ -212,6 +219,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
class TestPostTraininghistForMnist(TestPostTrainingQuantization): class TestPostTraininghistForMnist(TestPostTrainingQuantization):
def test_post_training_hist(self): def test_post_training_hist(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -233,6 +241,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization): ...@@ -233,6 +241,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseForMnist(TestPostTrainingQuantization): class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self): def test_post_training_mse(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -254,6 +263,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization): ...@@ -254,6 +263,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
class TestPostTrainingemdForMnist(TestPostTrainingQuantization): class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self): def test_post_training_mse(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -275,6 +285,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization): ...@@ -275,6 +285,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
class TestPostTrainingavgForMnist(TestPostTrainingQuantization): class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
def test_post_training_avg(self): def test_post_training_avg(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -296,6 +307,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization): ...@@ -296,6 +307,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
def test_post_training_abs_max(self): def test_post_training_abs_max(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -317,6 +329,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization): ...@@ -317,6 +329,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization): class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
def test_post_training_mse(self): def test_post_training_mse(self):
model_name = "mnist_model" model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz" data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -331,13 +344,25 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
bias_correction = True
self.run_test(model_name,
data_url,
data_md5,
algo,
round_type,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_size,
infer_iterations,
quant_iterations,
bias_correction=bias_correction)
class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
def test_post_training_kl(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -359,6 +384,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
def test_post_training_mse_onnx_format(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -374,8 +400,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name,
data_url, data_url,
data_md5, data_md5,
algo, algo,
...@@ -393,6 +418,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
class TestPostTrainingmseForMnistONNXFormatFullQuant(
TestPostTrainingQuantization):
def test_post_training_mse_onnx_format_full_quant(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -408,8 +434,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
batch_size = 10
infer_iterations = 50
quant_iterations = 5
self.run_test(model_name,
data_url, data_url,
data_md5, data_md5,
algo, algo,
...@@ -426,6 +451,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
def test_post_training_avg_skip_op(self):
model_name = "mnist_model"
data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
...@@ -441,8 +467,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
infer_iterations = 50
quant_iterations = 5
skip_tensor_list = ["fc_0.w_0"]
self.run_test(model_name,
data_url, data_url,
data_md5, data_md5,
algo, algo,
...
...@@ -83,6 +83,7 @@ def _reader_creator(file_list,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
...@@ -97,8 +98,10 @@ def _reader_creator(file_list,
continue
yield img_path, int(label)
mapper = functools.partial(process_image,
mode=mode,
color_jitter=color_jitter,
rotate=rotate)
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
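Editor's note: `_reader_creator` above returns a multi-threaded `xmap_readers` pipeline. A minimal consumption sketch, assuming the `val()` helper defined just below and an arbitrary batch size (not part of this diff):
# --- Editor's sketch: consuming the reader pipeline built above.
val_reader = val()                                # eval-mode reader over DATA_DIR
batched_reader = paddle.batch(val_reader, batch_size=25)
for batch in batched_reader():
    # each element of `batch` is one sample produced by process_image
    pass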
...@@ -109,6 +112,7 @@ def val(data_dir=DATA_DIR):
class TestPostTrainingQuantization(unittest.TestCase):
def setUp(self):
self.int8_download = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
...@@ -156,8 +160,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def cache_unzipping(self, target_folder, zip_path):
if not os.path.exists(target_folder):
cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(
target_folder, zip_path)
os.system(cmd)
def download_data(self, data_urls, data_md5s, folder_name, is_model=True):
...@@ -210,10 +214,11 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -210,10 +214,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
label = label.reshape([-1, 1]) label = label.reshape([-1, 1])
t1 = time.time() t1 = time.time()
_, acc1, _ = exe.run( _, acc1, _ = exe.run(infer_program,
infer_program, feed={
feed={feed_dict[0]: image, feed_dict[0]: image,
feed_dict[1]: label}, feed_dict[1]: label
},
fetch_list=fetch_targets) fetch_list=fetch_targets)
t2 = time.time() t2 = time.time()
period = t2 - t1 period = t2 - t1
...@@ -241,13 +246,12 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -241,13 +246,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_full_quantize=False, is_full_quantize=False,
is_use_cache_file=False, is_use_cache_file=False,
is_optimize_model=False, is_optimize_model=False,
onnx_format=False, onnx_format=False):
skip_tensor_list=None):
try: try:
os.system("mkdir " + self.int8_model) os.system("mkdir " + self.int8_model)
except Exception as e: except Exception as e:
print("Failed to create {} due to {}".format(self.int8_model, print("Failed to create {} due to {}".format(
str(e))) self.int8_model, str(e)))
sys.exit(-1) sys.exit(-1)
place = fluid.CPUPlace() place = fluid.CPUPlace()
...@@ -255,8 +259,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -255,8 +259,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
scope = fluid.global_scope() scope = fluid.global_scope()
val_reader = val() val_reader = val()
ptq = PostTrainingQuantization( ptq = PostTrainingQuantization(executor=exe,
executor=exe,
sample_generator=val_reader, sample_generator=val_reader,
model_dir=model_path, model_dir=model_path,
algo=algo, algo=algo,
...@@ -265,7 +268,6 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -265,7 +268,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_full_quantize=is_full_quantize, is_full_quantize=is_full_quantize,
optimize_model=is_optimize_model, optimize_model=is_optimize_model,
onnx_format=onnx_format, onnx_format=onnx_format,
skip_tensor_list=skip_tensor_list,
is_use_cache_file=is_use_cache_file) is_use_cache_file=is_use_cache_file)
ptq.quantize() ptq.quantize()
ptq.save_quantized_model(self.int8_model) ptq.save_quantized_model(self.int8_model)
...@@ -281,8 +283,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -281,8 +283,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_use_cache_file, is_use_cache_file,
is_optimize_model, is_optimize_model,
diff_threshold, diff_threshold,
onnx_format=False, onnx_format=False):
skip_tensor_list=None):
infer_iterations = self.infer_iterations infer_iterations = self.infer_iterations
batch_size = self.batch_size batch_size = self.batch_size
sample_iterations = self.sample_iterations sample_iterations = self.sample_iterations
...@@ -291,20 +292,22 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -291,20 +292,22 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start FP32 inference for {0} on {1} images ...".format( print("Start FP32 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size)) model, infer_iterations * batch_size))
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program( (fp32_throughput, fp32_latency,
model_cache_folder + "/model", batch_size, infer_iterations) fp32_acc1) = self.run_program(model_cache_folder + "/model",
batch_size, infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size)) format(model, sample_iterations * batch_size))
self.generate_quantized_model( self.generate_quantized_model(model_cache_folder + "/model",
model_cache_folder + "/model", quantizable_op_type, algo, quantizable_op_type, algo, round_type,
round_type, is_full_quantize, is_use_cache_file, is_optimize_model, is_full_quantize, is_use_cache_file,
onnx_format, skip_tensor_list) is_optimize_model, onnx_format)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size)) model, infer_iterations * batch_size))
(int8_throughput, int8_latency, int8_acc1) = self.run_program( (int8_throughput, int8_latency,
self.int8_model, batch_size, infer_iterations) int8_acc1) = self.run_program(self.int8_model, batch_size,
infer_iterations)
print("---Post training quantization of {} method---".format(algo)) print("---Post training quantization of {} method---".format(algo))
print( print(
...@@ -322,6 +325,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -322,6 +325,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_kl_mobilenetv1(self): def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "KL" algo = "KL"
...@@ -346,6 +350,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization): ...@@ -346,6 +350,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_avg_mobilenetv1(self): def test_post_training_avg_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "avg" algo = "avg"
...@@ -369,6 +374,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization): ...@@ -369,6 +374,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_hist_mobilenetv1(self): def test_post_training_hist_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "hist" algo = "hist"
...@@ -392,6 +398,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization): ...@@ -392,6 +398,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_abs_max_mobilenetv1(self): def test_post_training_abs_max_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "abs_max" algo = "abs_max"
...@@ -415,9 +422,10 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization): ...@@ -415,9 +422,10 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_onnx_format_mobilenetv1(self): def test_post_training_onnx_format_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "avg" algo = "emd"
round_type = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
...@@ -433,8 +441,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): ...@@ -433,8 +441,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
is_optimize_model = True is_optimize_model = True
onnx_format = True onnx_format = True
diff_threshold = 0.05 diff_threshold = 0.05
self.run_test( self.run_test(model,
model,
algo, algo,
round_type, round_type,
data_urls, data_urls,
...@@ -447,38 +454,5 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): ...@@ -447,38 +454,5 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
onnx_format=onnx_format) onnx_format=onnx_format)
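For readers following the test flow: run_test first measures FP32 accuracy, then builds an INT8 model through the PostTrainingQuantization call shown in generate_quantized_model above, and finally re-measures accuracy on the quantized model. A condensed sketch of that middle step is below; it uses the arguments visible in this diff plus the new round_type knob (its presence in the constructor is an assumption based on this PR), and model_path, val_reader and int8_model_path are placeholders, not names from the patch.

import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

paddle.enable_static()

# Placeholders for this sketch: model_path points at a saved inference model,
# val_reader is a sample generator, int8_model_path is the output directory.
exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(executor=exe,
                               sample_generator=val_reader,
                               model_dir=model_path,
                               algo="avg",
                               round_type="round",  # knob threaded through by this change (assumed)
                               is_full_quantize=False,
                               optimize_model=True,
                               onnx_format=True,
                               is_use_cache_file=False)
ptq.quantize()
ptq.save_quantized_model(int8_model_path)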
class TestPostTrainingForMobilenetv1SkipOP(TestPostTrainingQuantization):
def test_post_training_mobilenetv1_skip(self):
model = "MobileNet-V1"
algo = "avg"
round_type = "round"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d",
"depthwise_conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
skip_tensor_list = ["fc_0.w_0"]
self.run_test(
model,
algo,
round_type,
data_urls,
data_md5s,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
skip_tensor_list=skip_tensor_list)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -21,6 +21,7 @@ paddle.enable_static() ...@@ -21,6 +21,7 @@ paddle.enable_static()
class TestPostTrainingForResnet50(TestPostTrainingQuantization): class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self): def test_post_training_resnet50(self):
model = "ResNet-50" model = "ResNet-50"
algo = "min_max" algo = "min_max"
...@@ -40,6 +41,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): ...@@ -40,6 +41,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
def test_post_training_resnet50(self): def test_post_training_resnet50(self):
model = "ResNet-50" model = "ResNet-50"
algo = "min_max" algo = "min_max"
...@@ -54,8 +56,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization): ...@@ -54,8 +56,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
is_optimize_model = False is_optimize_model = False
diff_threshold = 0.025 diff_threshold = 0.025
onnx_format = True onnx_format = True
self.run_test( self.run_test(model,
model,
algo, algo,
round_type, round_type,
data_urls, data_urls,
......
...@@ -21,8 +21,6 @@ import math ...@@ -21,8 +21,6 @@ import math
from op_test import OpTest from op_test import OpTest
# numpy.round has different behavior in comparison to the c++ round function
# so we use round_c instead of numpy.round to align the output data
def round_c_single_element(val): def round_c_single_element(val):
dtype = type(val) dtype = type(val)
if val >= 0: if val >= 0:
...@@ -30,6 +28,7 @@ def round_c_single_element(val): ...@@ -30,6 +28,7 @@ def round_c_single_element(val):
return dtype(np.ceil(val - 0.5)) return dtype(np.ceil(val - 0.5))
# rounding to nearest ties away from zero
round_c = np.vectorize(round_c_single_element) round_c = np.vectorize(round_c_single_element)
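The core of this change is which rounding rule the reference implementation uses when mapping floats onto the integer grid. A minimal comparison of the two modes follows (illustration only; the positive branch of round_c_single_element is collapsed above and is assumed here to be np.floor(val + 0.5), i.e. ties rounded away from zero).

import numpy as np

def round_away_from_zero(val):
    # ties away from zero, matching the C++ std::round used by the kernels
    if val >= 0:
        return np.floor(val + 0.5)
    return np.ceil(val - 0.5)

round_away = np.vectorize(round_away_from_zero)

x = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5], dtype=np.float32)
print(round_away(x))  # [-3. -2. -1.  1.  2.  3.]  round_type = 1
print(np.round(x))    # [-2. -2. -0.  0.  2.  2.]  round_type = 0 (ties to even)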
...@@ -41,17 +40,30 @@ def get_compute_type(dtype): ...@@ -41,17 +40,30 @@ def get_compute_type(dtype):
class TestFakeQuantizeAbsMaxOp(OpTest): class TestFakeQuantizeAbsMaxOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'fake_quantize_abs_max' self.op_type = 'fake_quantize_abs_max'
self.attrs = {'bit_length': 8} self.attrs = {'bit_length': 8}
def _fake_quantize_abs_max(self, dtype, input_shape, distribution): def _fake_quantize_abs_max(self,
dtype,
input_shape,
distribution,
round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype) compute_type = get_compute_type(dtype)
scale = np.max(np.abs(input_data)) scale = np.max(np.abs(input_data))
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
output_data = round_c(input_data.astype(compute_type) * inv_scale * bnt) if round_type == 'TiesToEven':
round_out = np.round(
input_data.astype(compute_type) * inv_scale * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0
else:
output_data = round_c(
input_data.astype(compute_type) * inv_scale * bnt)
self.attrs['round_type'] = 1
self.inputs = {'X': input_data} self.inputs = {'X': input_data}
self.outputs = {'Out': output_data, 'OutScale': scale} self.outputs = {'Out': output_data, 'OutScale': scale}
self.dtype = dtype self.dtype = dtype
...@@ -60,6 +72,11 @@ class TestFakeQuantizeAbsMaxOp(OpTest): ...@@ -60,6 +72,11 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
def test_fake_quantize_abs_max(self): def test_fake_quantize_abs_max(self):
self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random) self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random)
def test_fake_quantize_abs_max_round1(self):
self._fake_quantize_abs_max(np.float32, (124, 240),
np.random.random,
round_type='TiesToEven')
def test_fake_quantize_abs_max_float16(self): def test_fake_quantize_abs_max_float16(self):
self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random) self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)
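Putting the pieces of the updated reference together for the plain abs_max op: scale is the largest absolute value, bnt = 2^(bit_length-1) - 1 = 127 for 8 bits, and the new TiesToEven path (round_type = 0) clips the rounded result to the signed range [-bnt-1, bnt] = [-128, 127]. The sketch below mirrors the test's NumPy formula, not the C++ kernel.

import numpy as np

def fake_quantize_abs_max(x, bit_length=8, ties_to_even=True):
    bnt = (1 << (bit_length - 1)) - 1          # 127 for 8 bits
    scale = np.max(np.abs(x))
    inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
    if ties_to_even:                            # round_type = 0: round then clip
        return np.clip(np.round(x * inv_scale * bnt), -bnt - 1, bnt), scale
    # round_type = 1: ties away from zero (equivalent to round_c), no clip
    return np.sign(x) * np.floor(np.abs(x) * inv_scale * bnt + 0.5), scale

x = np.array([-1.0, -0.26, 0.25, 1.0], dtype=np.float32)
q, s = fake_quantize_abs_max(x)
# s == 1.0, q == [-127., -33., 32., 127.] with ties-to-even rounding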
...@@ -72,21 +89,33 @@ class TestFakeQuantizeAbsMaxOp(OpTest): ...@@ -72,21 +89,33 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'fake_channel_wise_quantize_abs_max' self.op_type = 'fake_channel_wise_quantize_abs_max'
self.attrs = {'bit_length': 8} self.attrs = {'bit_length': 8}
def _fake_channel_wise_quantize_abs_max(self, dtype, input_shape, def _fake_channel_wise_quantize_abs_max(self,
quant_axis, distribution): dtype,
input_shape,
quant_axis,
distribution,
round_type='TiesToEven'):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.' assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype) compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
compute_axis = tuple( compute_axis = tuple(i for i in range(len(input_shape))
i for i in range(len(input_shape)) if i != quant_axis) if i != quant_axis)
scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
if round_type == 'TiesToEven':
round_out = np.round(
input_data.astype(compute_type) / scale_broadcast * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0
else:
output_data = round_c(bnt * input_data.astype(compute_type) / output_data = round_c(bnt * input_data.astype(compute_type) /
scale_broadcast) scale_broadcast)
self.attrs['round_type'] = 1
if quant_axis == 1: if quant_axis == 1:
scale_broadcast = np.transpose(scale_broadcast, scale_broadcast = np.transpose(scale_broadcast,
(1, ) + compute_axis) (1, ) + compute_axis)
...@@ -100,19 +129,24 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): ...@@ -100,19 +129,24 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
def test_fake_channel_wise_quantize_abs_max(self): def test_fake_channel_wise_quantize_abs_max(self):
dtype_options = [np.float32, np.float16] dtype_options = [np.float32, np.float16]
input_shape_quant_axis_options = [[(20, 15, 6, 6), 0], input_shape_quant_axis_options = [[(20, 15, 6, 6), 0],
[(15, 20, 5, 5), 1], [(30, 15), 0], [(20, 15, 6, 6), 1], [(30, 30), 0],
[(30, 15), 1]] [(30, 30), 1]]
for dtype, input_shape_quant_axis in itertools.product( round_type_options = ['TiesToEven', 'TiesAwayFromZero']
dtype_options, input_shape_quant_axis_options): for dtype, input_shape_quant_axis, round_type in itertools.product(
dtype_options, input_shape_quant_axis_options,
round_type_options):
input_shape, quant_axis = input_shape_quant_axis input_shape, quant_axis = input_shape_quant_axis
with self.subTest( with self.subTest(dtype=dtype,
dtype=dtype, input_shape=input_shape, input_shape=input_shape,
quant_axis=quant_axis): quant_axis=quant_axis,
round_type=round_type):
self._fake_channel_wise_quantize_abs_max( self._fake_channel_wise_quantize_abs_max(
dtype, input_shape, quant_axis, np.random.random) dtype, input_shape, quant_axis, np.random.random,
round_type)
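The channel-wise variant differs only in that the scale is computed per slice along quant_axis instead of over the whole tensor: compute_axis collects every other axis so np.amax reduces them away, while keepdims keeps the scale broadcastable. A small illustration (np.abs is used here for generality; the test's np.random.random data is non-negative, so it can omit it):

import numpy as np

x = np.random.random((4, 3, 2, 2)).astype(np.float32)  # e.g. OIHW-like weights
quant_axis = 0
bnt = 127

# every axis except the quantization axis is reduced away
compute_axis = tuple(i for i in range(x.ndim) if i != quant_axis)
scale = np.amax(np.abs(x), axis=compute_axis, keepdims=True)  # shape (4, 1, 1, 1)

# TiesToEven path of the updated reference: round, then clip per channel
q = np.clip(np.round(x / scale * bnt), -bnt - 1, bnt)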
class TestFakeQuantizeRangeAbsMaxOp(OpTest): class TestFakeQuantizeRangeAbsMaxOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'fake_quantize_range_abs_max' self.op_type = 'fake_quantize_range_abs_max'
self.attrs = {'bit_length': 5, 'window_size': 1} self.attrs = {'bit_length': 5, 'window_size': 1}
...@@ -121,7 +155,8 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): ...@@ -121,7 +155,8 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
dtype, dtype,
input_shape, input_shape,
distribution, distribution,
is_test=False): is_test=False,
round_type='TiesToEven'):
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype) compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
...@@ -130,11 +165,19 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): ...@@ -130,11 +165,19 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
out_scale[0] = np.max(np.abs(input_data)) out_scale[0] = np.max(np.abs(input_data))
if is_test: if is_test:
out_scale[0] = in_scale[0] = out_scale[0] - 1.0 out_scale[0] = in_scale[0] = out_scale[0] - 1.0
if round_type == 'TiesToEven':
round_out = np.round(
input_data.astype(compute_type) / out_scale[0] * bnt)
self.attrs['round_type'] = 0
output_data = np.clip(round_out, -bnt - 1, bnt)
else:
if is_test:
clip_data = np.clip(input_data, -in_scale, in_scale) clip_data = np.clip(input_data, -in_scale, in_scale)
else: else:
clip_data = input_data clip_data = input_data
output_data = round_c( output_data = round_c(
clip_data.astype(compute_type) / out_scale[0] * bnt) clip_data.astype(compute_type) / out_scale[0] * bnt)
self.attrs['round_type'] = 1
self.inputs = { self.inputs = {
'X': input_data, 'X': input_data,
'Iter': np.zeros(1).astype(np.int64), 'Iter': np.zeros(1).astype(np.int64),
...@@ -150,18 +193,24 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest): ...@@ -150,18 +193,24 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
self.check_output() self.check_output()
def test_fake_quantize_range_abs_max(self): def test_fake_quantize_range_abs_max(self):
dtype_options = [np.float32, np.float16] dtype_options = [np.float16, np.float32]
is_test_options = [False, True] is_test_options = [False, True]
for dtype, is_test in itertools.product(dtype_options, is_test_options): round_type_options = ['TiesToEven', 'TiesAwayFromZero']
for dtype, is_test, round_type in itertools.product(
dtype_options, is_test_options, round_type_options):
self.attrs['bit_length'] = 8 if is_test else 5 self.attrs['bit_length'] = 8 if is_test else 5
with self.subTest(dtype=dtype, is_test=is_test): with self.subTest(dtype=dtype,
is_test=is_test,
round_type=round_type):
self._fake_quantize_range_abs_max( self._fake_quantize_range_abs_max(
dtype, (8, 16, 7, 7), dtype, (8, 16, 6, 6),
lambda shape: (np.random.random(shape) - 0.5) * 10, lambda shape: (np.random.random(shape) - 0.4) * 10,
is_test=is_test) is_test=is_test,
round_type=round_type)
class TestMovingAverageAbsMaxScaleOp(OpTest): class TestMovingAverageAbsMaxScaleOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'moving_average_abs_max_scale' self.op_type = 'moving_average_abs_max_scale'
self.attrs = {'moving_rate': float(0.9), 'is_test': False} self.attrs = {'moving_rate': float(0.9), 'is_test': False}
...@@ -194,6 +243,7 @@ class TestMovingAverageAbsMaxScaleOp(OpTest): ...@@ -194,6 +243,7 @@ class TestMovingAverageAbsMaxScaleOp(OpTest):
class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'fake_quantize_moving_average_abs_max' self.op_type = 'fake_quantize_moving_average_abs_max'
self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False} self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False}
...@@ -203,7 +253,8 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): ...@@ -203,7 +253,8 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
input_shape, input_shape,
distribution, distribution,
dequantize=False, dequantize=False,
with_gradient=False): with_gradient=False,
round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype) compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
...@@ -217,12 +268,20 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): ...@@ -217,12 +268,20 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
np.abs(input_data)) np.abs(input_data))
out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0 out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
out_scale = out_accum / out_state out_scale = out_accum / out_state
round_data = round_c(input_data.astype(compute_type) / out_scale * bnt) if round_type == 'TiesToEven':
round_out = np.round(
input_data.astype(compute_type) / out_scale * bnt)
quant_data = np.clip(round_out, -bnt - 1, bnt)
self.attrs['round_type'] = 0
else:
quant_data = round_c(
input_data.astype(compute_type) / out_scale * bnt)
self.attrs['round_type'] = 1
if dequantize: if dequantize:
output_data = (round_data * out_scale / bnt).astype(dtype) output_data = (quant_data * out_scale / bnt).astype(dtype)
self.op_type = 'fake_quantize_dequantize_moving_average_abs_max' self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
else: else:
output_data = round_data.astype(dtype) output_data = quant_data.astype(dtype)
self.inputs = { self.inputs = {
'X': input_data, 'X': input_data,
'InScale': in_scale, 'InScale': in_scale,
...@@ -251,25 +310,39 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): ...@@ -251,25 +310,39 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7), self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7),
np.random.random) np.random.random)
def test_fake_quantize_moving_average_abs_max_round1(self):
self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
np.random.random,
round_type='TiesToEven')
def test_fake_quantize_dequantize_moving_average_abs_max(self): def test_fake_quantize_dequantize_moving_average_abs_max(self):
self._fake_quantize_moving_average_abs_max( self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
np.float32, (8, 16, 7, 7),
np.random.random, np.random.random,
dequantize=True, dequantize=True,
with_gradient=True) with_gradient=True)
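fake_quantize_moving_average_abs_max keeps a running estimate of the scale instead of recomputing it from each batch: accum and state decay by moving_rate every step, and their ratio is the scale fed into the same round/clip step as above. A minimal sketch of that state update, assuming moving_rate = 0.9:

import numpy as np

moving_rate = 0.9
accum, state, scale = 1.0, 1.0, 1.0   # InAccum, InState, InScale

def update(x, accum, state):
    # running estimate: scale = accum / state
    accum = moving_rate * accum + np.max(np.abs(x))
    state = moving_rate * state + 1.0
    return accum, state, accum / state

for _ in range(3):
    batch = np.random.random((8, 16)).astype(np.float32)
    accum, state, scale = update(batch, accum, state)
# 'scale' then replaces the per-batch abs-max in the quantization formula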
class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'fake_quantize_dequantize_abs_max' self.op_type = 'fake_quantize_dequantize_abs_max'
self.attrs = {'bit_length': 8} self.attrs = {'bit_length': 8}
def _fake_quantize_dequantize_abs_max(self, dtype, input_shape, def _fake_quantize_dequantize_abs_max(self,
distribution): dtype,
input_shape,
distribution,
round_type='TiesAwayFromZero'):
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
scale = np.max(np.abs(input_data)).astype(dtype) scale = np.max(np.abs(input_data)).astype(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
if round_type == 'TiesToEven':
round_out = np.round(input_data / scale * bnt)
output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt
self.attrs['round_type'] = 0
else:
output_data = round_c(input_data / scale * bnt) * scale / bnt output_data = round_c(input_data / scale * bnt) * scale / bnt
self.attrs['round_type'] = 1
self.inputs = {'X': input_data} self.inputs = {'X': input_data}
self.outputs = { self.outputs = {
'Out': output_data, 'Out': output_data,
...@@ -284,24 +357,41 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): ...@@ -284,24 +357,41 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
self._fake_quantize_dequantize_abs_max(np.float32, (124, 240), self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
np.random.random) np.random.random)
def test_fake_quantize_dequantize_abs_max_round1(self):
self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
np.random.random,
round_type='TiesToEven')
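The quantize-dequantize ops apply the same round-and-clip and then immediately rescale back to floating point, so the output approximates the input to within about one quantization step (scale / bnt). A sketch of the abs_max case on the TiesToEven path:

import numpy as np

def fake_quant_dequant_abs_max(x, bit_length=8):
    bnt = (1 << (bit_length - 1)) - 1
    scale = np.max(np.abs(x))
    q = np.clip(np.round(x / scale * bnt), -bnt - 1, bnt)   # TiesToEven path
    return q * scale / bnt                                   # back to float

x = np.random.uniform(-1, 1, (124, 240)).astype(np.float32)
y = fake_quant_dequant_abs_max(x)
# the round-trip error is bounded by half a quantization step
assert np.max(np.abs(x - y)) <= np.max(np.abs(x)) / 127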
class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max' self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max'
self.attrs = {'bit_length': 8} self.attrs = {'bit_length': 8}
def _fake_channel_wise_quantize_dequantize_abs_max( def _fake_channel_wise_quantize_dequantize_abs_max(self,
self, dtype, input_shape, quant_axis, distribution): dtype,
input_shape,
quant_axis,
distribution,
round_type='TiesToEven'):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.' assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
input_data = distribution(input_shape).astype(dtype) input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype) compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
output_data = input_data.copy().astype(compute_type) output_data = input_data.copy().astype(compute_type)
compute_axis = tuple( compute_axis = tuple(i for i in range(len(input_shape))
i for i in range(len(input_shape)) if i != quant_axis) if i != quant_axis)
scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
output_data = round_c(bnt * output_data / if round_type == 'TiesToEven':
scale_broadcast) * scale_broadcast / bnt round_out = np.round(bnt * output_data / scale_broadcast)
output_data = np.clip(round_out, -bnt - 1,
bnt) * scale_broadcast / bnt
self.attrs['round_type'] = 0
else:
output_data = round_c(
bnt * output_data / scale_broadcast) * scale_broadcast / bnt
self.attrs['round_type'] = 1
if quant_axis == 1: if quant_axis == 1:
scale_broadcast = np.transpose(scale_broadcast, scale_broadcast = np.transpose(scale_broadcast,
(1, ) + compute_axis) (1, ) + compute_axis)
...@@ -318,10 +408,19 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): ...@@ -318,10 +408,19 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
input_shape_quant_axis_options = [[(3, 4, 64, 64), 0], input_shape_quant_axis_options = [[(3, 4, 64, 64), 0],
[(15, 20, 5, 5), 1], [(30, 15), 0], [(15, 20, 5, 5), 1], [(30, 15), 0],
[(30, 15), 1]] [(30, 15), 1]]
for input_shape, quant_axis in input_shape_quant_axis_options: round_type_options = ['TiesToEven', 'TiesAwayFromZero']
with self.subTest(input_shape=input_shape, quant_axis=quant_axis): for input_shape_quant_axis, round_type in itertools.product(
input_shape_quant_axis_options, round_type_options):
input_shape, quant_axis = input_shape_quant_axis
with self.subTest(input_shape=input_shape,
quant_axis=quant_axis,
round_type=round_type):
self._fake_channel_wise_quantize_dequantize_abs_max( self._fake_channel_wise_quantize_dequantize_abs_max(
np.float32, input_shape, quant_axis, np.random.random) np.float32,
input_shape,
quant_axis,
np.random.random,
round_type=round_type)
def quantize_max_abs(x, max_range): def quantize_max_abs(x, max_range):
...@@ -349,6 +448,7 @@ def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0): ...@@ -349,6 +448,7 @@ def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
class TestChannelWiseQuantizeOp(OpTest): class TestChannelWiseQuantizeOp(OpTest):
def set_args(self): def set_args(self):
self.bit_length = 8 self.bit_length = 8
self.data_type = "float32" self.data_type = "float32"
...@@ -375,6 +475,7 @@ class TestChannelWiseQuantizeOp(OpTest): ...@@ -375,6 +475,7 @@ class TestChannelWiseQuantizeOp(OpTest):
class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp): class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
def set_args(self): def set_args(self):
self.bit_length = 8 self.bit_length = 8
self.data_type = "float32" self.data_type = "float32"
...@@ -382,6 +483,7 @@ class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp): ...@@ -382,6 +483,7 @@ class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
class TestChannelWiseQuantizeOpTrain(OpTest): class TestChannelWiseQuantizeOpTrain(OpTest):
def set_args(self): def set_args(self):
self.bit_length = 8 self.bit_length = 8
self.data_type = "float32" self.data_type = "float32"
...@@ -410,6 +512,7 @@ class TestChannelWiseQuantizeOpTrain(OpTest): ...@@ -410,6 +512,7 @@ class TestChannelWiseQuantizeOpTrain(OpTest):
class TestquantizeOp(OpTest): class TestquantizeOp(OpTest):
def set_args(self): def set_args(self):
self.bit_length = 8 self.bit_length = 8
self.quant_axis = -1 self.quant_axis = -1
...@@ -436,6 +539,7 @@ class TestquantizeOp(OpTest): ...@@ -436,6 +539,7 @@ class TestquantizeOp(OpTest):
class TestquantizeOpTrain(TestquantizeOp): class TestquantizeOpTrain(TestquantizeOp):
def set_args(self): def set_args(self):
self.bit_length = 8 self.bit_length = 8
self.quant_axis = -1 self.quant_axis = -1
......