From e33a05278595b50320ff6731b72fc00ae33a0d50 Mon Sep 17 00:00:00 2001 From: Bin Li Date: Thu, 2 Aug 2018 19:39:59 +0800 Subject: [PATCH] Add quantized convolution --- WORKSPACE | 6 +- mace/core/tensor.h | 2 +- mace/core/testing/test_benchmark_main.cc | 8 +- mace/kernels/conv_2d.h | 285 +++++++++++++++++++++++ mace/kernels/conv_pool_2d_util.cc | 260 ++++++++++----------- mace/kernels/conv_pool_2d_util.h | 22 +- mace/kernels/gemmlowp_util.cc | 58 +++++ mace/kernels/gemmlowp_util.h | 27 +++ mace/ops/BUILD | 2 +- mace/ops/conv_2d.cc | 6 + mace/ops/conv_2d_benchmark.cc | 64 ++++- mace/ops/conv_2d_test.cc | 158 +++++++++++++ mace/ops/ops_test_util.h | 45 +++- mace/public/mace_runtime.h | 10 + mace/utils/logging.h | 2 +- tools/bazel.rc | 1 + 16 files changed, 811 insertions(+), 145 deletions(-) create mode 100644 mace/kernels/gemmlowp_util.cc create mode 100644 mace/kernels/gemmlowp_util.h diff --git a/WORKSPACE b/WORKSPACE index 61bd9c08..d7e2449f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -78,10 +78,10 @@ new_http_archive( http_archive( name = "gemmlowp", - sha256 = "4160b941d374d1a941776625405c22c32d8cb3d64c772ce8c1683efcd56cbc98", - strip_prefix = "gemmlowp-master-cae29f7fd3ca6672012ade2894ca028461003fb4", + sha256 = "4ed0bfeb81a41d8a6b953cecb5921a6455b42661493661cf51ef42bf5bc81db3", + strip_prefix = "gemmlowp-master-cee239e8386372e857c781081e0971b241eff722", urls = [ - "https://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/gemmlowp-master-cae29f7fd3ca6672012ade2894ca028461003fb4.zip", + "https://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/gemmlowp-master-cee239e8386372e857c781081e0971b241eff722.zip", ], ) diff --git a/mace/core/tensor.h b/mace/core/tensor.h index aef57969..e48edc87 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -96,7 +96,7 @@ inline std::ostream &operator<<(std::ostream &os, unsigned char c) { } } // namespace numerical_chars -enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4 }; +enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4, OHWI = 5 }; class Tensor { public: diff --git a/mace/core/testing/test_benchmark_main.cc b/mace/core/testing/test_benchmark_main.cc index 7af19d1c..e730c10e 100644 --- a/mace/core/testing/test_benchmark_main.cc +++ b/mace/core/testing/test_benchmark_main.cc @@ -38,10 +38,16 @@ int main(int argc, char **argv) { // config runtime mace::MaceStatus status = mace::SetOpenMPThreadsAndAffinityPolicy( FLAGS_omp_num_threads, - static_cast(FLAGS_cpu_affinity_policy)); + static_cast(FLAGS_cpu_affinity_policy)); if (status != mace::MACE_SUCCESS) { LOG(WARNING) << "Set openmp or cpu affinity failed."; } + status = SetGemmlowpThreadPolicy( + FLAGS_omp_num_threads, + static_cast(FLAGS_cpu_affinity_policy)); + if (status != mace::MACE_SUCCESS) { + LOG(WARNING) << "Set gemmlowp threads or cpu affinity failed."; + } mace::OpenCLRuntime::Configure( static_cast(FLAGS_gpu_perf_hint), static_cast(FLAGS_gpu_priority_hint)); diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 9654d967..f51eb7b1 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -20,7 +20,9 @@ #endif #include #include +#include #include +#include #include #include "mace/core/future.h" @@ -29,6 +31,7 @@ #include "mace/kernels/conv_pool_2d_util.h" #include "mace/kernels/arm/conv_2d_neon.h" #include "mace/kernels/arm/conv_winograd.h" +#include "mace/kernels/gemmlowp_util.h" #include "mace/utils/utils.h" #ifdef MACE_ENABLE_OPENCL @@ -715,6 +718,288 @@ struct Conv2dFunctor : Conv2dFunctorBase { ScratchBuffer *scratch_; }; +template<> +struct Conv2dFunctor : Conv2dFunctorBase { + Conv2dFunctor(const int *strides, + const Padding &padding_type, + const std::vector &paddings, + const int *dilations, + const ActivationType activation, + const float relux_max_limit, + const bool is_filter_transformed, + ScratchBuffer *scratch) + : Conv2dFunctorBase(strides, + padding_type, + paddings, + dilations, + activation, + relux_max_limit), + scratch_(scratch) { + MACE_UNUSED(is_filter_transformed); + } + + template + inline void Im2col( + const T *in_data, const std::vector &in_shape, + const index_t filter_h, const index_t filter_w, const index_t stride_h, + const index_t stride_w, const T zero_point, const int pad_height, + const int pad_width, const std::vector &out_shape, + const index_t depth, T* im2col_data) { + const index_t batches = out_shape[0]; + const index_t out_height = out_shape[1]; + const index_t out_width = out_shape[2]; + const index_t column_len = depth; + const index_t in_height = in_shape[1]; + const index_t in_width = in_shape[2]; + const index_t in_channels = in_shape[3]; + const index_t input_row_size = in_width * in_channels; + const index_t patch_row_size = filter_w * in_channels; + +#pragma omp parallel for collapse(3) + for (index_t b = 0; b < batches; ++b) { + for (index_t h = 0; h < out_height; ++h) { + for (index_t w = 0; w < out_width; ++w) { + // Reshape a patch of input to column, which is corresponding to + // a column of output(:, column). + const index_t ih_begin = h * stride_h - (pad_height >> 1); + const index_t ih_end = ih_begin + filter_h; + const index_t iw_begin = w * stride_w - (pad_width >> 1); + const index_t iw_end = iw_begin + filter_w; + // gate height and width to separate padding + const index_t ih_begin_gated = std::max(0, ih_begin); + const index_t ih_end_gated = std::min(ih_end, in_height); + const index_t iw_begin_gated = std::max(0, iw_begin); + const index_t iw_end_gated = std::min(iw_end, in_width); + const index_t pad_top = std::max(0, -ih_begin); + const index_t pad_bottom = ih_end - ih_end_gated; + const index_t pad_left = std::max(0, -iw_begin); + const index_t pad_right = iw_end - iw_end_gated; + index_t im2col_column_offset = + ((b * out_height + h) * out_width + w) * column_len; + + // fill in padding top + if (pad_top > 0) { + std::fill_n(im2col_data + im2col_column_offset, + pad_top * patch_row_size, zero_point); + } + + const index_t patch_row_size_gated = + std::min(filter_w - pad_left, + in_width - iw_begin_gated) * in_channels; + MACE_CHECK(patch_row_size_gated == + ((filter_w - (pad_left + pad_right)) * in_channels)); + const index_t pad_left_size = pad_left * in_channels; + const index_t pad_right_size = pad_right * in_channels; + index_t im2col_offset = im2col_column_offset + + (pad_top * filter_w + pad_left) * in_channels; + index_t in_offset = + ((b * in_height + ih_begin_gated) * in_width + iw_begin_gated) * + in_channels; + + // fill in effective rows + for (index_t ih = ih_begin_gated; ih < ih_end_gated; ++ih) { + // fill in padding left + if (pad_left > 0) { + const index_t left_offset = im2col_offset - pad_left_size; + std::fill_n(im2col_data + left_offset, pad_left_size, zero_point); + } + // copy effective data + std::copy_n(in_data + in_offset, patch_row_size_gated, + im2col_data + im2col_offset); + // fill in padding right + if (pad_right > 0) { + const index_t right_offset = im2col_offset + patch_row_size_gated; + std::fill_n(im2col_data + right_offset, pad_right_size, + zero_point); + } + in_offset += input_row_size; + im2col_offset += patch_row_size; + } + + // fill in padding bottom + if (pad_bottom > 0) { + const index_t pad_bottom_size = pad_bottom * patch_row_size; + const index_t bottom_offset = + im2col_column_offset + column_len - pad_bottom_size; + std::fill_n(im2col_data + bottom_offset, pad_bottom_size, + zero_point); + } + } + } + } + } + + inline void GetOutputMultiplierAndShift( + const float lhs_scale, const float rhs_scale, const float output_scale, + int32_t *quantized_multiplier, int *right_shift) { + float real_multiplier = lhs_scale * rhs_scale / output_scale; + MACE_CHECK(real_multiplier > 0.f && real_multiplier < 1.f, real_multiplier); + int exponent; + const double significand = std::frexp(real_multiplier, &exponent); + *right_shift = -exponent; + int64_t q = static_cast(std::round(significand * (1ll << 31))); + MACE_CHECK(q <= (1ll << 31)); + if (q == (1ll << 31)) { + q /= 2; + (*right_shift)--; + } + MACE_CHECK(*right_shift >= 0); + MACE_CHECK(q <= std::numeric_limits::max()); + *quantized_multiplier = static_cast(q); + } + + typedef gemmlowp::VectorMap + ColVectorMap; + typedef std::tuple< + gemmlowp::OutputStageBiasAddition, + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, + gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; + inline Pipeline MakeOutputPipeline( + const int32_t* bias_data, const index_t channels, const float lhs_scale, + const float rhs_scale, const float output_scale, + const int32_t output_zero_point) { + ColVectorMap bias_vector(bias_data, channels); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + int32_t quantized_multiplier; + int32_t right_shift; + GetOutputMultiplierAndShift(lhs_scale, rhs_scale, output_scale, + &quantized_multiplier, &right_shift); + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + quantize_down_stage; + quantize_down_stage.result_offset_after_shift = output_zero_point; + quantize_down_stage.result_fixedpoint_multiplier = quantized_multiplier; + quantize_down_stage.result_shift = right_shift; + + gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; + return std::make_tuple(bias_addition_stage, quantize_down_stage, + saturating_cast_stage); + } + + MaceStatus operator()(const Tensor *input, // NHWC + const Tensor *filter, // OHWI + const Tensor *bias, + Tensor *output, // NHWC + StatsFuture *future) { + MACE_UNUSED(future); + MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1, + "Quantization convolution does not support dilation > 1 yet."); + + gemmlowp::GemmContext& gemm_context = GetGemmlowpContext(); + + std::vector output_shape(4); + std::vector paddings(2); + if (paddings_.empty()) { + CalcPaddingAndOutputSize(input->shape().data(), + NHWC, + filter->shape().data(), + OHWI, + dilations_, + strides_, + padding_type_, + output_shape.data(), + paddings.data()); + } else { + paddings = paddings_; + CalcOutputSize(input->shape().data(), + NHWC, + filter->shape().data(), + OHWI, + paddings_.data(), + dilations_, + strides_, + RoundType::FLOOR, + output_shape.data()); + } + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + index_t batch = output->dim(0); + index_t height = output->dim(1); + index_t width = output->dim(2); + index_t channels = output->dim(3); + index_t input_batch = input->dim(0); + index_t input_channels = input->dim(3); + index_t filter_h = filter->dim(1); + index_t filter_w = filter->dim(2); + index_t stride_h = strides_[0]; + index_t stride_w = strides_[1]; + const index_t depth = input_channels * filter_h * filter_w; + const index_t columns = batch * height * width; + + MACE_CHECK(filter->dim(0) == channels, filter->dim(0), " != ", channels); + MACE_CHECK(filter->dim(3) == input_channels, filter->dim(3), " != ", + input_channels); + MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard filter_guard(filter); + Tensor::MappingGuard output_guard(output); + + auto input_data = input->data(); + auto filter_data = filter->data(); + auto output_data = output->mutable_data(); + + index_t total_scratch_size = 0; + index_t zero_bias_size = channels * sizeof(int32_t); + total_scratch_size += (bias == nullptr ? zero_bias_size : 0); + index_t im2col_size = depth * columns * sizeof(uint8_t); + bool im2col_required = + filter_h != 1 || filter_w != 1 || stride_h != 1 || stride_w != 1; + total_scratch_size += (im2col_required ? im2col_size : 0); + scratch_->Rewind(); + scratch_->GrowSize(total_scratch_size); + + std::unique_ptr zero_bias; + const int32_t *bias_data = nullptr; + if (bias == nullptr) { + zero_bias.reset(new Tensor(scratch_->Scratch(zero_bias_size), DT_INT32)); + zero_bias->Reshape({channels}); + zero_bias->Clear(); + bias_data = zero_bias->data(); + } else { + bias_data = bias->data(); + } + + std::unique_ptr im2col; + auto gemm_input_data = input_data; + if (im2col_required) { + // prepare im2col + im2col.reset(new Tensor(scratch_->Scratch(im2col_size), DT_UINT8)); + uint8_t *im2col_data = im2col->mutable_data(); + Im2col(input_data, input->shape(), filter_h, filter_w, stride_h, + stride_w, static_cast(input->zero_point()), + paddings[0], paddings[1], output->shape(), depth, im2col_data); + gemm_input_data = im2col_data; + } + + const int gemm_filter_rows = static_cast(channels); + const int gemm_filter_cols = static_cast(depth); + const int gemm_input_rows = static_cast(depth); + const int gemm_input_cols = static_cast(columns); + const int gemm_output_rows = static_cast(channels); + const int gemm_output_cols = static_cast(columns); + gemmlowp::MatrixMap + filter_matrix(filter_data, gemm_filter_rows, gemm_filter_cols); + gemmlowp::MatrixMap + input_matrix(gemm_input_data, gemm_input_rows, gemm_input_cols); + gemmlowp::MatrixMap + output_matrix(output_data, gemm_output_rows, gemm_output_cols); + + const auto &output_pipeline = MakeOutputPipeline( + bias_data, channels, filter->scale(), input->scale(), output->scale(), + output->zero_point()); + + using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams; + gemmlowp::GemmWithOutputPipeline( + &gemm_context, filter_matrix, input_matrix, &output_matrix, + -filter->zero_point(), -input->zero_point(), output_pipeline); + + return MACE_SUCCESS; + } + + ScratchBuffer *scratch_; +}; + #ifdef MACE_ENABLE_OPENCL template struct Conv2dFunctor : Conv2dFunctorBase { diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index 2af765e3..72a03072 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -20,20 +20,43 @@ namespace mace { namespace kernels { -void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW - const index_t *filter_shape, // OIHW - const int *dilations, - const int *strides, - Padding padding, - index_t *output_shape, - int *padding_size) { +void CalcPaddingAndOutputSize(const index_t *input_shape, + const DataFormat input_format, + const index_t *filter_shape, + const DataFormat filter_format, + const int *dilations, + const int *strides, + Padding padding, + index_t *output_shape, + int *padding_size) { MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, "Invalid dilations, must >= 1"); MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), + (dilations[1] == 1 || strides[1] == 1), "If dilations > 1, strides should be 1"); MACE_CHECK_NOTNULL(output_shape); MACE_CHECK_NOTNULL(padding_size); + + index_t input_height = 0, input_width = 0; + index_t kernel_height = 0, kernel_width = 0; + if (input_format == NCHW) { + input_height = input_shape[2]; + input_width = input_shape[3]; + } else if (input_format == NHWC) { + input_height = input_shape[1]; + input_width = input_shape[2]; + } else { + MACE_NOT_IMPLEMENTED; + } + if (filter_format == OIHW) { + kernel_height = filter_shape[2]; + kernel_width = filter_shape[3]; + } else if (filter_format == OHWI) { + kernel_height = filter_shape[1]; + kernel_width = filter_shape[2]; + } else { + MACE_NOT_IMPLEMENTED; + } /* * Convlution/pooling arithmetic: * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1 @@ -42,27 +65,23 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW */ padding_size[0] = 0; padding_size[1] = 0; - index_t output_height = 0, output_width = 0; - index_t kernel_height = filter_shape[2]; - index_t kernel_width = filter_shape[3]; index_t output_channels = filter_shape[0]; - index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1; index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1; switch (padding) { case VALID: - output_height = (input_shape[2] - k_extent_height) / strides[0] + 1; - output_width = (input_shape[3] - k_extent_width) / strides[1] + 1; + output_height = (input_height - k_extent_height) / strides[0] + 1; + output_width = (input_width - k_extent_width) / strides[1] + 1; break; case SAME: - output_height = (input_shape[2] - 1) / strides[0] + 1; - output_width = (input_shape[3] - 1) / strides[1] + 1; + output_height = (input_height - 1) / strides[0] + 1; + output_width = (input_width - 1) / strides[1] + 1; break; case FULL: - output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1; - output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1; + output_height = (input_height + k_extent_height - 2) / strides[0] + 1; + output_width = (input_width + k_extent_width - 2) / strides[1] + 1; break; default: MACE_CHECK(false, "Unsupported padding type: ", padding); @@ -74,14 +93,33 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW // based on the model accuracy. padding_size[0] = std::max( - 0, (output_height - 1) * strides[0] + k_extent_height - input_shape[2]); + 0, (output_height - 1) * strides[0] + k_extent_height - input_height); padding_size[1] = std::max( - 0, (output_width - 1) * strides[1] + k_extent_width - input_shape[3]); + 0, (output_width - 1) * strides[1] + k_extent_width - input_width); output_shape[0] = input_shape[0]; - output_shape[1] = output_channels; - output_shape[2] = output_height; - output_shape[3] = output_width; + if (input_format == NCHW) { + output_shape[1] = output_channels; + output_shape[2] = output_height; + output_shape[3] = output_width; + } else if (input_format == NHWC) { + output_shape[1] = output_height; + output_shape[2] = output_width; + output_shape[3] = output_channels; + } else { + MACE_NOT_IMPLEMENTED; + } +} + +void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW + const index_t *filter_shape, // OIHW + const int *dilations, + const int *strides, + Padding padding, + index_t *output_shape, + int *padding_size) { + CalcPaddingAndOutputSize(input_shape, NCHW, filter_shape, OIHW, dilations, + strides, padding, output_shape, padding_size); } void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC @@ -91,65 +129,86 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC Padding padding, index_t *output_shape, int *padding_size) { + CalcPaddingAndOutputSize(input_shape, NHWC, filter_shape, OIHW, dilations, + strides, padding, output_shape, padding_size); +} + +void CalcOutputSize(const index_t *input_shape, + const DataFormat input_format, + const index_t *filter_shape, + const DataFormat filter_format, + const int *padding_size, + const int *dilations, + const int *strides, + const RoundType round_type, + index_t *output_shape) { MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, "Invalid dilations, must >= 1"); MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), + (dilations[1] == 1 || strides[1] == 1), "If dilations > 1, strides should be 1"); MACE_CHECK_NOTNULL(output_shape); MACE_CHECK_NOTNULL(padding_size); + + index_t input_height = 0, input_width = 0; + index_t kernel_height = 0, kernel_width = 0; + if (input_format == NCHW) { + input_height = input_shape[2]; + input_width = input_shape[3]; + } else if (input_format == NHWC) { + input_height = input_shape[1]; + input_width = input_shape[2]; + } else { + MACE_NOT_IMPLEMENTED; + } + if (filter_format == OIHW) { + kernel_height = filter_shape[2]; + kernel_width = filter_shape[3]; + } else if (filter_format == OHWI) { + kernel_height = filter_shape[1]; + kernel_width = filter_shape[2]; + } else { + MACE_NOT_IMPLEMENTED; + } /* * Convlution/pooling arithmetic: * o = (i + 2 * p - k - (k - 1) * (d - 1)) / s + 1 * For details, see https://arxiv.org/pdf/1603.07285.pdf or * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html */ - padding_size[0] = 0; - padding_size[1] = 0; - index_t output_height = 0, output_width = 0; index_t output_channels = filter_shape[0]; - index_t kernel_height = filter_shape[2]; - index_t kernel_width = filter_shape[3]; - - index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1; - index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1; - switch (padding) { - case VALID: - output_height = (input_shape[1] - k_extent_height) / strides[0] + 1; - output_width = (input_shape[2] - k_extent_width) / strides[1] + 1; - break; - case SAME: - output_height = (input_shape[1] - 1) / strides[0] + 1; - output_width = (input_shape[2] - 1) / strides[1] + 1; - break; - case FULL: - output_height = (input_shape[1] + k_extent_height - 2) / strides[0] + 1; - output_width = (input_shape[2] + k_extent_width - 2) / strides[1] + 1; - break; - default: - MACE_CHECK(false, "Unsupported padding type: ", padding); + if (round_type == FLOOR) { + output_height = static_cast( + std::floor(1.0 * (input_height + padding_size[0] - kernel_height - + (kernel_height - 1) * (dilations[0] - 1)) / strides[0]) + 1); + output_width = static_cast( + std::floor(1.0 * (input_width + padding_size[1] - kernel_width - + (kernel_width - 1) * (dilations[1] - 1)) / strides[1]) + 1); + } else { + output_height = static_cast( + std::ceil(1.0 * (input_height + padding_size[0] - kernel_height - + (kernel_height - 1) * (dilations[0] - 1)) / strides[0]) + 1); + output_width = static_cast( + std::ceil(1.0 * (input_width + padding_size[1] - kernel_width - + (kernel_width - 1) * (dilations[1] - 1)) / strides[1]) + 1); } - // Note: TensorFlow may padded one more on the right/bottom side - // TODO(liuqi): may be it's better to also truncate the left/top to - // utilize the more centered features. We need to benchmark - // based on the model accuracy. - - padding_size[0] = std::max( - 0, (output_height - 1) * strides[0] + k_extent_height - input_shape[1]); - padding_size[1] = std::max( - 0, (output_width - 1) * strides[1] + k_extent_width - input_shape[2]); - output_shape[0] = input_shape[0]; - output_shape[1] = output_height; - output_shape[2] = output_width; - output_shape[3] = output_channels; + if (input_format == NCHW) { + output_shape[1] = output_channels; + output_shape[2] = output_height; + output_shape[3] = output_width; + } else if (input_format == NHWC) { + output_shape[1] = output_height; + output_shape[2] = output_width; + output_shape[3] = output_channels; + } else { + MACE_NOT_IMPLEMENTED; + } } - - void CalcOutputSize(const index_t *input_shape, // NHWC const index_t *filter_shape, // OIHW const int *padding_size, @@ -157,39 +216,8 @@ void CalcOutputSize(const index_t *input_shape, // NHWC const int *strides, const RoundType round_type, index_t *output_shape) { - MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, - "Invalid dilations, must >= 1"); - MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), - "If dilations > 1, strides should be 1"); - MACE_CHECK_NOTNULL(output_shape); - MACE_CHECK_NOTNULL(padding_size); - - output_shape[0] = input_shape[0]; - if (round_type == FLOOR) { - output_shape[1] = static_cast( - std::floor(1.0 * (input_shape[1] + padding_size[0] - filter_shape[2] - - (filter_shape[2] - 1) * (dilations[0] - 1)) / - strides[0]) + - 1); - output_shape[2] = static_cast( - std::floor(1.0 * (input_shape[2] + padding_size[1] - filter_shape[3] - - (filter_shape[3] - 1) * (dilations[1] - 1)) / - strides[1]) + - 1); - } else { - output_shape[1] = static_cast( - std::ceil(1.0 * (input_shape[1] + padding_size[0] - filter_shape[2] - - (filter_shape[2] - 1) * (dilations[0] - 1)) / - strides[0]) + - 1); - output_shape[2] = static_cast( - std::ceil(1.0 * (input_shape[2] + padding_size[1] - filter_shape[3] - - (filter_shape[3] - 1) * (dilations[1] - 1)) / - strides[1]) + - 1); - } - output_shape[3] = filter_shape[0]; + CalcOutputSize(input_shape, NHWC, filter_shape, OIHW, padding_size, dilations, + strides, round_type, output_shape); } void CalcNCHWOutputSize(const index_t *input_shape, // NCHW @@ -199,46 +227,8 @@ void CalcNCHWOutputSize(const index_t *input_shape, // NCHW const int *strides, const RoundType round_type, index_t *output_shape) { - MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, - "Invalid dilations, must >= 1"); - MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), - "If dilations > 1, strides should be 1"); - MACE_CHECK_NOTNULL(output_shape); - MACE_CHECK_NOTNULL(padding_size); - /* - * Convolution arithmetic: - * o = floor((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1 - * Pooling arithmetic: - * o = ceil((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1 - * For details, see https://arxiv.org/pdf/1603.07285.pdf or - * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html - */ - output_shape[0] = input_shape[0]; - if (round_type == FLOOR) { - output_shape[2] = static_cast( - std::floor(1.0 * (input_shape[2] + padding_size[0] - filter_shape[2] - - (filter_shape[2] - 1) * (dilations[0] - 1)) / - strides[0]) + - 1); - output_shape[3] = static_cast( - std::floor(1.0 * (input_shape[3] + padding_size[1] - filter_shape[3] - - (filter_shape[3] - 1) * (dilations[1] - 1)) / - strides[1]) + - 1); - } else { - output_shape[2] = static_cast( - std::ceil(1.0 * (input_shape[2] + padding_size[0] - filter_shape[2] - - (filter_shape[2] - 1) * (dilations[0] - 1)) / - strides[0]) + - 1); - output_shape[3] = static_cast( - std::ceil(1.0 * (input_shape[3] + padding_size[1] - filter_shape[3] - - (filter_shape[3] - 1) * (dilations[1] - 1)) / - strides[1]) + - 1); - } - output_shape[1] = filter_shape[0]; + CalcOutputSize(input_shape, NCHW, filter_shape, OIHW, padding_size, dilations, + strides, round_type, output_shape); } void CalPaddingSize(const index_t *input_shape, // NCHW diff --git a/mace/kernels/conv_pool_2d_util.h b/mace/kernels/conv_pool_2d_util.h index 0f0909a3..dba90bc5 100644 --- a/mace/kernels/conv_pool_2d_util.h +++ b/mace/kernels/conv_pool_2d_util.h @@ -32,14 +32,24 @@ enum RoundType { namespace kernels { -void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, +void CalcPaddingAndOutputSize(const index_t *input_shape, + const DataFormat input_format, const index_t *filter_shape, + const DataFormat filter_format, const int *dilations, const int *strides, Padding padding, index_t *output_shape, int *padding_size); +void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, + const index_t *filter_shape, + const int *dilations, + const int *strides, + Padding padding, + index_t *output_shape, + int *padding_size); + void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, const index_t *filter_shape, const int *dilations, @@ -48,6 +58,16 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, index_t *output_shape, int *padding_size); +void CalcOutputSize(const index_t *input_shape, + const DataFormat input_format, + const index_t *filter_shape, + const DataFormat filter_format, + const int *padding_size, + const int *dilations, + const int *strides, + const RoundType round_type, + index_t *output_shape); + void CalcOutputSize(const index_t *input_shape, // NHWC const index_t *filter_shape, // OIHW const int *padding_size, diff --git a/mace/kernels/gemmlowp_util.cc b/mace/kernels/gemmlowp_util.cc new file mode 100644 index 00000000..50716d52 --- /dev/null +++ b/mace/kernels/gemmlowp_util.cc @@ -0,0 +1,58 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/kernels/gemmlowp_util.h" + +#include +#include + +#include "mace/core/runtime/cpu/cpu_runtime.h" + +namespace mace { + +gemmlowp::GemmContext& GetGemmlowpContext() { + static auto *gemm_context = new gemmlowp::GemmContext; + return *gemm_context; +} + +MaceStatus SetGemmlowpThreadPolicy(int num_threads_hint, + CPUAffinityPolicy policy) { + gemmlowp::GemmContext& gemm_context = GetGemmlowpContext(); + + if (policy != AFFINITY_NONE) { + std::vector big_core_ids; + std::vector little_core_ids; + MaceStatus res = GetCPUBigLittleCoreIDs(&big_core_ids, &little_core_ids); + if (res != MACE_SUCCESS) { + return res; + } + + int use_cpu_size; + if (policy == CPUAffinityPolicy::AFFINITY_BIG_ONLY) { + use_cpu_size = static_cast(big_core_ids.size()); + } else { + use_cpu_size = static_cast(little_core_ids.size()); + } + + if (num_threads_hint <= 0 || num_threads_hint > use_cpu_size) { + num_threads_hint = use_cpu_size; + } + } + + gemm_context.set_max_num_threads(std::max(0, num_threads_hint)); + + return MACE_SUCCESS; +} + +} // namespace mace diff --git a/mace/kernels/gemmlowp_util.h b/mace/kernels/gemmlowp_util.h new file mode 100644 index 00000000..ef6efe91 --- /dev/null +++ b/mace/kernels/gemmlowp_util.h @@ -0,0 +1,27 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_KERNELS_GEMMLOWP_UTIL_H_ +#define MACE_KERNELS_GEMMLOWP_UTIL_H_ + +#include +#include "public/gemmlowp.h" + +namespace mace { + +gemmlowp::GemmContext& GetGemmlowpContext(); + +} // namespace mace + +#endif // MACE_KERNELS_GEMMLOWP_UTIL_H_ diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 452b1e1e..690f4a96 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -25,7 +25,7 @@ cc_library( ], deps = [ "//mace/core", - "@gtest//:gtest", + "@gtest", ], ) diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc index 4377afb0..516520f9 100644 --- a/mace/ops/conv_2d.cc +++ b/mace/ops/conv_2d.cc @@ -24,6 +24,12 @@ void Register_Conv2D(OperatorRegistryBase *op_registry) { .Build(), Conv2dOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + Conv2dOp); + #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D") .Device(DeviceType::GPU) diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 56feeddc..6616b8ba 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -99,6 +99,67 @@ void Conv2d(int iters, net.Sync(); } } + +template <> +void Conv2d(int iters, + int batch, + int channels, + int height, + int width, + int kernel_h, + int kernel_w, + int stride, + int dilation, + Padding padding, + int output_channels) { + mace::testing::StopTiming(); + if (dilation > 1) { + LOG(WARNING) << "uint8_t benchmarking dilation = 1 instead."; + } + + OpsTestNet net; + + // Add input data + net.AddRandomInput( + "Input", {batch, height, width, channels}); + Tensor *input = net.GetTensor("Input"); + input->SetScale(0.00705); + input->SetZeroPoint(114); + net.AddRandomInput( + "Filter", {output_channels, kernel_h, kernel_w, channels}); + Tensor *filter = net.GetTensor("Filter"); + filter->SetScale(0.0066); + filter->SetZeroPoint(113); + net.AddRandomInput("Bias", {output_channels}); + OpDefBuilder("Conv2D", "Conv2dTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .AddIntsArg("strides", {stride, stride}) + .AddIntArg("padding", padding) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + + net.Setup(DeviceType::CPU); + + Tensor *output = net.GetTensor("Output"); + output->SetScale(0.0107); + output->SetZeroPoint(118); + + // Warm-up + for (int i = 0; i < 2; ++i) { + net.Run(); + net.Sync(); + } + mace::testing::StartTiming(); + while (iters--) { + net.Run(); + net.Sync(); + } +} + } // namespace // In common network, there are usually more than 1 layers, this is used to @@ -135,7 +196,8 @@ void Conv2d(int iters, #define MACE_BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \ MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \ MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, GPU); \ - MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, GPU); + MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, GPU); \ + MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, uint8_t, CPU); diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 022703bb..ecfdafa2 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -15,6 +15,7 @@ #include #include +#include "mace/kernels/quantize.h" #include "mace/ops/conv_2d.h" #include "mace/ops/ops_test_util.h" @@ -1069,6 +1070,163 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedPad4) { TestArbitraryPadConvNxN({107, 113, 5, 7}, {4, 4}); } +namespace { + +void TestQuantSimple3x3() { + OpsTestNet net; + + // Add input data + net.AddInputFromArray( + "Filter", {1, 3, 3, 2}, + {102, 150, 123, 135, 1, 216, 137, 47, 53, 75, 145, 130, 171, 62, 255, + 122, 72, 211}, 0.0226, 127); + net.AddInputFromArray( + "Input", {1, 3, 3, 2}, + {1, 75, 117, 161, 127, 119, 94, 151, 203, 151, 84, 61, 55, 142, 113, 139, + 3, 255}, 0.0204, 93); + + net.AddInputFromArray("Bias", {1}, {2}); + OpDefBuilder("Conv2D", "Conv2dTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .AddIntsArg("strides", {1, 1}) + .AddIntArg("padding", Padding::VALID) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DT_UINT8)) + .Finalize(net.NewOperatorDef()); + + net.Setup(DeviceType::CPU); + Tensor *output = net.GetTensor("Output"); + output->SetScale(0.000711); + output->SetZeroPoint(1); + // Run + net.Run(); + // Check + auto expected = CreateTensor({1, 1, 1, 1}, {230}); + ExpectTensorNear(*expected, *output); +} + +void TestQuant(const index_t batch, + const index_t out_channels, + const index_t in_channels, + const index_t in_height, + const index_t in_width, + const index_t k_height, + const index_t k_width, + enum Padding padding_type, + const std::vector &strides) { + OpsTestNet net; + net.AddRandomInput("Input", {batch, in_height, in_width, + in_channels}); + net.AddRandomInput("Filter", {out_channels, k_height, k_width, + in_channels}); + net.AddRandomInput("Bias", {out_channels}); + net.TransformDataFormat("Input", NHWC, "InputNCHW", + NCHW); + net.TransformDataFormat("Filter", OHWI, "FilterOIHW", + OIHW); + + OpDefBuilder("Conv2D", "Conv2dTest") + .Input("InputNCHW") + .Input("FilterOIHW") + .Input("Bias") + .Output("OutputNCHW") + .AddIntsArg("strides", strides) + .AddIntArg("padding", padding_type) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DT_FLOAT)) + .Finalize(net.NewOperatorDef()); + net.RunOp(CPU); + net.TransformDataFormat("OutputNCHW", NCHW, + "Output", NHWC); + + OpDefBuilder("Quantize", "QuantizeFilter") + .Input("Filter") + .Output("QuantizedFilter") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .AddIntArg("non_zero", true) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + OpDefBuilder("Quantize", "QuantizeInput") + .Input("Input") + .Output("QuantizedInput") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .AddIntArg("non_zero", true) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + OpDefBuilder("Quantize", "QuantizeOutput") + .Input("Output") + .Output("ExpectedQuantizedOutput") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .AddIntArg("non_zero", true) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + Tensor *q_filter = net.GetTensor("QuantizedFilter"); + Tensor *q_input = net.GetTensor("QuantizedInput"); + Tensor *bias = net.GetTensor("Bias"); + auto bias_data = bias->data(); + std::vector q_bias(bias->size()); + kernels::QuantizeWithScaleAndZeropoint( + bias_data, bias->size(), q_input->scale() * q_filter->scale(), 0, + q_bias.data()); + net.AddInputFromArray("QuantizedBias", + {out_channels}, q_bias); + OpDefBuilder("Conv2D", "QuantizeConv2dTest") + .Input("QuantizedInput") + .Input("QuantizedFilter") + .Input("QuantizedBias") + .Output("QuantizedOutput") + .AddIntsArg("strides", strides) + .AddIntArg("padding", padding_type) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DT_UINT8)) + .Finalize(net.NewOperatorDef()); + net.Setup(DeviceType::CPU); + Tensor *eq_output = net.GetTensor("ExpectedQuantizedOutput"); + Tensor *q_output = net.GetTensor("QuantizedOutput"); + q_output->SetScale(eq_output->scale()); + q_output->SetZeroPoint(eq_output->zero_point()); + net.Run(); + + OpDefBuilder("Dequantize", "DeQuantizeTest") + .Input("QuantizedOutput") + .Output("DequantizedOutput") + .OutputType({DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + // Check + ExpectTensorSimilar(*net.GetOutput("Output"), + *net.GetTensor("DequantizedOutput"), 0.01); +} +} // namespace + +TEST_F(Conv2dOpTest, Quant) { + TestQuantSimple3x3(); + TestQuant(1, 128, 64, 32, 32, 1, 1, VALID, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 3, 3, VALID, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 3, 3, SAME, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 3, 3, FULL, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 3, 3, SAME, {2, 2}); + TestQuant(1, 129, 63, 33, 31, 3, 3, SAME, {1, 1}); + TestQuant(9, 128, 64, 32, 32, 3, 3, SAME, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 1, 5, SAME, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 5, 5, SAME, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 5, 1, SAME, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 7, 7, SAME, {1, 1}); + TestQuant(1, 128, 64, 32, 32, 7, 7, SAME, {2, 2}); + TestQuant(1, 128, 64, 32, 32, 7, 7, SAME, {3, 3}); +} + } // namespace test } // namespace ops } // namespace mace diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index e8da144d..cdba4ee9 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -116,7 +116,9 @@ class OpsTestNet { template void AddInputFromArray(const std::string &name, const std::vector &shape, - const std::vector &data) { + const std::vector &data, + const float scale = 0.0, + const int32_t zero_point = 0) { Tensor *input = ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum::v()); input->Resize(shape); @@ -124,6 +126,8 @@ class OpsTestNet { T *input_data = input->mutable_data(); MACE_CHECK(static_cast(input->size()) == data.size()); memcpy(input_data, data.data(), data.size() * sizeof(T)); + input->SetScale(scale); + input->SetZeroPoint(zero_point); } template @@ -303,6 +307,26 @@ class OpsTestNet { } } } + } else if (src_format == OHWI && dst_format == OIHW) { + index_t out_channels = input_shape[0]; + index_t height = input_shape[1]; + index_t width = input_shape[2]; + index_t in_channels = input_shape[3]; + output->Resize({out_channels, in_channels, height, width}); + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + const T *input_data = input->data(); + T *output_data = output->mutable_data(); + for (index_t b = 0; b < out_channels; ++b) { + for (index_t c = 0; c < in_channels; ++c) { + for (index_t h = 0; h < height; ++h) { + for (index_t w = 0; w < width; ++w) { + output_data[((b * in_channels + c) * height + h) * width + w] = + input_data[((b * height + h) * width + w) * in_channels + c]; + } + } + } + } } else { MACE_NOT_IMPLEMENTED; } @@ -632,6 +656,25 @@ void ExpectTensorNear(const Tensor &x, Expector::Near(x, y, rel_err, abs_err); } +template +void ExpectTensorSimilar(const Tensor &x, + const Tensor &y, + const double abs_err = 1e-5) { + AssertSameDims(x, y); + Tensor::MappingGuard x_mapper(&x); + Tensor::MappingGuard y_mapper(&y); + auto x_data = x.data(); + auto y_data = y.data(); + double dot_product = 0.0, x_norm = 0.0, y_norm = 0.0; + for (index_t i = 0; i < x.size(); i++) { + dot_product += x_data[i] * y_data[i]; + x_norm += x_data[i] * x_data[i]; + y_norm += y_data[i] * y_data[i]; + } + double similarity = dot_product / (sqrt(x_norm) * sqrt(y_norm)); + EXPECT_NEAR(1.0, similarity, abs_err); +} + template void BufferToImage(OpsTestNet *net, const std::string &input_name, diff --git a/mace/public/mace_runtime.h b/mace/public/mace_runtime.h index 12d3d974..c1e158ab 100644 --- a/mace/public/mace_runtime.h +++ b/mace/public/mace_runtime.h @@ -180,6 +180,16 @@ __attribute__((visibility("default"))) MaceStatus GetBigLittleCoreIDs(std::vector *big_core_ids, std::vector *little_core_ids); +// Set gemmlowp threads number and processor affinity. +// gemmlowp is used by mace for quantization. +// Caution: this function may hurt performance if improper parameters provided. +// +// This function may not work well on some chips (e.g. MTK). Setting thread +// affinity to offline cores may run very slow or unexpectedly. In such cases, +// please use SetGemmlowpThreadPolicy with default policy instead. +__attribute__((visibility("default"))) +MaceStatus SetGemmlowpThreadPolicy(int num_threads_hint, + CPUAffinityPolicy policy); } // namespace mace #endif // MACE_PUBLIC_MACE_RUNTIME_H_ diff --git a/mace/utils/logging.h b/mace/utils/logging.h index 8f8fe87c..c6fc6b51 100644 --- a/mace/utils/logging.h +++ b/mace/utils/logging.h @@ -91,7 +91,7 @@ T &&CheckNotNull(const char *file, int line, const char *exprtext, T &&t) { #define MACE_CHECK_NOTNULL(val) \ ::mace::logging::CheckNotNull(__FILE__, __LINE__, \ - "'" #val "' Must be non NULL", (val)) + "'" #val "' Must not be NULL", (val)) #define MACE_NOT_IMPLEMENTED MACE_CHECK(false, "not implemented") diff --git a/tools/bazel.rc b/tools/bazel.rc index d5621063..22357b79 100644 --- a/tools/bazel.rc +++ b/tools/bazel.rc @@ -9,6 +9,7 @@ build --verbose_failures build --copt=-std=c++11 build --copt=-D_GLIBCXX_USE_C99_MATH_TR1 build --copt=-DMACE_OBFUSCATE_LITERALS +build --copt=-DGEMMLOWP_USE_OPENMP # Usage example: bazel build --config android build:android --crosstool_top=//external:android/crosstool -- GitLab