diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 83961e1058b86b5b8832cd9c443a79d9f79c2ef8..5f1e0dc16d6f9724245aa203b720d293fb1f7266 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -288,7 +288,8 @@ class Tensor { } CASES(dtype_, (os << (this->data()[i]) << ", ")); } - LOG(INFO) << os.str(); + LOG(INFO) << "Tensor size: [" << dim(0) << ", " << dim(1) << ", " + << dim(2) << ", " << dim(3) << "], content:\n" << os.str(); } inline size_t SizeOfType() const { diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD index ba1b601f267bcdfd3fa342ad4fa8bcc2d4f02d69..5c2fe3016cde48b8e451702cc5f37141e4b3dc2a 100644 --- a/mace/kernels/BUILD +++ b/mace/kernels/BUILD @@ -15,15 +15,14 @@ cc_library( "*.cc", "opencl/*.cc", ]) + if_neon_enabled(glob([ - "neon/*.cc", + "neon/addn_neon.cc", + "neon/batch_norm_neon.cc", ])), hdrs = glob([ "*.h", "opencl/*.h", - ]) + if_neon_enabled(glob([ - "neon/*.h", - ])), - copts = if_openmp_enabled(["-fopenmp"]), + ]), + copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]), linkopts = if_android(["-lm"]), deps = [ "//mace/core", diff --git a/mace/kernels/addn.h b/mace/kernels/addn.h index 0b4828a4491e18be8a1c35d4d82d70d1790d0abf..fd28517795b01190c7866a14357c9b863cf7c872 100644 --- a/mace/kernels/addn.h +++ b/mace/kernels/addn.h @@ -5,6 +5,10 @@ #ifndef MACE_KERNELS_ADDN_H_ #define MACE_KERNELS_ADDN_H_ +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) +#include +#endif + #include "mace/core/future.h" #include "mace/core/tensor.h" #include "mace/core/runtime/opencl/cl2_header.h" @@ -12,7 +16,6 @@ namespace mace { namespace kernels { - template struct AddNFunctor { void operator()(const std::vector &input_tensors, @@ -47,7 +50,7 @@ struct AddNFunctor { cl::Kernel kernel_; }; -} // namespace kernels -} // namespace mace +} // namespace kernels +} // namespace mace #endif // MACE_KERNELS_ADDN_H_ diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index 6f16bf6fc7aaa0c1c4b7400fe62458b5bd643574..1bb85d1675bc0f6284c19ab7a017ca14716c4dd6 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -136,7 +136,7 @@ struct BatchNormFunctor : BatchNormFunctorBase { cl::Kernel kernel_; }; -} // namepsace kernels -} // namespace mace +} // namepsace kernels +} // namespace mace -#endif // MACE_KERNELS_BATCH_NORM_H_ +#endif // MACE_KERNELS_BATCH_NORM_H_ diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index cc331a17c223eb252f11cc3555b79990ee8894ee..ae605016719200e98e54622bfdf584c642d6c97a 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -5,14 +5,176 @@ #ifndef MACE_KERNELS_CONV_2D_H_ #define MACE_KERNELS_CONV_2D_H_ +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) +#include +#endif + #include "mace/core/future.h" +#include "mace/core/runtime/opencl/cl2_header.h" #include "mace/core/tensor.h" #include "mace/kernels/activation.h" #include "mace/kernels/conv_pool_2d_util.h" -#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/utils/utils.h" namespace mace { namespace kernels { +namespace { + +template +void Conv2dKernelFunc(const T *input_ptr, // batch start + const T *filter_ptr, + const T *bias_ptr, + T *output_ptr, // batch start + const int h_offset, + const int w_offset, + const int c_offset, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channels, + const int input_channels, + const int width, + const int padded_width) { + T sum[h_count * w_count * c_count] = {0.0f}; + if (bias_ptr != nullptr) { + for (int hi = 0; hi < h_count; ++hi) { + for (int wi = 0; wi < w_count; ++wi) { + for (int ci = 0; ci < c_count; ++ci) { + const int sum_idx = (hi * w_count + wi) * c_count + ci; + sum[sum_idx] = bias_ptr[c_offset + ci]; + } + } + } + } + + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int inc = 0; + for (; inc + inc_tile_size <= input_channels; inc += inc_tile_size) { +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) + // AArch64 NEON has 32 128-bit general purpose registers + static_assert(inc_tile_size == 4, "input channels tile size must be 4"); + float32x4_t in[h_count * w_count]; +#else + T in[h_count * w_count * inc_tile_size]; +#endif + for (int hi = 0; hi < h_count; ++hi) { + for (int wi = 0; wi < w_count; ++wi) { + const int in_idx = hi * w_count + wi; + const int inh = (h_offset + hi) * stride_h + kh * dilation_h; + const int inw = (w_offset + wi) * stride_w + kw * dilation_w; + const int in_offset = + (inh * padded_width + inw) * input_channels + inc; +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) + static_assert(inc_tile_size == 4, + "input channels tile size must be 4"); + in[in_idx] = vld1q_f32(input_ptr + in_offset); +#else + for (int inci = 0; inci < inc_tile_size; ++inci) { + in[in_idx * inc_tile_size + inci] = input_ptr[in_offset + inci]; + } +#endif + } + } + +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) + static_assert(inc_tile_size == 4, "input channels tile size must be 4"); + float32x4_t weights[c_count]; +#else + T weights[c_count * inc_tile_size]; +#endif + for (int ci = 0; ci < c_count; ++ci) { + const int weights_idx = ci; + const int filter_offset = + ((kh * kernel_w + kw) * channels + c_offset + ci) * + input_channels + + inc; +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) + weights[weights_idx] = vld1q_f32(filter_ptr + filter_offset); +#else + for (int inci = 0; inci < inc_tile_size; ++inci) { + weights[weights_idx * inc_tile_size + inci] = + filter_ptr[filter_offset + inci]; + } +#endif + } + for (int hi = 0; hi < h_count; ++hi) { + for (int wi = 0; wi < w_count; ++wi) { + for (int ci = 0; ci < c_count; ++ci) { + const int weights_idx = ci; + const int in_idx = hi * w_count + wi; + const int sum_idx = (hi * w_count + wi) * c_count + ci; +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) + float32x4_t tmp = vmulq_f32(in[in_idx], weights[weights_idx]); + sum[sum_idx] += vaddvq_f32(tmp); +#else + for (int inci = 0; inci < inc_tile_size; ++inci) { + sum[sum_idx] += + in[in_idx * inc_tile_size + inci] * + weights[weights_idx * inc_tile_size + inci]; + } +#endif + } + } + } + } + // handling the remaining input channels + for (; inc < input_channels; ++inc) { + T in[h_count * w_count]; + for (int hi = 0; hi < h_count; ++hi) { + for (int wi = 0; wi < w_count; ++wi) { + const int in_idx = hi * w_count + wi; + const int inh = (h_offset + hi) * stride_h + kh * dilation_h; + const int inw = (w_offset + wi) * stride_w + kw * dilation_w; + const int in_offset = + (inh * padded_width + inw) * input_channels + inc; + in[in_idx] = input_ptr[in_offset]; + } + } + + T weights[c_count]; + for (int ci = 0; ci < c_count; ++ci) { + const int weights_idx = ci; + const int filter_offset = + ((kh * kernel_w + kw) * channels + c_offset + ci) * + input_channels + + inc; + weights[weights_idx] = filter_ptr[filter_offset]; + } + for (int hi = 0; hi < h_count; ++hi) { + for (int wi = 0; wi < w_count; ++wi) { + for (int ci = 0; ci < c_count; ++ci) { + const int weights_idx = ci; + const int in_idx = hi * w_count + wi; + const int sum_idx = (hi * w_count + wi) * c_count + ci; + sum[sum_idx] += in[in_idx] * weights[weights_idx]; + } + } + } + } + } + } + // save output + for (int hi = 0; hi < h_count; ++hi) { + for (int wi = 0; wi < w_count; ++wi) { + for (int ci = 0; ci < c_count; ++ci) { + const int out_offset = + ((h_offset + hi) * width + w_offset + wi) * channels + c_offset + + ci; + const int sum_idx = (hi * w_count + wi) * c_count + ci; + output_ptr[out_offset] = sum[sum_idx]; + } + } + } +} +}; // namespace struct Conv2dFunctorBase { Conv2dFunctorBase(const int *strides, @@ -28,7 +190,7 @@ struct Conv2dFunctorBase { relux_max_limit_(relux_max_limit), prelu_alpha_(prelu_alpha) {} - const int *strides_; // [stride_h, stride_w] + const int *strides_; // [stride_h, stride_w] const Padding paddings_; const int *dilations_; // [dilation_h, dilation_w] const ActivationType activation_; @@ -51,8 +213,8 @@ struct Conv2dFunctor : Conv2dFunctorBase { relux_max_limit, prelu_alpha) {} - void operator()(const Tensor *input, // NHWC - const Tensor *filter, // HWIO + void operator()(const Tensor *input, // NHWC + const Tensor *filter, // HWOI const Tensor *bias, Tensor *output, StatsFuture *future) { @@ -67,18 +229,21 @@ struct Conv2dFunctor : Conv2dFunctorBase { paddings_, output_shape.data(), paddings.data()); output->Resize(output_shape); - index_t batch = output->dim(0); - index_t height = output->dim(1); - index_t width = output->dim(2); - index_t channels = output->dim(3); + int batch = output->dim(0); + int height = output->dim(1); + int width = output->dim(2); + int channels = output->dim(3); - index_t input_batch = input->dim(0); - index_t input_height = input->dim(1); - index_t input_width = input->dim(2); - index_t input_channels = input->dim(3); + int input_batch = input->dim(0); + int input_height = input->dim(1); + int input_width = input->dim(2); + int input_channels = input->dim(3); - index_t kernel_h = filter->dim(0); - index_t kernel_w = filter->dim(1); + int kernel_h = filter->dim(0); + int kernel_w = filter->dim(1); + MACE_CHECK(filter->dim(2) == channels, filter->dim(2), " != ", channels); + MACE_CHECK(filter->dim(3) == input_channels, filter->dim(3), " != ", + input_channels); int stride_h = strides_[0]; int stride_w = strides_[1]; @@ -88,11 +253,17 @@ struct Conv2dFunctor : Conv2dFunctorBase { MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch"); - // The left-upper most offset of the padded input - int padded_h_start = 0 - paddings[0] / 2; - int padded_w_start = 0 - paddings[1] / 2; - index_t padded_h_stop = input_height + paddings[0] - paddings[0] / 2; - index_t padded_w_stop = input_width + paddings[1] - paddings[1] / 2; + int padded_height = input_height + paddings[0]; + int padded_width = input_width + paddings[1]; + + Tensor padded_input; + // Keep this alive during kernel execution + if (paddings[0] > 0 || paddings[1] > 0) { + ConstructNHWCInputWithPadding(input, paddings.data(), &padded_input); + input = &padded_input; + } + + // padded_input.DebugPrint(); Tensor::MappingGuard input_mapper(input); Tensor::MappingGuard filter_mapper(filter); @@ -103,40 +274,338 @@ struct Conv2dFunctor : Conv2dFunctorBase { auto bias_data = bias == nullptr ? nullptr : bias->data(); auto output_data = output->mutable_data(); + constexpr int inc_tile_size = 4; +// TODO Auto tuning these parameters +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) + const int c_tile_size = 4; + const int h_tile_size = 2; + const int w_tile_size = 2; +#else + const int c_tile_size = 4; + const int h_tile_size = 1; + const int w_tile_size = 2; +#endif + + const int c_tiles = RoundUpDiv(channels, c_tile_size); + const int h_tiles = RoundUpDiv(height, h_tile_size); + const int w_tiles = RoundUpDiv(width, w_tile_size); + #pragma omp parallel for collapse(4) for (int n = 0; n < batch; ++n) { - for (int h = 0; h < height; ++h) { - for (int w = 0; w < width; ++w) { - for (int c = 0; c < channels; ++c) { - const int out_idx = ((n * height + h) * width + w) * channels + c; - T bias_channel = 0.0f; - if (bias) bias_channel = bias_data[c]; - output_data[out_idx] = bias_channel; - T sum = 0.0f; - const T *filter_ptr = filter_data + c; - for (int kh = 0; kh < kernel_h; ++kh) { - for (int kw = 0; kw < kernel_w; ++kw) { - for (int inc = 0; inc < input_channels; ++inc) { - int inh = padded_h_start + h * stride_h + dilation_h * kh; - int inw = padded_w_start + w * stride_w + dilation_w * kw; - if (inh < 0 || inh >= input_height || inw < 0 || - inw >= input_width) { - MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop && - inw >= padded_w_start && inw < padded_w_stop, - "Out of range read from input: ", inh, ", ", - inw); - } else { - index_t input_offset = - n * input_height * input_width * input_channels + - inh * input_width * input_channels + - inw * input_channels + inc; - sum += input_data[input_offset] * *filter_ptr; - } - filter_ptr += channels; + for (int cb = 0; cb < c_tiles; ++cb) { + for (int hb = 0; hb < h_tiles; ++hb) { + for (int wb = 0; wb < w_tiles; ++wb) { + const T *input_ptr = + input_data + n * padded_height * padded_width * input_channels; + T *output_ptr = output_data + n * height * width * channels; + const int h_offset = hb * h_tile_size; + const int w_offset = wb * w_tile_size; + const int c_offset = cb * c_tile_size; + + const int h_count = std::min(h_tile_size, height - h_offset); + const int w_count = std::min(w_tile_size, width - w_offset); + const int c_count = std::min(c_tile_size, channels - c_offset); + + switch (c_count) { + case 1: + switch (h_count) { + case 1: + switch (w_count) { + case 1: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 2: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 3: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 4: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + default: + LOG(FATAL) << "Unsupported width tile: " << w_count; + } + break; + case 2: + switch (w_count) { + case 1: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 2: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 3: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 4: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + default: + LOG(FATAL) << "Unsupported width tile: " << w_count; + } + break; + default: + LOG(FATAL) << "Unsupported height tile: " << h_count; } - } + break; + case 2: + switch (h_count) { + case 1: + switch (w_count) { + case 1: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 2: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 3: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 4: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + default: + LOG(FATAL) << "Unsupported width tile: " << w_count; + } + break; + case 2: + switch (w_count) { + case 1: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 2: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 3: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 4: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + default: + LOG(FATAL) << "Unsupported width tile: " << w_count; + } + break; + default: + LOG(FATAL) << "Unsupported height tile: " << h_count; + } + break; + case 3: + switch (h_count) { + case 1: + switch (w_count) { + case 1: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 2: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 3: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 4: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + default: + LOG(FATAL) << "Unsupported width tile: " << w_count; + } + break; + case 2: + switch (w_count) { + case 1: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 2: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 3: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 4: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + default: + LOG(FATAL) << "Unsupported width tile: " << w_count; + } + break; + default: + LOG(FATAL) << "Unsupported height tile: " << h_count; + } + break; + case 4: + switch (h_count) { + case 1: + switch (w_count) { + case 1: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 2: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 3: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 4: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + default: + LOG(FATAL) << "Unsupported width tile: " << w_count; + } + break; + case 2: + switch (w_count) { + case 1: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 2: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 3: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + case 4: + Conv2dKernelFunc( + input_ptr, filter_data, bias_data, output_ptr, + h_offset, w_offset, c_offset, kernel_h, kernel_w, + stride_h, stride_w, dilation_h, dilation_w, + channels, input_channels, width, padded_width); + break; + default: + LOG(FATAL) << "Unsupported width tile: " << w_count; + } + break; + default: + LOG(FATAL) << "Unsupported height tile: " << h_count; + } + break; + default: + LOG(FATAL) << "Unsupported channel tile: " << c_count; } - output_data[out_idx] += sum; } } } diff --git a/mace/kernels/conv_pool_2d_util.cc b/mace/kernels/conv_pool_2d_util.cc index 95b0cf79751bc71eb0f2f0ab4a955b2d54cd8809..60d65cdd2cfbf8fd14f9e0a27d8cdb8e3b1cd7e2 100644 --- a/mace/kernels/conv_pool_2d_util.cc +++ b/mace/kernels/conv_pool_2d_util.cc @@ -17,7 +17,7 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, "Invalid dilations, must >= 1"); MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), + (dilations[1] == 1 || strides[1] == 1), "If dilations > 1, strides should be 1"); MACE_CHECK_NOTNULL(output_shape); MACE_CHECK_NOTNULL(padding_size); @@ -51,7 +51,8 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1; output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1; break; - default:MACE_CHECK(false, "Unsupported padding type: ", padding); + default: + MACE_CHECK(false, "Unsupported padding type: ", padding); } // Note: TensorFlow may padded one more on the right/bottom side @@ -59,12 +60,10 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW // utilize the more centered features. We need to benchmark // based on the model accuracy. - padding_size[0] = - std::max(0, (output_height - 1) * strides[0] - + k_extent_height - input_shape[2]); - padding_size[1] = - std::max(0, (output_width - 1) * strides[1] - + k_extent_width - input_shape[3]); + padding_size[0] = std::max( + 0, (output_height - 1) * strides[0] + k_extent_height - input_shape[2]); + padding_size[1] = std::max( + 0, (output_width - 1) * strides[1] + k_extent_width - input_shape[3]); output_shape[0] = input_shape[0]; output_shape[1] = output_channels; @@ -73,7 +72,7 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW } void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC - const index_t *filter_shape, // HWIO + const index_t *filter_shape, // HWOI const int *dilations, const int *strides, Padding padding, @@ -82,7 +81,7 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, "Invalid dilations, must >= 1"); MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), + (dilations[1] == 1 || strides[1] == 1), "If dilations > 1, strides should be 1"); MACE_CHECK_NOTNULL(output_shape); MACE_CHECK_NOTNULL(padding_size); @@ -98,7 +97,7 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC index_t output_height = 0, output_width = 0; index_t kernel_height = filter_shape[0]; index_t kernel_width = filter_shape[1]; - index_t output_channels = filter_shape[3]; + index_t output_channels = filter_shape[2]; index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1; index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1; @@ -116,7 +115,8 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC output_height = (input_shape[1] + k_extent_height - 2) / strides[0] + 1; output_width = (input_shape[2] + k_extent_width - 2) / strides[1] + 1; break; - default:MACE_CHECK(false, "Unsupported padding type: ", padding); + default: + MACE_CHECK(false, "Unsupported padding type: ", padding); } // Note: TensorFlow may padded one more on the right/bottom side @@ -124,12 +124,10 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC // utilize the more centered features. We need to benchmark // based on the model accuracy. - padding_size[0] = - std::max(0, (output_height - 1) * strides[0] - + k_extent_height - input_shape[1]); - padding_size[1] = - std::max(0, (output_width - 1) * strides[1] - + k_extent_width - input_shape[2]); + padding_size[0] = std::max( + 0, (output_height - 1) * strides[0] + k_extent_height - input_shape[1]); + padding_size[1] = std::max( + 0, (output_width - 1) * strides[1] + k_extent_width - input_shape[2]); output_shape[0] = input_shape[0]; output_shape[1] = output_height; @@ -146,7 +144,7 @@ void CalPaddingSize(const index_t *input_shape, // NCHW MACE_CHECK(dilations[0] > 0 && dilations[1] > 0, "Invalid dilations, must >= 1"); MACE_CHECK((dilations[0] == 1 || strides[0] == 1) && - (dilations[1] == 1 || strides[1] == 1), + (dilations[1] == 1 || strides[1] == 1), "If dilations > 1, strides should be 1"); MACE_CHECK_NOTNULL(padding_size); @@ -167,19 +165,18 @@ void CalPaddingSize(const index_t *input_shape, // NCHW output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1; output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1; break; - default:MACE_CHECK(false, "Unsupported padding type: ", padding); + default: + MACE_CHECK(false, "Unsupported padding type: ", padding); } // Note: TensorFlow may padded one more on the right/bottom side // TODO may be it's better to also truncate the left/top to // utilize the more centered features. We need to benchmark // based on the model accuracy. - padding_size[0] = - std::max(0, (output_height - 1) * strides[0] - + k_extent_height - input_shape[2]); - padding_size[1] = - std::max(0, (output_width - 1) * strides[1] - + k_extent_width - input_shape[3]); + padding_size[0] = std::max( + 0, (output_height - 1) * strides[0] + k_extent_height - input_shape[2]); + padding_size[1] = std::max( + 0, (output_width - 1) * strides[1] + k_extent_width - input_shape[3]); } void ConstructInputWithPadding(const Tensor *input_tensor, @@ -206,18 +203,18 @@ void ConstructInputWithPadding(const Tensor *input_tensor, output_tensor->Resize(output_shape); Tensor::MappingGuard padded_output_mapper(output_tensor); - float *output_ptr = output_tensor->mutable_data(); - memset(output_ptr, 0, output_tensor->size() * sizeof(float)); + float *output_data = output_tensor->mutable_data(); + memset(output_data, 0, output_tensor->size() * sizeof(float)); // Skip the padded top rows if (padding_same_value) { -#define COPY_INPUT \ - std::fill(output_ptr, output_ptr+padded_left, input[0]); \ - output_ptr += padded_left; \ - memcpy(output_ptr, input, width * sizeof(float)); \ - output_ptr += width; \ - std::fill(output_ptr , output_ptr + padded_right, input[width-1]); \ - output_ptr += padded_right; +#define COPY_INPUT \ + std::fill(output_data, output_data + padded_left, input[0]); \ + output_data += padded_left; \ + memcpy(output_data, input, width * sizeof(float)); \ + output_data += width; \ + std::fill(output_data, output_data + padded_right, input[width - 1]); \ + output_data += padded_right; const int padded_bottom = paddings[0] - padded_top; const int padded_right = paddings[1] - padded_left; @@ -239,19 +236,69 @@ void ConstructInputWithPadding(const Tensor *input_tensor, } #undef COPY_INPUT } else { - output_ptr += padded_top * output_width; + output_data += padded_top * output_width; for (int i = 0; i < batch; ++i) { for (int j = 0; j < channels; ++j) { for (int k = 0; k < height; ++k) { - memcpy(output_ptr + padded_left, input, width * sizeof(float)); + memcpy(output_data + padded_left, input, width * sizeof(float)); input += width; - output_ptr += output_width; + output_data += output_width; } // Skip the padded bottom in this channel and top in the next channel - output_ptr += paddings[0] * output_width; + output_data += paddings[0] * output_width; } } } } -} // namespace kernels -} // namespace mace + +void ConstructNHWCInputWithPadding(const Tensor *input_tensor, + const int *paddings, + Tensor *output_tensor, + bool padding_same_value) { + VLOG(1) << "input: " << input_tensor->NumElements(); + Tensor::MappingGuard input_mapper(input_tensor); + const float *input = input_tensor->data(); + const index_t *input_shape = input_tensor->shape().data(); + + index_t batch = input_shape[0]; + index_t height = input_shape[1]; + index_t width = input_shape[2]; + index_t channels = input_shape[3]; + + std::vector output_shape( + {batch, paddings[0] + height, paddings[1] + width, channels}); + + const int output_height = output_shape[1]; + const int output_width = output_shape[2]; + const int padded_top = paddings[0] / 2; + const int padded_left = paddings[1] / 2; + + output_tensor->Resize(output_shape); + + Tensor::MappingGuard padded_output_mapper(output_tensor); + float *output_data = output_tensor->mutable_data(); + memset(output_data, 0, output_tensor->size() * sizeof(float)); + + // Skip the padded top rows + if (padding_same_value) { + LOG(FATAL) << "Not implemented"; + } else { +#pragma omp parallel for collapse(3) + for (int n = 0; n < batch; ++n) { + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + const float *input_ptr = + input + ((n * height + h) * width + w) * channels; + float *output_ptr = + output_data + + ((n * output_height + h + padded_top) * output_width + w + + padded_left) * + channels; + memcpy(output_ptr, input_ptr, channels * sizeof(float)); + } + } + } + } +} +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/conv_pool_2d_util.h b/mace/kernels/conv_pool_2d_util.h index 8e305477d2d2d55cfbef28651dfd53b8ea811d64..87f9546f829c7ef81ba1275de8ee8d1c21ccb77d 100644 --- a/mace/kernels/conv_pool_2d_util.h +++ b/mace/kernels/conv_pool_2d_util.h @@ -44,6 +44,12 @@ void ConstructInputWithPadding(const Tensor *input, const int *paddings, Tensor *output_tensor, bool padding_same_value = false); + +void ConstructNHWCInputWithPadding(const Tensor *input, + const int *paddings, + Tensor *output_tensor, + bool padding_same_value = false); + } // namespace kernels } // namespace mace diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index 395797240c5dd589438f9fe5dc7c6de047d2290d..803cb34bcbd68a0948c2619a9a41f71dbeac885a 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -64,8 +64,8 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase { std::vector fake_filter_shape(4); fake_filter_shape[0] = filter->shape()[0]; fake_filter_shape[1] = filter->shape()[1]; - fake_filter_shape[3] = filter->shape()[2] * filter->shape()[3]; - fake_filter_shape[2] = 1; + fake_filter_shape[2] = filter->shape()[2] * filter->shape()[3]; + fake_filter_shape[3] = 1; std::vector output_shape(4); std::vector paddings(2); diff --git a/mace/kernels/neon/addn_neon.cc b/mace/kernels/neon/addn_neon.cc index f3f2a3ac426820d302f445c03c8b8bffc3d3b685..18f26af03acab04bfcb979fbccca8945990a5a41 100644 --- a/mace/kernels/neon/addn_neon.cc +++ b/mace/kernels/neon/addn_neon.cc @@ -10,9 +10,11 @@ namespace kernels { template <> void AddNFunctor::operator()( - const std::vector &input_tensors, Tensor *output_tensor, + const std::vector &input_tensors, + Tensor *output_tensor, StatsFuture *future) { // TODO: neon mem copy + output_tensor->ResizeLike(input_tensors[0]); index_t size = output_tensor->size(); float *output_ptr = output_tensor->mutable_data(); memset(output_ptr, 0, size * sizeof(float)); diff --git a/mace/kernels/neon/avg_pooling_neon_3x3.cc b/mace/kernels/neon/avg_pooling_neon_3x3.cc deleted file mode 100644 index e50f454c742f9307ee3b21300be06f52a36a3c0d..0000000000000000000000000000000000000000 --- a/mace/kernels/neon/avg_pooling_neon_3x3.cc +++ /dev/null @@ -1,224 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include -#include -#include - -#include "mace/core/common.h" - -namespace mace { -namespace kernels { - -void PoolingAvgNeonK3x3S2x2(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape, - const int *paddings) { - index_t batch = in_shape[0]; - index_t channels = in_shape[1]; - index_t in_height = in_shape[2]; - index_t in_width = in_shape[3]; - - index_t out_height = out_shape[2]; - index_t out_width = out_shape[3]; - - int padding_top = paddings[0] / 2; - int padding_bottom = paddings[0] - padding_top; - int padding_left = paddings[1] / 2; - int padding_right = paddings[1] - padding_left; - - int in_image_size = in_height * in_width; - int out_image_size = out_height * out_width; - index_t input_offset = 0; - index_t output_offset = 0; - float avg_factors[4] = {1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0}; - -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int c = 0; c < channels; ++c) { - float *outptr = output + output_offset; - - for (int h = 0; h < out_height; ++h) { - int w = 0; - int num_vectors = 0; - const float *r0, *r1, *r2; - if (!((h == 0 && padding_top > 0) || - (h == out_height - 1 && padding_bottom > 0))) { - r0 = input + input_offset + (h * 2 - padding_top) * in_width; - r1 = r0 + in_width; - r2 = r1 + in_width; - - if (padding_left > 0) { - if (padding_left == 1) { - float sum0 = std::max(r0[0], r0[1]); - float sum1 = std::max(r1[0], r1[1]); - float max2 = std::max(r2[0], r2[1]); - *outptr = (r0[0] + r0[1] + r1[0] + r1[1] + r2[0] + r2[1]) / 9.0; - ++r0; - ++r1; - } else { // padding_left == 2 - *outptr = (r0[0] + r1[0] + r2[0]) / 9.0; - } - ++outptr; - ++w; - } - if (padding_right > 0) { - num_vectors = (out_width - w - 1) >> 2; - } else { - num_vectors = (out_width - w) >> 2; - } - } - - w += num_vectors << 2; - float32x4_t factors = vld1q_f32(avg_factors); - float32x4x2_t row0 = vld2q_f32(r0); - float32x4x2_t row1 = vld2q_f32(r1); - float32x4x2_t row2 = vld2q_f32(r2); - for (; num_vectors > 0; --num_vectors) { - float32x4x2_t row0_next = vld2q_f32(r0 + 8); - float32x4x2_t row1_next = vld2q_f32(r1 + 8); - float32x4x2_t row2_next = vld2q_f32(r2 + 8); - - float32x4_t sum0 = vaddq_f32(row0.val[0], row0.val[1]); - float32x4_t sum1 = vaddq_f32(row1.val[0], row1.val[1]); - float32x4_t sum2 = vaddq_f32(row2.val[0], row2.val[1]); - - float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1); - float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1); - float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1); - - sum0 = vaddq_f32(sum0, row02); - sum1 = vaddq_f32(sum1, row12); - sum2 = vaddq_f32(sum2, row22); - - float32x4_t sum_result = vaddq_f32(vaddq_f32(sum0, sum1), sum2); - float32x4_t avg_result = vmulq_f32(sum_result, factors); - - vst1q_f32(outptr, avg_result); - - row0 = row0_next; - row1 = row1_next; - row2 = row2_next; - - r0 += 8; - r1 += 8; - r2 += 8; - outptr += 4; - } - - for (; w < out_width; ++w) { - float sum = 0.0; - for (int kh = 0; kh < 3; ++kh) { - for (int kw = 0; kw < 3; ++kw) { - int inh = h * 2 - padding_top + kh; - int inw = w * 2 - padding_left + kw; - if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) { - sum += input[input_offset + inh * in_width + inw]; - } - } - } - - *outptr = sum / 9.0; - ++outptr; - } - } - input_offset += in_image_size; - output_offset += out_image_size; - } - } -} - -// assume the input has already been padded -void PoolingAvgNeonK3x3S2x2Padded(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape) { - index_t batch = in_shape[0]; - index_t channels = in_shape[1]; - index_t in_height = in_shape[2]; - index_t in_width = in_shape[3]; - - index_t out_height = out_shape[2]; - index_t out_width = out_shape[3]; - - int in_image_size = in_height * in_width; - int out_image_size = out_height * out_width; - index_t input_offset = 0; - index_t output_offset = 0; - float avg_factors[4] = {1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0}; - -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int c = 0; c < channels; ++c) { - const float *img0 = input + input_offset; - float *outptr = output + output_offset; - - const float *r0 = img0; - const float *r1 = r0 + in_width; - const float *r2 = r1 + in_width; - - for (int h = 0; h < out_height; h++) { - int num_vectors = out_width >> 2; - int remain = out_width - (num_vectors << 2); - - float32x4_t factors = vld1q_f32(avg_factors); - float32x4x2_t row0 = vld2q_f32(r0); - float32x4x2_t row1 = vld2q_f32(r1); - float32x4x2_t row2 = vld2q_f32(r2); - for (; num_vectors > 0; --num_vectors) { - float32x4x2_t row0_next = vld2q_f32(r0 + 8); - float32x4x2_t row1_next = vld2q_f32(r1 + 8); - float32x4x2_t row2_next = vld2q_f32(r2 + 8); - - float32x4_t sum0 = vaddq_f32(row0.val[0], row0.val[1]); - float32x4_t sum1 = vaddq_f32(row1.val[0], row1.val[1]); - float32x4_t sum2 = vaddq_f32(row2.val[0], row2.val[1]); - - float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1); - float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1); - float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1); - - sum0 = vaddq_f32(sum0, row02); - sum1 = vaddq_f32(sum1, row12); - sum2 = vaddq_f32(sum2, row22); - - float32x4_t sum_result = vaddq_f32(vaddq_f32(sum0, sum1), sum2); - float32x4_t avg_result = vmulq_f32(sum_result, factors); - - vst1q_f32(outptr, avg_result); - - row0 = row0_next; - row1 = row1_next; - row2 = row2_next; - - r0 += 8; - r1 += 8; - r2 += 8; - outptr += 4; - } - - for (; remain > 0; remain--) { - *outptr = (r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + r1[2] + r2[0] + - r2[1] + r2[2]) / - 9.0; - - r0 += 2; - r1 += 2; - r2 += 2; - outptr++; - } - - r0 += 1 + in_width; - r1 += 1 + in_width; - r2 += 1 + in_width; - } - input_offset += in_image_size; - output_offset += out_image_size; - } - } -} - -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/neon/batch_norm_neon.cc b/mace/kernels/neon/batch_norm_neon.cc index b681f6e8679d174288fdaca68eb182c416a3036b..84dc44086223e290a0fa1b072b71c930bffef4fc 100644 --- a/mace/kernels/neon/batch_norm_neon.cc +++ b/mace/kernels/neon/batch_norm_neon.cc @@ -15,6 +15,7 @@ void BatchNormFunctor::operator()( const Tensor *offset, const Tensor *mean, const Tensor *var, + const float epsilon, Tensor *output, StatsFuture *future) { // Batch normalization in the paper https://arxiv.org/abs/1502.03167 . @@ -26,8 +27,8 @@ void BatchNormFunctor::operator()( // new_offset = \offset - mean * common_val; // Y = new_scale * X + new_offset; const index_t n = input->dim(0); - const index_t channel = input->dim(1); - const index_t sample_size = input->dim(2) * input->dim(3); + const index_t sample_size = input->dim(1) * input->dim(2); + const index_t channel = input->dim(3); const float *input_ptr = input->data(); const float *scale_ptr = scale->data(); @@ -36,36 +37,47 @@ void BatchNormFunctor::operator()( const float *var_ptr = var->data(); float *output_ptr = output->mutable_data(); - index_t count = sample_size >> 2; - index_t remain_count = sample_size - (count << 2); + const index_t ch_blks = channel >> 2; + const index_t remain_chs = channel - (ch_blks << 2); + + std::vector new_scale(channel); + std::vector new_offset(channel); + #pragma omp parallel for for (index_t c = 0; c < channel; ++c) { - float new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon_); - float new_offset = offset_ptr[c] - mean_ptr[c] * new_scale; - index_t pos = c * sample_size; - - float32x4_t new_scale_f = vdupq_n_f32(new_scale); - float32x4_t new_offset_f = vdupq_n_f32(new_offset); - for (index_t i = 0; i < n; ++i) { - const float *input_sample_ptr = input_ptr + pos; - float *output_sample_ptr = output_ptr + pos; + new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon); + new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c]; + } - for (index_t j = 0; j < count; ++j) { +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < n; ++i) { + for (index_t j = 0; j < sample_size; ++j) { + const float *input_sample_ptr = input_ptr + (i * sample_size + j) * channel; + float *output_sample_ptr = output_ptr + (i * sample_size + j) * channel; + const float *new_scale_ptr = new_scale.data(); + const float *new_offset_ptr = new_offset.data(); + for (index_t cb = 0; cb < ch_blks; ++cb) { + float32x4_t new_scale_f = vld1q_f32(new_scale_ptr); + float32x4_t new_offset_f = vld1q_f32(new_offset_ptr); float32x4_t input_f = vld1q_f32(input_sample_ptr); float32x4_t output_f = vfmaq_f32(new_offset_f, input_f, new_scale_f); vst1q_f32(output_sample_ptr, output_f); + input_sample_ptr += 4; output_sample_ptr += 4; + new_scale_ptr += 4; + new_offset_ptr += 4; } - for (index_t j = 0; j < remain_count; ++j) { - *output_sample_ptr = new_scale * *input_sample_ptr + new_offset; + for (index_t c = (ch_blks << 2); c < channel; ++c) { + *output_sample_ptr = new_scale[c] * *input_sample_ptr + new_offset[c]; ++output_sample_ptr; ++input_sample_ptr; + ++new_scale_ptr; + ++new_offset_ptr; } - pos += channel * sample_size; } } }; } // namespace kernels -} // namespace mace +} // namespace mace diff --git a/mace/kernels/neon/global_avg_pooling_neon.cc b/mace/kernels/neon/global_avg_pooling_neon.cc deleted file mode 100644 index cf639559bdfd5cd5b243aec636c98590cd05855a..0000000000000000000000000000000000000000 --- a/mace/kernels/neon/global_avg_pooling_neon.cc +++ /dev/null @@ -1,56 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/kernels/global_avg_pooling.h" -#include - -namespace mace { -namespace kernels { - -template <> -void GlobalAvgPoolingFunctor::operator()( - const float *input, const index_t *input_shape, - float *output, StatsFuture *future) { - index_t batch = input_shape[0]; - index_t channels = input_shape[1]; - index_t height = input_shape[2]; - index_t width = input_shape[3]; - - index_t image_size = height * width; - index_t input_offset = 0; - index_t total_channels = batch * channels; - -#pragma omp parallel for - for (int c = 0; c < total_channels; ++c) { - const float *inptr = input + c * image_size; - float sum = 0.0; - - int num_vectors = image_size >> 2; - int remain = image_size - (num_vectors << 2); - - if (num_vectors > 0) { - float sum_out[4] = {0.0, 0.0, 0.0, 0.0}; - - float32x4_t sum_vector = vld1q_f32(inptr); - inptr += 4; - for (int n = 1; n < num_vectors; ++n) { - float32x4_t vector = vld1q_f32(inptr); - sum_vector = vaddq_f32(sum_vector, vector); - inptr += 4; - } - vst1q_f32(sum_out, sum_vector); - - sum = sum_out[0] + sum_out[1] + sum_out[2] + sum_out[3]; - } - - for (int i = 0; i < remain; ++i) { - sum += *inptr; - ++inptr; - } - output[c] = sum / image_size; - } -}; - -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/neon/max_pooling_neon_2x2.cc b/mace/kernels/neon/max_pooling_neon_2x2.cc deleted file mode 100644 index 69743b33cb4886d88c645229e84a567184bab0a2..0000000000000000000000000000000000000000 --- a/mace/kernels/neon/max_pooling_neon_2x2.cc +++ /dev/null @@ -1,173 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include -#include -#include - -#include "mace/core/common.h" - -namespace mace { -namespace kernels { - -void PoolingMaxNeonK2x2S2x2(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape, - const int *paddings) { - index_t batch = in_shape[0]; - index_t channels = in_shape[1]; - index_t in_height = in_shape[2]; - index_t in_width = in_shape[3]; - - index_t out_height = out_shape[2]; - index_t out_width = out_shape[3]; - - int padding_top = paddings[0] / 2; - int padding_bottom = paddings[0] - padding_top; - int padding_left = paddings[1] / 2; - int padding_right = paddings[1] - padding_left; - - int in_image_size = in_height * in_width; - int out_image_size = out_height * out_width; - index_t input_offset = 0; - index_t output_offset = 0; - -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int c = 0; c < channels; ++c) { - float *outptr = output + output_offset; - const float *r0, *r1; - - for (int h = 0; h < out_height; ++h) { - int w = 0; - int num_vectors = 0; - if (!((h == 0 && padding_top > 0) || - (h == out_height - 1 && padding_bottom > 0))) { - r0 = input + input_offset + (h * 2 - padding_top) * in_width; - r1 = r0 + in_width; - if (padding_left > 0) { - *outptr = std::max(r0[0], r1[0]); - ++r0; - ++r1; - ++outptr; - ++w; - } - if (padding_right > 0) { - num_vectors = (out_width - w - 1) >> 2; - } else { - num_vectors = (out_width - w) >> 2; - } - } - - w += num_vectors << 2; - - for (; num_vectors > 0; --num_vectors) { - float32x4_t r00 = vld1q_f32(r0); - float32x4_t r10 = vld1q_f32(r1); - float32x4_t r01 = vld1q_f32(r0 + 4); - float32x4_t r11 = vld1q_f32(r1 + 4); - - float32x4_t max0 = vmaxq_f32(r00, r10); - float32x4_t max1 = vmaxq_f32(r01, r11); - - float32x4_t max_result = vpmaxq_f32(max0, max1); - - vst1q_f32(outptr, max_result); - - r0 += 8; - r1 += 8; - outptr += 4; - } - - for (; w < out_width; ++w) { - float max = std::numeric_limits::lowest(); - for (int kh = 0; kh < 2; ++kh) { - for (int kw = 0; kw < 2; ++kw) { - int inh = h * 2 - padding_top + kh; - int inw = w * 2 - padding_left + kw; - if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) { - max = std::max(max, input[input_offset + inh * in_width + inw]); - } - } - } - - *outptr = max; - ++outptr; - } - } - input_offset += in_image_size; - output_offset += out_image_size; - } - } -} - -// assume the input has already been padded -void PoolingMaxNeonK2x2S2x2Padded(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape) { - index_t batch = in_shape[0]; - index_t channels = in_shape[1]; - index_t in_height = in_shape[2]; - index_t in_width = in_shape[3]; - - index_t out_height = out_shape[2]; - index_t out_width = out_shape[3]; - - int in_image_size = in_height * in_width; - int out_image_size = out_height * out_width; - index_t input_offset = 0; - index_t output_offset = 0; - -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int c = 0; c < channels; ++c) { - const float *img0 = input + input_offset; - float *outptr = output + output_offset; - - const float *r0 = img0; - const float *r1 = img0 + in_width; - - for (int h = 0; h < out_height; ++h) { - int num_vectors = out_width >> 2; - int remain = out_width - (num_vectors << 2); - - for (; num_vectors > 0; --num_vectors) { - float32x4_t r00 = vld1q_f32(r0); - float32x4_t r10 = vld1q_f32(r1); - float32x4_t r01 = vld1q_f32(r0 + 4); - float32x4_t r11 = vld1q_f32(r1 + 4); - - float32x4_t max0 = vmaxq_f32(r00, r10); - float32x4_t max1 = vmaxq_f32(r01, r11); - - float32x4_t max_result = vpmaxq_f32(max0, max1); - - vst1q_f32(outptr, max_result); - r0 += 8; - r1 += 8; - outptr += 4; - } - - for (; remain > 0; --remain) { - float max0 = std::max(r0[0], r0[1]); - float max1 = std::max(r1[0], r1[1]); - *outptr = std::max(max0, max1); - - r0 += 2; - r1 += 2; - outptr++; - } - r0 += in_width; - r1 += in_width; - } - input_offset += in_image_size; - output_offset += out_image_size; - } - } -} - -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/neon/max_pooling_neon_3x3.cc b/mace/kernels/neon/max_pooling_neon_3x3.cc deleted file mode 100644 index 0c7a74d0b0d1133d9367ceac158240e84aa49d83..0000000000000000000000000000000000000000 --- a/mace/kernels/neon/max_pooling_neon_3x3.cc +++ /dev/null @@ -1,220 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include -#include - -#include "mace/core/common.h" - -namespace mace { -namespace kernels { - -void PoolingMaxNeonK3x3S2x2(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape, - const int *paddings) { - index_t batch = in_shape[0]; - index_t channels = in_shape[1]; - index_t in_height = in_shape[2]; - index_t in_width = in_shape[3]; - - index_t out_height = out_shape[2]; - index_t out_width = out_shape[3]; - - int padding_top = paddings[0] / 2; - int padding_bottom = paddings[0] - padding_top; - int padding_left = paddings[1] / 2; - int padding_right = paddings[1] - padding_left; - - int in_image_size = in_height * in_width; - int out_image_size = out_height * out_width; - index_t input_offset = 0; - index_t output_offset = 0; - -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int c = 0; c < channels; ++c) { - float *outptr = output + output_offset; - - for (int h = 0; h < out_height; ++h) { - int w = 0; - int num_vectors = 0; - const float *r0, *r1, *r2; - if (!((h == 0 && padding_top > 0) || - (h == out_height - 1 && padding_bottom > 0))) { - r0 = input + input_offset + (h * 2 - padding_top) * in_width; - r1 = r0 + in_width; - r2 = r1 + in_width; - - if (padding_left > 0) { - if (padding_left == 1) { - float max0 = std::max(r0[0], r0[1]); - float max1 = std::max(r1[0], r1[1]); - float max2 = std::max(r2[0], r2[1]); - *outptr = std::max(std::max(max0, max1), max2); - ++r0; - ++r1; - } else { // padding_left == 2 - float max_tmp = std::max(r0[0], r1[0]); - *outptr = std::max(max_tmp, r2[0]); - } - ++outptr; - ++w; - } - if (padding_right > 0) { - num_vectors = (out_width - w - 1) >> 2; - } else { - num_vectors = (out_width - w) >> 2; - } - } - - w += num_vectors << 2; - float32x4x2_t row0 = vld2q_f32(r0); - float32x4x2_t row1 = vld2q_f32(r1); - float32x4x2_t row2 = vld2q_f32(r2); - for (; num_vectors > 0; --num_vectors) { - float32x4x2_t row0_next = vld2q_f32(r0 + 8); - float32x4x2_t row1_next = vld2q_f32(r1 + 8); - float32x4x2_t row2_next = vld2q_f32(r2 + 8); - - float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]); - float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]); - float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]); - - float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1); - float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1); - float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1); - - max0 = vmaxq_f32(max0, row02); - max1 = vmaxq_f32(max1, row12); - max2 = vmaxq_f32(max2, row22); - - float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2); - - vst1q_f32(outptr, max_result); - - row0 = row0_next; - row1 = row1_next; - row2 = row2_next; - - r0 += 8; - r1 += 8; - r2 += 8; - outptr += 4; - } - - for (; w < out_width; ++w) { - float max = std::numeric_limits::lowest(); - for (int kh = 0; kh < 3; ++kh) { - for (int kw = 0; kw < 3; ++kw) { - int inh = h * 2 - padding_top + kh; - int inw = w * 2 - padding_left + kw; - if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) { - max = std::max(max, input[input_offset + inh * in_width + inw]); - } - } - } - - *outptr = max; - ++outptr; - } - } - input_offset += in_image_size; - output_offset += out_image_size; - } - } -} - -// assume the input has already been padded -void PoolingMaxNeonK3x3S2x2Padded(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape) { - index_t batch = in_shape[0]; - index_t channels = in_shape[1]; - index_t in_height = in_shape[2]; - index_t in_width = in_shape[3]; - - index_t out_height = out_shape[2]; - index_t out_width = out_shape[3]; - - int in_image_size = in_height * in_width; - int out_image_size = out_height * out_width; - index_t input_offset = 0; - index_t output_offset = 0; - -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int c = 0; c < channels; ++c) { - const float *img0 = input + input_offset; - float *outptr = output + output_offset; - - const float *r0 = img0; - const float *r1 = r0 + in_width; - const float *r2 = r1 + in_width; - - for (int h = 0; h < out_height; h++) { - int num_vectors = out_width >> 2; - int remain = out_width - (num_vectors << 2); - - float32x4x2_t row0 = vld2q_f32(r0); - float32x4x2_t row1 = vld2q_f32(r1); - float32x4x2_t row2 = vld2q_f32(r2); - for (; num_vectors > 0; num_vectors--) { - float32x4x2_t row0_next = vld2q_f32(r0 + 8); - float32x4x2_t row1_next = vld2q_f32(r1 + 8); - float32x4x2_t row2_next = vld2q_f32(r2 + 8); - - float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]); - float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]); - float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]); - - float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1); - float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1); - float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1); - - max0 = vmaxq_f32(max0, row02); - max1 = vmaxq_f32(max1, row12); - max2 = vmaxq_f32(max2, row22); - - float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2); - - vst1q_f32(outptr, max_result); - - row0 = row0_next; - row1 = row1_next; - row2 = row2_next; - - r0 += 8; - r1 += 8; - r2 += 8; - outptr += 4; - } - - for (; remain > 0; remain--) { - float max0 = std::max(std::max(r0[0], r0[1]), r0[2]); - float max1 = std::max(std::max(r1[0], r1[1]), r1[2]); - float max2 = std::max(std::max(r2[0], r2[1]), r2[2]); - - *outptr = std::max(std::max(max0, max1), max2); - - r0 += 2; - r1 += 2; - r2 += 2; - outptr++; - } - - r0 += 1 + in_width; - r1 += 1 + in_width; - r2 += 1 + in_width; - } - input_offset += in_image_size; - output_offset += out_image_size; - } - } -} - -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/neon/pooling_neon.cc b/mace/kernels/neon/pooling_neon.cc deleted file mode 100644 index cf280c38e7efb7b891768158850906b5d6695061..0000000000000000000000000000000000000000 --- a/mace/kernels/neon/pooling_neon.cc +++ /dev/null @@ -1,131 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/kernels/pooling.h" - -namespace mace { -namespace kernels { - -extern void PoolingMaxNeonK2x2S2x2(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape, - const int *paddings); - -extern void PoolingAvgNeonK2x2S2x2(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape, - const int *paddings); - -extern void PoolingMaxNeonK3x3S2x2(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape, - const int *paddings); - -extern void PoolingAvgNeonK3x3S2x2(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape, - const int *paddings); - -#ifdef __COPY_MAKE_PADDING -extern void PoolingMaxNeonK2x2S2x2Padded(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape); - -extern void PoolingAvgNeonK2x2S2x2Padded(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape); - -extern void PoolingMaxNeonK3x3S2x2Padded(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape); - -extern void PoolingAvgNeonK3x3S2x2Padded(const float *input, - const index_t *in_shape, - float *output, - const index_t *out_shape); -#endif - -template <> -void PoolingFunctor::operator()( - const Tensor *input_tensor, - Tensor *output_tensor, - StatsFuture *future) { - - std::vector output_shape(4); - std::vector paddings(2); - std::vector filter_shape(4); - filter_shape[0] = input_tensor->shape()[1]; - filter_shape[1] = input_tensor->shape()[1]; - filter_shape[2] = kernels_[0]; - filter_shape[3] = kernels_[1]; - - kernels::CalcPaddingAndOutputSize( - input_tensor->shape().data(), filter_shape.data(), this->dilations_, - strides_, this->padding_, output_shape.data(), - paddings.data()); - output_tensor->Resize(output_shape); - - const float *input = input_tensor->data(); - float *output = output_tensor->mutable_data(); - const index_t *input_shape = input_tensor->shape().data(); - -#ifdef __COPY_MAKE_PADDING - Tensor padded_input; - ConstructInputWithPadding(input_tensor, paddings.data(), &padded_input); - input = padded_input.data(); - input_shape = padded_input.shape().data(); -#endif - - if (kernels_[0] == 2 && kernels_[1] == 2 && strides_[0] == 2 && - strides_[1] == 2) { - // kernel_size: 2x2, strides: 2x2 - if (pooling_type_ == MAX) { // MAX_POOL_2x2s2x2 -#ifdef __COPY_MAKE_PADDING - PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape.data()); -#else - PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape.data(), - paddings.data()); -#endif - } else { // AVG_POOL_2x2s2x2 -#ifdef __COPY_MAKE_PADDING - PoolingAvgNeonK2x2S2x2Padded(input, input_shape, output, output_shape.data()); -#else - PoolingAvgNeonK2x2S2x2(input, input_shape, output, output_shape.data(), - paddings.data()); -#endif - } - } else if (kernels_[0] == 3 && kernels_[1] == 3 && strides_[0] == 2 && - strides_[1] == 2) { - // kernel_size: 3x3, strides: 2x2 - if (pooling_type_ == MAX) { // MAX_POOL_3x3s2x2 -#ifdef __COPY_MAKE_PADDING - PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape.data()); -#else - PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape.data(), - paddings.data()); -#endif - } else { // AVG_POOL_3x3s2x2 -#ifdef __COPY_MAKE_PADDING - PoolingAvgNeonK3x3S2x2Padded(input, input_shape, output, output_shape.data()); -#else - PoolingAvgNeonK3x3S2x2(input, input_shape, output, output_shape.data(), - paddings.data()); -#endif - } - } else { // not implement yet - PoolingFunctor(pooling_type_, kernels_, strides_, - padding_, dilations_)( - input_tensor, output_tensor, future); - } -} - -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/neon/relu_neon.cc b/mace/kernels/neon/relu_neon.cc deleted file mode 100644 index ad74b819224397b9b532a19729df32cb190e93a7..0000000000000000000000000000000000000000 --- a/mace/kernels/neon/relu_neon.cc +++ /dev/null @@ -1,70 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/kernels/relu.h" -#include - -namespace mace { -namespace kernels { - -template <> -void ActivationFunctor::operator()(const Tensor *input_tensor, - Tensor *output_tensor, - StatsFuture *future) { - const float *input = input_tensor->data(); - float *output = output_tensor->mutable_data(); - index_t size = input_tensor->size(); - if (max_limit_ < 0) { -#pragma omp parallel for - for (int64_t i = 0; i < size; i += kCostPerGroup) { - int64_t count = std::min(static_cast(kCostPerGroup), size - i); - int block = count >> 2; - int remain = count - (block << 2); - const float *inptr = input + i; - float *outptr = output + i; - float32x4_t zero = vdupq_n_f32(0.f); - for (; block > 0; --block) { - float32x4_t in = vld1q_f32(inptr); - float32x4_t out = vmaxq_f32(in, zero); - vst1q_f32(outptr, out); - - inptr += 4; - outptr += 4; - } - for (; remain > 0; --remain) { - *outptr = std::max(*inptr, 0.f); - ++inptr; - ++outptr; - } - } - } else { -#pragma omp parallel for - for (int64_t i = 0; i < size; i += kCostPerGroup) { - int64_t count = std::min(static_cast(kCostPerGroup), size - i); - int block = count >> 2; - int remain = count - (block << 2); - const float *inptr = input + i; - float *outptr = output + i; - float32x4_t zero = vdupq_n_f32(0.f); - float32x4_t vmax = vdupq_n_f32(max_limit_); - for (; block > 0; --block) { - float32x4_t in = vld1q_f32(inptr); - float32x4_t out = vmaxq_f32(in, zero); - out = vminq_f32(out, vmax); - vst1q_f32(outptr, out); - - inptr += 4; - outptr += 4; - } - for (; remain > 0; --remain) { - *outptr = std::min(std::max(*inptr, 0.f), max_limit_); - ++inptr; - ++outptr; - } - } - } -}; - -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/opencl/cl/buffer_to_image.cl b/mace/kernels/opencl/cl/buffer_to_image.cl index f95029c0300a1e44a93036136e33e8f77c393bad..d4a28e6e6180c8058b21fa9436334aaaa8fe501e 100644 --- a/mace/kernels/opencl/cl/buffer_to_image.cl +++ b/mace/kernels/opencl/cl/buffer_to_image.cl @@ -1,9 +1,9 @@ #include -__kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, ic, oc */ +__kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, oc, ic */ __private const int filter_w, - __private const int in_channel, __private const int out_channel, + __private const int in_channel, __write_only image2d_t output) { int w = get_global_id(0); int h = get_global_id(1); @@ -13,23 +13,26 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, i const int in_channel_idx = w % rounded_in_channel; const int h_idx = hw_idx / filter_w; const int w_idx = hw_idx % filter_w; - const int offset = ((h_idx * filter_w + w_idx) * in_channel + in_channel_idx) * out_channel - + out_channel_idx; + const int offset = ((h_idx * filter_w + w_idx) * out_channel + out_channel_idx) * in_channel + + in_channel_idx; - const int size = out_channel - out_channel_idx; VEC_DATA_TYPE(DATA_TYPE, 4) values = 0; - if (in_channel_idx < in_channel) { + if (out_channel_idx < out_channel) { + const int size = out_channel - out_channel_idx; if (size < 4) { - switch(size) { + switch (size) { case 3: - values.z = *(input + offset + 2); + values.z = *(input + offset + 2 * in_channel); case 2: - values.y = *(input + offset + 1); + values.y = *(input + offset + 1 * in_channel); case 1: values.x = *(input + offset); } } else { - values = vload4(0, input + offset); + values.w = *(input + offset + 3 * in_channel); + values.z = *(input + offset + 2 * in_channel); + values.y = *(input + offset + 1 * in_channel); + values.x = *(input + offset); } } @@ -37,10 +40,10 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, i CMD_TYPE(write_image, CMD_DATA_TYPE)(output, coord, values); } -__kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, ic, oc */ +__kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, oc, ic */ __private const int filter_w, - __private const int in_channel, __private const int out_channel, + __private const int in_channel, __read_only image2d_t input) { int w = get_global_id(0); int h = get_global_id(1); @@ -50,29 +53,31 @@ __kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, ic, oc const int in_channel_idx = w % rounded_in_channel; const int h_idx = hw_idx / filter_w; const int w_idx = hw_idx % filter_w; - const int offset = ((h_idx * filter_w + w_idx) * in_channel + in_channel_idx) * out_channel - + out_channel_idx; + const int offset = ((h_idx * filter_w + w_idx) * out_channel + out_channel_idx) * in_channel + + in_channel_idx; - if (in_channel_idx < in_channel) { + if (out_channel_idx < out_channel) { int2 coord = (int2)(w, h); VEC_DATA_TYPE(DATA_TYPE, 4) values = CMD_TYPE(read_image, CMD_DATA_TYPE)(input, SAMPLER, coord); const int size = (out_channel - out_channel_idx); if (size < 4) { switch (size) { case 3: - output[offset+2] = values.s2; + output[offset + 2 * in_channel] = values.z; case 2: - output[offset+1] = values.s1; + output[offset + 1 * in_channel] = values.y; case 1: - output[offset] = values.s0; + output[offset] = values.x; } } else { - vstore4(values, 0, output + offset); + output[offset + 3 * in_channel] = values.w; + output[offset + 2 * in_channel] = values.z; + output[offset + 1 * in_channel] = values.y; + output[offset] = values.x; } } } - __kernel void dw_filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, ic, m */ __private const int filter_w, __private const int in_channel, diff --git a/mace/kernels/opencl/depthwise_conv_opencl.cc b/mace/kernels/opencl/depthwise_conv_opencl.cc index 79c9196d94021ebf32ef0d96e2e46c4e0bfdd475..81cfb3dd9dc1ffe522ecc5c8c718f2c31357cae4 100644 --- a/mace/kernels/opencl/depthwise_conv_opencl.cc +++ b/mace/kernels/opencl/depthwise_conv_opencl.cc @@ -149,8 +149,8 @@ void DepthwiseConv2dFunctor::operator()( std::vector fake_filter_shape(4); fake_filter_shape[0] = filter->shape()[0]; fake_filter_shape[1] = filter->shape()[1]; - fake_filter_shape[3] = filter->shape()[2] * filter->shape()[3]; - fake_filter_shape[2] = 1; + fake_filter_shape[2] = filter->shape()[2] * filter->shape()[3]; + fake_filter_shape[3] = 1; std::vector output_shape(4); std::vector paddings(2); diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index cc9cfee9f86e8d35eedc39c66aed20cb9987fed5..a9923a6204e13fa7c7660fcc228d6ec34cb95162 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -19,12 +19,12 @@ void CalInOutputImageShape(const std::vector &shape, /* NHWC */ } // [RoundUp<4>(Ic) * H * W, (Oc + 3) / 4] -void CalConv2dFilterImageShape(const std::vector &shape, /* HWIO */ +void CalConv2dFilterImageShape(const std::vector &shape, /* HWOI */ std::vector &image_shape) { MACE_CHECK(shape.size() == 4); image_shape.resize(2); - image_shape[0] = shape[0] * shape[1] * RoundUp(shape[2], 4); - image_shape[1] = RoundUpDiv4(shape[3]); + image_shape[0] = shape[0] * shape[1] * RoundUp(shape[3], 4); + image_shape[1] = RoundUpDiv4(shape[2]); } // [H * W * M, (Ic + 3) / 4] @@ -179,6 +179,7 @@ void TuningOrRun3DKernel(cl::Kernel &kernel, local_ws[2] = std::min(gws[2], kwg_size / (local_ws[0] * local_ws[1])); return { + // TODO tuning these magic numbers {local_ws[0], local_ws[1], local_ws[2], 1}, {kwg_size / 16, 4, 4, 1}, {kwg_size / 32, 4, 8, 1}, @@ -200,7 +201,7 @@ void TuningOrRun3DKernel(cl::Kernel &kernel, {9, 7, 15, 1}, {15, 7, 9, 1}, {1, kwg_size, 1, 1}, - {4, 15, 8, 1}, // SNPE size + {4, 15, 8, 1}, }; }; cl::Event event; diff --git a/mace/ops/activation.cc b/mace/ops/activation.cc index 5cdffef16dd86852f112356413c829e8c9e5ff4a..96c04ac150e0b92b76032e162a09d96e07cc85ca 100644 --- a/mace/ops/activation.cc +++ b/mace/ops/activation.cc @@ -13,14 +13,6 @@ void Register_Activation(OperatorRegistry *op_registry) { .Build(), ActivationOp); -#if MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - ActivationOp); -#endif // MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/activation_benchmark.cc b/mace/ops/activation_benchmark.cc index 8010bc24dea8effe2750826e9d1c2bc8bb99fe9e..8a26e2436fd07c194e1b302d2d198f1ca3dc8d84 100644 --- a/mace/ops/activation_benchmark.cc +++ b/mace/ops/activation_benchmark.cc @@ -298,14 +298,15 @@ static void SigmoidBenchmark( } \ BENCHMARK(BM_SIGMOID_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) -#define BM_SIGMOID(N, C, H, W, TYPE) \ - BM_SIGMOID_MACRO(N, C, H, W, TYPE, CPU); \ - BM_SIGMOID_MACRO(N, C, H, W, TYPE, OPENCL); - -BM_SIGMOID(1, 1, 512, 512, float); -BM_SIGMOID(1, 3, 128, 128, float); -BM_SIGMOID(1, 3, 512, 512, float); -BM_SIGMOID(1, 32, 112, 112, float); -BM_SIGMOID(1, 64, 256, 256, float); +#define BM_SIGMOID(N, C, H, W) \ + BM_SIGMOID_MACRO(N, C, H, W, float, CPU); \ + BM_SIGMOID_MACRO(N, C, H, W, float, OPENCL); \ + BM_SIGMOID_MACRO(N, C, H, W, half, OPENCL); + +BM_SIGMOID(1, 1, 512, 512); +BM_SIGMOID(1, 3, 128, 128); +BM_SIGMOID(1, 3, 512, 512); +BM_SIGMOID(1, 32, 112, 112); +BM_SIGMOID(1, 64, 256, 256); } // namespace mace diff --git a/mace/ops/activation_test.cc b/mace/ops/activation_test.cc index 2fd1078c88ef3151268c2ff548a281bfb1bf3b3e..02e16108eaacc8f91608457717bfeb7e55260dac 100644 --- a/mace/ops/activation_test.cc +++ b/mace/ops/activation_test.cc @@ -53,10 +53,6 @@ void TestSimpleRelu() { TEST_F(ActivationOpTest, CPUSimpleRelu) { TestSimpleRelu(); } -#if __ARM_NEON -TEST_F(ActivationOpTest, NEONSimpleRelu) { TestSimpleRelu(); } -#endif - TEST_F(ActivationOpTest, OPENCLSimpleRelu) { TestSimpleRelu(); } @@ -104,12 +100,6 @@ TEST_F(ActivationOpTest, CPUUnalignedSimpleRelu) { TestUnalignedSimpleRelu(); } -#if __ARM_NEON -TEST_F(ActivationOpTest, NEONUnalignedSimpleRelu) { - TestUnalignedSimpleRelu(); -} -#endif - TEST_F(ActivationOpTest, OPENCLUnalignedSimpleRelu) { TestUnalignedSimpleRelu(); } @@ -160,10 +150,6 @@ void TestSimpleRelux() { TEST_F(ActivationOpTest, CPUSimple) { TestSimpleRelux(); } -#if __ARM_NEON -TEST_F(ActivationOpTest, NEONSimple) { TestSimpleRelux(); } -#endif - TEST_F(ActivationOpTest, OPENCLSimple) { TestSimpleRelux(); } @@ -216,12 +202,6 @@ TEST_F(ActivationOpTest, CPUSimpleRelux) { TestSimpleReluRelux(); } -#if __ARM_NEON -TEST_F(ActivationOpTest, NEONSimpleRelux) { - TestSimpleReluRelux(); -} -#endif - TEST_F(ActivationOpTest, OPENCLSimpleRelux) { TestSimpleReluRelux(); } @@ -272,12 +252,6 @@ void TestSimplePrelu() { TEST_F(ActivationOpTest, CPUSimplePrelu) { TestSimplePrelu(); } -#if __ARM_NEON -TEST_F(ActivationOpTest, NEONSimplePrelu) { - TestSimplePrelu(); -} -#endif - TEST_F(ActivationOpTest, OPENCLSimplePrelu) { TestSimplePrelu(); } @@ -329,10 +303,6 @@ void TestSimpleTanh() { TEST_F(ActivationOpTest, CPUSimpleTanh) { TestSimpleTanh(); } -#if __ARM_NEON -TEST_F(ActivationOpTest, NEONSimpleTanh) { TestSimpleTanh(); } -#endif - TEST_F(ActivationOpTest, OPENCLSimpleTanh) { TestSimpleTanh(); } @@ -387,12 +357,6 @@ TEST_F(ActivationOpTest, CPUSimpleSigmoid) { TestSimpleSigmoid(); } -#if __ARM_NEON -TEST_F(ActivationOpTest, NEONSimpleSigmoid) { - TestSimpleSigmoid(); -} -#endif - TEST_F(ActivationOpTest, OPENCLSimpleSigmoid) { TestSimpleSigmoid(); } diff --git a/mace/ops/addn_benchmark.cc b/mace/ops/addn_benchmark.cc index 7e9d9856be72fa95fbc968b8c056d7a4caf52d5d..a559ed07caa09c4ed7659022b0da905f14c8ece9 100644 --- a/mace/ops/addn_benchmark.cc +++ b/mace/ops/addn_benchmark.cc @@ -65,16 +65,16 @@ static void AddNBenchmark(int iters, int inputs, int n, int h, int w, int c) { } \ BENCHMARK(BM_ADDN_##INPUTS##_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE) -#define BM_ADDN(INPUTS, N, H, W, C, TYPE) \ - BM_ADDN_MACRO(INPUTS, N, H, W, C, TYPE, CPU); \ - BM_ADDN_MACRO(INPUTS, N, H, W, C, TYPE, OPENCL); +#define BM_ADDN(INPUTS, N, H, W, C) \ + BM_ADDN_MACRO(INPUTS, N, H, W, C, float, CPU); \ + BM_ADDN_MACRO(INPUTS, N, H, W, C, float, NEON); \ + BM_ADDN_MACRO(INPUTS, N, H, W, C, float, OPENCL); \ + BM_ADDN_MACRO(INPUTS, N, H, W, C, half, OPENCL); -BM_ADDN(2, 1, 256, 256, 32, float); -BM_ADDN(2, 1, 128, 128, 32, float); -// BM_ADDN(2, 1, 240, 240, 256, half); -BM_ADDN(4, 1, 128, 128, 3, float); -BM_ADDN(2, 1, 256, 256, 3, float); -BM_ADDN(2, 1, 512, 512, 3, float); -// BM_ADDN(4, 1, 240, 240, 256, half); +BM_ADDN(2, 1, 256, 256, 32); +BM_ADDN(2, 1, 128, 128, 32); +BM_ADDN(4, 1, 128, 128, 3); +BM_ADDN(2, 1, 256, 256, 3); +BM_ADDN(2, 1, 512, 512, 3); -} // namespace mace +} // namespace mace diff --git a/mace/ops/addn_test.cc b/mace/ops/addn_test.cc index 691b15712b4f72f074486c86a75fd95ee5d08d7e..cdb970be35af7b564f329d6716a5643698bc37f9 100644 --- a/mace/ops/addn_test.cc +++ b/mace/ops/addn_test.cc @@ -33,12 +33,8 @@ void SimpleAdd2() { TEST_F(AddnOpTest, CPUSimpleAdd2) { SimpleAdd2(); } -/* TEST_F(AddnOpTest, NEONSimpleAdd2) { SimpleAdd2(); } -TEST_F(AddnOpTest, OPENCLSimpleAdd2) { SimpleAdd2(); } -*/ - template void SimpleAdd3() { // Construct graph @@ -65,9 +61,7 @@ void SimpleAdd3() { TEST_F(AddnOpTest, CPUSimpleAdd3) { SimpleAdd3(); } -/* TEST_F(AddnOpTest, NEONSimpleAdd3) { SimpleAdd3(); } -*/ template void RandomTest() { diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc index abfe85a6cb8cd65a553a1610164b61aae570e2a1..900ce27372b37569324eecd45705d91c96c7369e 100644 --- a/mace/ops/batch_norm_benchmark.cc +++ b/mace/ops/batch_norm_benchmark.cc @@ -82,21 +82,24 @@ static void BatchNorm( } \ BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) -#define BM_BATCH_NORM(N, C, H, W, TYPE) \ - BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \ - BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, OPENCL); +#define BM_BATCH_NORM(N, C, H, W) \ + BM_BATCH_NORM_MACRO(N, C, H, W, float, CPU); \ + BM_BATCH_NORM_MACRO(N, C, H, W, float, NEON); \ + BM_BATCH_NORM_MACRO(N, C, H, W, float, OPENCL); \ + BM_BATCH_NORM_MACRO(N, C, H, W, half, OPENCL); -BM_BATCH_NORM(1, 1, 512, 512, float); -BM_BATCH_NORM(1, 3, 128, 128, float); -BM_BATCH_NORM(1, 3, 512, 512, float); -BM_BATCH_NORM(1, 32, 112, 112, float); -BM_BATCH_NORM(1, 64, 256, 256, float); -BM_BATCH_NORM(1, 64, 512, 512, float); -BM_BATCH_NORM(1, 128, 56, 56, float); -BM_BATCH_NORM(1, 128, 256, 256, float); -BM_BATCH_NORM(1, 256, 14, 14, float); -BM_BATCH_NORM(1, 512, 14, 14, float); -BM_BATCH_NORM(1, 1024, 7, 7, float); -BM_BATCH_NORM(32, 1, 256, 256, float); -BM_BATCH_NORM(32, 3, 256, 256, float); -} // namespace mace +BM_BATCH_NORM(1, 1, 512, 512); +BM_BATCH_NORM(1, 3, 128, 128); +BM_BATCH_NORM(1, 3, 512, 512); +BM_BATCH_NORM(1, 32, 112, 112); +BM_BATCH_NORM(1, 64, 256, 256); +BM_BATCH_NORM(1, 64, 512, 512); +BM_BATCH_NORM(1, 128, 56, 56); +BM_BATCH_NORM(1, 128, 256, 256); +BM_BATCH_NORM(1, 256, 14, 14); +BM_BATCH_NORM(1, 512, 14, 14); +BM_BATCH_NORM(1, 1024, 7, 7); +BM_BATCH_NORM(32, 1, 256, 256); +BM_BATCH_NORM(32, 3, 256, 256); + +} // namespace mace diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc index a312df78a10feed11087bb00e2ac3e67e9ee564c..db88f130ed4ae1bc267651d5c152a55d3d63fc47 100644 --- a/mace/ops/batch_norm_test.cc +++ b/mace/ops/batch_norm_test.cc @@ -72,23 +72,18 @@ void Simple() { TEST_F(BatchNormOpTest, SimpleCPU) { Simple(); } -/* -TEST_F(BatchNormOpTest, SimpleNEON) { - Simple(); -} -*/ +TEST_F(BatchNormOpTest, SimpleNEON) { Simple(); } TEST_F(BatchNormOpTest, SimpleOPENCL) { Simple(); } -/* TEST_F(BatchNormOpTest, SimpleRandomNeon) { srand(time(NULL)); // generate random input index_t batch = 1 + rand() % 10; - index_t channels = 3 + rand() % 50; index_t height = 64; index_t width = 64; + index_t channels = 3 + rand() % 50; // Construct graph OpsTestNet net; OpDefBuilder("BatchNorm", "BatchNormTest") @@ -97,18 +92,17 @@ TEST_F(BatchNormOpTest, SimpleRandomNeon) { .Input("Offset") .Input("Mean") .Input("Var") - .Input("Epsilon") + .AddFloatArg("epsilon", 1e-3) .Output("Output") .Finalize(net.NewOperatorDef()); // Add input data - net.AddRandomInput("Input", {batch, channels, height, -width}); + net.AddRandomInput("Input", + {batch, height, width, channels}); net.AddRandomInput("Scale", {channels}); net.AddRandomInput("Offset", {channels}); net.AddRandomInput("Mean", {channels}); net.AddRandomInput("Var", {channels}, true); - net.AddInputFromArray("Epsilon", {}, {1e-3}); // run cpu net.RunOp(); @@ -139,18 +133,17 @@ TEST_F(BatchNormOpTest, ComplexRandomNeon) { .Input("Offset") .Input("Mean") .Input("Var") - .Input("Epsilon") + .AddFloatArg("epsilon", 1e-3) .Output("Output") .Finalize(net.NewOperatorDef()); // Add input data - net.AddRandomInput("Input", {batch, channels, height, -width}); + net.AddRandomInput("Input", + {batch, height, width, channels}); net.AddRandomInput("Scale", {channels}); net.AddRandomInput("Offset", {channels}); net.AddRandomInput("Mean", {channels}); net.AddRandomInput("Var", {channels}, true); - net.AddInputFromArray("Epsilon", {}, {1e-3}); // run cpu net.RunOp(); @@ -164,7 +157,6 @@ width}); ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-2); } -*/ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { srand(time(NULL)); diff --git a/mace/ops/batch_to_space_benchmark.cc b/mace/ops/batch_to_space_benchmark.cc index 02da45ca126f2cb02962a67af7c887d5642e6036..46363f867c01b19659aff97bfb86d4a4bbb31a8a 100644 --- a/mace/ops/batch_to_space_benchmark.cc +++ b/mace/ops/batch_to_space_benchmark.cc @@ -47,10 +47,10 @@ static void BMBatchToSpace( } \ BENCHMARK(BM_BATCH_TO_SPACE_##N##_##H##_##W##_##C##_##ARG##_##TYPE##_##DEVICE) -#define BM_BATCH_TO_SPACE(N, H, W, C, ARG, TYPE) \ - BM_BATCH_TO_SPACE_MACRO(N, H, W, C, ARG, TYPE, OPENCL); +#define BM_BATCH_TO_SPACE(N, H, W, C, ARG) \ + BM_BATCH_TO_SPACE_MACRO(N, H, W, C, ARG, float, OPENCL); -BM_BATCH_TO_SPACE(128, 8, 8, 128, 2, float); -BM_BATCH_TO_SPACE(4, 128, 128, 32, 2, float); -BM_BATCH_TO_SPACE(16, 64, 64, 32, 4, float); -} // namespace mace \ No newline at end of file +BM_BATCH_TO_SPACE(128, 8, 8, 128, 2); +BM_BATCH_TO_SPACE(4, 128, 128, 32, 2); +BM_BATCH_TO_SPACE(16, 64, 64, 32, 4); +} // namespace mace diff --git a/mace/ops/bias_add.cc b/mace/ops/bias_add.cc index 01a7582dccba285633784dbaa3b2ae43b7ed366b..63f8661df565fb1a92480680e873d022a0affbbe 100644 --- a/mace/ops/bias_add.cc +++ b/mace/ops/bias_add.cc @@ -13,16 +13,6 @@ void Register_BiasAdd(OperatorRegistry *op_registry) { .Build(), BiasAddOp); - /* - #if __ARM_NEON - REGISTER_OPERATOR(op_registry,OpKeyBuilder("BiasAdd") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - BiasAddOp); - #endif // __ARM_NEON - */ - REGISTER_OPERATOR(op_registry, OpKeyBuilder("BiasAdd") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/bias_add_benchmark.cc b/mace/ops/bias_add_benchmark.cc index 09f96267940903cd6438a82882247d3e32e92961..7d091fd90cf210136d504c490a79e68a1d1bfde5 100644 --- a/mace/ops/bias_add_benchmark.cc +++ b/mace/ops/bias_add_benchmark.cc @@ -59,21 +59,22 @@ static void BiasAdd(int iters, int batch, int channels, int height, int width) { } \ BENCHMARK(BM_BIAS_ADD_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) -#define BM_BIAS_ADD(N, C, H, W, TYPE) \ - BM_BIAS_ADD_MACRO(N, C, H, W, TYPE, CPU); \ - BM_BIAS_ADD_MACRO(N, C, H, W, TYPE, OPENCL); +#define BM_BIAS_ADD(N, C, H, W) \ + BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); \ + BM_BIAS_ADD_MACRO(N, C, H, W, float, OPENCL); \ + BM_BIAS_ADD_MACRO(N, C, H, W, half, OPENCL); -BM_BIAS_ADD(1, 1, 512, 512, float); -BM_BIAS_ADD(1, 3, 128, 128, float); -BM_BIAS_ADD(1, 3, 512, 512, float); -BM_BIAS_ADD(1, 32, 112, 112, float); -BM_BIAS_ADD(1, 64, 256, 256, float); -BM_BIAS_ADD(1, 64, 512, 512, float); -BM_BIAS_ADD(1, 128, 56, 56, float); -BM_BIAS_ADD(1, 128, 256, 256, float); -BM_BIAS_ADD(1, 256, 14, 14, float); -BM_BIAS_ADD(1, 512, 14, 14, float); -BM_BIAS_ADD(1, 1024, 7, 7, float); -BM_BIAS_ADD(32, 1, 256, 256, float); -BM_BIAS_ADD(32, 3, 256, 256, float); -} // namespace mace +BM_BIAS_ADD(1, 1, 512, 512); +BM_BIAS_ADD(1, 3, 128, 128); +BM_BIAS_ADD(1, 3, 512, 512); +BM_BIAS_ADD(1, 32, 112, 112); +BM_BIAS_ADD(1, 64, 256, 256); +BM_BIAS_ADD(1, 64, 512, 512); +BM_BIAS_ADD(1, 128, 56, 56); +BM_BIAS_ADD(1, 128, 256, 256); +BM_BIAS_ADD(1, 256, 14, 14); +BM_BIAS_ADD(1, 512, 14, 14); +BM_BIAS_ADD(1, 1024, 7, 7); +BM_BIAS_ADD(32, 1, 256, 256); +BM_BIAS_ADD(32, 3, 256, 256); +} // namespace mace diff --git a/mace/ops/buffer_to_image.cc b/mace/ops/buffer_to_image.cc index c9118a19392b9f90fe3eb80bba2a1b9b8a17f4b3..718374de349ef20d476faa063ee63fb2557bb3b7 100644 --- a/mace/ops/buffer_to_image.cc +++ b/mace/ops/buffer_to_image.cc @@ -20,4 +20,4 @@ void Register_BufferToImage(OperatorRegistry *op_registry) { BufferToImageOp); } -} // namespace mace +} // namespace mace diff --git a/mace/ops/buffer_to_image.h b/mace/ops/buffer_to_image.h index 4b3cfb6e8bb1f4dcb4a365cafb4e4334c5363e4b..723063387a0f010b38975dddac2b0f8688ca3edc 100644 --- a/mace/ops/buffer_to_image.h +++ b/mace/ops/buffer_to_image.h @@ -35,5 +35,5 @@ class BufferToImageOp: public Operator { OP_OUTPUT_TAGS(OUTPUT); }; -} // namespace mace -#endif // MACE_OPS_BUFFER_TO_IMAGE_H_ +} // namespace mace +#endif // MACE_OPS_BUFFER_TO_IMAGE_H_ diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc index f7191c370b4bd713f94d825775c745560e234422..0185c1d1e62cdb0bf836b162e29101b12aa6f348 100644 --- a/mace/ops/conv_2d.cc +++ b/mace/ops/conv_2d.cc @@ -13,20 +13,6 @@ void Register_Conv2D(OperatorRegistry *op_registry) { .Build(), Conv2dOp); - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D") - .Device(DeviceType::CPU) - .TypeConstraint("T") - .Build(), - Conv2dOp); - -#if MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - Conv2dOp); -#endif // MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index b02eb17e63d5e1e3cf126b6be0284bdb9f73d954..63c9df80c9d40bc17b6ec491755ffbb80e5d7986 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -29,7 +29,7 @@ static void Conv2d(int iters, // Add input data net.AddRandomInput("Input", {batch, height, width, channels}); net.AddRandomInput("Filter", - {kernel_h, kernel_w, channels, output_channels}); + {kernel_h, kernel_w, output_channels, channels}); net.AddRandomInput("Bias", {output_channels}); if (D == DeviceType::OPENCL) { @@ -92,50 +92,46 @@ static void Conv2d(int iters, BENCHMARK( \ BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE) -#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \ - BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL); - -// ICNet -BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024, half); -//// SNPE GPU ExecutionDuration = 448us, % ALU Utilization = 105 -BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128, half); -//// SNPE GPU ExecutionDuration = 258us, % ALU Utilization = 108 -BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128, half); - -BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128, half); -//// SNPE GPU ExecutionDuration = 506us, % ALU Utilization = 106.8 -BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32, half); -BM_CONV_2D(1, 3, 512, 512, 7, 7, 2, SAME, 64, half); -BM_CONV_2D(1, 512, 64, 64, 1, 1, 1, SAME, 256, half); - -BM_CONV_2D(1, 128, 16, 16, 3, 3, 1, VALID, 32, half); -BM_CONV_2D(1, 128, 64, 64, 3, 3, 1, VALID, 32, half); -BM_CONV_2D(1, 128, 128, 128, 3, 3, 1, VALID, 32, half); - -// Test RGB <-> YUV -// BM_CONV_2D(1, 3, 2160, 1080, 1, 1, 1, VALID, 3, float); -// BM_CONV_2D(1, 3, 480, 480, 1, 1, 1, VALID, 3, float); -// -// BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float); -// BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad -// alignments -// BM_CONV_2D(1, 3, 512, 512, 1, 1, 1, VALID, 3, float); -// BM_CONV_2D(1, 32, 112, 112, 1, 1, 1, VALID, 64, float); -// BM_CONV_2D(1, 64, 56, 56, 1, 1, 1, VALID, 128, float); -// BM_CONV_2D(1, 256, 28, 28, 1, 1, 1, VALID, 256, float); -// BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, VALID, 1024, float); -// BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float); -// BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float); -// BM_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3, float); -// BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float); -// BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float); -// BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float); -// BM_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3, float); -// BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float); -// BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float); -// BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float); -// BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128, float); -// BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float); -// BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float); -// BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128, float); -} // namespace mace +#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC) \ + BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, CPU); \ + BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, OPENCL); \ + BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, half, OPENCL); + +BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024); +BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128); +BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128); + +BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128); +BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32); +BM_CONV_2D(1, 3, 512, 512, 7, 7, 2, SAME, 64); +BM_CONV_2D(1, 512, 64, 64, 1, 1, 1, SAME, 256); + +BM_CONV_2D(1, 128, 16, 16, 3, 3, 1, VALID, 32); +BM_CONV_2D(1, 128, 64, 64, 3, 3, 1, VALID, 32); +BM_CONV_2D(1, 128, 128, 128, 3, 3, 1, VALID, 32); + +BM_CONV_2D(1, 3, 480, 480, 1, 1, 1, VALID, 3); + +BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128); +BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128); // Test bad alignments +BM_CONV_2D(1, 3, 512, 512, 1, 1, 1, VALID, 3); +BM_CONV_2D(1, 32, 112, 112, 1, 1, 1, VALID, 64); +BM_CONV_2D(1, 64, 56, 56, 1, 1, 1, VALID, 128); +BM_CONV_2D(1, 256, 28, 28, 1, 1, 1, VALID, 256); +BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, VALID, 1024); +BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128); +BM_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3); +BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128); +BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128); +BM_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128); +BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128); +BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128); +BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128); +BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128); +BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128); +BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128); + +} // namespace mace diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 877da76da3ae1d45c9fcca0abdc6e54426091401..fb93504e5d946d832edc9bbed37a829a10985d8f 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -10,81 +10,6 @@ using namespace mace; class Conv2dOpTest : public OpsTestBase {}; -template -void TestSimple3x3VALID() { - OpsTestNet net; - OpDefBuilder("Conv2D", "Conv2dTest") - .Input("Input") - .Input("Filter") - .Input("Bias") - .Output("Output") - .AddIntsArg("strides", {1, 1}) - .AddIntArg("padding", Padding::VALID) - .AddIntsArg("dilations", {1, 1}) - .Finalize(net.NewOperatorDef()); - - // Add args - - // Add input data - net.AddInputFromArray( - "Input", {1, 2, 3, 3}, - {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - net.AddInputFromArray( - "Filter", {1, 2, 3, 3}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); - net.AddInputFromArray("Bias", {1}, {0.1f}); - - // Run - net.RunOp(D); - - // Check - auto expected = CreateTensor({1, 1, 1, 1}, {18.1f}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); -} - -template -void TestSimple3x3SAME() { - OpsTestNet net; - OpDefBuilder("Conv2D", "Conv2dTest") - .Input("Input") - .Input("Filter") - .Input("Bias") - .Output("Output") - .AddIntsArg("strides", {1, 1}) - .AddIntArg("padding", Padding::SAME) - .AddIntsArg("dilations", {1, 1}) - .Finalize(net.NewOperatorDef()); - - // Add input data - net.AddInputFromArray( - "Input", {1, 2, 3, 3}, - {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - net.AddInputFromArray( - "Filter", {1, 2, 3, 3}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); - net.AddInputFromArray("Bias", {1}, {0.1f}); - - // Run - net.RunOp(D); - - // Check - auto expected = CreateTensor( - {1, 1, 3, 3}, - {8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); -} - -#if __ARM_NEON -TEST_F(Conv2dOpTest, NEONSimple) { - TestSimple3x3VALID(); - TestSimple3x3SAME(); -} -#endif - template void TestNHWCSimple3x3VALID() { OpsTestNet net; @@ -93,7 +18,7 @@ void TestNHWCSimple3x3VALID() { "Input", {1, 3, 3, 2}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); net.AddInputFromArray( - "Filter", {3, 3, 2, 1}, + "Filter", {3, 3, 1, 2}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); net.AddInputFromArray("Bias", {1}, {0.1f}); @@ -150,7 +75,7 @@ void TestNHWCSimple3x3SAME() { "Input", {1, 3, 3, 2}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); net.AddInputFromArray( - "Filter", {3, 3, 2, 1}, + "Filter", {3, 3, 1, 2}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); net.AddInputFromArray("Bias", {1}, {0.1f}); @@ -211,42 +136,6 @@ TEST_F(Conv2dOpTest, OPENCLSimple) { TestNHWCSimple3x3SAME(); } -template -void TestSimple3x3WithoutBias() { - OpsTestNet net; - OpDefBuilder("Conv2D", "Conv2dTest") - .Input("Input") - .Input("Filter") - .Output("Output") - .AddIntsArg("strides", {1, 1}) - .AddIntArg("padding", Padding::VALID) - .AddIntsArg("dilations", {1, 1}) - .Finalize(net.NewOperatorDef()); - - // Add input data - net.AddInputFromArray( - "Input", {1, 2, 3, 3}, - {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - net.AddInputFromArray( - "Filter", {1, 2, 3, 3}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); - - // Run - net.RunOp(D); - - // Check - auto expected = CreateTensor({1, 1, 1, 1}, {18.0f}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); -} - -#ifdef __ARM_NEON -TEST_F(Conv2dOpTest, NEONWithouBias) { - TestSimple3x3WithoutBias(); -} -#endif - template void TestNHWCSimple3x3WithoutBias() { OpsTestNet net; @@ -256,7 +145,7 @@ void TestNHWCSimple3x3WithoutBias() { "Input", {1, 3, 3, 2}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); net.AddInputFromArray( - "Filter", {3, 3, 2, 1}, + "Filter", {3, 3, 1, 2}, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); @@ -309,47 +198,6 @@ TEST_F(Conv2dOpTest, OPENCLWithoutBias) { TestNHWCSimple3x3WithoutBias(); } -template -static void TestCombined3x3() { - // Construct graph - OpsTestNet net; - OpDefBuilder("Conv2D", "Conv2DTest") - .Input("Input") - .Input("Filter") - .Input("Bias") - .Output("Output") - .AddIntsArg("strides", {2, 2}) - .AddIntArg("padding", Padding::SAME) - .AddIntsArg("dilations", {1, 1}) - .Finalize(net.NewOperatorDef()); - - // Add input data - net.AddInputFromArray( - "Input", {1, 2, 5, 5}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); - net.AddInputFromArray( - "Filter", {2, 2, 3, 3}, - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, - 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, - 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); - net.AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); - - // Run - net.RunOp(D); - - // Check - auto expected = CreateTensor( - {1, 2, 3, 3}, {8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f, - 4.2f, 6.2f, 4.2f, 6.2f, 9.2f, 6.2f, 4.2f, 6.2f, 4.2f}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); -} - -#ifdef __ARM_NEON -TEST_F(Conv2dOpTest, NEONCombined) { TestCombined3x3(); } -#endif - template static void TestNHWCCombined3x3() { // Construct graph @@ -362,9 +210,9 @@ static void TestNHWCCombined3x3() { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); net.AddInputFromArray( "Filter", {3, 3, 2, 2}, - {1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, - 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, - 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f}); + {1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, + 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, + 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f}); net.AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); if (D == DeviceType::OPENCL) { @@ -436,8 +284,8 @@ void TestConv1x1() { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); net.AddInputFromArray( - "Filter", {1, 1, 5, 2}, - {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}); + "Filter", {1, 1, 2, 5}, + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f}); net.AddInputFromArray("Bias", {2}, {0.1f, 0.2f}); if (D == DeviceType::OPENCL) { @@ -522,7 +370,7 @@ static void TestComplexConvNxNS12(const std::vector &shape) { // Add input data net.AddRandomInput("Input", {batch, height, width, input_channels}); net.AddRandomInput( - "Filter", {kernel_h, kernel_w, input_channels, output_channels}); + "Filter", {kernel_h, kernel_w, output_channels, input_channels}); net.AddRandomInput("Bias", {output_channels}); // run on cpu @@ -606,7 +454,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &input_shape, float_input_data); std::vector float_filter_data; GenerateRandomRealTypeData( - {kernel_h, kernel_w, input_channels, output_channels}, + {kernel_h, kernel_w, output_channels, input_channels}, float_filter_data); std::vector float_bias_data; GenerateRandomRealTypeData({output_channels}, float_bias_data); @@ -614,7 +462,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &input_shape, net.AddInputFromArray( "Input", {batch, height, width, input_channels}, float_input_data); net.AddInputFromArray( - "Filter", {kernel_h, kernel_w, input_channels, output_channels}, + "Filter", {kernel_h, kernel_w, output_channels, input_channels}, float_filter_data); net.AddInputFromArray("Bias", {output_channels}, float_bias_data); @@ -748,7 +596,7 @@ static void TestDilationConvNxN(const std::vector &shape, const int dil // Add input data net.AddRandomInput("Input", {batch, height, width, input_channels}); net.AddRandomInput( - "Filter", {kernel_h, kernel_w, input_channels, output_channels}); + "Filter", {kernel_h, kernel_w, output_channels, input_channels}); net.AddRandomInput("Bias", {output_channels}); // run on cpu diff --git a/mace/ops/depthwise_conv2d.cc b/mace/ops/depthwise_conv2d.cc index da8a51d564c14e7f1eb511f0c373fee3e466ce6b..a6fb9eb4d86d60f2762ea0052e820f6f2e2b79af 100644 --- a/mace/ops/depthwise_conv2d.cc +++ b/mace/ops/depthwise_conv2d.cc @@ -13,14 +13,6 @@ void Register_DepthwiseConv2d(OperatorRegistry *op_registry) { .Build(), DepthwiseConv2dOp); -#if MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - DepthwiseConv2dOp); -#endif // MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc index f2cabdee9254773f3791a81b212610a7f8a878c4..1df30b01a3cb6ec7da8e2977e3036ea6a9c5366a 100644 --- a/mace/ops/depthwise_conv2d_test.cc +++ b/mace/ops/depthwise_conv2d_test.cc @@ -288,12 +288,6 @@ void TestNxNS12(const index_t height, const index_t width) { } } -#if __ARM_NEON -TEST_F(DepthwiseConv2dOpTest, NeonSimpleNxNS12) { - TestNxNS12(4, 4); -} -#endif - TEST_F(DepthwiseConv2dOpTest, OpenCLSimpleNxNS12) { TestNxNS12(4, 4); } @@ -302,13 +296,6 @@ TEST_F(DepthwiseConv2dOpTest, OpenCLSimpleNxNS12Half) { TestNxNS12(4, 4); } -#if __ARM_NEON -TEST_F(DepthwiseConv2dOpTest, NeonAlignedNxNS12) { - TestNxNS12(64, 64); - TestNxNS12(128, 128); -} -#endif - TEST_F(DepthwiseConv2dOpTest, OpenCLAlignedNxNS12) { TestNxNS12(64, 64); TestNxNS12(128, 128); @@ -319,12 +306,6 @@ TEST_F(DepthwiseConv2dOpTest, OpenCLAlignedNxNS12Half) { TestNxNS12(128, 128); } -#if __ARM_NEON -TEST_F(DepthwiseConv2dOpTest, NeonUnalignedNxNS12) { - TestNxNS12(107, 113); -} -#endif - TEST_F(DepthwiseConv2dOpTest, OpenCLUnalignedNxNS12) { TestNxNS12(107, 113); } diff --git a/mace/ops/depthwise_conv_2d_benchmark.cc b/mace/ops/depthwise_conv_2d_benchmark.cc index 1f7dfa3ca6bb3ecbf2e1a2528af165bd2ab7c8b6..561c5af030697b8f4641bfb71fa0f8f4753613e2 100644 --- a/mace/ops/depthwise_conv_2d_benchmark.cc +++ b/mace/ops/depthwise_conv_2d_benchmark.cc @@ -89,21 +89,22 @@ static void DepthwiseConv2d(int iters, BENCHMARK( \ BM_DEPTHWISE_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE) -#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \ - BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \ - BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL); +#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC) \ + BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, CPU); \ + BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, OPENCL); \ + BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, half, OPENCL); -BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1, float); -BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1, float); -//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1, float); +BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 1); +//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1); +BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 1); +//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1); +//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1); +//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1); +//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1); +//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1); +//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1); +//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1); +//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1); +//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1); } // namespace mace diff --git a/mace/ops/folded_batch_norm.cc b/mace/ops/folded_batch_norm.cc index 9915bee4128f1e3766a91070d1cae48e044f459f..b8fadbb72419800f61c73aed5cb0112447d41aff 100644 --- a/mace/ops/folded_batch_norm.cc +++ b/mace/ops/folded_batch_norm.cc @@ -14,14 +14,6 @@ void Register_FoldedBatchNorm(OperatorRegistry *op_registry) { .Build(), FoldedBatchNormOp); -#if MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - FoldedBatchNormOp); -#endif // MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm") .Device(DeviceType::OPENCL) diff --git a/mace/ops/fused_conv_2d.cc b/mace/ops/fused_conv_2d.cc index fd17a12a1a3e72df988fe68b967d71b4a8c50640..4a0245f5e03e252775b1f17cea5a6823ecdaa56e 100644 --- a/mace/ops/fused_conv_2d.cc +++ b/mace/ops/fused_conv_2d.cc @@ -13,12 +13,6 @@ void Register_FusedConv2D(OperatorRegistry *op_registry) { .Build(), FusedConv2dOp); - REGISTER_OPERATOR(op_registry, OpKeyBuilder("FusedConv2D") - .Device(DeviceType::CPU) - .TypeConstraint("T") - .Build(), - FusedConv2dOp); - REGISTER_OPERATOR(op_registry, OpKeyBuilder("FusedConv2D") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/fused_conv_2d_test.cc b/mace/ops/fused_conv_2d_test.cc index 87d99b9e1dc401b5948a6734fb4925aaa5870d9e..37a056f18a916fc228aaebc35eba905418856c2e 100644 --- a/mace/ops/fused_conv_2d_test.cc +++ b/mace/ops/fused_conv_2d_test.cc @@ -298,7 +298,7 @@ static void TestComplexConvNxNS12(const std::vector &shape) { // Add input data net.AddRandomInput("Input", {batch, height, width, input_channels}); net.AddRandomInput( - "Filter", {kernel_h, kernel_w, input_channels, output_channels}); + "Filter", {kernel_h, kernel_w, output_channels, input_channels}); net.AddRandomInput("Bias", {output_channels}); // run on cpu @@ -375,7 +375,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &shape) { float_input_data); std::vector float_filter_data; GenerateRandomRealTypeData( - {kernel_h, kernel_w, input_channels, output_channels}, + {kernel_h, kernel_w, output_channels, input_channels}, float_filter_data); std::vector float_bias_data; GenerateRandomRealTypeData({output_channels}, float_bias_data); @@ -383,7 +383,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &shape) { net.AddInputFromArray( "Input", {batch, height, width, input_channels}, float_input_data); net.AddInputFromArray( - "Filter", {kernel_h, kernel_w, input_channels, output_channels}, + "Filter", {kernel_h, kernel_w, output_channels, input_channels}, float_filter_data); net.AddInputFromArray("Bias", {output_channels}, float_bias_data); @@ -462,7 +462,7 @@ static void TestGeneralConvNxNS12(const std::vector &image_shape, // Add input data net.AddRandomInput("Input", {batch, height, width, input_channels}); net.AddRandomInput( - "Filter", {kernel_h, kernel_w, input_channels, output_channels}); + "Filter", {kernel_h, kernel_w, output_channels, input_channels}); net.AddRandomInput("Bias", {output_channels}); // run on cpu @@ -540,7 +540,7 @@ static void TestAtrousConvNxN(const std::vector &shape, const int dilat // Add input data net.AddRandomInput("Input", {batch, height, width, input_channels}); net.AddRandomInput( - "Filter", {kernel_h, kernel_w, input_channels, output_channels}); + "Filter", {kernel_h, kernel_w, output_channels, input_channels}); net.AddRandomInput("Bias", {output_channels}); // run on cpu @@ -622,7 +622,7 @@ static void TestGeneralHalfAtrousConv(const std::vector &image_shape, // Add input data net.AddRandomInput("Input", {batch, height, width, input_channels}); net.AddRandomInput( - "Filter", {kernel_h, kernel_w, input_channels, output_channels}); + "Filter", {kernel_h, kernel_w, output_channels, input_channels}); net.AddRandomInput("Bias", {output_channels}); // run on cpu diff --git a/mace/ops/global_avg_pooling.cc b/mace/ops/global_avg_pooling.cc index 65fd7f43b8051971fac29818aeace86db6a3a98f..d2ab13cfd8b8b7ca867af256a8d9e9a43e6c3ca3 100644 --- a/mace/ops/global_avg_pooling.cc +++ b/mace/ops/global_avg_pooling.cc @@ -12,14 +12,6 @@ void Register_GlobalAvgPooling(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), GlobalAvgPoolingOp); - -#if MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("GlobalAvgPooling") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - GlobalAvgPoolingOp); -#endif // MACE_ENABLE_NEON } } // namespace mace diff --git a/mace/ops/global_avg_pooling_test.cc b/mace/ops/global_avg_pooling_test.cc index cb12c0d489d50ee7315a3aa7e6411b6a8c792aa0..a00ffc36d542562bf1bf4365f67977fda4fdbc38 100644 --- a/mace/ops/global_avg_pooling_test.cc +++ b/mace/ops/global_avg_pooling_test.cc @@ -31,29 +31,3 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) { ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } - -#if __ARM_NEON -TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) { - // Construct graph - OpsTestNet net; - OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest") - .Input("Input") - .Output("Output") - .Finalize(net.NewOperatorDef()); - - // Add input data - std::vector input(147); - for (int i = 0; i < 147; ++i) { - input[i] = i / 49 + 1; - } - net.AddInputFromArray("Input", {1, 3, 7, 7}, input); - - // Run - net.RunOp(DeviceType::NEON); - - // Check - auto expected = CreateTensor({1, 3, 1, 1}, {1, 2, 3}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); -} -#endif diff --git a/mace/ops/matmul_benchmark.cc b/mace/ops/matmul_benchmark.cc index 864767999c80e8dfd1ce5d047317f9d85432c1dc..1850086dfcfd5fad716146efeb23e19c8934768c 100644 --- a/mace/ops/matmul_benchmark.cc +++ b/mace/ops/matmul_benchmark.cc @@ -61,10 +61,10 @@ static void MatMulBenchmark( } \ BENCHMARK(BM_MATMUL_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE) -#define BM_MATMUL(N, H, C, W, TYPE) \ - BM_MATMUL_MACRO(N, H, C, W, TYPE, OPENCL); +#define BM_MATMUL(N, H, C, W) \ + BM_MATMUL_MACRO(N, H, C, W, half, OPENCL); -BM_MATMUL(16, 32, 128, 49, half); -BM_MATMUL(16, 32, 128, 961, half); -BM_MATMUL(16, 32, 128, 3969, half); -} // namespace mace +BM_MATMUL(16, 32, 128, 49); +BM_MATMUL(16, 32, 128, 961); +BM_MATMUL(16, 32, 128, 3969); +} // namespace mace diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc index d372f242e74deda8df82335667c096d73d6f7228..2ac3b9ace3d04ce1af987b916c7184eae89150b7 100644 --- a/mace/ops/pooling.cc +++ b/mace/ops/pooling.cc @@ -18,14 +18,6 @@ void Register_Pooling(OperatorRegistry *op_registry) { .Build(), PoolingOp); -#if MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - PoolingOp); -#endif // MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/resize_bilinear.cc b/mace/ops/resize_bilinear.cc index b44f462b1aaae066b84280de038a0a33a95c9970..c933dc876164227c658aa72cf8fc2a2e5e174248 100644 --- a/mace/ops/resize_bilinear.cc +++ b/mace/ops/resize_bilinear.cc @@ -13,14 +13,6 @@ void Register_ResizeBilinear(OperatorRegistry *op_registry) { .Build(), ResizeBilinearOp); -#if MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBilinear") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - ResizeBilinearOp); -#endif // MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBilinear") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/resize_bilinear_benchmark.cc b/mace/ops/resize_bilinear_benchmark.cc index 01ffda0e686527b5a6f30a24b02c853ebe56d5ce..f582ede7edef794eeeef568abeb2f9eb0d999615 100644 --- a/mace/ops/resize_bilinear_benchmark.cc +++ b/mace/ops/resize_bilinear_benchmark.cc @@ -69,18 +69,18 @@ static void ResizeBilinearBenchmark(int iters, BENCHMARK( \ BM_RESIZE_BILINEAR_##N##_##C##_##H0##_##W0##_##H1##_##W1##_##TYPE##_##DEVICE) -#define BM_RESIZE_BILINEAR(N, C, H0, W0, H1, W1, TYPE) \ - BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, CPU); \ - BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, OPENCL); +#define BM_RESIZE_BILINEAR(N, C, H0, W0, H1, W1) \ + BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, float, CPU); \ + BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, float, OPENCL); \ + BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, half, OPENCL); -// SNPE 835 GPU: 6870us -BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480, float); +BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480); -BM_RESIZE_BILINEAR(1, 256, 7, 7, 15, 15, float); -BM_RESIZE_BILINEAR(1, 256, 15, 15, 30, 30, float); -BM_RESIZE_BILINEAR(1, 128, 30, 30, 60, 60, float); -BM_RESIZE_BILINEAR(1, 128, 240, 240, 480, 480, float); -BM_RESIZE_BILINEAR(1, 3, 4032, 3016, 480, 480, float); -BM_RESIZE_BILINEAR(1, 3, 480, 480, 4032, 3016, float); +BM_RESIZE_BILINEAR(1, 256, 7, 7, 15, 15); +BM_RESIZE_BILINEAR(1, 256, 15, 15, 30, 30); +BM_RESIZE_BILINEAR(1, 128, 30, 30, 60, 60); +BM_RESIZE_BILINEAR(1, 128, 240, 240, 480, 480); +BM_RESIZE_BILINEAR(1, 3, 4032, 3016, 480, 480); +BM_RESIZE_BILINEAR(1, 3, 480, 480, 4032, 3016); -} // namespace mace +} // namespace mace diff --git a/mace/ops/softmax_benchmark.cc b/mace/ops/softmax_benchmark.cc index 267074a7c1130ce6063d0b2a937395005c082b04..5e8a283d889b8012853c77340832c377a5f43169 100644 --- a/mace/ops/softmax_benchmark.cc +++ b/mace/ops/softmax_benchmark.cc @@ -55,13 +55,14 @@ static void SoftmaxBenchmark( } \ BENCHMARK(BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) -#define BM_SOFTMAX(N, C, H, W, TYPE) \ - BM_SOFTMAX_MACRO(N, C, H, W, TYPE, CPU); \ - BM_SOFTMAX_MACRO(N, C, H, W, TYPE, OPENCL); +#define BM_SOFTMAX(N, C, H, W) \ + BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \ + BM_SOFTMAX_MACRO(N, C, H, W, float, OPENCL); \ + BM_SOFTMAX_MACRO(N, C, H, W, half, OPENCL); -BM_SOFTMAX(1, 1, 512, 512, float); -BM_SOFTMAX(1, 3, 128, 128, float); -BM_SOFTMAX(1, 3, 512, 512, float); -BM_SOFTMAX(1, 32, 112, 112, float); -BM_SOFTMAX(1, 64, 256, 256, float); +BM_SOFTMAX(1, 1, 512, 512); +BM_SOFTMAX(1, 3, 128, 128); +BM_SOFTMAX(1, 3, 512, 512); +BM_SOFTMAX(1, 32, 112, 112); +BM_SOFTMAX(1, 64, 256, 256); } // namespace mace diff --git a/mace/ops/space_to_batch_benchmark.cc b/mace/ops/space_to_batch_benchmark.cc index 9b3e4d1cb68178406c418c76505e1d2a90a8eb69..86ba58085c9a1e4aa103da5a30555a15bf40c838 100644 --- a/mace/ops/space_to_batch_benchmark.cc +++ b/mace/ops/space_to_batch_benchmark.cc @@ -49,10 +49,10 @@ static void BMSpaceToBatch( BENCHMARK( \ BM_SPACE_TO_BATCH_##N##_##H##_##W##_##C##_##SHAPE##_##TYPE##_##DEVICE) -#define BM_SPACE_TO_BATCH(N, H, W, C, SHAPE, TYPE) \ - BM_SPACE_TO_BATCH_MACRO(N, H, W, C, SHAPE, TYPE, OPENCL); +#define BM_SPACE_TO_BATCH(N, H, W, C, SHAPE) \ + BM_SPACE_TO_BATCH_MACRO(N, H, W, C, SHAPE, float, OPENCL); -BM_SPACE_TO_BATCH(128, 16, 16, 128, 2, float); -BM_SPACE_TO_BATCH(1, 256, 256, 32, 2, float); -BM_SPACE_TO_BATCH(1, 256, 256, 32, 4, float); -} // namespace mace \ No newline at end of file +BM_SPACE_TO_BATCH(128, 16, 16, 128, 2); +BM_SPACE_TO_BATCH(1, 256, 256, 32, 2); +BM_SPACE_TO_BATCH(1, 256, 256, 32, 4); +} // namespace mace diff --git a/mace/ops/winograd_convolution_test.cc b/mace/ops/winograd_convolution_test.cc index 364aec6b482169f5e99574160cffd2833a88bce7..c76757f948ba049120dd792ee1884ff99dc856c3 100644 --- a/mace/ops/winograd_convolution_test.cc +++ b/mace/ops/winograd_convolution_test.cc @@ -19,9 +19,9 @@ void TransposeFilter(const std::vector &input, const float *input_ptr = input.data(); for (index_t h = 0; h < input_shape[0]; ++h) { for (index_t w = 0; w < input_shape[1]; ++w) { - for (index_t ic = 0; ic < input_shape[2]; ++ic) { - for (index_t oc = 0; oc < input_shape[3]; ++oc) { - int offset = ((oc * input_shape[2] + ic) * input_shape[0] + h) * input_shape[1] + w; + for (index_t oc = 0; oc < input_shape[2]; ++oc) { + for (index_t ic = 0; ic < input_shape[3]; ++ic) { + int offset = ((oc * input_shape[3] + ic) * input_shape[0] + h) * input_shape[1] + w; output[offset] = *input_ptr; ++input_ptr; } @@ -43,7 +43,7 @@ void WinogradConvolution(const index_t batch, OpsTestNet net; // Add input data std::vector filter_data; - std::vector filter_shape = {3, 3, in_channels, out_channels}; + std::vector filter_shape = {3, 3, out_channels, in_channels}; GenerateRandomRealTypeData(filter_shape, filter_data); net.AddRandomInput("Input", {batch, height, width, in_channels}); net.AddInputFromArray("Filter", filter_shape, filter_data); diff --git a/mace/ops/winograd_transform_benchmark.cc b/mace/ops/winograd_transform_benchmark.cc index 28d73b2a5cc96f3ca41dea871251a3400b48fb5b..a7b99257e03193d52cb5c6d6e83e7106171318de 100644 --- a/mace/ops/winograd_transform_benchmark.cc +++ b/mace/ops/winograd_transform_benchmark.cc @@ -48,12 +48,12 @@ static void BMWinogradTransform( BENCHMARK( \ BM_WINOGRAD_TRANSFORM_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE) -#define BM_WINOGRAD_TRANSFORM(N, H, W, C, TYPE) \ - BM_WINOGRAD_TRANSFORM_MACRO(N, H, W, C, TYPE, OPENCL); +#define BM_WINOGRAD_TRANSFORM(N, H, W, C) \ + BM_WINOGRAD_TRANSFORM_MACRO(N, H, W, C, half, OPENCL); -BM_WINOGRAD_TRANSFORM(1, 16, 16, 128, half); -BM_WINOGRAD_TRANSFORM(1, 64, 64, 128, half); -BM_WINOGRAD_TRANSFORM(1, 128, 128, 128, half); +BM_WINOGRAD_TRANSFORM(1, 16, 16, 128); +BM_WINOGRAD_TRANSFORM(1, 64, 64, 128); +BM_WINOGRAD_TRANSFORM(1, 128, 128, 128); template static void BMWinogradInverseTransform( @@ -100,11 +100,11 @@ static void BMWinogradInverseTransform( BENCHMARK( \ BM_WINOGRAD_INVERSE_TRANSFORM_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE) -#define BM_WINOGRAD_INVERSE_TRANSFORM(N, H, W, C, TYPE) \ - BM_WINOGRAD_INVERSE_TRANSFORM_MACRO(N, H, W, C, TYPE, OPENCL); +#define BM_WINOGRAD_INVERSE_TRANSFORM(N, H, W, C) \ + BM_WINOGRAD_INVERSE_TRANSFORM_MACRO(N, H, W, C, half, OPENCL); -BM_WINOGRAD_INVERSE_TRANSFORM(1, 14, 14, 32, half); -BM_WINOGRAD_INVERSE_TRANSFORM(1, 62, 62, 32, half); -BM_WINOGRAD_INVERSE_TRANSFORM(1, 126, 126, 32, half); +BM_WINOGRAD_INVERSE_TRANSFORM(1, 14, 14, 32); +BM_WINOGRAD_INVERSE_TRANSFORM(1, 62, 62, 32); +BM_WINOGRAD_INVERSE_TRANSFORM(1, 126, 126, 32); -} // namespace mace \ No newline at end of file +} // namespace mace diff --git a/mace/utils/utils.h b/mace/utils/utils.h index ce682f06080813b52c6c883b25cc899aa40ca174..968d4f6e61c6d9557500dc27b795475972f4a93b 100644 --- a/mace/utils/utils.h +++ b/mace/utils/utils.h @@ -30,6 +30,11 @@ Integer RoundUpDiv8(Integer i) { return (i + 7) >> 3; } +template +Integer RoundUpDiv(Integer i, Integer factor) { + return (i + factor - 1) / factor; +} + template Integer CeilQuotient(Integer a, Integer b) { return (a + b - 1) / b; diff --git a/tools/bazel-adb-run.sh b/tools/bazel-adb-run.sh index 6e1e68f2754254ca0e2b1015670231c546cc1875..9b964c36380391a07d522f019460e465e0178899 100755 --- a/tools/bazel-adb-run.sh +++ b/tools/bazel-adb-run.sh @@ -18,8 +18,8 @@ BAZEL_BIN_PATH=${BAZEL_BIN_PATH#//} BAZEL_BIN_PATH=bazel-bin/$BAZEL_BIN_PATH BIN_NAME=`echo $BAZEL_TARGET | cut -d: -f2` -ANDROID_ABI=arm64-v8a ANDROID_ABI=armeabi-v7a +ANDROID_ABI=arm64-v8a STRIP="--strip always" VLOG_LEVEL=0 PROFILING="1" @@ -43,7 +43,8 @@ bazel build -c opt $STRIP --verbose_failures $BAZEL_TARGET \ --copt="-D_GLIBCXX_USE_C99_MATH_TR1" \ --copt="-DMACE_DISABLE_NO_TUNING_WARNING" \ --copt="-Werror=return-type" \ - --define neon=false \ + --copt="-O3" \ + --define neon=true \ --define openmp=true if [ $? -ne 0 ]; then