Commit 73609ed6 authored by 李寅

Merge branch 'quantize' into 'master'

Fix quantize input

See merge request !954
......@@ -33,6 +33,7 @@ cc_library(
"buffer_transform.cc",
"lstm_cell.cc",
"quantize.cc",
"quantization_util.cc",
],
) + if_opencl_enabled(glob(
[
......@@ -48,6 +49,7 @@ cc_library(
)) + if_quantize_enabled(glob(
[
"quantize.cc",
"quantization_util.cc",
],
)),
hdrs = glob(
......@@ -61,6 +63,7 @@ cc_library(
"fixpoint.h",
"gemmlowp_util.h",
"arm/fixpoint_*.h",
"quantization_util.h",
],
) + if_opencl_enabled(glob([
"opencl/*.h",
......@@ -70,6 +73,7 @@ cc_library(
"fixpoint.h",
"gemmlowp_util.h",
"arm/fixpoint_*.h",
"quantization_util.h",
])),
copts = [
"-Werror",
......
......@@ -35,6 +35,7 @@
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/gemmlowp_util.h"
#include "mace/ops/quantization_util.h"
#endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL
......@@ -802,33 +803,22 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
auto input_data = input->data<uint8_t>();
auto filter_data = filter->data<uint8_t>();
auto output_data = output->mutable_data<uint8_t>();
auto bias_data = GetBiasData(bias,
input->scale(),
filter->scale(),
channels,
&bias_);
index_t total_scratch_size = 0;
index_t zero_bias_size = channels * sizeof(int32_t);
total_scratch_size += (bias == nullptr ? zero_bias_size : 0);
index_t im2col_size = depth * columns * sizeof(uint8_t);
auto gemm_input_data = input_data;
std::unique_ptr<Tensor> im2col;
bool im2col_required =
filter_h != 1 || filter_w != 1 || stride_h != 1 || stride_w != 1;
total_scratch_size += (im2col_required ? im2col_size : 0);
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(total_scratch_size);
std::unique_ptr<Tensor> zero_bias;
const int32_t *bias_data = nullptr;
if (bias == nullptr) {
zero_bias.reset(new Tensor(scratch->Scratch(zero_bias_size), DT_INT32));
zero_bias->Reshape({channels});
zero_bias->Clear();
bias_data = zero_bias->data<int32_t>();
} else {
bias_data = bias->data<int32_t>();
}
std::unique_ptr<Tensor> im2col;
auto gemm_input_data = input_data;
if (im2col_required) {
// prepare im2col
index_t im2col_size = depth * columns * sizeof(uint8_t);
ScratchBuffer *scratch = context->device()->scratch_buffer();
scratch->Rewind();
scratch->GrowSize(im2col_size);
im2col.reset(new Tensor(scratch->Scratch(im2col_size), DT_UINT8));
uint8_t *im2col_data = im2col->mutable_data<uint8_t>();
Im2col(input_data, input->shape(), filter_h, filter_w, stride_h,
......@@ -950,6 +940,7 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
const ActivationType activation_;
const float relux_max_limit_;
const float leakyrelu_coefficient_;
std::vector<int32_t> bias_;
private:
MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS);
......
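
The hunk above replaces the op-local zero-bias tensor with the shared GetBiasData helper (defined in the new quantization_util.cc shown further down), which also rescales the bias whenever its stored scale differs from input_scale * filter_scale. A minimal sketch of why that scale convention matters; names and structure here are illustrative, not MACE's kernel code:

```cpp
// Minimal sketch (illustrative only): the int32 accumulator of a quantized
// convolution/GEMM carries the scale input_scale * filter_scale, so a bias
// quantized with exactly that scale and zero point 0 can be added directly.
#include <cstddef>
#include <cstdint>
#include <vector>

int32_t QuantizedDotWithBias(const std::vector<uint8_t> &input, int32_t input_zp,
                             const std::vector<uint8_t> &filter, int32_t filter_zp,
                             int32_t bias /* scale = input_scale * filter_scale */) {
  int32_t acc = bias;
  for (size_t i = 0; i < input.size(); ++i) {
    acc += (static_cast<int32_t>(input[i]) - input_zp) *
           (static_cast<int32_t>(filter[i]) - filter_zp);
  }
  // Multiplying acc by input_scale * filter_scale recovers the real-valued
  // result; requantization to the output scale happens afterwards.
  return acc;
}
```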
......@@ -1076,8 +1076,9 @@ void TestQuantSimple3x3() {
"Input", {1, 3, 3, 2},
{1, 75, 117, 161, 127, 119, 94, 151, 203, 151, 84, 61, 55, 142, 113, 139,
3, 255}, false, 0.0204, 93);
net.AddInputFromArray<DeviceType::CPU, int32_t>(
"Bias", {1}, {2}, true, 0.00046104, 0);
net.AddInputFromArray<DeviceType::CPU, int32_t>("Bias", {1}, {2}, true);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
......@@ -1167,12 +1168,13 @@ void TestQuant(const index_t batch,
Tensor *q_input = net.GetTensor("QuantizedInput");
Tensor *bias = net.GetTensor("Bias");
auto bias_data = bias->data<float>();
float bias_scale = q_input->scale() * q_filter->scale();
std::vector<int32_t> q_bias(bias->size());
QuantizeWithScaleAndZeropoint(
bias_data, bias->size(), q_input->scale() * q_filter->scale(), 0,
q_bias.data());
net.AddInputFromArray<DeviceType::CPU, int32_t>("QuantizedBias",
{out_channels}, q_bias, true);
bias_data, bias->size(), bias_scale, 0, q_bias.data());
net.AddInputFromArray<DeviceType::CPU, int32_t>(
"QuantizedBias", {out_channels}, q_bias, true, bias_scale, 0);
OpDefBuilder("Conv2D", "QuantizeConv2dTest")
.Input("QuantizedInput")
.Input("QuantizedFilter")
......
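
The test now records the bias scale (q_input->scale() * q_filter->scale()) on the "QuantizedBias" tensor so that GetBiasData sees a matching scale and skips the rescale branch. A hedged sketch of the affine mapping the test relies on; it mirrors how QuantizeWithScaleAndZeropoint is used here, not its literal implementation:

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Quantize a float bias with scale = input_scale * filter_scale and zero
// point 0: q = round(x / scale).
std::vector<int32_t> QuantizeBias(const std::vector<float> &bias,
                                  float input_scale, float filter_scale) {
  const float scale = input_scale * filter_scale;
  std::vector<int32_t> q(bias.size());
  for (size_t i = 0; i < bias.size(); ++i) {
    q[i] = static_cast<int32_t>(std::round(bias[i] / scale));
  }
  return q;
}
```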
......@@ -21,6 +21,7 @@
#include <vector>
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/quantization_util.h"
// We reuse TensorFlow Lite's optimized depthwiseconv_uint8 and parallelize it
// using OpenMP for MACE's quantized depthwise_conv2d.
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
......@@ -355,21 +356,13 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t>
auto input_data = input->data<uint8_t>();
auto filter_data = filter->data<uint8_t>();
auto output_data = output->mutable_data<uint8_t>();
auto bias_data = GetBiasData(bias,
input->scale(),
filter->scale(),
out_channels,
&bias_);
if (dilation_h == 1 && dilation_w == 1) {
std::vector<index_t> bias_shape{out_channels};
std::unique_ptr<Tensor> zero_bias;
const int32_t *bias_data = nullptr;
if (bias == nullptr) {
zero_bias.reset(
new Tensor(GetCPUAllocator(), DT_INT32));
zero_bias->Resize(bias_shape);
zero_bias->Clear();
bias_data = zero_bias->data<int32_t>();
} else {
bias_data = bias->data<int32_t>();
}
int32_t quantized_multiplier;
int32_t right_shift;
GetOutputMultiplierAndShift(input->scale(), filter->scale(),
......@@ -378,6 +371,7 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t>
// 1HWO
std::vector<index_t> filter_shape{
1, filter->dim(0), filter->dim(1), filter->dim(2) * filter->dim(3)};
std::vector<index_t> bias_shape{out_channels};
tflite::optimized_ops::DepthwiseConv(
input_data, ShapeToTfliteDims(input->shape()), -input->zero_point(),
......@@ -387,7 +381,6 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t>
quantized_multiplier, right_shift, 0, 255, output_data,
ShapeToTfliteDims(output->shape()));
} else {
auto bias_data = bias == nullptr ? nullptr : bias->data<int32_t>();
float output_multiplier =
input->scale() * filter->scale() / output->scale();
const int pad_hw[2] = {pad_top, pad_left};
......@@ -485,6 +478,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t>
protected:
MACE_OP_INPUT_TAGS(INPUT, FILTER, BIAS);
MACE_OP_OUTPUT_TAGS(OUTPUT);
private:
std::vector<int32_t> bias_;
};
#endif // MACE_ENABLE_QUANTIZE
......
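
The dilation == 1 path still calls GetOutputMultiplierAndShift with input->scale() and filter->scale() to derive the fixed-point requantization parameters passed to tflite's DepthwiseConv. As a hedged sketch (this is the standard gemmlowp/TFLite-style decomposition, not necessarily MACE's exact code), the real multiplier M = input_scale * filter_scale / output_scale, which is less than 1, is split into a Q31 multiplier and a right shift:

```cpp
#include <cmath>
#include <cstdint>

// Decompose a real multiplier in (0, 1) into
// quantized_multiplier * 2^-right_shift, where quantized_multiplier is a
// Q31 fixed-point value in [2^30, 2^31).
void MultiplierToFixedPoint(double real_multiplier,
                            int32_t *quantized_multiplier,
                            int32_t *right_shift) {
  int exponent = 0;
  const double q = std::frexp(real_multiplier, &exponent);  // q in [0.5, 1)
  *right_shift = -exponent;
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // rounding pushed q up to exactly 1.0
    q_fixed /= 2;
    --(*right_shift);
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}
```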
......@@ -345,7 +345,9 @@ void QuantSimpleValidTest() {
"Filter", {3, 3, 2, 1},
{212, 239, 110, 170, 216, 91, 162, 161, 255, 2, 10, 120, 183, 101, 100,
33, 137, 51}, true, 0.0137587, 120);
net.AddInputFromArray<CPU, int32_t>("Bias", {2}, {2, 2}, true);
net.AddInputFromArray<CPU, int32_t>(
"Bias", {2}, {2, 2}, true, 0.000101168, 0);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input")
.Input("Filter")
......@@ -436,12 +438,13 @@ void TestQuant(const index_t batch,
Tensor *q_input = net.GetTensor("QuantizedInput");
Tensor *bias = net.GetTensor("Bias");
auto bias_data = bias->data<float>();
float bias_scale = q_input->scale() * q_filter->scale();
std::vector<int32_t> q_bias(bias->size());
QuantizeWithScaleAndZeropoint(
bias_data, bias->size(), q_input->scale() * q_filter->scale(), 0,
q_bias.data());
bias_data, bias->size(), bias_scale, 0, q_bias.data());
net.AddInputFromArray<DeviceType::CPU, int32_t>(
"QuantizedBias", {out_channels}, q_bias, true);
"QuantizedBias", {out_channels}, q_bias, true, bias_scale, 0);
OpDefBuilder("DepthwiseConv2d", "QuantizedDepthwiseConv2DTest")
.Input("QuantizedInput")
.Input("QuantizedFilter")
......
......@@ -24,6 +24,7 @@
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/gemmlowp_util.h"
#include "mace/ops/quantization_util.h"
#endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL
......@@ -155,19 +156,11 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
auto input_ptr = input->data<uint8_t>();
auto weight_ptr = weight->data<uint8_t>();
auto output_ptr = output->mutable_data<uint8_t>();
std::vector<index_t> bias_shape{output_size};
std::unique_ptr<Tensor> zero_bias;
const int32_t *bias_ptr = nullptr;
if (bias == nullptr) {
zero_bias.reset(
new Tensor(GetCPUAllocator(), DT_INT32));
zero_bias->Resize(bias_shape);
zero_bias->Clear();
bias_ptr = zero_bias->data<int32_t>();
} else {
bias_ptr = bias->data<int32_t>();
}
auto bias_ptr = GetBiasData(bias,
input->scale(),
weight->scale(),
output_size,
&bias_);
gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::RowMajor>
weight_matrix(weight_ptr, output_size, input_size);
......@@ -187,6 +180,9 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<int32_t> bias_;
};
#endif // MACE_ENABLE_QUANTIZE
......
......@@ -259,12 +259,12 @@ void QuantRandom(const index_t batch,
Tensor *q_input = net.GetTensor("QuantizedInput");
Tensor *bias = net.GetTensor("Bias");
auto bias_data = bias->data<float>();
float bias_scale = q_input->scale() * q_weight->scale();
std::vector<int32_t> q_bias(bias->size());
QuantizeWithScaleAndZeropoint(
bias_data, bias->size(), q_input->scale() * q_weight->scale(), 0,
q_bias.data());
net.AddInputFromArray<DeviceType::CPU, int32_t>("QuantizedBias",
{out_channel}, q_bias);
bias_data, bias->size(), bias_scale, 0, q_bias.data());
net.AddInputFromArray<DeviceType::CPU, int32_t>(
"QuantizedBias", {out_channel}, q_bias, true, bias_scale, 0);
OpDefBuilder("FullyConnected", "QuantizeFullyConnectedTest")
.Input("QuantizedInput")
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/quantization_util.h"
namespace mace {
namespace ops {
const int32_t *GetBiasData(const Tensor *bias,
const float input_scale,
const float filter_scale,
const index_t channels,
std::vector<int32_t> *bias_vec) {
const int32_t *bias_data = nullptr;
if (bias == nullptr) {
bias_vec->resize(channels, 0);
bias_data = bias_vec->data();
} else {
auto original_bias_data = bias->data<int32_t>();
bool adjust_bias_required =
fabs(input_scale * filter_scale - bias->scale()) > 1e-6;
if (!adjust_bias_required) {
bias_data = original_bias_data;
} else {
bias_vec->resize(channels);
float adjust_scale = bias->scale() / (input_scale * filter_scale);
for (index_t i = 0; i < channels; ++i) {
(*bias_vec)[i] = static_cast<int32_t>(
roundf(original_bias_data[i] * adjust_scale));
}
bias_data = bias_vec->data();
}
}
return bias_data;
}
} // namespace ops
} // namespace mace
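
A toy numeric illustration of the rescale branch above (all values are made up): if the GEMM accumulator expects bias scale input_scale * filter_scale = 2e-4 but the bias tensor was stored with scale 4e-4, every stored value is multiplied by adjust_scale = 2 so it lands in the expected scale.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const float input_scale = 0.02f;
  const float filter_scale = 0.01f;   // GEMM accumulator scale = 2e-4
  const float bias_scale = 0.0004f;   // scale the bias was stored with
  const int32_t stored_q = 50;        // represents 50 * 4e-4 = 0.02

  const float adjust_scale = bias_scale / (input_scale * filter_scale);  // 2.0
  const int32_t adjusted_q =
      static_cast<int32_t>(std::round(stored_q * adjust_scale));         // 100

  // 100 * 2e-4 = 0.02: the same real value, now in the accumulator's scale.
  std::printf("adjusted bias = %d (represents %.4f)\n",
              adjusted_q, adjusted_q * input_scale * filter_scale);
  return 0;
}
```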
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_QUANTIZATION_UTIL_H_
#define MACE_OPS_QUANTIZATION_UTIL_H_
#include <vector>
#include "mace/core/tensor.h"
namespace mace {
namespace ops {
const int32_t *GetBiasData(const Tensor *bias,
const float input_scale,
const float filter_scale,
const index_t channels,
std::vector<int32_t> *bias_vec);
} // namespace ops
} // namespace mace
#endif // MACE_OPS_QUANTIZATION_UTIL_H_
......@@ -218,6 +218,7 @@ class MaceKeyword(object):
mace_variance_str = 'variance'
mace_step_h_str = 'step_h'
mace_step_w_str = 'step_w'
mace_find_range_every_time = 'find_range_every_time'
class TransformerRule(Enum):
......
......@@ -117,6 +117,10 @@ class Transformer(base_converter.ConverterInterface):
self._quantize_activation_info = {}
self._quantized_tensor = set()
self.input_name_map = {}
self.output_name_map = {}
self.initialize_name_map()
def run(self):
for key in self._option.transformer_option:
transformer = self._registered_transformers[key]
......@@ -128,6 +132,18 @@ class Transformer(base_converter.ConverterInterface):
self.delete_after_check_nodes()
return self._model, self._quantize_activation_info
def initialize_name_map(self):
for input_node in self._option.input_nodes.values():
new_input_name = MaceKeyword.mace_input_node_name \
+ '_' + input_node.name
self.input_name_map[input_node.name] = new_input_name
output_nodes = self._option.check_nodes.values()
for output_node in output_nodes:
new_output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name
self.output_name_map[output_node.name] = new_output_name
def filter_format(self):
filter_format_value = ConverterUtil.get_arg(self._model,
MaceKeyword.mace_filter_format_str).i # noqa
......@@ -1382,29 +1398,16 @@ class Transformer(base_converter.ConverterInterface):
return False
print("Add mace quantize and dequantize nodes")
input_name_map = {}
output_name_map = {}
for input_node in self._option.input_nodes.values():
new_input_name = MaceKeyword.mace_input_node_name \
+ '_' + input_node.name
input_name_map[input_node.name] = new_input_name
output_nodes = self._option.check_nodes.values()
for output_node in output_nodes:
new_output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name
output_name_map[output_node.name] = new_output_name
for op in self._model.op:
for i in range(len(op.input)):
if op.input[i] in input_name_map:
op.input[i] = input_name_map[op.input[i]]
if op.input[i] in self.input_name_map:
op.input[i] = self.input_name_map[op.input[i]]
for i in range(len(op.output)):
if op.output[i] in output_name_map:
if op.output[i] in self.output_name_map:
op.name = MaceKeyword.mace_output_node_name \
+ '_' + op.name
new_output_name = output_name_map[op.output[i]]
new_output_name = self.output_name_map[op.output[i]]
self._quantize_activation_info[new_output_name] = \
self._quantize_activation_info[op.output[i]]
op.output[i] = new_output_name
......@@ -1427,23 +1430,31 @@ class Transformer(base_converter.ConverterInterface):
% (op.name, op.type))
for input_node in self._option.input_nodes.values():
new_input_name = self.input_name_map[input_node.name]
op_def = self._model.op.add()
op_def.name = \
self.normalize_op_name(input_name_map[input_node.name])
op_def.name = self.normalize_op_name(new_input_name)
op_def.type = MaceOp.Quantize.name
op_def.input.extend([input_node.name])
op_def.output.extend([input_name_map[input_node.name]])
op_def.output.extend([new_input_name])
output_shape = op_def.output_shape.add()
output_shape.dims.extend(input_node.shape)
self.copy_quantize_info(
op_def, self._quantize_activation_info[new_input_name])
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
# use actual ranges for model input quantization
find_range_every_time_arg = op_def.arg.add()
find_range_every_time_arg.name = \
MaceKeyword.mace_find_range_every_time
find_range_every_time_arg.i = 1
output_nodes = self._option.check_nodes.values()
for output_node in output_nodes:
op_def = self._model.op.add()
op_def.name = self.normalize_op_name(output_node.name)
op_def.type = MaceOp.Dequantize.name
op_def.input.extend([output_name_map[output_node.name]])
op_def.input.extend([self.output_name_map[output_node.name]])
op_def.output.extend([output_node.name])
output_shape = op_def.output_shape.add()
output_shape.dims.extend(
......@@ -1651,6 +1662,24 @@ class Transformer(base_converter.ConverterInterface):
if not self._option.quantize:
return False
print("Add default quantize info for input")
for input_node in self._option.input_nodes.values():
if input_node.name not in self._quantize_activation_info:
print("Input range %s: %s" % (input_node.name,
str(input_node.range)))
new_input_name = self.input_name_map[input_node.name]
scale, zero, minval, maxval = \
quantize_util.adjust_range(input_node.range[0],
input_node.range[1],
non_zero=False)
quantize_info = mace_pb2.QuantizeActivationInfo()
quantize_info.minval = minval
quantize_info.maxval = maxval
quantize_info.scale = scale
quantize_info.zero_point = zero
self._quantize_activation_info[new_input_name] = quantize_info
print("Add default quantize info for ops like Pooling, Softmax")
for op in self._model.op:
if op.type in [MaceOp.Pooling.name,
......@@ -1661,7 +1690,12 @@ class Transformer(base_converter.ConverterInterface):
MaceOp.SpaceToBatchND.name]:
del op.quantize_info[:]
producer_op = self._producer[op.input[0]]
self.copy_quantize_info(op, producer_op.quantize_info[0])
if producer_op.output[0] in self._option.input_nodes:
new_input_name = self.input_name_map[producer_op.output[0]]
self.copy_quantize_info(
op, self._quantize_activation_info[new_input_name])
else:
self.copy_quantize_info(op, producer_op.quantize_info[0])
self._quantize_activation_info[op.output[0]] = \
op.quantize_info[0]
elif (op.type == MaceOp.Concat.name
......@@ -1709,24 +1743,6 @@ class Transformer(base_converter.ConverterInterface):
self.add_quantize_info(op, minval, maxval)
self._quantize_activation_info[op.output[0]] = quantize_info
print("Add default quantize info for input")
for input_node in self._option.input_nodes.values():
if input_node.name not in self._quantize_activation_info:
print("Input range %s: %s" % (input_node.name,
str(input_node.range)))
new_input_name = MaceKeyword.mace_input_node_name \
+ '_' + input_node.name
scale, zero, minval, maxval = \
quantize_util.adjust_range(input_node.range[0],
input_node.range[1],
non_zero=False)
quantize_info = mace_pb2.QuantizeActivationInfo()
quantize_info.minval = minval
quantize_info.maxval = maxval
quantize_info.scale = scale
quantize_info.zero_point = zero
self._quantize_activation_info[new_input_name] = quantize_info
return False
def check_quantize_info(self):
......
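
The converter changes above do two related things: they derive a default scale/zero point for each model input from input_node.range via quantize_util.adjust_range (now keyed by the renamed input from input_name_map), and they mark the generated Quantize op with find_range_every_time so the actual input range can be used at run time. Below is a hedged sketch of the standard uint8 affine-parameter computation such a range is turned into; MACE's adjust_range may additionally nudge the range, so this is not its literal code:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

struct QuantParams {
  float scale;
  int32_t zero_point;
};

// Map a float range [minval, maxval] to uint8 so that 0.0 is exactly
// representable by an integer zero point.
QuantParams Uint8ParamsFromRange(float minval, float maxval) {
  minval = std::min(minval, 0.0f);   // the range must contain zero
  maxval = std::max(maxval, 0.0f);
  const float scale = (maxval > minval) ? (maxval - minval) / 255.0f : 1.0f;
  const int32_t zero_point =
      static_cast<int32_t>(std::round(-minval / scale));
  return {scale, zero_point};
}
```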