Commit 4f407e70 authored by 刘琦

Merge branch 'neon' into 'master'

Improve CPU implementations

See merge request !236
......@@ -32,6 +32,4 @@ typedef int64_t index_t;
#define MACE_NOT_IMPLEMENTED MACE_CHECK(false, "not implemented")
#define kCostPerGroup 10240
#endif // MACE_CORE_COMMON_H_
......@@ -325,6 +325,12 @@ class Tensor {
}
}
}
MappingGuard(MappingGuard &&other) {
tensor_ = other.tensor_;
other.tensor_ = nullptr;
}
MappingGuard(const MappingGuard &other) = delete;
MappingGuard & operator = (const MappingGuard &other) = delete;
~MappingGuard() {
if (tensor_ != nullptr) tensor_->Unmap();
}
......
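Note on the MappingGuard change above: the new move constructor transfers ownership of the mapping (the moved-from guard is nulled, so Unmap runs exactly once), which is what allows guards to be kept in standard containers, as addn.h below does with a std::vector of guards. A minimal usage sketch, assuming a `const Tensor *tensor` is in scope (illustrative only, not part of this change):

    std::vector<Tensor::MappingGuard> guards;
    // emplace_back of a temporary needs the move constructor added above;
    // the tensor is mapped once here and unmapped when `guards` is destroyed.
    guards.emplace_back(Tensor::MappingGuard(tensor));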
......@@ -15,7 +15,6 @@ cc_library(
"*.cc",
"opencl/*.cc",
]) + if_neon_enabled(glob([
"neon/addn_neon.cc",
"neon/batch_norm_neon.cc",
])),
hdrs = glob([
......
......@@ -54,17 +54,20 @@ void DoActivation(const T *input_ptr,
case NOOP:
break;
case RELU:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0));
}
break;
case RELUX:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min(std::max(input_ptr[i], static_cast<T>(0)),
static_cast<T>(relux_max_limit));
}
break;
case PRELU:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
T in = input_ptr[i];
if (in < 0) {
......@@ -75,12 +78,14 @@ void DoActivation(const T *input_ptr,
}
break;
case TANH:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
T in_exp = std::exp(-2 * input_ptr[i]);
output_ptr[i] = (1 - in_exp) / (1 + in_exp);
}
break;
case SIGMOID:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i]));
}
......
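Each element-wise branch above now carries its own `#pragma omp parallel for`, which is safe because every output element depends only on the matching input element. A standalone sketch of the same pattern for the RELU case (hypothetical free function, not the MACE API):

    #include <algorithm>
    #include <cstdint>

    // Sketch: parallel element-wise ReLU; iterations are independent, so no
    // synchronization is needed beyond the implicit barrier at the loop end.
    void ReluSketch(const float *input, float *output, int64_t size) {
    #pragma omp parallel for
      for (int64_t i = 0; i < size; ++i) {
        output[i] = std::max(input[i], 0.0f);
      }
    }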
......@@ -8,29 +8,71 @@
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
namespace {
constexpr int kCostPerGroup = 1024;
} // namespace
template <DeviceType D, typename T>
struct AddNFunctor {
void operator()(const std::vector<const Tensor *> &input_tensors,
Tensor *output_tensor, StatsFuture *future) {
Tensor *output_tensor,
StatsFuture *future) {
output_tensor->ResizeLike(input_tensors[0]);
index_t size = output_tensor->size();
Tensor::MappingGuard output_map(output_tensor);
index_t size = input_tensors[0]->size();
T *output_ptr = output_tensor->mutable_data<T>();
memset(output_ptr, 0, size * sizeof(T));
float *output_data = output_tensor->mutable_data<float>();
memset(output_data, 0, size * sizeof(float));
int n = input_tensors.size();
for (int i = 0; i < n; ++i) {
Tensor::MappingGuard input_map(input_tensors[i]);
const T *input_ptr = input_tensors[i]->data<T>();
for (index_t j = 0; j < size; ++j) {
output_ptr[j] += input_ptr[j];
int64_t cost = size * n;
int64_t groups = 1;
if (cost > kCostPerGroup) {
groups = cost / kCostPerGroup;
}
int64_t element_per_group = size / groups;
std::vector<Tensor::MappingGuard> mappers;
for (int64_t i = 0; i < n; ++i) {
mappers.emplace_back(Tensor::MappingGuard(input_tensors[i]));
}
#pragma omp parallel for
for (int64_t i = 0; i < size; i += element_per_group) {
int64_t count = std::min(element_per_group, size - i);
int nn = count >> 2;
int remain = count - (nn << 2);
for (int64_t j = 0; j < n; ++j) {
const float *input_data = input_tensors[j]->data<float>();
const float *input_ptr = input_data + i;
float *output_ptr = output_data + i;
for (int k = 0; k < nn; ++k) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
float32x4_t in = vld1q_f32(input_ptr);
float32x4_t out = vld1q_f32(output_ptr);
out = vaddq_f32(out, in);
vst1q_f32(output_ptr, out);
#else
for (int m = 0; m < 4; ++m) {
output_ptr[m] += input_ptr[m];
}
#endif
input_ptr += 4;
output_ptr += 4;
}
for (int k = 0; k < remain; ++k) {
*output_ptr += *input_ptr;
++input_ptr;
++output_ptr;
}
}
}
}
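For intuition on the grouping heuristic above (hypothetical numbers, not taken from the change): with size = 100000 elements and n = 4 inputs, cost = 400000 exceeds kCostPerGroup = 1024, so groups = 400000 / 1024 = 390 and element_per_group = 100000 / 390 = 256. The parallel loop then runs 391 chunks (390 chunks of 256 elements plus a 160-element tail); inside each chunk the NEON path accumulates four floats per vaddq_f32, the scalar tail loop handles the remaining count % 4 elements, and the non-NEON fallback performs the same four adds in plain C++.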
......@@ -45,7 +87,8 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(
template <typename T>
struct AddNFunctor<DeviceType::OPENCL, T> {
void operator()(const std::vector<const Tensor *> &input_tensors,
Tensor *output_tensor, StatsFuture *future);
Tensor *output_tensor,
StatsFuture *future);
cl::Kernel kernel_;
};
......
......@@ -70,8 +70,8 @@ struct BatchNormFunctor : BatchNormFunctorBase {
const T *offset_ptr = offset->data<T>();
T *output_ptr = output->mutable_data<T>();
vector<T> new_scale;
vector<T> new_offset;
std::vector<T> new_scale;
std::vector<T> new_offset;
if (!folded_constant_) {
new_scale.resize(channels);
new_offset.resize(channels);
......@@ -86,6 +86,8 @@ struct BatchNormFunctor : BatchNormFunctorBase {
}
}
const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
#pragma omp parallel for collapse(4)
for (index_t n = 0; n < batch; ++n) {
......@@ -93,11 +95,7 @@ struct BatchNormFunctor : BatchNormFunctorBase {
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < channels; ++c) {
index_t pos = (((n * height) + h) * width + w) * channels + c;
if (folded_constant_) {
output_ptr[pos] = scale_ptr[c] * input_ptr[pos] + offset_ptr[c];
} else {
output_ptr[pos] = new_scale[c] * input_ptr[pos] + new_offset[c];
}
output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
}
}
}
......
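The per-channel precomputation elided in the hunk above is the usual batch-norm folding, which is what lets the innermost loop collapse to the single multiply-add shown at the end of this hunk. A sketch of that fold, assuming pointer names mean_ptr and var_ptr and the member epsilon_ (names assumed for illustration; needs <cmath>):

    // Sketch: fold  y = scale * (x - mean) / sqrt(var + eps) + offset
    //         into  y = new_scale * x + new_offset
    for (index_t c = 0; c < channels; ++c) {
      new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon_);  // assumed names
      new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
    }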
......@@ -5,15 +5,237 @@
#ifndef MACE_KERNELS_DEPTHWISE_CONV2D_H_
#define MACE_KERNELS_DEPTHWISE_CONV2D_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/common.h"
#include "mace/core/future.h"
#include "mace/core/public/mace.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
namespace kernels {
namespace {
template <typename T>
void DepthwiseConv2dKernel(const T *input_ptr,
const T *filter_ptr,
const T *bias_ptr,
T *output_ptr,
int batch,
int height,
int width,
int channels,
int input_height,
int input_width,
int input_channels,
int multiplier,
int padded_h_start,
int padded_h_stop,
int padded_w_start,
int padded_w_stop,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int h_start,
int h_stop,
int w_start,
int w_stop) {
#pragma omp parallel for collapse(4)
for (int n = 0; n < batch; ++n) {
for (int h = h_start; h < h_stop; ++h) {
for (int w = w_start; w < w_stop; ++w) {
for (int c = 0; c < channels; ++c) {
const index_t inc = c / multiplier;
const index_t m = c % multiplier;
T bias_channel = bias_ptr ? bias_ptr[c] : 0;
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
output_ptr[offset] = bias_channel;
T sum = 0;
const T *filter_base = filter_ptr + inc * multiplier + m;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh < 0 || inh >= input_height || inw < 0 ||
inw >= input_width) {
MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
inw >= padded_w_start && inw < padded_w_stop,
"Out of range read from input: ", padded_h_start,
" <= ", inh, " < ", padded_h_stop, ", ",
padded_w_start, " <= ", inw, " < ", padded_w_stop);
} else {
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
inc;
sum += input_ptr[input_offset] * filter_base[0]; // HWIM
}
filter_base += input_channels * multiplier;
}
}
output_ptr[offset] += sum;
}
}
}
}
}
template <typename T>
void DepthwiseConv2dNoOOBCheckKernel(const T *input_ptr,
const T *filter_ptr,
const T *bias_ptr,
T *output_ptr,
int batch,
int height,
int width,
int channels,
int input_height,
int input_width,
int input_channels,
int multiplier,
int padded_h_start,
int padded_h_stop,
int padded_w_start,
int padded_w_stop,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int h_start,
int h_stop,
int w_start,
int w_stop) {
if (multiplier == 1) {
constexpr int c_tile_size = 4;
#pragma omp parallel for collapse(3)
for (int n = 0; n < batch; ++n) {
for (int h = h_start; h < h_stop; ++h) {
for (int w = w_start; w < w_stop; ++w) {
int c;
for (c = 0; c + c_tile_size <= channels; c += c_tile_size) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
static_assert(c_tile_size == 4, "channels tile size must be 4");
float32x4_t sum = vdupq_n_f32(0);
if (bias_ptr != nullptr) {
sum = vld1q_f32(bias_ptr + c);
}
#else
T sum[c_tile_size] = {0};
if (bias_ptr != nullptr) {
for (int ci = 0; ci < c_tile_size; ++ci) {
sum[ci] = bias_ptr[c + ci];
}
}
#endif
const T *filter_base = filter_ptr + c;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width);
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
c;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
float32x4_t in = vld1q_f32(input_ptr + input_offset);
float32x4_t weights = vld1q_f32(filter_base);
sum = vfmaq_f32(sum, in, weights);
#else
for (int ci = 0; ci < c_tile_size; ++ci) {
sum[ci] +=
input_ptr[input_offset + ci] * filter_base[ci]; // HWIM
}
#endif
filter_base += input_channels;
}
}
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
vst1q_f32(output_ptr + offset, sum);
#else
for (int ci = 0; ci < c_tile_size; ++ci) {
output_ptr[offset + ci] = sum[ci];
}
#endif
}
for (; c < channels; ++c) {
T bias_channel = bias_ptr ? bias_ptr[c] : 0;
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
output_ptr[offset] = bias_channel;
T sum = 0;
const T *filter_base = filter_ptr + c;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width);
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
c;
sum += input_ptr[input_offset] * filter_base[0]; // HWIM
filter_base += input_channels * multiplier;
}
}
output_ptr[offset] += sum;
}
}
}
}
} else {
#pragma omp parallel for collapse(4)
for (int n = 0; n < batch; ++n) {
for (int h = h_start; h < h_stop; ++h) {
for (int w = w_start; w < w_stop; ++w) {
for (int c = 0; c < channels; ++c) {
const index_t inc = c / multiplier;
const index_t m = c % multiplier;
T bias_channel = bias_ptr ? bias_ptr[c] : 0;
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
output_ptr[offset] = bias_channel;
T sum = 0;
const T *filter_base = filter_ptr + inc * multiplier + m;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
MACE_ASSERT(inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width);
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels + inw * input_channels +
inc;
sum += input_ptr[input_offset] * filter_base[0]; // HWIM
filter_base += input_channels * multiplier;
}
}
output_ptr[offset] += sum;
}
}
}
}
}
}
} // namespace
struct DepthwiseConv2dFunctorBase {
DepthwiseConv2dFunctorBase(const int *strides,
const Padding padding,
......@@ -28,7 +250,7 @@ struct DepthwiseConv2dFunctorBase {
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
const int *strides_; // [stride_h, stride_w]
const int *strides_; // [stride_h, stride_w]
const Padding padding_;
const int *dilations_; // [dilation_h, dilation_w]
const ActivationType activation_;
......@@ -88,7 +310,8 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
index_t kernel_h = filter->dim(0);
index_t kernel_w = filter->dim(1);
index_t multiplier = filter->dim(3);
MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=", input_channels);
MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=",
input_channels);
MACE_CHECK(channels == input_channels * multiplier);
int stride_h = strides_[0];
......@@ -100,10 +323,15 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
// The top-left offset of the padded input
int padded_h_start = 0 - paddings[0] / 2;
int padded_w_start = 0 - paddings[1] / 2;
index_t padded_h_stop = input_height + paddings[0] - paddings[0] / 2;
index_t padded_w_stop = input_width + paddings[1] - paddings[1] / 2;
int paddings_top = paddings[0] / 2;
int paddings_bottom = paddings[0] - paddings_top;
int paddings_left = paddings[1] / 2;
int paddings_right = paddings[1] - paddings_left;
int padded_h_start = 0 - paddings_top;
int padded_w_start = 0 - paddings_left;
index_t padded_h_stop = input_height + paddings_bottom;
index_t padded_w_stop = input_width + paddings_right;
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard filter_mapper(filter);
......@@ -114,43 +342,59 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
const T *bias_ptr = bias == nullptr ? nullptr : bias->data<T>();
T *output_ptr = output->mutable_data<T>();
#pragma omp parallel for collapse(4)
for (int n = 0; n < batch; ++n) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int c = 0; c < channels; ++c) {
const index_t inc = c / multiplier;
const index_t m = c % multiplier;
T bias_channel = bias_ptr ? bias_ptr[c] : 0;
index_t offset = n * height * width * channels +
h * width * channels + w * channels + c;
output_ptr[offset] = bias_channel;
T sum = 0;
const T *filter_base = filter_ptr + inc * multiplier + m;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh < 0 || inh >= input_height || inw < 0 ||
inw >= input_width) {
MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
inw >= padded_w_start && inw < padded_w_stop,
"Out of range read from input: ", inh, ", ", inw);
} else {
index_t input_offset =
n * input_height * input_width * input_channels +
inh * input_width * input_channels +
inw * input_channels + inc;
sum += input_ptr[input_offset] * filter_base[0]; // HWIM
}
filter_base += input_channels * multiplier;
}
}
output_ptr[offset] += sum;
}
}
}
int valid_h_start =
paddings_top == 0 ? 0 : (paddings_top - 1) / stride_h + 1;
int valid_h_stop = paddings_bottom == 0
? height
: height - ((paddings_bottom - 1) / stride_h + 1);
int valid_w_start =
paddings_left == 0 ? 0 : (paddings_left - 1) / stride_w + 1;
int valid_w_stop = paddings_right == 0
? width
: width - ((paddings_right - 1) / stride_w + 1);
// Calculate border elements with out-of-boundary checking
if (valid_h_start > 0) {
DepthwiseConv2dKernel<T>(
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
channels, input_height, input_width, input_channels, multiplier,
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop,
kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w, 0,
valid_h_start, 0, width);
}
if (valid_h_stop < height) {
DepthwiseConv2dKernel<T>(
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
channels, input_height, input_width, input_channels, multiplier,
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop,
kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w,
std::max(valid_h_start, valid_h_stop), height, 0, width);
}
if (valid_w_start > 0) {
DepthwiseConv2dKernel<T>(
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
channels, input_height, input_width, input_channels, multiplier,
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop,
kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w,
valid_h_start, valid_h_stop, 0, valid_w_start);
}
if (valid_w_stop < width) {
DepthwiseConv2dKernel<T>(
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
channels, input_height, input_width, input_channels, multiplier,
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop,
kernel_h, kernel_w, stride_h, stride_w, dilation_h, dilation_w,
valid_h_start, valid_h_stop, std::max(valid_w_start, valid_w_stop),
width);
}
// Calculate inner elements without out-of-boundary checking
DepthwiseConv2dNoOOBCheckKernel<T>(
input_ptr, filter_ptr, bias_ptr, output_ptr, batch, height, width,
channels, input_height, input_width, input_channels, multiplier,
padded_h_start, padded_h_stop, padded_w_start, padded_w_stop, kernel_h,
kernel_w, stride_h, stride_w, dilation_h, dilation_w, valid_h_start,
valid_h_stop, valid_w_start, valid_w_stop);
output_ptr = output->mutable_data<T>();
DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
......@@ -180,7 +424,7 @@ struct DepthwiseConv2dFunctor<DeviceType::OPENCL, T>
dilations,
activation,
relux_max_limit,
prelu_alpha){}
prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *filter,
......
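To make the border/interior split introduced in depthwise_conv2d.h above concrete (hypothetical shape, not from the change): with a 3x3 kernel, stride 1, and SAME padding of 1 on every side, valid_h_start = (1 - 1) / 1 + 1 = 1 and valid_h_stop = height - ((1 - 1) / 1 + 1) = height - 1, and likewise for width. Only the one-pixel border rows and columns go through DepthwiseConv2dKernel, which bounds-checks every read, while the interior rectangle [1, height - 1) x [1, width - 1) takes DepthwiseConv2dNoOOBCheckKernel, which only asserts and, when multiplier == 1, can use the NEON four-channel fast path.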
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/addn.h"
#include <arm_neon.h>
namespace mace {
namespace kernels {
template <>
void AddNFunctor<DeviceType::NEON, float>::operator()(
const std::vector<const Tensor *> &input_tensors,
Tensor *output_tensor,
StatsFuture *future) {
// TODO: neon mem copy
output_tensor->ResizeLike(input_tensors[0]);
index_t size = output_tensor->size();
float *output_ptr = output_tensor->mutable_data<float>();
memset(output_ptr, 0, size * sizeof(float));
int n = input_tensors.size();
int64_t cost = size * n;
int64_t groups = 1;
if (cost > kCostPerGroup) {
groups = cost / kCostPerGroup;
}
int64_t element_per_group = size / groups;
#pragma omp parallel for
for (int64_t i = 0; i < size; i += element_per_group) {
int64_t count = std::min(element_per_group, size - i);
int nn = count >> 2;
int remain = count - (nn << 2);
for (int64_t j = 0; j < n; ++j) {
const float *input_base = input_tensors[j]->data<float>();
const float *inptr = input_base + i;
float *outptr = output_ptr + i;
for (int k = 0; k < nn; ++k) {
float32x4_t _inptr = vld1q_f32(inptr);
float32x4_t _outptr = vld1q_f32(outptr);
_outptr = vaddq_f32(_outptr, _inptr);
vst1q_f32(outptr, _outptr);
inptr += 4;
outptr += 4;
}
for (int k = 0; k < remain; ++k) {
*outptr += *inptr;
++inptr;
++outptr;
}
}
}
};
} // namespace kernels
} // namespace mace
......@@ -68,12 +68,11 @@ void ResizeImage(const T *images,
const index_t out_batch_num_values = channels * out_height * out_width;
const CachedInterpolation *xs = xs_vec.data();
#pragma omp parallel for
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch_size; ++b) {
const T *batch_input_ptr = images + in_batch_num_values * b;;
T *batch_output_ptr = output + out_batch_num_values * b;
for (index_t y = 0; y < out_height; ++y) {
const T *batch_input_ptr = images + in_batch_num_values * b;
T *batch_output_ptr = output + out_batch_num_values * b;
const T *y_lower_input_ptr =
batch_input_ptr + ys[y].lower * in_width * channels;
const T *y_upper_input_ptr =
......
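The resize_bilinear change above moves the per-batch pointer setup inside the y loop so the two loops are perfectly nested, which is what `collapse(2)` requires; the collapsed iteration space (batch_size * out_height) then still exposes parallelism when batch_size is small. A minimal sketch of the pattern, with an assumed signature that is not the MACE API:

    #include <cstdint>

    // Sketch: collapse(2) fuses the two perfectly nested loops into one
    // parallel iteration space; no statements may sit between the loop headers.
    void ResizeRowsSketch(const float *images, float *output,
                          int64_t batch_size, int64_t out_height,
                          int64_t in_batch_num_values, int64_t out_batch_num_values) {
    #pragma omp parallel for collapse(2)
      for (int64_t b = 0; b < batch_size; ++b) {
        for (int64_t y = 0; y < out_height; ++y) {
          const float *batch_input_ptr = images + in_batch_num_values * b;
          float *batch_output_ptr = output + out_batch_num_values * b;
          // ... interpolate output row y from batch_input_ptr into batch_output_ptr ...
          (void)batch_input_ptr;
          (void)batch_output_ptr;
        }
      }
    }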
......@@ -13,14 +13,6 @@ void Register_AddN(OperatorRegistry *op_registry) {
.Build(),
AddNOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
AddNOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("AddN")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
......
......@@ -67,7 +67,6 @@ static void AddNBenchmark(int iters, int inputs, int n, int h, int w, int c) {
#define BM_ADDN(INPUTS, N, H, W, C) \
BM_ADDN_MACRO(INPUTS, N, H, W, C, float, CPU); \
BM_ADDN_MACRO(INPUTS, N, H, W, C, float, NEON); \
BM_ADDN_MACRO(INPUTS, N, H, W, C, float, OPENCL); \
BM_ADDN_MACRO(INPUTS, N, H, W, C, half, OPENCL);
......
......@@ -33,8 +33,6 @@ void SimpleAdd2() {
TEST_F(AddnOpTest, CPUSimpleAdd2) { SimpleAdd2<DeviceType::CPU>(); }
TEST_F(AddnOpTest, NEONSimpleAdd2) { SimpleAdd2<DeviceType::NEON>(); }
template <DeviceType D>
void SimpleAdd3() {
// Construct graph
......@@ -61,8 +59,6 @@ void SimpleAdd3() {
TEST_F(AddnOpTest, CPUSimpleAdd3) { SimpleAdd3<DeviceType::CPU>(); }
TEST_F(AddnOpTest, NEONSimpleAdd3) { SimpleAdd3<DeviceType::NEON>(); }
template <DeviceType D>
void RandomTest() {
testing::internal::LogToStderr();
......
......@@ -84,7 +84,7 @@ static void BatchNorm(
#define BM_BATCH_NORM(N, C, H, W) \
BM_BATCH_NORM_MACRO(N, C, H, W, float, CPU); \
BM_BATCH_NORM_MACRO(N, C, H, W, float, NEON); \
BM_BATCH_NORM_MACRO(N, C, H, W, float, NEON); \
BM_BATCH_NORM_MACRO(N, C, H, W, float, OPENCL); \
BM_BATCH_NORM_MACRO(N, C, H, W, half, OPENCL);
......
......@@ -72,92 +72,8 @@ void Simple() {
TEST_F(BatchNormOpTest, SimpleCPU) { Simple<DeviceType::CPU>(); }
TEST_F(BatchNormOpTest, SimpleNEON) { Simple<DeviceType::NEON>(); }
TEST_F(BatchNormOpTest, SimpleOPENCL) { Simple<DeviceType::OPENCL>(); }
TEST_F(BatchNormOpTest, SimpleRandomNeon) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t height = 64;
index_t width = 64;
index_t channels = 3 + rand() % 50;
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input",
{batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run NEON
net.RunOp(DeviceType::NEON);
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
}
TEST_F(BatchNormOpTest, ComplexRandomNeon) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 103;
index_t width = 113;
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input")
.Input("Scale")
.Input("Offset")
.Input("Mean")
.Input("Var")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input",
{batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run NEON
net.RunOp(DeviceType::NEON);
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
}
TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
srand(time(NULL));
......
......@@ -75,36 +75,38 @@ static void DepthwiseConv2d(int iters,
}
}
#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, \
DEVICE) \
static void \
BM_DEPTHWISE_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, \
mace::Padding::P, OC); \
} \
BENCHMARK( \
BM_DEPTHWISE_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, \
DEVICE) \
static void \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
DepthwiseConv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, \
mace::Padding::P, OC); \
} \
BENCHMARK( \
BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC) \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, CPU); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, OPENCL); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, half, OPENCL);
BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 1, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1);
} // namespace mace
} // namespace mace
......@@ -280,7 +280,7 @@ void TestNxNS12(const index_t height, const index_t width) {
ExpectTensorNear<float>(expected, *net.GetOutput("DeviceOutput"), 0.1);
};
for (int kernel_size : {3}) {
for (int kernel_size : {2, 3, 4}) {
for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, SAME);
......
......@@ -3,8 +3,8 @@
set -e
Usage() {
echo "Usage: ./tools/export_lib.sh android_abi[armeabi-v7a/arm64-v8a] runtime[gpu/dsp] export_include_dir export_lib_dir"
echo "eg: ./tools/export_lib.sh armeabi-v7a ../include ../lib/libmace_v7"
echo "Usage: ./tools/export_lib.sh target_abi[armeabi-v7a | arm64-v8a | host] runtime[gpu | dsp] export_include_dir export_lib_dir"
echo "eg: ./tools/export_lib.sh armeabi-v7a gpu ../include ../lib/libmace-armeabi-v7a"
}
if [ $# -lt 4 ]; then
......@@ -12,9 +12,7 @@ if [ $# -lt 4 ]; then
exit 1
fi
# ANDROID_ABI=arm64-v8a
# ANDROID_ABI=armeabi-v7a
ANDROID_ABI=$1
TARGET_ABI=$1
RUNTIME=$2
EXPORT_INCLUDE_DIR=$3
EXPORT_LIB_DIR=$4
......@@ -63,15 +61,18 @@ build_target()
bazel build --verbose_failures -c opt --strip always $BAZEL_TARGET \
--crosstool_top=//external:android/crosstool \
--host_crosstool_top=@bazel_tools//tools/cpp:toolchain \
--cpu=$ANDROID_ABI \
--cpu=$TARGET_ABI \
--copt="-std=c++11" \
--copt="-D_GLIBCXX_USE_C99_MATH_TR1" \
--copt="-Werror=return-type" \
--copt="-DMACE_OBFUSCATE_LITERALS" \
--copt="-O3" \
--define neon=true \
--define openmp=true \
$DSP_MODE_BUILD_FLAGS || exit 1
}
build_local_target()
build_host_target()
{
BAZEL_TARGET=$1
bazel build --verbose_failures -c opt --strip always $BAZEL_TARGET \
......@@ -79,7 +80,8 @@ build_local_target()
--copt="-D_GLIBCXX_USE_C99_MATH_TR1" \
--copt="-Werror=return-type" \
--copt="-DMACE_OBFUSCATE_LITERALS" \
--define openmp=true || exit -1
--copt="-O3" \
--define openmp=true || exit 1
}
merge_libs()
......@@ -132,10 +134,10 @@ bash mace/tools/git/gen_version_source.sh ${CODEGEN_DIR}/version/version.cc || e
echo "Step 3: Build libmace targets"
bazel clean
if [ x"${RUNTIME}" = x"local" ]; then
if [ x"${TARGET_ABI}" = x"host" ] || [ x"${TARGET_ABI}" = x"local" ]; then
for target in ${all_targets[*]}
do
build_local_target ${target}
build_host_target ${target}
done
else
for target in ${all_targets[*]}
......