diff --git a/mace/core/common.h b/mace/core/common.h
index 8eaf062f2cd5a4fb912444623056349620a240b0..b5f819a3ad5f3d0b4a46ac10be719e95763f9e90 100644
--- a/mace/core/common.h
+++ b/mace/core/common.h
@@ -32,6 +32,4 @@ typedef int64_t index_t;
 
 #define MACE_NOT_IMPLEMENTED MACE_CHECK(false, "not implemented")
 
-#define kCostPerGroup 10240
-
 #endif  // MACE_CORE_COMMON_H_
diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index 5f1e0dc16d6f9724245aa203b720d293fb1f7266..c35115645aa2e252cef8fbfc16cdd3cd82dfb7d5 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -325,6 +325,12 @@ class Tensor {
         }
       }
     }
+    MappingGuard(MappingGuard &&other) {
+      tensor_ = other.tensor_;
+      other.tensor_ = nullptr;
+    }
+    MappingGuard(const MappingGuard &other) = delete;
+    MappingGuard &operator=(const MappingGuard &other) = delete;
     ~MappingGuard() {
       if (tensor_ != nullptr) tensor_->Unmap();
     }
diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD
index 5c2fe3016cde48b8e451702cc5f37141e4b3dc2a..54ed3fcd3f73d0a6cfab668a45dead37b88f09e8 100644
--- a/mace/kernels/BUILD
+++ b/mace/kernels/BUILD
@@ -15,7 +15,6 @@ cc_library(
         "*.cc",
         "opencl/*.cc",
     ]) + if_neon_enabled(glob([
-        "neon/addn_neon.cc",
         "neon/batch_norm_neon.cc",
     ])),
     hdrs = glob([
diff --git a/mace/kernels/activation.h b/mace/kernels/activation.h
index 745f174466fefee3683746213247d1e549306309..0a856fc945b3ce2eebc7eaa8e9d0f6fc4985629a 100644
--- a/mace/kernels/activation.h
+++ b/mace/kernels/activation.h
@@ -54,17 +54,20 @@ void DoActivation(const T *input_ptr,
     case NOOP:
       break;
     case RELU:
+#pragma omp parallel for
      for (index_t i = 0; i < size; ++i) {
        output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0));
      }
      break;
    case RELUX:
+#pragma omp parallel for
      for (index_t i = 0; i < size; ++i) {
        output_ptr[i] = std::min(std::max(input_ptr[i], static_cast<T>(0)),
                                 static_cast<T>(relux_max_limit));
      }
      break;
    case PRELU:
+#pragma omp parallel for
      for (index_t i = 0; i < size; ++i) {
        T in = input_ptr[i];
        if (in < 0) {
@@ -75,12 +78,14 @@ void DoActivation(const T *input_ptr,
      }
      break;
    case TANH:
+#pragma omp parallel for
      for (index_t i = 0; i < size; ++i) {
        T in_exp = std::exp(-2 * input_ptr[i]);
        output_ptr[i] = (1 - in_exp) / (1 + in_exp);
      }
      break;
    case SIGMOID:
+#pragma omp parallel for
      for (index_t i = 0; i < size; ++i) {
        output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i]));
      }
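The move constructor added to `Tensor::MappingGuard` above is what allows guards to be stored in a container, which the reworked `AddNFunctor` below relies on to keep every input tensor mapped for the lifetime of its parallel loop. A minimal sketch of the ownership semantics, assuming a simplified `Tensor` with `Map()`/`Unmap()` (names here are illustrative, not the full MACE API):

// Sketch: why the guard must be movable but not copyable — a copy
// would Unmap the same tensor twice.
#include <utility>
#include <vector>

struct Tensor {
  void Map() {}
  void Unmap() {}
};

class MappingGuard {
 public:
  explicit MappingGuard(Tensor *tensor) : tensor_(tensor) {
    if (tensor_ != nullptr) tensor_->Map();
  }
  MappingGuard(MappingGuard &&other) : tensor_(other.tensor_) {
    other.tensor_ = nullptr;  // moved-from guard no longer owns the mapping
  }
  MappingGuard(const MappingGuard &) = delete;
  MappingGuard &operator=(const MappingGuard &) = delete;
  ~MappingGuard() {
    if (tensor_ != nullptr) tensor_->Unmap();
  }

 private:
  Tensor *tensor_;
};

int main() {
  Tensor a, b;
  std::vector<MappingGuard> guards;  // requires the move constructor
  guards.emplace_back(MappingGuard(&a));
  guards.emplace_back(MappingGuard(&b));
  return 0;  // each tensor is unmapped exactly once
}

Copy construction is deleted because two guards owning the same mapping would call `Unmap()` twice; the move leaves the source guard empty, so unmapping happens exactly once even as the vector grows.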
diff --git a/mace/kernels/addn.h b/mace/kernels/addn.h
index fd28517795b01190c7866a14357c9b863cf7c872..e772d880b6210737167ff3ea48e6aa767986368d 100644
--- a/mace/kernels/addn.h
+++ b/mace/kernels/addn.h
@@ -8,29 +8,71 @@
 #if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
 #include <arm_neon.h>
 #endif
+#include <algorithm>
 
 #include "mace/core/future.h"
-#include "mace/core/tensor.h"
 #include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/tensor.h"
 
 namespace mace {
 namespace kernels {
 
+namespace {
+constexpr int kCostPerGroup = 1024;
+}  // namespace
+
 template <DeviceType D, typename T>
 struct AddNFunctor {
   void operator()(const std::vector<const Tensor *> &input_tensors,
-                  Tensor *output_tensor, StatsFuture *future) {
+                  Tensor *output_tensor,
+                  StatsFuture *future) {
     output_tensor->ResizeLike(input_tensors[0]);
+    index_t size = output_tensor->size();
     Tensor::MappingGuard output_map(output_tensor);
-    index_t size = input_tensors[0]->size();
-    T *output_ptr = output_tensor->mutable_data<T>();
-    memset(output_ptr, 0, size * sizeof(T));
+    float *output_data = output_tensor->mutable_data<float>();
+    memset(output_data, 0, size * sizeof(float));
     int n = input_tensors.size();
-    for (int i = 0; i < n; ++i) {
-      Tensor::MappingGuard input_map(input_tensors[i]);
-      const T *input_ptr = input_tensors[i]->data<T>();
-      for (index_t j = 0; j < size; ++j) {
-        output_ptr[j] += input_ptr[j];
+    int64_t cost = size * n;
+    int64_t groups = 1;
+    if (cost > kCostPerGroup) {
+      groups = cost / kCostPerGroup;
+    }
+    int64_t element_per_group = size / groups;
+
+    std::vector<Tensor::MappingGuard> mappers;
+    for (int64_t i = 0; i < n; ++i) {
+      mappers.emplace_back(Tensor::MappingGuard(input_tensors[i]));
+    }
+
+#pragma omp parallel for
+    for (int64_t i = 0; i < size; i += element_per_group) {
+      int64_t count = std::min(element_per_group, size - i);
+      int nn = count >> 2;
+      int remain = count - (nn << 2);
+      for (int64_t j = 0; j < n; ++j) {
+        const float *input_data = input_tensors[j]->data<float>();
+        const float *input_ptr = input_data + i;
+        float *output_ptr = output_data + i;
+        for (int k = 0; k < nn; ++k) {
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+          float32x4_t in = vld1q_f32(input_ptr);
+          float32x4_t out = vld1q_f32(output_ptr);
+          out = vaddq_f32(out, in);
+          vst1q_f32(output_ptr, out);
+#else
+          for (int m = 0; m < 4; ++m) {
+            output_ptr[m] += input_ptr[m];
+          }
+#endif
+
+          input_ptr += 4;
+          output_ptr += 4;
+        }
+        for (int k = 0; k < remain; ++k) {
+          *output_ptr += *input_ptr;
+          ++input_ptr;
+          ++output_ptr;
+        }
       }
     }
   }
@@ -45,7 +87,8 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(
 template <typename T>
 struct AddNFunctor<DeviceType::OPENCL, T> {
   void operator()(const std::vector<const Tensor *> &input_tensors,
-                  Tensor *output_tensor, StatsFuture *future);
+                  Tensor *output_tensor,
+                  StatsFuture *future);
 
   cl::Kernel kernel_;
 };
diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h
index 1bb85d1675bc0f6284c19ab7a017ca14716c4dd6..107b3242bc29a5f34b133c7412c6c20d2e9a1134 100644
--- a/mace/kernels/batch_norm.h
+++ b/mace/kernels/batch_norm.h
@@ -70,8 +70,8 @@ struct BatchNormFunctor : BatchNormFunctorBase {
     const T *offset_ptr = offset->data<T>();
     T *output_ptr = output->mutable_data<T>();
 
-    vector<T> new_scale;
-    vector<T> new_offset;
+    std::vector<T> new_scale;
+    std::vector<T> new_offset;
     if (!folded_constant_) {
       new_scale.resize(channels);
       new_offset.resize(channels);
@@ -86,6 +86,8 @@ struct BatchNormFunctor : BatchNormFunctorBase {
       }
     }
 
+    const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
+    const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
+
 #pragma omp parallel for collapse(4)
     for (index_t n = 0; n < batch; ++n) {
@@ -93,11 +95,7 @@ struct BatchNormFunctor : BatchNormFunctorBase {
       for (index_t h = 0; h < height; ++h) {
         for (index_t w = 0; w < width; ++w) {
           for (index_t c = 0; c < channels; ++c) {
             index_t pos = (((n * height) + h) * width + w) * channels + c;
-            if (folded_constant_) {
-              output_ptr[pos] = scale_ptr[c] * input_ptr[pos] + offset_ptr[c];
-            } else {
-              output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
-            }
+            output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
           }
         }
       }
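The batch_norm.h hunk replaces a per-element `folded_constant_` branch with `scale_data`/`offset_data` pointers chosen once outside the loop. The loop that fills `new_scale`/`new_offset` is unchanged context the diff does not show; a sketch of that folding under the standard batch-norm algebra (function name and signature are illustrative):

// Standard batch-norm constant folding assumed by the hunk above:
//   y = scale * (x - mean) / sqrt(var + epsilon) + offset
//     = new_scale * x + new_offset
#include <cmath>
#include <vector>

void FoldBatchNorm(const std::vector<float> &scale,
                   const std::vector<float> &offset,
                   const std::vector<float> &mean,
                   const std::vector<float> &var,
                   float epsilon,
                   std::vector<float> *new_scale,
                   std::vector<float> *new_offset) {
  const size_t channels = scale.size();
  new_scale->resize(channels);
  new_offset->resize(channels);
  for (size_t c = 0; c < channels; ++c) {
    (*new_scale)[c] = scale[c] / std::sqrt(var[c] + epsilon);
    (*new_offset)[c] = offset[c] - mean[c] * (*new_scale)[c];
  }
}

When `folded_constant_` is set, the scale and offset inputs already encode this transform, so the functor can use them directly and skip the folding.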
diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h
index 803cb34bcbd68a0948c2619a9a41f71dbeac885a..d87999f1a1b7defc0caa6280580a7c5704f19587 100644
--- a/mace/kernels/depthwise_conv2d.h
+++ b/mace/kernels/depthwise_conv2d.h
@@ -5,15 +5,237 @@
 #ifndef MACE_KERNELS_DEPTHWISE_CONV2D_H_
 #define MACE_KERNELS_DEPTHWISE_CONV2D_H_
 
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+#include <arm_neon.h>
+#endif
+
 #include "mace/core/common.h"
 #include "mace/core/future.h"
 #include "mace/core/public/mace.h"
-#include "mace/kernels/conv_pool_2d_util.h"
 #include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/kernels/conv_pool_2d_util.h"
 
 namespace mace {
 namespace kernels {
 
+namespace {
+
+template <typename T>
+void DepthwiseConv2dKernel(const T *input_ptr,
+                           const T *filter_ptr,
+                           const T *bias_ptr,
+                           T *output_ptr,
+                           int batch,
+                           int height,
+                           int width,
+                           int channels,
+                           int input_height,
+                           int input_width,
+                           int input_channels,
+                           int multiplier,
+                           int padded_h_start,
+                           int padded_h_stop,
+                           int padded_w_start,
+                           int padded_w_stop,
+                           int kernel_h,
+                           int kernel_w,
+                           int stride_h,
+                           int stride_w,
+                           int dilation_h,
+                           int dilation_w,
+                           int h_start,
+                           int h_stop,
+                           int w_start,
+                           int w_stop) {
+#pragma omp parallel for collapse(4)
+  for (int n = 0; n < batch; ++n) {
+    for (int h = h_start; h < h_stop; ++h) {
+      for (int w = w_start; w < w_stop; ++w) {
+        for (int c = 0; c < channels; ++c) {
+          const index_t inc = c / multiplier;
+          const index_t m = c % multiplier;
+          T bias_channel = bias_ptr ? bias_ptr[c] : 0;
+          index_t offset = n * height * width * channels +
+                           h * width * channels + w * channels + c;
+          output_ptr[offset] = bias_channel;
+          T sum = 0;
+          const T *filter_base = filter_ptr + inc * multiplier + m;
+          for (int kh = 0; kh < kernel_h; ++kh) {
+            for (int kw = 0; kw < kernel_w; ++kw) {
+              int inh = padded_h_start + h * stride_h + dilation_h * kh;
+              int inw = padded_w_start + w * stride_w + dilation_w * kw;
+              if (inh < 0 || inh >= input_height || inw < 0 ||
+                  inw >= input_width) {
+                MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
+                               inw >= padded_w_start && inw < padded_w_stop,
+                           "Out of range read from input: ", padded_h_start,
+                           " <= ", inh, " < ", padded_h_stop, ", ",
+                           padded_w_start, " <= ", inw, " < ", padded_w_stop);
+              } else {
+                index_t input_offset =
+                    n * input_height * input_width * input_channels +
+                    inh * input_width * input_channels + inw * input_channels +
+                    inc;
+                sum += input_ptr[input_offset] * filter_base[0];  // HWIM
+              }
+              filter_base += input_channels * multiplier;
+            }
+          }
+          output_ptr[offset] += sum;
+        }
+      }
+    }
+  }
+}
+
+template <typename T>
+void DepthwiseConv2dNoOOBCheckKernel(const T *input_ptr,
+                                     const T *filter_ptr,
+                                     const T *bias_ptr,
+                                     T *output_ptr,
+                                     int batch,
+                                     int height,
+                                     int width,
+                                     int channels,
+                                     int input_height,
+                                     int input_width,
+                                     int input_channels,
+                                     int multiplier,
+                                     int padded_h_start,
+                                     int padded_h_stop,
+                                     int padded_w_start,
+                                     int padded_w_stop,
+                                     int kernel_h,
+                                     int kernel_w,
+                                     int stride_h,
+                                     int stride_w,
+                                     int dilation_h,
+                                     int dilation_w,
+                                     int h_start,
+                                     int h_stop,
+                                     int w_start,
+                                     int w_stop) {
+  if (multiplier == 1) {
+    constexpr int c_tile_size = 4;
+
+#pragma omp parallel for collapse(3)
+    for (int n = 0; n < batch; ++n) {
+      for (int h = h_start; h < h_stop; ++h) {
+        for (int w = w_start; w < w_stop; ++w) {
+          int c;
+          for (c = 0; c + c_tile_size <= channels; c += c_tile_size) {
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+            static_assert(c_tile_size == 4, "channels tile size must be 4");
+            float32x4_t sum = vdupq_n_f32(0);
+            if (bias_ptr != nullptr) {
+              sum = vld1q_f32(bias_ptr + c);
+            }
+#else
+            T sum[c_tile_size] = {0};
+            if (bias_ptr != nullptr) {
+              for (int ci = 0; ci < c_tile_size; ++ci) {
+                sum[ci] = bias_ptr[c + ci];
+              }
+            }
+#endif
+            const T *filter_base = filter_ptr + c;
+            for (int kh = 0; kh < kernel_h; ++kh) {
+              for (int kw = 0; kw < kernel_w; ++kw) {
+                int inh = padded_h_start + h * stride_h + dilation_h * kh;
+                int inw = padded_w_start + w * stride_w + dilation_w * kw;
+                MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
+                            inw < input_width);
+                index_t input_offset =
+                    n * input_height * input_width * input_channels +
+                    inh * input_width * input_channels + inw * input_channels +
+                    c;
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+                float32x4_t in = vld1q_f32(input_ptr + input_offset);
+                float32x4_t weights = vld1q_f32(filter_base);
+                sum = vfmaq_f32(sum, in, weights);
+#else
+                for (int ci = 0; ci < c_tile_size; ++ci) {
+                  sum[ci] +=
+                      input_ptr[input_offset + ci] * filter_base[ci];  // HWIM
+                }
+#endif
+                filter_base += input_channels;
+              }
+            }
+
+            index_t offset = n * height * width * channels +
+                             h * width * channels + w * channels + c;
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+            vst1q_f32(output_ptr + offset, sum);
+#else
+            for (int ci = 0; ci < c_tile_size; ++ci) {
+              output_ptr[offset + ci] = sum[ci];
+            }
+#endif
+          }
+          for (; c < channels; ++c) {
+            T bias_channel = bias_ptr ? bias_ptr[c] : 0;
+            index_t offset = n * height * width * channels +
+                             h * width * channels + w * channels + c;
+            output_ptr[offset] = bias_channel;
+            T sum = 0;
+            const T *filter_base = filter_ptr + c;
+            for (int kh = 0; kh < kernel_h; ++kh) {
+              for (int kw = 0; kw < kernel_w; ++kw) {
+                int inh = padded_h_start + h * stride_h + dilation_h * kh;
+                int inw = padded_w_start + w * stride_w + dilation_w * kw;
+                MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
+                            inw < input_width);
+                index_t input_offset =
+                    n * input_height * input_width * input_channels +
+                    inh * input_width * input_channels + inw * input_channels +
+                    c;
+                sum += input_ptr[input_offset] * filter_base[0];  // HWIM
+                filter_base += input_channels * multiplier;
+              }
+            }
+            output_ptr[offset] += sum;
+          }
+        }
+      }
+    }
+  } else {
+#pragma omp parallel for collapse(4)
+    for (int n = 0; n < batch; ++n) {
+      for (int h = h_start; h < h_stop; ++h) {
+        for (int w = w_start; w < w_stop; ++w) {
+          for (int c = 0; c < channels; ++c) {
+            const index_t inc = c / multiplier;
+            const index_t m = c % multiplier;
+            T bias_channel = bias_ptr ? bias_ptr[c] : 0;
+            index_t offset = n * height * width * channels +
+                             h * width * channels + w * channels + c;
+            output_ptr[offset] = bias_channel;
+            T sum = 0;
+            const T *filter_base = filter_ptr + inc * multiplier + m;
+            for (int kh = 0; kh < kernel_h; ++kh) {
+              for (int kw = 0; kw < kernel_w; ++kw) {
+                int inh = padded_h_start + h * stride_h + dilation_h * kh;
+                int inw = padded_w_start + w * stride_w + dilation_w * kw;
+                MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
+                            inw < input_width);
+                index_t input_offset =
+                    n * input_height * input_width * input_channels +
+                    inh * input_width * input_channels + inw * input_channels +
+                    inc;
+                sum += input_ptr[input_offset] * filter_base[0];  // HWIM
+                filter_base += input_channels * multiplier;
+              }
+            }
+            output_ptr[offset] += sum;
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace
+
 struct DepthwiseConv2dFunctorBase {
   DepthwiseConv2dFunctorBase(const int *strides,
                              const Padding padding,
@@ -28,7 +250,7 @@ struct DepthwiseConv2dFunctorBase {
         relux_max_limit_(relux_max_limit),
         prelu_alpha_(prelu_alpha) {}
 
-  const int *strides_;         // [stride_h, stride_w]
+  const int *strides_;  // [stride_h, stride_w]
   const Padding padding_;
   const int *dilations_;  // [dilation_h, dilation_w]
   const ActivationType activation_;
@@ -88,7 +310,8 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
     index_t kernel_h = filter->dim(0);
     index_t kernel_w = filter->dim(1);
     index_t multiplier = filter->dim(3);
-    MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=", input_channels);
+    MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=",
+               input_channels);
     MACE_CHECK(channels == input_channels * multiplier);
 
     int stride_h = strides_[0];
@@ -100,10 +323,15 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
     MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
 
     // The left-upper most offset of the padded input
-    int padded_h_start = 0 - paddings[0] / 2;
-    int padded_w_start = 0 - paddings[1] / 2;
-    index_t padded_h_stop = input_height + paddings[0] - paddings[0] / 2;
-    index_t padded_w_stop = input_width + paddings[1] - paddings[1] / 2;
+    int paddings_top = paddings[0] / 2;
+    int paddings_bottom = paddings[0] - paddings_top;
+    int paddings_left = paddings[1] / 2;
+    int paddings_right = paddings[1] - paddings_left;
+
+    int padded_h_start = 0 - paddings_top;
+    int padded_w_start = 0 - paddings_left;
+    index_t padded_h_stop = input_height + paddings_bottom;
+    index_t padded_w_stop = input_width + paddings_right;
 
     Tensor::MappingGuard input_mapper(input);
     Tensor::MappingGuard filter_mapper(filter);
@@ -114,43 +342,59 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
     const T *bias_ptr = bias == nullptr ? nullptr : bias->data<T>();
     T *output_ptr = output->mutable_data<T>();
 
-#pragma omp parallel for collapse(4)
-    for (int n = 0; n < batch; ++n) {
-      for (int h = 0; h < height; ++h) {
-        for (int w = 0; w < width; ++w) {
-          for (int c = 0; c < channels; ++c) {
-            const index_t inc = c / multiplier;
-            const index_t m = c % multiplier;
-            T bias_channel = bias_ptr ? bias_ptr[c] : 0;
-            index_t offset = n * height * width * channels +
-                             h * width * channels + w * channels + c;
-            output_ptr[offset] = bias_channel;
-            T sum = 0;
-            const T *filter_base = filter_ptr + inc * multiplier + m;
-            for (int kh = 0; kh < kernel_h; ++kh) {
-              for (int kw = 0; kw < kernel_w; ++kw) {
-                int inh = padded_h_start + h * stride_h + dilation_h * kh;
-                int inw = padded_w_start + w * stride_w + dilation_w * kw;
-                if (inh < 0 || inh >= input_height || inw < 0 ||
-                    inw >= input_width) {
-                  MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
-                                 inw >= padded_w_start && inw < padded_w_stop,
-                             "Out of range read from input: ", inh, ", ", inw);
-                } else {
-                  index_t input_offset =
-                      n * input_height * input_width * input_channels +
-                      inh * input_width * input_channels +
-                      inw * input_channels + inc;
-                  sum += input_ptr[input_offset] * filter_base[0];  // HWIM
-                }
-                filter_base += input_channels * multiplier;
-              }
-            }
-            output_ptr[offset] += sum;
-          }
-        }
-      }
+    int valid_h_start =
+        paddings_top == 0 ? 0 : (paddings_top - 1) / stride_h + 1;
+    int valid_h_stop = paddings_bottom == 0
+                           ? height
+                           : height - ((paddings_bottom - 1) / stride_h + 1);
+    int valid_w_start =
+        paddings_left == 0 ? 0 : (paddings_left - 1) / stride_w + 1;
+    int valid_w_stop = paddings_right == 0
+                           ? width
+                           : width - ((paddings_right - 1) / stride_w + 1);
+
+    // Calculate border elements with out-of-boundary checking
+    if (valid_h_start > 0) {
+      DepthwiseConv2dKernel<T>(
+          input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
+          channels, input_height, input_width, input_channels, multiplier,
+          padded_h_start, padded_h_stop, padded_w_start, padded_w_stop,
+          kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, 0,
+          valid_h_start, 0, width);
+    }
+    if (valid_h_stop < height) {
+      DepthwiseConv2dKernel<T>(
+          input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
+          channels, input_height, input_width, input_channels, multiplier,
+          padded_h_start, padded_h_stop, padded_w_start, padded_w_stop,
+          kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w,
+          std::max(valid_h_start, valid_h_stop), height, 0, width);
+    }
+    if (valid_w_start > 0) {
+      DepthwiseConv2dKernel<T>(
+          input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
+          channels, input_height, input_width, input_channels, multiplier,
+          padded_h_start, padded_h_stop, padded_w_start, padded_w_stop,
+          kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w,
+          valid_h_start, valid_h_stop, 0, valid_w_start);
     }
+    if (valid_w_stop < width) {
+      DepthwiseConv2dKernel<T>(
+          input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
+          channels, input_height, input_width, input_channels, multiplier,
+          padded_h_start, padded_h_stop, padded_w_start, padded_w_stop,
+          kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w,
+          valid_h_start, valid_h_stop, std::max(valid_w_start, valid_w_stop),
+          width);
+    }
+
+    // Calculate inner elements without out-of-boundary checking
+    DepthwiseConv2dNoOOBCheckKernel<T>(
+        input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
+        channels, input_height, input_width, input_channels, multiplier,
+        padded_h_start, padded_h_stop, padded_w_start, padded_w_stop, kernel_h,
+        kernel_w, stride_h, stride_w, dilation_h, dilation_w, valid_h_start,
+        valid_h_stop, valid_w_start, valid_w_stop);
 
     output_ptr = output->mutable_data<T>();
     DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
@@ -180,7 +424,7 @@ struct DepthwiseConv2dFunctor
                                    dilations,
                                    activation,
                                    relux_max_limit,
-                                   prelu_alpha){}
+                                   prelu_alpha) {}
 
   void operator()(const Tensor *input,
                   const Tensor *filter,
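The `valid_*` bounds above split each output feature map into four border strips, handled by `DepthwiseConv2dKernel` with its range checks, and an interior rectangle handled by `DepthwiseConv2dNoOOBCheckKernel` with no checks at all. An illustrative self-check of that invariant for one axis, assuming `dilation == 1` (the shape constants below are arbitrary, not taken from the patch):

// For output rows in [valid_h_start, valid_h_stop), every input row the
// kernel touches lies inside the real (unpadded) input, so bounds checks
// can be skipped there.
#include <cassert>

int main() {
  const int input_h = 13, kernel_h = 3, stride_h = 2;
  const int pad_top = 1, pad_bottom = 1;
  const int height =
      (input_h + pad_top + pad_bottom - kernel_h) / stride_h + 1;

  const int valid_h_start = pad_top == 0 ? 0 : (pad_top - 1) / stride_h + 1;
  const int valid_h_stop =
      pad_bottom == 0 ? height : height - ((pad_bottom - 1) / stride_h + 1);

  for (int h = valid_h_start; h < valid_h_stop; ++h) {
    for (int kh = 0; kh < kernel_h; ++kh) {
      int inh = -pad_top + h * stride_h + kh;
      assert(inh >= 0 && inh < input_h);  // interior rows never leave input
    }
  }
  return 0;
}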
diff --git a/mace/kernels/neon/addn_neon.cc b/mace/kernels/neon/addn_neon.cc
deleted file mode 100644
index 18f26af03acab04bfcb979fbccca8945990a5a41..0000000000000000000000000000000000000000
--- a/mace/kernels/neon/addn_neon.cc
+++ /dev/null
@@ -1,57 +0,0 @@
-//
-// Copyright (c) 2017 XiaoMi All rights reserved.
-//
-
-#include "mace/kernels/addn.h"
-#include <arm_neon.h>
-
-namespace mace {
-namespace kernels {
-
-template <>
-void AddNFunctor<DeviceType::NEON, float>::operator()(
-    const std::vector<const Tensor *> &input_tensors,
-    Tensor *output_tensor,
-    StatsFuture *future) {
-  // TODO: neon mem copy
-  output_tensor->ResizeLike(input_tensors[0]);
-  index_t size = output_tensor->size();
-  float *output_ptr = output_tensor->mutable_data<float>();
-  memset(output_ptr, 0, size * sizeof(float));
-  int n = input_tensors.size();
-  int64_t cost = size * n;
-  int64_t groups = 1;
-  if (cost > kCostPerGroup) {
-    groups = cost / kCostPerGroup;
-  }
-  int64_t element_per_group = size / groups;
-
-#pragma omp parallel for
-  for (int64_t i = 0; i < size; i += element_per_group) {
-    int64_t count = std::min(element_per_group, size - i);
-    int nn = count >> 2;
-    int remain = count - (nn << 2);
-    for (int64_t j = 0; j < n; ++j) {
-      const float *input_base = input_tensors[j]->data<float>();
-      const float *inptr = input_base + i;
-      float *outptr = output_ptr + i;
-      for (int k = 0; k < nn; ++k) {
-        float32x4_t _inptr = vld1q_f32(inptr);
-        float32x4_t _outptr = vld1q_f32(outptr);
-        _outptr = vaddq_f32(_outptr, _inptr);
-        vst1q_f32(outptr, _outptr);
-
-        inptr += 4;
-        outptr += 4;
-      }
-      for (int k = 0; k < remain; ++k) {
-        *outptr += *inptr;
-        ++inptr;
-        ++outptr;
-      }
-    }
-  }
-};
-
-}  // namespace kernels
-}  // namespace mace
diff --git a/mace/kernels/resize_bilinear.h b/mace/kernels/resize_bilinear.h
index 43e6a2df6140e5a2700bbe8dea528e72411724e0..1762cb3bc0a49b82abb0b2166b610dab54893bb4 100644
--- a/mace/kernels/resize_bilinear.h
+++ b/mace/kernels/resize_bilinear.h
@@ -68,12 +68,11 @@ void ResizeImage(const T *images,
   const index_t out_batch_num_values = channels * out_height * out_width;
   const CachedInterpolation *xs = xs_vec.data();
 
-#pragma omp parallel for
+#pragma omp parallel for collapse(2)
   for (index_t b = 0; b < batch_size; ++b) {
-    const T *batch_input_ptr = images + in_batch_num_values * b;;
-    T *batch_output_ptr = output + out_batch_num_values * b;
-
     for (index_t y = 0; y < out_height; ++y) {
+      const T *batch_input_ptr = images + in_batch_num_values * b;
+      T *batch_output_ptr = output + out_batch_num_values * b;
       const T *y_lower_input_ptr =
           batch_input_ptr + ys[y].lower * in_width * channels;
       const T *y_upper_input_ptr =
diff --git a/mace/ops/addn.cc b/mace/ops/addn.cc
index d9b514d4e3043c598e01d02aa3612c7ecac73abf..c0fd26715cfcc8b6401a45882f5e356b747594c9 100644
--- a/mace/ops/addn.cc
+++ b/mace/ops/addn.cc
@@ -13,14 +13,6 @@ void Register_AddN(OperatorRegistry *op_registry) {
                     .Build(),
                     AddNOp<DeviceType::CPU, float>);
 
-#if MACE_ENABLE_NEON
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN")
-                                     .Device(DeviceType::NEON)
-                                     .TypeConstraint<float>("T")
-                                     .Build(),
-                    AddNOp<DeviceType::NEON, float>);
-#endif  // MACE_ENABLE_NEON
-
   REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN")
                                      .Device(DeviceType::OPENCL)
                                      .TypeConstraint<float>("T")
diff --git a/mace/ops/addn_benchmark.cc b/mace/ops/addn_benchmark.cc
index a559ed07caa09c4ed7659022b0da905f14c8ece9..bd56e676d0764478850067455519a9164385bdcd 100644
--- a/mace/ops/addn_benchmark.cc
+++ b/mace/ops/addn_benchmark.cc
@@ -67,7 +67,6 @@ static void AddNBenchmark(int iters, int inputs, int n, int h, int w, int c) {
 
 #define BM_ADDN(INPUTS, N, H, W, C)                 \
   BM_ADDN_MACRO(INPUTS, N, H, W, C, float, CPU);    \
-  BM_ADDN_MACRO(INPUTS, N, H, W, C, float, NEON);   \
   BM_ADDN_MACRO(INPUTS, N, H, W, C, float, OPENCL); \
   BM_ADDN_MACRO(INPUTS, N, H, W, C, half, OPENCL);
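The resize_bilinear.h hunk above works because OpenMP's `collapse(2)` requires the collapsed loops to be perfectly nested, with no statements between the two `for` headers; that is why the per-batch pointer setup moves inside the `y` loop. A standalone sketch of the same restructuring (the `Scale` function is hypothetical, not MACE code):

// collapse(2) fuses both loop nests into one parallel iteration space,
// which is only legal when nothing sits between the two for headers.
#include <vector>

void Scale(const std::vector<float> &in, std::vector<float> *out,
           int batch, int per_batch) {
#pragma omp parallel for collapse(2)
  for (int b = 0; b < batch; ++b) {
    for (int i = 0; i < per_batch; ++i) {  // perfectly nested: legal collapse
      const float *in_ptr = in.data() + b * per_batch;  // moved inside
      float *out_ptr = out->data() + b * per_batch;
      out_ptr[i] = in_ptr[i] * 2.0f;
    }
  }
}

The cost is recomputing the base pointers per iteration, but it exposes `batch * per_batch` units of parallelism instead of just `batch`, which matters when the batch size is small.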
diff --git a/mace/ops/addn_test.cc b/mace/ops/addn_test.cc
index cdb970be35af7b564f329d6716a5643698bc37f9..84e6811bef0ebfa5320d3012477b0e269a9a515b 100644
--- a/mace/ops/addn_test.cc
+++ b/mace/ops/addn_test.cc
@@ -33,8 +33,6 @@ void SimpleAdd2() {
 
 TEST_F(AddnOpTest, CPUSimpleAdd2) { SimpleAdd2<DeviceType::CPU>(); }
 
-TEST_F(AddnOpTest, NEONSimpleAdd2) { SimpleAdd2<DeviceType::NEON>(); }
-
 template <DeviceType D>
 void SimpleAdd3() {
   // Construct graph
@@ -61,8 +59,6 @@ void SimpleAdd3() {
 
 TEST_F(AddnOpTest, CPUSimpleAdd3) { SimpleAdd3<DeviceType::CPU>(); }
 
-TEST_F(AddnOpTest, NEONSimpleAdd3) { SimpleAdd3<DeviceType::NEON>(); }
-
 template <DeviceType D>
 void RandomTest() {
   testing::internal::LogToStderr();
diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc
index 900ce27372b37569324eecd45705d91c96c7369e..ab2fa610adf05389cab58753d6bd77b40c339846 100644
--- a/mace/ops/batch_norm_benchmark.cc
+++ b/mace/ops/batch_norm_benchmark.cc
@@ -84,7 +84,7 @@ static void BatchNorm(
 
 #define BM_BATCH_NORM(N, C, H, W)                 \
   BM_BATCH_NORM_MACRO(N, C, H, W, float, CPU);    \
-  BM_BATCH_NORM_MACRO(N, C, H, W, float, NEON);  \
+  BM_BATCH_NORM_MACRO(N, C, H, W, float, NEON);   \
   BM_BATCH_NORM_MACRO(N, C, H, W, float, OPENCL); \
   BM_BATCH_NORM_MACRO(N, C, H, W, half, OPENCL);
diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc
index db88f130ed4ae1bc267651d5c152a55d3d63fc47..5c2f703841e0ce084ff8120d3b8002f2d8b4d407 100644
--- a/mace/ops/batch_norm_test.cc
+++ b/mace/ops/batch_norm_test.cc
@@ -72,92 +72,8 @@ void Simple() {
 
 TEST_F(BatchNormOpTest, SimpleCPU) { Simple<DeviceType::CPU>(); }
 
-TEST_F(BatchNormOpTest, SimpleNEON) { Simple<DeviceType::NEON>(); }
-
 TEST_F(BatchNormOpTest, SimpleOPENCL) { Simple<DeviceType::OPENCL>(); }
 
-TEST_F(BatchNormOpTest, SimpleRandomNeon) {
-  srand(time(NULL));
-
-  // generate random input
-  index_t batch = 1 + rand() % 10;
-  index_t height = 64;
-  index_t width = 64;
-  index_t channels = 3 + rand() % 50;
-
-  // Construct graph
-  OpsTestNet net;
-  OpDefBuilder("BatchNorm", "BatchNormTest")
-      .Input("Input")
-      .Input("Scale")
-      .Input("Offset")
-      .Input("Mean")
-      .Input("Var")
-      .AddFloatArg("epsilon", 1e-3)
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
-
-  // Add input data
-  net.AddRandomInput<DeviceType::CPU, float>("Input",
-                                             {batch, height, width, channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
-
-  // run cpu
-  net.RunOp();
-
-  // Check
-  Tensor expected;
-  expected.Copy(*net.GetOutput("Output"));
-
-  // Run NEON
-  net.RunOp(DeviceType::NEON);
-
-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
-}
-
-TEST_F(BatchNormOpTest, ComplexRandomNeon) {
-  srand(time(NULL));
-
-  // generate random input
-  index_t batch = 1 + rand() % 10;
-  index_t channels = 3 + rand() % 50;
-  index_t height = 103;
-  index_t width = 113;
-
-  // Construct graph
-  OpsTestNet net;
-  OpDefBuilder("BatchNorm", "BatchNormTest")
-      .Input("Input")
-      .Input("Scale")
-      .Input("Offset")
-      .Input("Mean")
-      .Input("Var")
-      .AddFloatArg("epsilon", 1e-3)
-      .Output("Output")
-      .Finalize(net.NewOperatorDef());
-
-  // Add input data
-  net.AddRandomInput<DeviceType::CPU, float>("Input",
-                                             {batch, height, width, channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
-  net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
-
-  // run cpu
-  net.RunOp();
-
-  // Check
-  Tensor expected;
-  expected.Copy(*net.GetOutput("Output"));
-
-  // Run NEON
-  net.RunOp(DeviceType::NEON);
-
-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
-}
-
 TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
   srand(time(NULL));
diff --git a/mace/ops/depthwise_conv_2d_benchmark.cc b/mace/ops/depthwise_conv2d_benchmark.cc
similarity index 71%
rename from mace/ops/depthwise_conv_2d_benchmark.cc
rename to mace/ops/depthwise_conv2d_benchmark.cc
index 561c5af030697b8f4641bfb71fa0f8f4753613e2..2f58343ada3665017e197fa171e323348504a706 100644
--- a/mace/ops/depthwise_conv_2d_benchmark.cc
+++ b/mace/ops/depthwise_conv2d_benchmark.cc
@@ -75,36 +75,38 @@ static void DepthwiseConv2d(int iters,
   }
 }
 
-#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, \
-                                   DEVICE)                                  \
-  static void                                                               \
-      BM_DEPTHWISE_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
-          int iters) {                                                      \
-    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;        \
-    mace::testing::ItemsProcessed(tot);                                     \
-    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                     \
-    DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE,        \
-                                  mace::Padding::P, OC);                    \
-  }                                                                         \
-  BENCHMARK(                                                                \
-      BM_DEPTHWISE_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
+#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, \
+                                   DEVICE)                                  \
+  static void                                                               \
+      BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
+          int iters) {                                                      \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;        \
+    mace::testing::ItemsProcessed(tot);                                     \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                     \
+    DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE,        \
+                                  mace::Padding::P, OC);                    \
+  }                                                                         \
+  BENCHMARK(                                                                \
+      BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
 
 #define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC)                 \
   BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, CPU);    \
   BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, OPENCL); \
   BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, half, OPENCL);
 
+BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 1, SAME, 1);
+BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 2, SAME, 1);
 BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 1);
-//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1);
+BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1);
 BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 1);
-//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1);
-//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1);
-//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1);
-//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1);
-//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1);
-//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1);
-//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1);
-//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1);
-//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1);
+BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1);
+BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1);
+BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1);
+BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1);
+BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1);
+BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1);
+BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1);
+BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1);
+BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1);
 
-} //  namespace mace
+}  // namespace mace
diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc
index 1df30b01a3cb6ec7da8e2977e3036ea6a9c5366a..c5ff2713d73795421e159c2ad9c7f20e9869d8dc 100644
--- a/mace/ops/depthwise_conv2d_test.cc
+++ b/mace/ops/depthwise_conv2d_test.cc
@@ -280,7 +280,7 @@ void TestNxNS12(const index_t height, const index_t width) {
     ExpectTensorNear<float>(expected, *net.GetOutput("DeviceOutput"), 0.1);
   };
 
-  for (int kernel_size : {3}) {
+  for (int kernel_size : {2, 3, 4}) {
     for (int stride : {1, 2}) {
       func(kernel_size, kernel_size, stride, stride, VALID);
       func(kernel_size, kernel_size, stride, stride, SAME);
diff --git a/tools/export_lib.sh b/tools/export_lib.sh
index 6e84ed7d4dd4212dee1bd2105d938b900cfa11c0..eb620b10d59bb96595aeeb00362b38a4c826fcb5 100755
--- a/tools/export_lib.sh
+++ b/tools/export_lib.sh
@@ -3,8 +3,8 @@ set -e
 
 Usage() {
-  echo "Usage: ./tools/export_lib.sh android_abi[armeabi-v7a/arm64-v8a] runtime[gpu/dsp] export_include_dir export_lib_dir"
-  echo "eg: ./tools/export_lib.sh armeabi-v7a ../include ../lib/libmace_v7"
+  echo "Usage: ./tools/export_lib.sh target_abi[armeabi-v7a | arm64-v8a | host] runtime[gpu | dsp] export_include_dir export_lib_dir"
+  echo "eg: ./tools/export_lib.sh armeabi-v7a gpu ../include ../lib/libmace-armeabi-v7a"
 }
 
 if [ $# -lt 4 ]; then
@@ -12,9 +12,7 @@ if [ $# -lt 4 ]; then
   exit 1
 fi
 
-# ANDROID_ABI=arm64-v8a
-# ANDROID_ABI=armeabi-v7a
-ANDROID_ABI=$1
+TARGET_ABI=$1
 RUNTIME=$2
 EXPORT_INCLUDE_DIR=$3
 EXPORT_LIB_DIR=$4
@@ -63,15 +61,18 @@ build_target()
   bazel build --verbose_failures -c opt --strip always $BAZEL_TARGET \
     --crosstool_top=//external:android/crosstool \
     --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
-    --cpu=$ANDROID_ABI \
+    --cpu=$TARGET_ABI \
     --copt="-std=c++11" \
    --copt="-D_GLIBCXX_USE_C99_MATH_TR1" \
     --copt="-Werror=return-type" \
     --copt="-DMACE_OBFUSCATE_LITERALS" \
+    --copt="-O3" \
+    --define neon=true \
+    --define openmp=true \
     $DSP_MODE_BUILD_FLAGS || exit 1
 }
 
-build_local_target()
+build_host_target()
 {
   BAZEL_TARGET=$1
   bazel build --verbose_failures -c opt --strip always $BAZEL_TARGET \
@@ -79,7 +80,8 @@ build_local_target()
     --copt="-D_GLIBCXX_USE_C99_MATH_TR1" \
     --copt="-Werror=return-type" \
     --copt="-DMACE_OBFUSCATE_LITERALS" \
-    --define openmp=true || exit -1
+    --copt="-O3" \
+    --define openmp=true || exit 1
 }
 
 merge_libs()
@@ -132,10 +134,10 @@ bash mace/tools/git/gen_version_source.sh ${CODEGEN_DIR}/version/version.cc || exit 1
 
 echo "Step 3: Build libmace targets"
 bazel clean
-if [ x"${RUNTIME}" = x"local" ]; then
+if [ x"${TARGET_ABI}" = x"host" ] || [ x"${TARGET_ABI}" = x"local" ]; then
  for target in ${all_targets[*]}
  do
-    build_local_target ${target}
+    build_host_target ${target}
  done
 else
  for target in ${all_targets[*]}
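With the renamed `host` mode, typical invocations would look like `./tools/export_lib.sh arm64-v8a gpu ../include ../lib/libmace-arm64-v8a` for an Android build and `./tools/export_lib.sh host gpu ../include ../lib/libmace-host` for a host build; the runtime argument and output paths here are illustrative, following the updated `Usage()` message rather than a documented example.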