Commit eddba255 authored by 李寅

Merge branch 'gemmlowp' into 'master'

Add quantized convolution

See merge request !706
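The new CPU uint8_t path lowers the convolution to a gemmlowp GEMM over an im2col buffer and then requantizes the int32 accumulator back to uint8. Below is a minimal scalar sketch of that arithmetic for a single output element, assuming per-tensor scales and zero points as used by the functor in this MR; the helper name and signature are illustrative only, not MACE API.

// Scalar reference of the quantized convolution arithmetic (sketch only).
#include <algorithm>
#include <cmath>
#include <cstdint>

// input: one im2col column (length = filter_h * filter_w * in_channels),
// filter: the matching filter row, bias: int32 bias quantized with
// scale = input_scale * filter_scale and zero point 0.
inline uint8_t QuantConvRef(const uint8_t *input, const uint8_t *filter,
                            int length, int32_t bias,
                            int32_t input_zero, int32_t filter_zero,
                            float input_scale, float filter_scale,
                            float output_scale, int32_t output_zero) {
  int32_t acc = bias;
  for (int i = 0; i < length; ++i) {
    acc += (static_cast<int32_t>(input[i]) - input_zero) *
           (static_cast<int32_t>(filter[i]) - filter_zero);
  }
  // Requantize: scale the accumulator to the output scale and clamp to uint8.
  const float real_multiplier = input_scale * filter_scale / output_scale;
  const int32_t out =
      static_cast<int32_t>(std::lround(acc * real_multiplier)) + output_zero;
  return static_cast<uint8_t>(
      std::min<int32_t>(255, std::max<int32_t>(0, out)));
}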
......@@ -78,10 +78,10 @@ new_http_archive(
http_archive(
name = "gemmlowp",
sha256 = "4160b941d374d1a941776625405c22c32d8cb3d64c772ce8c1683efcd56cbc98",
strip_prefix = "gemmlowp-master-cae29f7fd3ca6672012ade2894ca028461003fb4",
sha256 = "4ed0bfeb81a41d8a6b953cecb5921a6455b42661493661cf51ef42bf5bc81db3",
strip_prefix = "gemmlowp-master-cee239e8386372e857c781081e0971b241eff722",
urls = [
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/gemmlowp-master-cae29f7fd3ca6672012ade2894ca028461003fb4.zip",
"https://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/gemmlowp-master-cee239e8386372e857c781081e0971b241eff722.zip",
],
)
......
......@@ -96,7 +96,7 @@ inline std::ostream &operator<<(std::ostream &os, unsigned char c) {
}
} // namespace numerical_chars
enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4 };
enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4, OHWI = 5 };
class Tensor {
public:
......
......@@ -38,10 +38,16 @@ int main(int argc, char **argv) {
// config runtime
mace::MaceStatus status = mace::SetOpenMPThreadsAndAffinityPolicy(
FLAGS_omp_num_threads,
static_cast<mace::CPUAffinityPolicy >(FLAGS_cpu_affinity_policy));
static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy));
if (status != mace::MACE_SUCCESS) {
LOG(WARNING) << "Set openmp or cpu affinity failed.";
}
status = SetGemmlowpThreadPolicy(
FLAGS_omp_num_threads,
static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy));
if (status != mace::MACE_SUCCESS) {
LOG(WARNING) << "Set gemmlowp threads or cpu affinity failed.";
}
mace::OpenCLRuntime::Configure(
static_cast<mace::GPUPerfHint>(FLAGS_gpu_perf_hint),
static_cast<mace::GPUPriorityHint>(FLAGS_gpu_priority_hint));
......
......@@ -20,7 +20,9 @@
#endif
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <tuple>
#include <vector>
#include "mace/core/future.h"
......@@ -29,6 +31,7 @@
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/kernels/arm/conv_2d_neon.h"
#include "mace/kernels/arm/conv_winograd.h"
#include "mace/kernels/gemmlowp_util.h"
#include "mace/utils/utils.h"
#ifdef MACE_ENABLE_OPENCL
......@@ -715,6 +718,288 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
ScratchBuffer *scratch_;
};
template<>
struct Conv2dFunctor<DeviceType::CPU, uint8_t> : Conv2dFunctorBase {
Conv2dFunctor(const int *strides,
const Padding &padding_type,
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const bool is_filter_transformed,
ScratchBuffer *scratch)
: Conv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit),
scratch_(scratch) {
MACE_UNUSED(is_filter_transformed);
}
template <typename T>
inline void Im2col(
const T *in_data, const std::vector<index_t> &in_shape,
const index_t filter_h, const index_t filter_w, const index_t stride_h,
const index_t stride_w, const T zero_point, const int pad_height,
const int pad_width, const std::vector<index_t> &out_shape,
const index_t depth, T* im2col_data) {
const index_t batches = out_shape[0];
const index_t out_height = out_shape[1];
const index_t out_width = out_shape[2];
const index_t column_len = depth;
const index_t in_height = in_shape[1];
const index_t in_width = in_shape[2];
const index_t in_channels = in_shape[3];
const index_t input_row_size = in_width * in_channels;
const index_t patch_row_size = filter_w * in_channels;
#pragma omp parallel for collapse(3)
for (index_t b = 0; b < batches; ++b) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
// Reshape a patch of the input into a column, which corresponds to
// one column of the output (:, column).
const index_t ih_begin = h * stride_h - (pad_height >> 1);
const index_t ih_end = ih_begin + filter_h;
const index_t iw_begin = w * stride_w - (pad_width >> 1);
const index_t iw_end = iw_begin + filter_w;
// clamp the height and width ranges to separate padding from valid data
const index_t ih_begin_gated = std::max<index_t>(0, ih_begin);
const index_t ih_end_gated = std::min<index_t>(ih_end, in_height);
const index_t iw_begin_gated = std::max<index_t>(0, iw_begin);
const index_t iw_end_gated = std::min<index_t>(iw_end, in_width);
const index_t pad_top = std::max<index_t>(0, -ih_begin);
const index_t pad_bottom = ih_end - ih_end_gated;
const index_t pad_left = std::max<index_t>(0, -iw_begin);
const index_t pad_right = iw_end - iw_end_gated;
index_t im2col_column_offset =
((b * out_height + h) * out_width + w) * column_len;
// fill in padding top
if (pad_top > 0) {
std::fill_n(im2col_data + im2col_column_offset,
pad_top * patch_row_size, zero_point);
}
const index_t patch_row_size_gated =
std::min(filter_w - pad_left,
in_width - iw_begin_gated) * in_channels;
MACE_CHECK(patch_row_size_gated ==
((filter_w - (pad_left + pad_right)) * in_channels));
const index_t pad_left_size = pad_left * in_channels;
const index_t pad_right_size = pad_right * in_channels;
index_t im2col_offset = im2col_column_offset +
(pad_top * filter_w + pad_left) * in_channels;
index_t in_offset =
((b * in_height + ih_begin_gated) * in_width + iw_begin_gated) *
in_channels;
// fill in effective rows
for (index_t ih = ih_begin_gated; ih < ih_end_gated; ++ih) {
// fill in padding left
if (pad_left > 0) {
const index_t left_offset = im2col_offset - pad_left_size;
std::fill_n(im2col_data + left_offset, pad_left_size, zero_point);
}
// copy effective data
std::copy_n(in_data + in_offset, patch_row_size_gated,
im2col_data + im2col_offset);
// fill in padding right
if (pad_right > 0) {
const index_t right_offset = im2col_offset + patch_row_size_gated;
std::fill_n(im2col_data + right_offset, pad_right_size,
zero_point);
}
in_offset += input_row_size;
im2col_offset += patch_row_size;
}
// fill in padding bottom
if (pad_bottom > 0) {
const index_t pad_bottom_size = pad_bottom * patch_row_size;
const index_t bottom_offset =
im2col_column_offset + column_len - pad_bottom_size;
std::fill_n(im2col_data + bottom_offset, pad_bottom_size,
zero_point);
}
}
}
}
}
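// (Reference for Im2col above, not part of the original change.)
// The buffer holds one column per output position:
//   columns = batch * out_height * out_width
//   depth   = filter_h * filter_w * in_channels   (column length)
// Padded positions are filled with the input zero point, so they contribute
// nothing after the (value - zero_point) shift inside the GEMM; gemmlowp
// later maps this buffer column-major as a (depth x columns) input matrix.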
inline void GetOutputMultiplierAndShift(
const float lhs_scale, const float rhs_scale, const float output_scale,
int32_t *quantized_multiplier, int *right_shift) {
float real_multiplier = lhs_scale * rhs_scale / output_scale;
MACE_CHECK(real_multiplier > 0.f && real_multiplier < 1.f, real_multiplier);
int exponent;
const double significand = std::frexp(real_multiplier, &exponent);
*right_shift = -exponent;
int64_t q = static_cast<int64_t>(std::round(significand * (1ll << 31)));
MACE_CHECK(q <= (1ll << 31));
if (q == (1ll << 31)) {
q /= 2;
(*right_shift)--;
}
MACE_CHECK(*right_shift >= 0);
MACE_CHECK(q <= std::numeric_limits<int32_t>::max());
*quantized_multiplier = static_cast<int32_t>(q);
}
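// (Reference for GetOutputMultiplierAndShift above, using the scales from
// TestQuantSimple3x3 later in this MR as an illustrative example only.)
//   real_multiplier = 0.0226 * 0.0204 / 0.000711 ~= 0.648
//   frexp(0.648) -> significand 0.648, exponent 0, so right_shift = 0
//   quantized_multiplier = round(0.648 * 2^31) ~= 1.39e9 (fits in int32)
// The output pipeline then applies 0.648 as quantized_multiplier / 2^31
// followed by a rounding right shift of right_shift bits.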
typedef gemmlowp::VectorMap<const int32_t, gemmlowp::VectorShape::Col>
ColVectorMap;
typedef std::tuple<
gemmlowp::OutputStageBiasAddition<ColVectorMap>,
gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
gemmlowp::OutputStageSaturatingCastToUint8> Pipeline;
inline Pipeline MakeOutputPipeline(
const int32_t* bias_data, const index_t channels, const float lhs_scale,
const float rhs_scale, const float output_scale,
const int32_t output_zero_point) {
ColVectorMap bias_vector(bias_data, channels);
gemmlowp::OutputStageBiasAddition<ColVectorMap> bias_addition_stage;
bias_addition_stage.bias_vector = bias_vector;
int32_t quantized_multiplier;
int32_t right_shift;
GetOutputMultiplierAndShift(lhs_scale, rhs_scale, output_scale,
&quantized_multiplier, &right_shift);
gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
quantize_down_stage;
quantize_down_stage.result_offset_after_shift = output_zero_point;
quantize_down_stage.result_fixedpoint_multiplier = quantized_multiplier;
quantize_down_stage.result_shift = right_shift;
gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
return std::make_tuple(bias_addition_stage, quantize_down_stage,
saturating_cast_stage);
}
MaceStatus operator()(const Tensor *input, // NHWC
const Tensor *filter, // OHWI
const Tensor *bias,
Tensor *output, // NHWC
StatsFuture *future) {
MACE_UNUSED(future);
MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1,
"Quantization convolution does not support dilation > 1 yet.");
gemmlowp::GemmContext& gemm_context = GetGemmlowpContext();
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
if (paddings_.empty()) {
CalcPaddingAndOutputSize(input->shape().data(),
NHWC,
filter->shape().data(),
OHWI,
dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input->shape().data(),
NHWC,
filter->shape().data(),
OHWI,
paddings_.data(),
dilations_,
strides_,
RoundType::FLOOR,
output_shape.data());
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
index_t batch = output->dim(0);
index_t height = output->dim(1);
index_t width = output->dim(2);
index_t channels = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(3);
index_t filter_h = filter->dim(1);
index_t filter_w = filter->dim(2);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
const index_t depth = input_channels * filter_h * filter_w;
const index_t columns = batch * height * width;
MACE_CHECK(filter->dim(0) == channels, filter->dim(0), " != ", channels);
MACE_CHECK(filter->dim(3) == input_channels, filter->dim(3), " != ",
input_channels);
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard filter_guard(filter);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<uint8_t>();
auto filter_data = filter->data<uint8_t>();
auto output_data = output->mutable_data<uint8_t>();
index_t total_scratch_size = 0;
index_t zero_bias_size = channels * sizeof(int32_t);
total_scratch_size += (bias == nullptr ? zero_bias_size : 0);
index_t im2col_size = depth * columns * sizeof(uint8_t);
bool im2col_required =
filter_h != 1 || filter_w != 1 || stride_h != 1 || stride_w != 1;
total_scratch_size += (im2col_required ? im2col_size : 0);
scratch_->Rewind();
scratch_->GrowSize(total_scratch_size);
std::unique_ptr<Tensor> zero_bias;
const int32_t *bias_data = nullptr;
if (bias == nullptr) {
zero_bias.reset(new Tensor(scratch_->Scratch(zero_bias_size), DT_INT32));
zero_bias->Reshape({channels});
zero_bias->Clear();
bias_data = zero_bias->data<int32_t>();
} else {
bias_data = bias->data<int32_t>();
}
std::unique_ptr<Tensor> im2col;
auto gemm_input_data = input_data;
if (im2col_required) {
// prepare im2col
im2col.reset(new Tensor(scratch_->Scratch(im2col_size), DT_UINT8));
uint8_t *im2col_data = im2col->mutable_data<uint8_t>();
Im2col(input_data, input->shape(), filter_h, filter_w, stride_h,
stride_w, static_cast<uint8_t>(input->zero_point()),
paddings[0], paddings[1], output->shape(), depth, im2col_data);
gemm_input_data = im2col_data;
}
const int gemm_filter_rows = static_cast<int>(channels);
const int gemm_filter_cols = static_cast<int>(depth);
const int gemm_input_rows = static_cast<int>(depth);
const int gemm_input_cols = static_cast<int>(columns);
const int gemm_output_rows = static_cast<int>(channels);
const int gemm_output_cols = static_cast<int>(columns);
gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::RowMajor>
filter_matrix(filter_data, gemm_filter_rows, gemm_filter_cols);
gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::ColMajor>
input_matrix(gemm_input_data, gemm_input_rows, gemm_input_cols);
gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::ColMajor>
output_matrix(output_data, gemm_output_rows, gemm_output_cols);
const auto &output_pipeline = MakeOutputPipeline(
bias_data, channels, filter->scale(), input->scale(), output->scale(),
output->zero_point());
using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
&gemm_context, filter_matrix, input_matrix, &output_matrix,
-filter->zero_point(), -input->zero_point(), output_pipeline);
return MACE_SUCCESS;
}
ScratchBuffer *scratch_;
};
#ifdef MACE_ENABLE_OPENCL
template<typename T>
struct Conv2dFunctor<DeviceType::GPU, T> : Conv2dFunctorBase {
......
......@@ -20,20 +20,43 @@
namespace mace {
namespace kernels {
void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size) {
void CalcPaddingAndOutputSize(const index_t *input_shape,
const DataFormat input_format,
const index_t *filter_shape,
const DataFormat filter_format,
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
index_t input_height = 0, input_width = 0;
index_t kernel_height = 0, kernel_width = 0;
if (input_format == NCHW) {
input_height = input_shape[2];
input_width = input_shape[3];
} else if (input_format == NHWC) {
input_height = input_shape[1];
input_width = input_shape[2];
} else {
MACE_NOT_IMPLEMENTED;
}
if (filter_format == OIHW) {
kernel_height = filter_shape[2];
kernel_width = filter_shape[3];
} else if (filter_format == OHWI) {
kernel_height = filter_shape[1];
kernel_width = filter_shape[2];
} else {
MACE_NOT_IMPLEMENTED;
}
/*
* Convolution/pooling arithmetic:
* o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
......@@ -42,27 +65,23 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
*/
padding_size[0] = 0;
padding_size[1] = 0;
index_t output_height = 0, output_width = 0;
index_t kernel_height = filter_shape[2];
index_t kernel_width = filter_shape[3];
index_t output_channels = filter_shape[0];
index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1;
index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1;
switch (padding) {
case VALID:
output_height = (input_shape[2] - k_extent_height) / strides[0] + 1;
output_width = (input_shape[3] - k_extent_width) / strides[1] + 1;
output_height = (input_height - k_extent_height) / strides[0] + 1;
output_width = (input_width - k_extent_width) / strides[1] + 1;
break;
case SAME:
output_height = (input_shape[2] - 1) / strides[0] + 1;
output_width = (input_shape[3] - 1) / strides[1] + 1;
output_height = (input_height - 1) / strides[0] + 1;
output_width = (input_width - 1) / strides[1] + 1;
break;
case FULL:
output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
output_height = (input_height + k_extent_height - 2) / strides[0] + 1;
output_width = (input_width + k_extent_width - 2) / strides[1] + 1;
break;
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
......@@ -74,14 +93,33 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
// based on the model accuracy.
padding_size[0] = std::max<int>(
0, (output_height - 1) * strides[0] + k_extent_height - input_shape[2]);
0, (output_height - 1) * strides[0] + k_extent_height - input_height);
padding_size[1] = std::max<int>(
0, (output_width - 1) * strides[1] + k_extent_width - input_shape[3]);
0, (output_width - 1) * strides[1] + k_extent_width - input_width);
output_shape[0] = input_shape[0];
output_shape[1] = output_channels;
output_shape[2] = output_height;
output_shape[3] = output_width;
if (input_format == NCHW) {
output_shape[1] = output_channels;
output_shape[2] = output_height;
output_shape[3] = output_width;
} else if (input_format == NHWC) {
output_shape[1] = output_height;
output_shape[2] = output_width;
output_shape[3] = output_channels;
} else {
MACE_NOT_IMPLEMENTED;
}
}
void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size) {
CalcPaddingAndOutputSize(input_shape, NCHW, filter_shape, OIHW, dilations,
strides, padding, output_shape, padding_size);
}
void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
......@@ -91,65 +129,86 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
Padding padding,
index_t *output_shape,
int *padding_size) {
CalcPaddingAndOutputSize(input_shape, NHWC, filter_shape, OIHW, dilations,
strides, padding, output_shape, padding_size);
}
void CalcOutputSize(const index_t *input_shape,
const DataFormat input_format,
const index_t *filter_shape,
const DataFormat filter_format,
const int *padding_size,
const int *dilations,
const int *strides,
const RoundType round_type,
index_t *output_shape) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
index_t input_height = 0, input_width = 0;
index_t kernel_height = 0, kernel_width = 0;
if (input_format == NCHW) {
input_height = input_shape[2];
input_width = input_shape[3];
} else if (input_format == NHWC) {
input_height = input_shape[1];
input_width = input_shape[2];
} else {
MACE_NOT_IMPLEMENTED;
}
if (filter_format == OIHW) {
kernel_height = filter_shape[2];
kernel_width = filter_shape[3];
} else if (filter_format == OHWI) {
kernel_height = filter_shape[1];
kernel_width = filter_shape[2];
} else {
MACE_NOT_IMPLEMENTED;
}
/*
* Convolution/pooling arithmetic:
* o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1
* For details, see https://arxiv.org/pdf/1603.07285.pdf or
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
padding_size[0] = 0;
padding_size[1] = 0;
index_t output_height = 0, output_width = 0;
index_t output_channels = filter_shape[0];
index_t kernel_height = filter_shape[2];
index_t kernel_width = filter_shape[3];
index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1;
index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1;
switch (padding) {
case VALID:
output_height = (input_shape[1] - k_extent_height) / strides[0] + 1;
output_width = (input_shape[2] - k_extent_width) / strides[1] + 1;
break;
case SAME:
output_height = (input_shape[1] - 1) / strides[0] + 1;
output_width = (input_shape[2] - 1) / strides[1] + 1;
break;
case FULL:
output_height = (input_shape[1] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[2] + k_extent_width - 2) / strides[1] + 1;
break;
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
if (round_type == FLOOR) {
output_height = static_cast<index_t>(
std::floor(1.0 * (input_height + padding_size[0] - kernel_height -
(kernel_height - 1) * (dilations[0] - 1)) / strides[0]) + 1);
output_width = static_cast<index_t>(
std::floor(1.0 * (input_width + padding_size[1] - kernel_width -
(kernel_width - 1) * (dilations[1] - 1)) / strides[1]) + 1);
} else {
output_height = static_cast<index_t>(
std::ceil(1.0 * (input_height + padding_size[0] - kernel_height -
(kernel_height - 1) * (dilations[0] - 1)) / strides[0]) + 1);
output_width = static_cast<index_t>(
std::ceil(1.0 * (input_width + padding_size[1] - kernel_width -
(kernel_width - 1) * (dilations[1] - 1)) / strides[1]) + 1);
}
// Note: TensorFlow may pad one more pixel on the right/bottom side.
// TODO(liuqi): maybe it's better to also truncate the left/top to
// utilize the more centered features. We need to benchmark
// based on the model accuracy.
padding_size[0] = std::max<int>(
0, (output_height - 1) * strides[0] + k_extent_height - input_shape[1]);
padding_size[1] = std::max<int>(
0, (output_width - 1) * strides[1] + k_extent_width - input_shape[2]);
output_shape[0] = input_shape[0];
output_shape[1] = output_height;
output_shape[2] = output_width;
output_shape[3] = output_channels;
if (input_format == NCHW) {
output_shape[1] = output_channels;
output_shape[2] = output_height;
output_shape[3] = output_width;
} else if (input_format == NHWC) {
output_shape[1] = output_height;
output_shape[2] = output_width;
output_shape[3] = output_channels;
} else {
MACE_NOT_IMPLEMENTED;
}
}
void CalcOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // OIHW
const int *padding_size,
......@@ -157,39 +216,8 @@ void CalcOutputSize(const index_t *input_shape, // NHWC
const int *strides,
const RoundType round_type,
index_t *output_shape) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
output_shape[0] = input_shape[0];
if (round_type == FLOOR) {
output_shape[1] = static_cast<index_t>(
std::floor(1.0 * (input_shape[1] + padding_size[0] - filter_shape[2] -
(filter_shape[2] - 1) * (dilations[0] - 1)) /
strides[0]) +
1);
output_shape[2] = static_cast<index_t>(
std::floor(1.0 * (input_shape[2] + padding_size[1] - filter_shape[3] -
(filter_shape[3] - 1) * (dilations[1] - 1)) /
strides[1]) +
1);
} else {
output_shape[1] = static_cast<index_t>(
std::ceil(1.0 * (input_shape[1] + padding_size[0] - filter_shape[2] -
(filter_shape[2] - 1) * (dilations[0] - 1)) /
strides[0]) +
1);
output_shape[2] = static_cast<index_t>(
std::ceil(1.0 * (input_shape[2] + padding_size[1] - filter_shape[3] -
(filter_shape[3] - 1) * (dilations[1] - 1)) /
strides[1]) +
1);
}
output_shape[3] = filter_shape[0];
CalcOutputSize(input_shape, NHWC, filter_shape, OIHW, padding_size, dilations,
strides, round_type, output_shape);
}
void CalcNCHWOutputSize(const index_t *input_shape, // NCHW
......@@ -199,46 +227,8 @@ void CalcNCHWOutputSize(const index_t *input_shape, // NCHW
const int *strides,
const RoundType round_type,
index_t *output_shape) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
/*
* Convolution arithmetic:
* o = floor((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
* Pooling arithmetic:
* o = ceil((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
* For details, see https://arxiv.org/pdf/1603.07285.pdf or
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
output_shape[0] = input_shape[0];
if (round_type == FLOOR) {
output_shape[2] = static_cast<index_t>(
std::floor(1.0 * (input_shape[2] + padding_size[0] - filter_shape[2] -
(filter_shape[2] - 1) * (dilations[0] - 1)) /
strides[0]) +
1);
output_shape[3] = static_cast<index_t>(
std::floor(1.0 * (input_shape[3] + padding_size[1] - filter_shape[3] -
(filter_shape[3] - 1) * (dilations[1] - 1)) /
strides[1]) +
1);
} else {
output_shape[2] = static_cast<index_t>(
std::ceil(1.0 * (input_shape[2] + padding_size[0] - filter_shape[2] -
(filter_shape[2] - 1) * (dilations[0] - 1)) /
strides[0]) +
1);
output_shape[3] = static_cast<index_t>(
std::ceil(1.0 * (input_shape[3] + padding_size[1] - filter_shape[3] -
(filter_shape[3] - 1) * (dilations[1] - 1)) /
strides[1]) +
1);
}
output_shape[1] = filter_shape[0];
CalcOutputSize(input_shape, NCHW, filter_shape, OIHW, padding_size, dilations,
strides, round_type, output_shape);
}
void CalPaddingSize(const index_t *input_shape, // NCHW
......
......@@ -32,14 +32,24 @@ enum RoundType {
namespace kernels {
void CalcNCHWPaddingAndOutputSize(const index_t *input_shape,
void CalcPaddingAndOutputSize(const index_t *input_shape,
const DataFormat input_format,
const index_t *filter_shape,
const DataFormat filter_format,
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size);
void CalcNCHWPaddingAndOutputSize(const index_t *input_shape,
const index_t *filter_shape,
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size);
void CalcNHWCPaddingAndOutputSize(const index_t *input_shape,
const index_t *filter_shape,
const int *dilations,
......@@ -48,6 +58,16 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape,
index_t *output_shape,
int *padding_size);
void CalcOutputSize(const index_t *input_shape,
const DataFormat input_format,
const index_t *filter_shape,
const DataFormat filter_format,
const int *padding_size,
const int *dilations,
const int *strides,
const RoundType round_type,
index_t *output_shape);
void CalcOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // OIHW
const int *padding_size,
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/gemmlowp_util.h"
#include <algorithm>
#include <vector>
#include "mace/core/runtime/cpu/cpu_runtime.h"
namespace mace {
gemmlowp::GemmContext& GetGemmlowpContext() {
static auto *gemm_context = new gemmlowp::GemmContext;
return *gemm_context;
}
MaceStatus SetGemmlowpThreadPolicy(int num_threads_hint,
CPUAffinityPolicy policy) {
gemmlowp::GemmContext& gemm_context = GetGemmlowpContext();
if (policy != AFFINITY_NONE) {
std::vector<int> big_core_ids;
std::vector<int> little_core_ids;
MaceStatus res = GetCPUBigLittleCoreIDs(&big_core_ids, &little_core_ids);
if (res != MACE_SUCCESS) {
return res;
}
int use_cpu_size;
if (policy == CPUAffinityPolicy::AFFINITY_BIG_ONLY) {
use_cpu_size = static_cast<int>(big_core_ids.size());
} else {
use_cpu_size = static_cast<int>(little_core_ids.size());
}
if (num_threads_hint <= 0 || num_threads_hint > use_cpu_size) {
num_threads_hint = use_cpu_size;
}
}
gemm_context.set_max_num_threads(std::max(0, num_threads_hint));
return MACE_SUCCESS;
}
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_GEMMLOWP_UTIL_H_
#define MACE_KERNELS_GEMMLOWP_UTIL_H_
#include <iostream>
#include "public/gemmlowp.h"
namespace mace {
gemmlowp::GemmContext& GetGemmlowpContext();
} // namespace mace
#endif // MACE_KERNELS_GEMMLOWP_UTIL_H_
......@@ -25,7 +25,7 @@ cc_library(
],
deps = [
"//mace/core",
"@gtest//:gtest",
"@gtest",
],
)
......
......@@ -24,6 +24,12 @@ void Register_Conv2D(OperatorRegistryBase *op_registry) {
.Build(),
Conv2dOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
Conv2dOp<DeviceType::CPU, uint8_t>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::GPU)
......
......@@ -99,6 +99,67 @@ void Conv2d(int iters,
net.Sync();
}
}
template <>
void Conv2d<CPU, uint8_t>(int iters,
int batch,
int channels,
int height,
int width,
int kernel_h,
int kernel_w,
int stride,
int dilation,
Padding padding,
int output_channels) {
mace::testing::StopTiming();
if (dilation > 1) {
LOG(WARNING) << "uint8_t benchmarking dilation = 1 instead.";
}
OpsTestNet net;
// Add input data
net.AddRandomInput<DeviceType::CPU, uint8_t>(
"Input", {batch, height, width, channels});
Tensor *input = net.GetTensor("Input");
input->SetScale(0.00705);
input->SetZeroPoint(114);
net.AddRandomInput<DeviceType::CPU, uint8_t>(
"Filter", {output_channels, kernel_h, kernel_w, channels});
Tensor *filter = net.GetTensor("Filter");
filter->SetScale(0.0066);
filter->SetZeroPoint(113);
net.AddRandomInput<DeviceType::CPU, int32_t>("Bias", {output_channels});
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
net.Setup(DeviceType::CPU);
Tensor *output = net.GetTensor("Output");
output->SetScale(0.0107);
output->SetZeroPoint(118);
// Warm-up
for (int i = 0; i < 2; ++i) {
net.Run();
net.Sync();
}
mace::testing::StartTiming();
while (iters--) {
net.Run();
net.Sync();
}
}
} // namespace
// In a common network there is usually more than one layer; this is used to
......@@ -135,7 +196,8 @@ void Conv2d(int iters,
#define MACE_BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \
MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \
MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, GPU); \
MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, GPU);
MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, GPU); \
MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, uint8_t, CPU);
......
......@@ -15,6 +15,7 @@
#include <fstream>
#include <vector>
#include "mace/kernels/quantize.h"
#include "mace/ops/conv_2d.h"
#include "mace/ops/ops_test_util.h"
......@@ -1069,6 +1070,163 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedPad4) {
TestArbitraryPadConvNxN<DeviceType::GPU, float>({107, 113, 5, 7}, {4, 4});
}
namespace {
void TestQuantSimple3x3() {
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, uint8_t>(
"Filter", {1, 3, 3, 2},
{102, 150, 123, 135, 1, 216, 137, 47, 53, 75, 145, 130, 171, 62, 255,
122, 72, 211}, 0.0226, 127);
net.AddInputFromArray<DeviceType::CPU, uint8_t>(
"Input", {1, 3, 3, 2},
{1, 75, 117, 161, 127, 119, 94, 151, 203, 151, 84, 61, 55, 142, 113, 139,
3, 255}, 0.0204, 93);
net.AddInputFromArray<DeviceType::CPU, int32_t>("Bias", {1}, {2});
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DT_UINT8))
.Finalize(net.NewOperatorDef());
net.Setup(DeviceType::CPU);
Tensor *output = net.GetTensor("Output");
output->SetScale(0.000711);
output->SetZeroPoint(1);
// Run
net.Run();
// Check
auto expected = CreateTensor<uint8_t>({1, 1, 1, 1}, {230});
ExpectTensorNear<uint8_t>(*expected, *output);
}
void TestQuant(const index_t batch,
const index_t out_channels,
const index_t in_channels,
const index_t in_height,
const index_t in_width,
const index_t k_height,
const index_t k_width,
enum Padding padding_type,
const std::vector<int> &strides) {
OpsTestNet net;
net.AddRandomInput<CPU, float>("Input", {batch, in_height, in_width,
in_channels});
net.AddRandomInput<CPU, float>("Filter", {out_channels, k_height, k_width,
in_channels});
net.AddRandomInput<CPU, float>("Bias", {out_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter", OHWI, "FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding_type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DT_FLOAT))
.Finalize(net.NewOperatorDef());
net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
"Output", NHWC);
OpDefBuilder("Quantize", "QuantizeFilter")
.Input("Filter")
.Output("QuantizedFilter")
.OutputType({DT_UINT8})
.AddIntArg("T", DT_UINT8)
.AddIntArg("non_zero", true)
.Finalize(net.NewOperatorDef());
net.RunOp();
OpDefBuilder("Quantize", "QuantizeInput")
.Input("Input")
.Output("QuantizedInput")
.OutputType({DT_UINT8})
.AddIntArg("T", DT_UINT8)
.AddIntArg("non_zero", true)
.Finalize(net.NewOperatorDef());
net.RunOp();
OpDefBuilder("Quantize", "QuantizeOutput")
.Input("Output")
.Output("ExpectedQuantizedOutput")
.OutputType({DT_UINT8})
.AddIntArg("T", DT_UINT8)
.AddIntArg("non_zero", true)
.Finalize(net.NewOperatorDef());
net.RunOp();
Tensor *q_filter = net.GetTensor("QuantizedFilter");
Tensor *q_input = net.GetTensor("QuantizedInput");
Tensor *bias = net.GetTensor("Bias");
auto bias_data = bias->data<float>();
std::vector<int32_t> q_bias(bias->size());
kernels::QuantizeWithScaleAndZeropoint(
bias_data, bias->size(), q_input->scale() * q_filter->scale(), 0,
q_bias.data());
net.AddInputFromArray<DeviceType::CPU, int32_t>("QuantizedBias",
{out_channels}, q_bias);
OpDefBuilder("Conv2D", "QuantizeConv2dTest")
.Input("QuantizedInput")
.Input("QuantizedFilter")
.Input("QuantizedBias")
.Output("QuantizedOutput")
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding_type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DT_UINT8))
.Finalize(net.NewOperatorDef());
net.Setup(DeviceType::CPU);
Tensor *eq_output = net.GetTensor("ExpectedQuantizedOutput");
Tensor *q_output = net.GetTensor("QuantizedOutput");
q_output->SetScale(eq_output->scale());
q_output->SetZeroPoint(eq_output->zero_point());
net.Run();
OpDefBuilder("Dequantize", "DeQuantizeTest")
.Input("QuantizedOutput")
.Output("DequantizedOutput")
.OutputType({DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
net.RunOp();
// Check
ExpectTensorSimilar<float>(*net.GetOutput("Output"),
*net.GetTensor("DequantizedOutput"), 0.01);
}
} // namespace
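Note how the test quantizes the bias: scale = q_input->scale() * q_filter->scale() with zero point 0, which is the only representation that can be added directly to the int32 GEMM accumulator. A minimal sketch of that conversion, assuming a free-standing helper (the real code calls kernels::QuantizeWithScaleAndZeropoint):

#include <cmath>
#include <cstdint>
#include <vector>

// Quantize a float bias so it matches the int32 accumulator scale.
std::vector<int32_t> QuantizeBias(const std::vector<float> &bias,
                                  float input_scale, float filter_scale) {
  const float bias_scale = input_scale * filter_scale;  // zero point is 0
  std::vector<int32_t> q_bias(bias.size());
  for (size_t i = 0; i < bias.size(); ++i) {
    q_bias[i] = static_cast<int32_t>(std::lround(bias[i] / bias_scale));
  }
  return q_bias;
}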
TEST_F(Conv2dOpTest, Quant) {
TestQuantSimple3x3();
TestQuant(1, 128, 64, 32, 32, 1, 1, VALID, {1, 1});
TestQuant(1, 128, 64, 32, 32, 3, 3, VALID, {1, 1});
TestQuant(1, 128, 64, 32, 32, 3, 3, SAME, {1, 1});
TestQuant(1, 128, 64, 32, 32, 3, 3, FULL, {1, 1});
TestQuant(1, 128, 64, 32, 32, 3, 3, SAME, {2, 2});
TestQuant(1, 129, 63, 33, 31, 3, 3, SAME, {1, 1});
TestQuant(9, 128, 64, 32, 32, 3, 3, SAME, {1, 1});
TestQuant(1, 128, 64, 32, 32, 1, 5, SAME, {1, 1});
TestQuant(1, 128, 64, 32, 32, 5, 5, SAME, {1, 1});
TestQuant(1, 128, 64, 32, 32, 5, 1, SAME, {1, 1});
TestQuant(1, 128, 64, 32, 32, 7, 7, SAME, {1, 1});
TestQuant(1, 128, 64, 32, 32, 7, 7, SAME, {2, 2});
TestQuant(1, 128, 64, 32, 32, 7, 7, SAME, {3, 3});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -116,7 +116,9 @@ class OpsTestNet {
template <DeviceType D, typename T>
void AddInputFromArray(const std::string &name,
const std::vector<index_t> &shape,
const std::vector<T> &data) {
const std::vector<T> &data,
const float scale = 0.0,
const int32_t zero_point = 0) {
Tensor *input =
ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum<T>::v());
input->Resize(shape);
......@@ -124,6 +126,8 @@ class OpsTestNet {
T *input_data = input->mutable_data<T>();
MACE_CHECK(static_cast<size_t>(input->size()) == data.size());
memcpy(input_data, data.data(), data.size() * sizeof(T));
input->SetScale(scale);
input->SetZeroPoint(zero_point);
}
template <DeviceType D, typename T>
......@@ -303,6 +307,26 @@ class OpsTestNet {
}
}
}
} else if (src_format == OHWI && dst_format == OIHW) {
index_t out_channels = input_shape[0];
index_t height = input_shape[1];
index_t width = input_shape[2];
index_t in_channels = input_shape[3];
output->Resize({out_channels, in_channels, height, width});
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
for (index_t b = 0; b < out_channels; ++b) {
for (index_t c = 0; c < in_channels; ++c) {
for (index_t h = 0; h < height; ++h) {
for (index_t w = 0; w < width; ++w) {
output_data[((b * in_channels + c) * height + h) * width + w] =
input_data[((b * height + h) * width + w) * in_channels + c];
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -632,6 +656,25 @@ void ExpectTensorNear(const Tensor &x,
Expector<EXP_TYPE, RES_TYPE>::Near(x, y, rel_err, abs_err);
}
template <typename T>
void ExpectTensorSimilar(const Tensor &x,
const Tensor &y,
const double abs_err = 1e-5) {
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto x_data = x.data<T>();
auto y_data = y.data<T>();
double dot_product = 0.0, x_norm = 0.0, y_norm = 0.0;
for (index_t i = 0; i < x.size(); i++) {
dot_product += x_data[i] * y_data[i];
x_norm += x_data[i] * x_data[i];
y_norm += y_data[i] * y_data[i];
}
double similarity = dot_product / (sqrt(x_norm) * sqrt(y_norm));
EXPECT_NEAR(1.0, similarity, abs_err);
}
template <DeviceType D, typename T>
void BufferToImage(OpsTestNet *net,
const std::string &input_name,
......
......@@ -180,6 +180,16 @@ __attribute__((visibility("default")))
MaceStatus GetBigLittleCoreIDs(std::vector<int> *big_core_ids,
std::vector<int> *little_core_ids);
// Set the gemmlowp thread count and processor affinity.
// gemmlowp is used by MACE for quantized computation.
// Caution: this function may hurt performance if improper parameters are
// provided.
//
// This function may not work well on some chips (e.g. MTK). Setting thread
// affinity to offline cores may make execution very slow or behave
// unexpectedly. In such cases, please call SetGemmlowpThreadPolicy with the
// default policy instead.
__attribute__((visibility("default")))
MaceStatus SetGemmlowpThreadPolicy(int num_threads_hint,
CPUAffinityPolicy policy);
} // namespace mace
#endif // MACE_PUBLIC_MACE_RUNTIME_H_
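A minimal usage sketch, assuming the public header path implied by the MACE_PUBLIC_MACE_RUNTIME_H_ guard in this file; in this MR the real call site is the mace_run change near the top of the diff.

#include "mace/public/mace_runtime.h"

// Configure gemmlowp before running a quantized model, alongside the
// existing OpenMP policy setup (illustrative wrapper, not part of this MR).
void ConfigureQuantizedRuntime(int num_threads_hint) {
  mace::MaceStatus status = mace::SetGemmlowpThreadPolicy(
      num_threads_hint, mace::CPUAffinityPolicy::AFFINITY_BIG_ONLY);
  if (status != mace::MACE_SUCCESS) {
    // Core detection failed; gemmlowp keeps its default thread settings.
  }
}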
......@@ -91,7 +91,7 @@ T &&CheckNotNull(const char *file, int line, const char *exprtext, T &&t) {
#define MACE_CHECK_NOTNULL(val) \
::mace::logging::CheckNotNull(__FILE__, __LINE__, \
"'" #val "' Must be non NULL", (val))
"'" #val "' Must not be NULL", (val))
#define MACE_NOT_IMPLEMENTED MACE_CHECK(false, "not implemented")
......
......@@ -9,6 +9,7 @@ build --verbose_failures
build --copt=-std=c++11
build --copt=-D_GLIBCXX_USE_C99_MATH_TR1
build --copt=-DMACE_OBFUSCATE_LITERALS
build --copt=-DGEMMLOWP_USE_OPENMP
# Usage example: bazel build --config android
build:android --crosstool_top=//external:android/crosstool
......