Commit 780f5a60 authored by Liangliang He

Refactor conv2d NEON implementations

Parent 07f8ff18
......@@ -288,7 +288,8 @@ class Tensor {
}
CASES(dtype_, (os << (this->data<T>()[i]) << ", "));
}
LOG(INFO) << os.str();
LOG(INFO) << "Tensor size: [" << dim(0) << ", " << dim(1) << ", "
<< dim(2) << ", " << dim(3) << "], content:\n" << os.str();
}
inline size_t SizeOfType() const {
......
......@@ -15,15 +15,14 @@ cc_library(
"*.cc",
"opencl/*.cc",
]) + if_neon_enabled(glob([
"neon/*.cc",
"neon/addn_neon.cc",
"neon/batch_norm_neon.cc",
])),
hdrs = glob([
"*.h",
"opencl/*.h",
]) + if_neon_enabled(glob([
"neon/*.h",
])),
copts = if_openmp_enabled(["-fopenmp"]),
]),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]),
linkopts = if_android(["-lm"]),
deps = [
"//mace/core",
......
......@@ -5,6 +5,10 @@
#ifndef MACE_KERNELS_ADDN_H_
#define MACE_KERNELS_ADDN_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/core/runtime/opencl/cl2_header.h"
......@@ -12,7 +16,6 @@
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct AddNFunctor {
void operator()(const std::vector<const Tensor *> &input_tensors,
......@@ -47,7 +50,7 @@ struct AddNFunctor<DeviceType::OPENCL, T> {
cl::Kernel kernel_;
};
} // namespace kernels
} // namespace mace
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_ADDN_H_
......@@ -136,7 +136,7 @@ struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
cl::Kernel kernel_;
};
} // namespace kernels
} // namespace mace
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_BATCH_NORM_H_
#endif // MACE_KERNELS_BATCH_NORM_H_
This diff is collapsed.
......@@ -17,7 +17,7 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
......@@ -51,7 +51,8 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
break;
default:MACE_CHECK(false, "Unsupported padding type: ", padding);
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
}
// Note: TensorFlow may pad one more on the right/bottom side
......@@ -59,12 +60,10 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
// utilize the more centered features. We need to benchmark
// based on the model accuracy.
padding_size[0] =
std::max<int>(0, (output_height - 1) * strides[0]
+ k_extent_height - input_shape[2]);
padding_size[1] =
std::max<int>(0, (output_width - 1) * strides[1]
+ k_extent_width - input_shape[3]);
padding_size[0] = std::max<int>(
0, (output_height - 1) * strides[0] + k_extent_height - input_shape[2]);
padding_size[1] = std::max<int>(
0, (output_width - 1) * strides[1] + k_extent_width - input_shape[3]);
output_shape[0] = input_shape[0];
output_shape[1] = output_channels;
......@@ -73,7 +72,7 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
}
void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // HWIO
const index_t *filter_shape, // HWOI
const int *dilations,
const int *strides,
Padding padding,
......@@ -82,7 +81,7 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
......@@ -98,7 +97,7 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
index_t output_height = 0, output_width = 0;
index_t kernel_height = filter_shape[0];
index_t kernel_width = filter_shape[1];
index_t output_channels = filter_shape[3];
index_t output_channels = filter_shape[2];
index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1;
index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1;
......@@ -116,7 +115,8 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
output_height = (input_shape[1] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[2] + k_extent_width - 2) / strides[1] + 1;
break;
default:MACE_CHECK(false, "Unsupported padding type: ", padding);
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
}
// Note: TensorFlow may pad one more on the right/bottom side
......@@ -124,12 +124,10 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
// utilize the more centered features. We need to benchmark
// based on the model accuracy.
padding_size[0] =
std::max<int>(0, (output_height - 1) * strides[0]
+ k_extent_height - input_shape[1]);
padding_size[1] =
std::max<int>(0, (output_width - 1) * strides[1]
+ k_extent_width - input_shape[2]);
padding_size[0] = std::max<int>(
0, (output_height - 1) * strides[0] + k_extent_height - input_shape[1]);
padding_size[1] = std::max<int>(
0, (output_width - 1) * strides[1] + k_extent_width - input_shape[2]);
output_shape[0] = input_shape[0];
output_shape[1] = output_height;
......@@ -146,7 +144,7 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(padding_size);
......@@ -167,19 +165,18 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
output_height = (input_shape[2] + k_extent_height - 2) / strides[0] + 1;
output_width = (input_shape[3] + k_extent_width - 2) / strides[1] + 1;
break;
default:MACE_CHECK(false, "Unsupported padding type: ", padding);
default:
MACE_CHECK(false, "Unsupported padding type: ", padding);
}
// Note: TensorFlow may pad one more on the right/bottom side
// TODO maybe it's better to also truncate the left/top to
// utilize the more centered features. We need to benchmark
// based on the model accuracy.
padding_size[0] =
std::max<int>(0, (output_height - 1) * strides[0]
+ k_extent_height - input_shape[2]);
padding_size[1] =
std::max<int>(0, (output_width - 1) * strides[1]
+ k_extent_width - input_shape[3]);
padding_size[0] = std::max<int>(
0, (output_height - 1) * strides[0] + k_extent_height - input_shape[2]);
padding_size[1] = std::max<int>(
0, (output_width - 1) * strides[1] + k_extent_width - input_shape[3]);
}
void ConstructInputWithPadding(const Tensor *input_tensor,
......@@ -206,18 +203,18 @@ void ConstructInputWithPadding(const Tensor *input_tensor,
output_tensor->Resize(output_shape);
Tensor::MappingGuard padded_output_mapper(output_tensor);
float *output_ptr = output_tensor->mutable_data<float>();
memset(output_ptr, 0, output_tensor->size() * sizeof(float));
float *output_data = output_tensor->mutable_data<float>();
memset(output_data, 0, output_tensor->size() * sizeof(float));
// Skip the padded top rows
if (padding_same_value) {
#define COPY_INPUT \
std::fill(output_ptr, output_ptr+padded_left, input[0]); \
output_ptr += padded_left; \
memcpy(output_ptr, input, width * sizeof(float)); \
output_ptr += width; \
std::fill(output_ptr , output_ptr + padded_right, input[width-1]); \
output_ptr += padded_right;
#define COPY_INPUT \
std::fill(output_data, output_data + padded_left, input[0]); \
output_data += padded_left; \
memcpy(output_data, input, width * sizeof(float)); \
output_data += width; \
std::fill(output_data, output_data + padded_right, input[width - 1]); \
output_data += padded_right;
const int padded_bottom = paddings[0] - padded_top;
const int padded_right = paddings[1] - padded_left;
......@@ -239,19 +236,69 @@ void ConstructInputWithPadding(const Tensor *input_tensor,
}
#undef COPY_INPUT
} else {
output_ptr += padded_top * output_width;
output_data += padded_top * output_width;
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
for (int k = 0; k < height; ++k) {
memcpy(output_ptr + padded_left, input, width * sizeof(float));
memcpy(output_data + padded_left, input, width * sizeof(float));
input += width;
output_ptr += output_width;
output_data += output_width;
}
// Skip the padded bottom in this channel and top in the next channel
output_ptr += paddings[0] * output_width;
output_data += paddings[0] * output_width;
}
}
}
}
} // namespace kernels
} // namespace mace
void ConstructNHWCInputWithPadding(const Tensor *input_tensor,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value) {
VLOG(1) << "input: " << input_tensor->NumElements();
Tensor::MappingGuard input_mapper(input_tensor);
const float *input = input_tensor->data<float>();
const index_t *input_shape = input_tensor->shape().data();
index_t batch = input_shape[0];
index_t height = input_shape[1];
index_t width = input_shape[2];
index_t channels = input_shape[3];
std::vector<index_t> output_shape(
{batch, paddings[0] + height, paddings[1] + width, channels});
const int output_height = output_shape[1];
const int output_width = output_shape[2];
const int padded_top = paddings[0] / 2;
const int padded_left = paddings[1] / 2;
output_tensor->Resize(output_shape);
Tensor::MappingGuard padded_output_mapper(output_tensor);
float *output_data = output_tensor->mutable_data<float>();
memset(output_data, 0, output_tensor->size() * sizeof(float));
// Skip the padded top rows
if (padding_same_value) {
LOG(FATAL) << "Not implemented";
} else {
#pragma omp parallel for collapse(3)
for (int n = 0; n < batch; ++n) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
const float *input_ptr =
input + ((n * height + h) * width + w) * channels;
float *output_ptr =
output_data +
((n * output_height + h + padded_top) * output_width + w +
padded_left) *
channels;
memcpy(output_ptr, input_ptr, channels * sizeof(float));
}
}
}
}
}
} // namespace kernels
} // namespace mace
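The reformatted padding helpers above (CalcPaddingAndOutputSize, CalcNHWCPaddingAndOutputSize, CalPaddingSize) all reduce to the same relation: the total padding along a dimension is max(0, (out - 1) * stride + dilated_kernel - in). A minimal standalone sketch of that relation follows; the function name and plain int types are illustrative, not part of the commit.

// Sketch: total padding needed so that `out` positions of a dilated kernel fit into `in`.
// dilated_kernel = (kernel - 1) * dilation + 1, matching the k_extent_* values above.
inline int TotalPadding(int in, int out, int stride, int kernel, int dilation) {
  int dilated_kernel = (kernel - 1) * dilation + 1;
  int pad = (out - 1) * stride + dilated_kernel - in;
  return pad > 0 ? pad : 0;
}
// Example: in = 7, kernel = 3, dilation = 1, stride = 2, out = 4
//   => TotalPadding = max(0, 3 * 2 + 3 - 7) = 2, which the callers split as top = pad / 2 = 1
//      and bottom = pad - top = 1.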
......@@ -44,6 +44,12 @@ void ConstructInputWithPadding(const Tensor *input,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value = false);
void ConstructNHWCInputWithPadding(const Tensor *input,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value = false);
} // namespace kernels
} // namespace mace
......
......@@ -64,8 +64,8 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
std::vector<index_t> fake_filter_shape(4);
fake_filter_shape[0] = filter->shape()[0];
fake_filter_shape[1] = filter->shape()[1];
fake_filter_shape[3] = filter->shape()[2] * filter->shape()[3];
fake_filter_shape[2] = 1;
fake_filter_shape[2] = filter->shape()[2] * filter->shape()[3];
fake_filter_shape[3] = 1;
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
......
......@@ -10,9 +10,11 @@ namespace kernels {
template <>
void AddNFunctor<DeviceType::NEON, float>::operator()(
const std::vector<const Tensor *> &input_tensors, Tensor *output_tensor,
const std::vector<const Tensor *> &input_tensors,
Tensor *output_tensor,
StatsFuture *future) {
// TODO: neon mem copy
output_tensor->ResizeLike(input_tensors[0]);
index_t size = output_tensor->size();
float *output_ptr = output_tensor->mutable_data<float>();
memset(output_ptr, 0, size * sizeof(float));
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include <float.h>
#include <limits>
#include "mace/core/common.h"
namespace mace {
namespace kernels {
void PoolingAvgNeonK3x3S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int padding_top = paddings[0] / 2;
int padding_bottom = paddings[0] - padding_top;
int padding_left = paddings[1] / 2;
int padding_right = paddings[1] - padding_left;
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
float avg_factors[4] = {1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0};
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
float *outptr = output + output_offset;
for (int h = 0; h < out_height; ++h) {
int w = 0;
int num_vectors = 0;
const float *r0, *r1, *r2;
if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width;
r2 = r1 + in_width;
if (padding_left > 0) {
if (padding_left == 1) {
float sum0 = std::max(r0[0], r0[1]);
float sum1 = std::max(r1[0], r1[1]);
float max2 = std::max(r2[0], r2[1]);
*outptr = (r0[0] + r0[1] + r1[0] + r1[1] + r2[0] + r2[1]) / 9.0;
++r0;
++r1;
} else { // padding_left == 2
*outptr = (r0[0] + r1[0] + r2[0]) / 9.0;
}
++outptr;
++w;
}
if (padding_right > 0) {
num_vectors = (out_width - w - 1) >> 2;
} else {
num_vectors = (out_width - w) >> 2;
}
}
w += num_vectors << 2;
float32x4_t factors = vld1q_f32(avg_factors);
float32x4x2_t row0 = vld2q_f32(r0);
float32x4x2_t row1 = vld2q_f32(r1);
float32x4x2_t row2 = vld2q_f32(r2);
for (; num_vectors > 0; --num_vectors) {
float32x4x2_t row0_next = vld2q_f32(r0 + 8);
float32x4x2_t row1_next = vld2q_f32(r1 + 8);
float32x4x2_t row2_next = vld2q_f32(r2 + 8);
float32x4_t sum0 = vaddq_f32(row0.val[0], row0.val[1]);
float32x4_t sum1 = vaddq_f32(row1.val[0], row1.val[1]);
float32x4_t sum2 = vaddq_f32(row2.val[0], row2.val[1]);
float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1);
float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1);
float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1);
sum0 = vaddq_f32(sum0, row02);
sum1 = vaddq_f32(sum1, row12);
sum2 = vaddq_f32(sum2, row22);
float32x4_t sum_result = vaddq_f32(vaddq_f32(sum0, sum1), sum2);
float32x4_t avg_result = vmulq_f32(sum_result, factors);
vst1q_f32(outptr, avg_result);
row0 = row0_next;
row1 = row1_next;
row2 = row2_next;
r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}
for (; w < out_width; ++w) {
float sum = 0.0;
for (int kh = 0; kh < 3; ++kh) {
for (int kw = 0; kw < 3; ++kw) {
int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
sum += input[input_offset + inh * in_width + inw];
}
}
}
*outptr = sum / 9.0;
++outptr;
}
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
// assume the input has already been padded
void PoolingAvgNeonK3x3S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
float avg_factors[4] = {1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0, 1.0 / 9.0};
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
const float *img0 = input + input_offset;
float *outptr = output + output_offset;
const float *r0 = img0;
const float *r1 = r0 + in_width;
const float *r2 = r1 + in_width;
for (int h = 0; h < out_height; h++) {
int num_vectors = out_width >> 2;
int remain = out_width - (num_vectors << 2);
float32x4_t factors = vld1q_f32(avg_factors);
float32x4x2_t row0 = vld2q_f32(r0);
float32x4x2_t row1 = vld2q_f32(r1);
float32x4x2_t row2 = vld2q_f32(r2);
for (; num_vectors > 0; --num_vectors) {
float32x4x2_t row0_next = vld2q_f32(r0 + 8);
float32x4x2_t row1_next = vld2q_f32(r1 + 8);
float32x4x2_t row2_next = vld2q_f32(r2 + 8);
float32x4_t sum0 = vaddq_f32(row0.val[0], row0.val[1]);
float32x4_t sum1 = vaddq_f32(row1.val[0], row1.val[1]);
float32x4_t sum2 = vaddq_f32(row2.val[0], row2.val[1]);
float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1);
float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1);
float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1);
sum0 = vaddq_f32(sum0, row02);
sum1 = vaddq_f32(sum1, row12);
sum2 = vaddq_f32(sum2, row22);
float32x4_t sum_result = vaddq_f32(vaddq_f32(sum0, sum1), sum2);
float32x4_t avg_result = vmulq_f32(sum_result, factors);
vst1q_f32(outptr, avg_result);
row0 = row0_next;
row1 = row1_next;
row2 = row2_next;
r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}
for (; remain > 0; remain--) {
*outptr = (r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + r1[2] + r2[0] +
r2[1] + r2[2]) /
9.0;
r0 += 2;
r1 += 2;
r2 += 2;
outptr++;
}
r0 += 1 + in_width;
r1 += 1 + in_width;
r2 += 1 + in_width;
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
} // namespace kernels
} // namespace mace
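The stride-2 kernels above (and the max-pooling kernels later in this diff) rely on vld2q_f32 deinterleaving consecutive floats into even/odd lanes, with vextq_f32 supplying the third tap of each window. A minimal sketch of that trick in isolation; the helper name is illustrative, and it assumes at least 16 readable floats at `in`.

#include <arm_neon.h>
// out[i] = in[2*i] + in[2*i+1] + in[2*i+2] for i = 0..3, i.e. four stride-2 3-tap sums at once.
inline void Sum3TapStride2x4(const float *in, float *out) {
  float32x4x2_t cur = vld2q_f32(in);      // cur.val[0] = in[0,2,4,6], cur.val[1] = in[1,3,5,7]
  float32x4x2_t nxt = vld2q_f32(in + 8);  // next block, needed for the last window's third tap
  float32x4_t sum = vaddq_f32(cur.val[0], cur.val[1]);       // in[2i] + in[2i+1]
  float32x4_t third = vextq_f32(cur.val[0], nxt.val[0], 1);  // in[2i+2]
  vst1q_f32(out, vaddq_f32(sum, third));
}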
......@@ -15,6 +15,7 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
......@@ -26,8 +27,8 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
// new_offset = \offset - mean * common_val;
// Y = new_scale * X + new_offset;
const index_t n = input->dim(0);
const index_t channel = input->dim(1);
const index_t sample_size = input->dim(2) * input->dim(3);
const index_t sample_size = input->dim(1) * input->dim(2);
const index_t channel = input->dim(3);
const float *input_ptr = input->data<float>();
const float *scale_ptr = scale->data<float>();
......@@ -36,36 +37,47 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float *var_ptr = var->data<float>();
float *output_ptr = output->mutable_data<float>();
index_t count = sample_size >> 2;
index_t remain_count = sample_size - (count << 2);
const index_t ch_blks = channel >> 2;
const index_t remain_chs = channel - (ch_blks << 2);
std::vector<float> new_scale(channel);
std::vector<float> new_offset(channel);
#pragma omp parallel for
for (index_t c = 0; c < channel; ++c) {
float new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon_);
float new_offset = offset_ptr[c] - mean_ptr[c] * new_scale;
index_t pos = c * sample_size;
float32x4_t new_scale_f = vdupq_n_f32(new_scale);
float32x4_t new_offset_f = vdupq_n_f32(new_offset);
for (index_t i = 0; i < n; ++i) {
const float *input_sample_ptr = input_ptr + pos;
float *output_sample_ptr = output_ptr + pos;
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon);
new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
}
for (index_t j = 0; j < count; ++j) {
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < n; ++i) {
for (index_t j = 0; j < sample_size; ++j) {
const float *input_sample_ptr = input_ptr + (i * sample_size + j) * channel;
float *output_sample_ptr = output_ptr + (i * sample_size + j) * channel;
const float *new_scale_ptr = new_scale.data();
const float *new_offset_ptr = new_offset.data();
for (index_t cb = 0; cb < ch_blks; ++cb) {
float32x4_t new_scale_f = vld1q_f32(new_scale_ptr);
float32x4_t new_offset_f = vld1q_f32(new_offset_ptr);
float32x4_t input_f = vld1q_f32(input_sample_ptr);
float32x4_t output_f = vfmaq_f32(new_offset_f, input_f, new_scale_f);
vst1q_f32(output_sample_ptr, output_f);
input_sample_ptr += 4;
output_sample_ptr += 4;
new_scale_ptr += 4;
new_offset_ptr += 4;
}
for (index_t j = 0; j < remain_count; ++j) {
*output_sample_ptr = new_scale * *input_sample_ptr + new_offset;
for (index_t c = (ch_blks << 2); c < channel; ++c) {
*output_sample_ptr = new_scale[c] * *input_sample_ptr + new_offset[c];
++output_sample_ptr;
++input_sample_ptr;
++new_scale_ptr;
++new_offset_ptr;
}
pos += channel * sample_size;
}
}
};
} // namespace kernels
} // namespace mace
} // namespace mace
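For reference, the vectorized loop above is the NHWC form of the folded batch-norm math described in the comment block: precompute per-channel new_scale and new_offset, then apply one fused multiply-add per element along the innermost channel dimension. A scalar sketch of the same computation, for illustration only and not part of the commit:

#include <cmath>
// y[n,h,w,c] = new_scale[c] * x[n,h,w,c] + new_offset[c], NHWC layout.
void BatchNormNHWCRef(const float *x, const float *scale, const float *offset,
                      const float *mean, const float *var, float eps,
                      int outer /* N*H*W */, int channels, float *y) {
  for (int i = 0; i < outer; ++i) {
    for (int c = 0; c < channels; ++c) {
      float new_scale = scale[c] / std::sqrt(var[c] + eps);
      float new_offset = offset[c] - mean[c] * new_scale;
      y[i * channels + c] = new_scale * x[i * channels + c] + new_offset;
    }
  }
}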
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/global_avg_pooling.h"
#include <arm_neon.h>
namespace mace {
namespace kernels {
template <>
void GlobalAvgPoolingFunctor<DeviceType::NEON, float>::operator()(
const float *input, const index_t *input_shape,
float *output, StatsFuture *future) {
index_t batch = input_shape[0];
index_t channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
index_t image_size = height * width;
index_t input_offset = 0;
index_t total_channels = batch * channels;
#pragma omp parallel for
for (int c = 0; c < total_channels; ++c) {
const float *inptr = input + c * image_size;
float sum = 0.0;
int num_vectors = image_size >> 2;
int remain = image_size - (num_vectors << 2);
if (num_vectors > 0) {
float sum_out[4] = {0.0, 0.0, 0.0, 0.0};
float32x4_t sum_vector = vld1q_f32(inptr);
inptr += 4;
for (int n = 1; n < num_vectors; ++n) {
float32x4_t vector = vld1q_f32(inptr);
sum_vector = vaddq_f32(sum_vector, vector);
inptr += 4;
}
vst1q_f32(sum_out, sum_vector);
sum = sum_out[0] + sum_out[1] + sum_out[2] + sum_out[3];
}
for (int i = 0; i < remain; ++i) {
sum += *inptr;
++inptr;
}
output[c] = sum / image_size;
}
};
} // namespace kernels
} // namespace mace
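The reduction above accumulates four lanes and then stores them out to add as scalars; on ARMv8 the same horizontal sum can be done with a single across-lanes instruction. A hedged sketch of that alternative (the helper name is illustrative):

#include <arm_neon.h>
inline float HorizontalSum(float32x4_t v) {
#if defined(__aarch64__)
  return vaddvq_f32(v);           // single across-lanes add on ARMv8
#else
  float tmp[4];
  vst1q_f32(tmp, v);              // store-and-add, as the functor above does
  return tmp[0] + tmp[1] + tmp[2] + tmp[3];
#endif
}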
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include <float.h>
#include <limits>
#include "mace/core/common.h"
namespace mace {
namespace kernels {
void PoolingMaxNeonK2x2S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int padding_top = paddings[0] / 2;
int padding_bottom = paddings[0] - padding_top;
int padding_left = paddings[1] / 2;
int padding_right = paddings[1] - padding_left;
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
float *outptr = output + output_offset;
const float *r0, *r1;
for (int h = 0; h < out_height; ++h) {
int w = 0;
int num_vectors = 0;
if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width;
if (padding_left > 0) {
*outptr = std::max(r0[0], r1[0]);
++r0;
++r1;
++outptr;
++w;
}
if (padding_right > 0) {
num_vectors = (out_width - w - 1) >> 2;
} else {
num_vectors = (out_width - w) >> 2;
}
}
w += num_vectors << 2;
for (; num_vectors > 0; --num_vectors) {
float32x4_t r00 = vld1q_f32(r0);
float32x4_t r10 = vld1q_f32(r1);
float32x4_t r01 = vld1q_f32(r0 + 4);
float32x4_t r11 = vld1q_f32(r1 + 4);
float32x4_t max0 = vmaxq_f32(r00, r10);
float32x4_t max1 = vmaxq_f32(r01, r11);
float32x4_t max_result = vpmaxq_f32(max0, max1);
vst1q_f32(outptr, max_result);
r0 += 8;
r1 += 8;
outptr += 4;
}
for (; w < out_width; ++w) {
float max = std::numeric_limits<float>::lowest();
for (int kh = 0; kh < 2; ++kh) {
for (int kw = 0; kw < 2; ++kw) {
int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
max = std::max(max, input[input_offset + inh * in_width + inw]);
}
}
}
*outptr = max;
++outptr;
}
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
// assume the input has already been padded
void PoolingMaxNeonK2x2S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
const float *img0 = input + input_offset;
float *outptr = output + output_offset;
const float *r0 = img0;
const float *r1 = img0 + in_width;
for (int h = 0; h < out_height; ++h) {
int num_vectors = out_width >> 2;
int remain = out_width - (num_vectors << 2);
for (; num_vectors > 0; --num_vectors) {
float32x4_t r00 = vld1q_f32(r0);
float32x4_t r10 = vld1q_f32(r1);
float32x4_t r01 = vld1q_f32(r0 + 4);
float32x4_t r11 = vld1q_f32(r1 + 4);
float32x4_t max0 = vmaxq_f32(r00, r10);
float32x4_t max1 = vmaxq_f32(r01, r11);
float32x4_t max_result = vpmaxq_f32(max0, max1);
vst1q_f32(outptr, max_result);
r0 += 8;
r1 += 8;
outptr += 4;
}
for (; remain > 0; --remain) {
float max0 = std::max(r0[0], r0[1]);
float max1 = std::max(r1[0], r1[1]);
*outptr = std::max(max0, max1);
r0 += 2;
r1 += 2;
outptr++;
}
r0 += in_width;
r1 += in_width;
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include <limits>
#include "mace/core/common.h"
namespace mace {
namespace kernels {
void PoolingMaxNeonK3x3S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int padding_top = paddings[0] / 2;
int padding_bottom = paddings[0] - padding_top;
int padding_left = paddings[1] / 2;
int padding_right = paddings[1] - padding_left;
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
float *outptr = output + output_offset;
for (int h = 0; h < out_height; ++h) {
int w = 0;
int num_vectors = 0;
const float *r0, *r1, *r2;
if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width;
r2 = r1 + in_width;
if (padding_left > 0) {
if (padding_left == 1) {
float max0 = std::max(r0[0], r0[1]);
float max1 = std::max(r1[0], r1[1]);
float max2 = std::max(r2[0], r2[1]);
*outptr = std::max(std::max(max0, max1), max2);
++r0;
++r1;
} else { // padding_left == 2
float max_tmp = std::max(r0[0], r1[0]);
*outptr = std::max(max_tmp, r2[0]);
}
++outptr;
++w;
}
if (padding_right > 0) {
num_vectors = (out_width - w - 1) >> 2;
} else {
num_vectors = (out_width - w) >> 2;
}
}
w += num_vectors << 2;
float32x4x2_t row0 = vld2q_f32(r0);
float32x4x2_t row1 = vld2q_f32(r1);
float32x4x2_t row2 = vld2q_f32(r2);
for (; num_vectors > 0; --num_vectors) {
float32x4x2_t row0_next = vld2q_f32(r0 + 8);
float32x4x2_t row1_next = vld2q_f32(r1 + 8);
float32x4x2_t row2_next = vld2q_f32(r2 + 8);
float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]);
float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]);
float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]);
float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1);
float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1);
float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1);
max0 = vmaxq_f32(max0, row02);
max1 = vmaxq_f32(max1, row12);
max2 = vmaxq_f32(max2, row22);
float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2);
vst1q_f32(outptr, max_result);
row0 = row0_next;
row1 = row1_next;
row2 = row2_next;
r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}
for (; w < out_width; ++w) {
float max = std::numeric_limits<float>::lowest();
for (int kh = 0; kh < 3; ++kh) {
for (int kw = 0; kw < 3; ++kw) {
int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
max = std::max(max, input[input_offset + inh * in_width + inw]);
}
}
}
*outptr = max;
++outptr;
}
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
// assume the input has already been padded
void PoolingMaxNeonK3x3S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
index_t input_offset = 0;
index_t output_offset = 0;
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
const float *img0 = input + input_offset;
float *outptr = output + output_offset;
const float *r0 = img0;
const float *r1 = r0 + in_width;
const float *r2 = r1 + in_width;
for (int h = 0; h < out_height; h++) {
int num_vectors = out_width >> 2;
int remain = out_width - (num_vectors << 2);
float32x4x2_t row0 = vld2q_f32(r0);
float32x4x2_t row1 = vld2q_f32(r1);
float32x4x2_t row2 = vld2q_f32(r2);
for (; num_vectors > 0; num_vectors--) {
float32x4x2_t row0_next = vld2q_f32(r0 + 8);
float32x4x2_t row1_next = vld2q_f32(r1 + 8);
float32x4x2_t row2_next = vld2q_f32(r2 + 8);
float32x4_t max0 = vmaxq_f32(row0.val[0], row0.val[1]);
float32x4_t max1 = vmaxq_f32(row1.val[0], row1.val[1]);
float32x4_t max2 = vmaxq_f32(row2.val[0], row2.val[1]);
float32x4_t row02 = vextq_f32(row0.val[0], row0_next.val[0], 1);
float32x4_t row12 = vextq_f32(row1.val[0], row1_next.val[0], 1);
float32x4_t row22 = vextq_f32(row2.val[0], row2_next.val[0], 1);
max0 = vmaxq_f32(max0, row02);
max1 = vmaxq_f32(max1, row12);
max2 = vmaxq_f32(max2, row22);
float32x4_t max_result = vmaxq_f32(vmaxq_f32(max0, max1), max2);
vst1q_f32(outptr, max_result);
row0 = row0_next;
row1 = row1_next;
row2 = row2_next;
r0 += 8;
r1 += 8;
r2 += 8;
outptr += 4;
}
for (; remain > 0; remain--) {
float max0 = std::max(std::max(r0[0], r0[1]), r0[2]);
float max1 = std::max(std::max(r1[0], r1[1]), r1[2]);
float max2 = std::max(std::max(r2[0], r2[1]), r2[2]);
*outptr = std::max(std::max(max0, max1), max2);
r0 += 2;
r1 += 2;
r2 += 2;
outptr++;
}
r0 += 1 + in_width;
r1 += 1 + in_width;
r2 += 1 + in_width;
}
input_offset += in_image_size;
output_offset += out_image_size;
}
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/pooling.h"
namespace mace {
namespace kernels {
extern void PoolingMaxNeonK2x2S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings);
extern void PoolingAvgNeonK2x2S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings);
extern void PoolingMaxNeonK3x3S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings);
extern void PoolingAvgNeonK3x3S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings);
#ifdef __COPY_MAKE_PADDING
extern void PoolingMaxNeonK2x2S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape);
extern void PoolingAvgNeonK2x2S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape);
extern void PoolingMaxNeonK3x3S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape);
extern void PoolingAvgNeonK3x3S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape);
#endif
template <>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape(4);
filter_shape[0] = input_tensor->shape()[1];
filter_shape[1] = input_tensor->shape()[1];
filter_shape[2] = kernels_[0];
filter_shape[3] = kernels_[1];
kernels::CalcPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), this->dilations_,
strides_, this->padding_, output_shape.data(),
paddings.data());
output_tensor->Resize(output_shape);
const float *input = input_tensor->data<float>();
float *output = output_tensor->mutable_data<float>();
const index_t *input_shape = input_tensor->shape().data();
#ifdef __COPY_MAKE_PADDING
Tensor padded_input;
ConstructInputWithPadding(input_tensor, paddings.data(), &padded_input);
input = padded_input.data<float>();
input_shape = padded_input.shape().data();
#endif
if (kernels_[0] == 2 && kernels_[1] == 2 && strides_[0] == 2 &&
strides_[1] == 2) {
// kernel_size: 2x2, strides: 2x2
if (pooling_type_ == MAX) { // MAX_POOL_2x2s2x2
#ifdef __COPY_MAKE_PADDING
PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape.data());
#else
PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape.data(),
paddings.data());
#endif
} else { // AVG_POOL_2x2s2x2
#ifdef __COPY_MAKE_PADDING
PoolingAvgNeonK2x2S2x2Padded(input, input_shape, output, output_shape.data());
#else
PoolingAvgNeonK2x2S2x2(input, input_shape, output, output_shape.data(),
paddings.data());
#endif
}
} else if (kernels_[0] == 3 && kernels_[1] == 3 && strides_[0] == 2 &&
strides_[1] == 2) {
// kernel_size: 3x3, strides: 2x2
if (pooling_type_ == MAX) { // MAX_POOL_3x3s2x2
#ifdef __COPY_MAKE_PADDING
PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape.data());
#else
PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape.data(),
paddings.data());
#endif
} else { // AVG_POOL_3x3s2x2
#ifdef __COPY_MAKE_PADDING
PoolingAvgNeonK3x3S2x2Padded(input, input_shape, output, output_shape.data());
#else
PoolingAvgNeonK3x3S2x2(input, input_shape, output, output_shape.data(),
paddings.data());
#endif
}
} else { // not implement yet
PoolingFunctor<DeviceType::CPU, float>(pooling_type_, kernels_, strides_,
padding_, dilations_)(
input_tensor, output_tensor, future);
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/relu.h"
#include <arm_neon.h>
namespace mace {
namespace kernels {
template <>
void ActivationFunctor<DeviceType::NEON, float>::operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
const float *input = input_tensor->data<float>();
float *output = output_tensor->mutable_data<float>();
index_t size = input_tensor->size();
if (max_limit_ < 0) {
#pragma omp parallel for
for (int64_t i = 0; i < size; i += kCostPerGroup) {
int64_t count = std::min(static_cast<int64_t>(kCostPerGroup), size - i);
int block = count >> 2;
int remain = count - (block << 2);
const float *inptr = input + i;
float *outptr = output + i;
float32x4_t zero = vdupq_n_f32(0.f);
for (; block > 0; --block) {
float32x4_t in = vld1q_f32(inptr);
float32x4_t out = vmaxq_f32(in, zero);
vst1q_f32(outptr, out);
inptr += 4;
outptr += 4;
}
for (; remain > 0; --remain) {
*outptr = std::max(*inptr, 0.f);
++inptr;
++outptr;
}
}
} else {
#pragma omp parallel for
for (int64_t i = 0; i < size; i += kCostPerGroup) {
int64_t count = std::min(static_cast<int64_t>(kCostPerGroup), size - i);
int block = count >> 2;
int remain = count - (block << 2);
const float *inptr = input + i;
float *outptr = output + i;
float32x4_t zero = vdupq_n_f32(0.f);
float32x4_t vmax = vdupq_n_f32(max_limit_);
for (; block > 0; --block) {
float32x4_t in = vld1q_f32(inptr);
float32x4_t out = vmaxq_f32(in, zero);
out = vminq_f32(out, vmax);
vst1q_f32(outptr, out);
inptr += 4;
outptr += 4;
}
for (; remain > 0; --remain) {
*outptr = std::min(std::max(*inptr, 0.f), max_limit_);
++inptr;
++outptr;
}
}
}
};
} // namespace kernels
} // namespace mace
#include <common.h>
__kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, ic, oc */
__kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, oc, ic */
__private const int filter_w,
__private const int in_channel,
__private const int out_channel,
__private const int in_channel,
__write_only image2d_t output) {
int w = get_global_id(0);
int h = get_global_id(1);
......@@ -13,23 +13,26 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, i
const int in_channel_idx = w % rounded_in_channel;
const int h_idx = hw_idx / filter_w;
const int w_idx = hw_idx % filter_w;
const int offset = ((h_idx * filter_w + w_idx) * in_channel + in_channel_idx) * out_channel
+ out_channel_idx;
const int offset = ((h_idx * filter_w + w_idx) * out_channel + out_channel_idx) * in_channel
+ in_channel_idx;
const int size = out_channel - out_channel_idx;
VEC_DATA_TYPE(DATA_TYPE, 4) values = 0;
if (in_channel_idx < in_channel) {
if (out_channel_idx < out_channel) {
const int size = out_channel - out_channel_idx;
if (size < 4) {
switch(size) {
switch (size) {
case 3:
values.z = *(input + offset + 2);
values.z = *(input + offset + 2 * in_channel);
case 2:
values.y = *(input + offset + 1);
values.y = *(input + offset + 1 * in_channel);
case 1:
values.x = *(input + offset);
}
} else {
values = vload4(0, input + offset);
values.w = *(input + offset + 3 * in_channel);
values.z = *(input + offset + 2 * in_channel);
values.y = *(input + offset + 1 * in_channel);
values.x = *(input + offset);
}
}
......@@ -37,10 +40,10 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, i
CMD_TYPE(write_image, CMD_DATA_TYPE)(output, coord, values);
}
__kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, ic, oc */
__kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, oc, ic */
__private const int filter_w,
__private const int in_channel,
__private const int out_channel,
__private const int in_channel,
__read_only image2d_t input) {
int w = get_global_id(0);
int h = get_global_id(1);
......@@ -50,29 +53,31 @@ __kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, ic, oc
const int in_channel_idx = w % rounded_in_channel;
const int h_idx = hw_idx / filter_w;
const int w_idx = hw_idx % filter_w;
const int offset = ((h_idx * filter_w + w_idx) * in_channel + in_channel_idx) * out_channel
+ out_channel_idx;
const int offset = ((h_idx * filter_w + w_idx) * out_channel + out_channel_idx) * in_channel
+ in_channel_idx;
if (in_channel_idx < in_channel) {
if (out_channel_idx < out_channel) {
int2 coord = (int2)(w, h);
VEC_DATA_TYPE(DATA_TYPE, 4) values = CMD_TYPE(read_image, CMD_DATA_TYPE)(input, SAMPLER, coord);
const int size = (out_channel - out_channel_idx);
if (size < 4) {
switch (size) {
case 3:
output[offset+2] = values.s2;
output[offset + 2 * in_channel] = values.z;
case 2:
output[offset+1] = values.s1;
output[offset + 1 * in_channel] = values.y;
case 1:
output[offset] = values.s0;
output[offset] = values.x;
}
} else {
vstore4(values, 0, output + offset);
output[offset + 3 * in_channel] = values.w;
output[offset + 2 * in_channel] = values.z;
output[offset + 1 * in_channel] = values.y;
output[offset] = values.x;
}
}
}
__kernel void dw_filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, ic, m */
__private const int filter_w,
__private const int in_channel,
......
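The kernel changes above, together with the HWIO -> HWOI comment changes in the padding helpers earlier in this diff, amount to swapping the last two filter dimensions, so the flat offset of filter element (h, w, o, i) changes as sketched below. These helpers are plain C++ for illustration and are not part of the commit.

inline int OffsetHWIO(int h, int w, int i, int o, int W, int I, int O) {
  return ((h * W + w) * I + i) * O + o;  // old layout: h, w, in_channel, out_channel
}
inline int OffsetHWOI(int h, int w, int o, int i, int W, int I, int O) {
  return ((h * W + w) * O + o) * I + i;  // new layout: h, w, out_channel, in_channel
}
// Consecutive out_channel values are no longer contiguous, which is why the image<->buffer
// kernels above read and write the four values of a pixel with a stride of in_channel
// instead of using vload4/vstore4.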
......@@ -149,8 +149,8 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
std::vector<index_t> fake_filter_shape(4);
fake_filter_shape[0] = filter->shape()[0];
fake_filter_shape[1] = filter->shape()[1];
fake_filter_shape[3] = filter->shape()[2] * filter->shape()[3];
fake_filter_shape[2] = 1;
fake_filter_shape[2] = filter->shape()[2] * filter->shape()[3];
fake_filter_shape[3] = 1;
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
......
......@@ -19,12 +19,12 @@ void CalInOutputImageShape(const std::vector<index_t> &shape, /* NHWC */
}
// [RoundUp<4>(Ic) * H * W, (Oc + 3) / 4]
void CalConv2dFilterImageShape(const std::vector<index_t> &shape, /* HWIO */
void CalConv2dFilterImageShape(const std::vector<index_t> &shape, /* HWOI */
std::vector<size_t> &image_shape) {
MACE_CHECK(shape.size() == 4);
image_shape.resize(2);
image_shape[0] = shape[0] * shape[1] * RoundUp<index_t>(shape[2], 4);
image_shape[1] = RoundUpDiv4(shape[3]);
image_shape[0] = shape[0] * shape[1] * RoundUp<index_t>(shape[3], 4);
image_shape[1] = RoundUpDiv4(shape[2]);
}
// [H * W * M, (Ic + 3) / 4]
......@@ -179,6 +179,7 @@ void TuningOrRun3DKernel(cl::Kernel &kernel,
local_ws[2] = std::min<uint32_t>(gws[2],
kwg_size / (local_ws[0] * local_ws[1]));
return {
// TODO tuning these magic numbers
{local_ws[0], local_ws[1], local_ws[2], 1},
{kwg_size / 16, 4, 4, 1},
{kwg_size / 32, 4, 8, 1},
......@@ -200,7 +201,7 @@ void TuningOrRun3DKernel(cl::Kernel &kernel,
{9, 7, 15, 1},
{15, 7, 9, 1},
{1, kwg_size, 1, 1},
{4, 15, 8, 1}, // SNPE size
{4, 15, 8, 1},
};
};
cl::Event event;
......
......@@ -13,14 +13,6 @@ void Register_Activation(OperatorRegistry *op_registry) {
.Build(),
ActivationOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
ActivationOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
......
......@@ -298,14 +298,15 @@ static void SigmoidBenchmark(
} \
BENCHMARK(BM_SIGMOID_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_SIGMOID(N, C, H, W, TYPE) \
BM_SIGMOID_MACRO(N, C, H, W, TYPE, CPU); \
BM_SIGMOID_MACRO(N, C, H, W, TYPE, OPENCL);
BM_SIGMOID(1, 1, 512, 512, float);
BM_SIGMOID(1, 3, 128, 128, float);
BM_SIGMOID(1, 3, 512, 512, float);
BM_SIGMOID(1, 32, 112, 112, float);
BM_SIGMOID(1, 64, 256, 256, float);
#define BM_SIGMOID(N, C, H, W) \
BM_SIGMOID_MACRO(N, C, H, W, float, CPU); \
BM_SIGMOID_MACRO(N, C, H, W, float, OPENCL); \
BM_SIGMOID_MACRO(N, C, H, W, half, OPENCL);
BM_SIGMOID(1, 1, 512, 512);
BM_SIGMOID(1, 3, 128, 128);
BM_SIGMOID(1, 3, 512, 512);
BM_SIGMOID(1, 32, 112, 112);
BM_SIGMOID(1, 64, 256, 256);
} // namespace mace
......@@ -53,10 +53,6 @@ void TestSimpleRelu() {
TEST_F(ActivationOpTest, CPUSimpleRelu) { TestSimpleRelu<DeviceType::CPU>(); }
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimpleRelu) { TestSimpleRelu<DeviceType::NEON>(); }
#endif
TEST_F(ActivationOpTest, OPENCLSimpleRelu) {
TestSimpleRelu<DeviceType::OPENCL>();
}
......@@ -104,12 +100,6 @@ TEST_F(ActivationOpTest, CPUUnalignedSimpleRelu) {
TestUnalignedSimpleRelu<DeviceType::CPU>();
}
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONUnalignedSimpleRelu) {
TestUnalignedSimpleRelu<DeviceType::NEON>();
}
#endif
TEST_F(ActivationOpTest, OPENCLUnalignedSimpleRelu) {
TestUnalignedSimpleRelu<DeviceType::OPENCL>();
}
......@@ -160,10 +150,6 @@ void TestSimpleRelux() {
TEST_F(ActivationOpTest, CPUSimple) { TestSimpleRelux<DeviceType::CPU>(); }
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimple) { TestSimpleRelux<DeviceType::NEON>(); }
#endif
TEST_F(ActivationOpTest, OPENCLSimple) {
TestSimpleRelux<DeviceType::OPENCL>();
}
......@@ -216,12 +202,6 @@ TEST_F(ActivationOpTest, CPUSimpleRelux) {
TestSimpleReluRelux<DeviceType::CPU>();
}
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimpleRelux) {
TestSimpleReluRelux<DeviceType::NEON>();
}
#endif
TEST_F(ActivationOpTest, OPENCLSimpleRelux) {
TestSimpleReluRelux<DeviceType::OPENCL>();
}
......@@ -272,12 +252,6 @@ void TestSimplePrelu() {
TEST_F(ActivationOpTest, CPUSimplePrelu) { TestSimplePrelu<DeviceType::CPU>(); }
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimplePrelu) {
TestSimplePrelu<DeviceType::NEON>();
}
#endif
TEST_F(ActivationOpTest, OPENCLSimplePrelu) {
TestSimplePrelu<DeviceType::OPENCL>();
}
......@@ -329,10 +303,6 @@ void TestSimpleTanh() {
TEST_F(ActivationOpTest, CPUSimpleTanh) { TestSimpleTanh<DeviceType::CPU>(); }
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimpleTanh) { TestSimpleTanh<DeviceType::NEON>(); }
#endif
TEST_F(ActivationOpTest, OPENCLSimpleTanh) {
TestSimpleTanh<DeviceType::OPENCL>();
}
......@@ -387,12 +357,6 @@ TEST_F(ActivationOpTest, CPUSimpleSigmoid) {
TestSimpleSigmoid<DeviceType::CPU>();
}
#if __ARM_NEON
TEST_F(ActivationOpTest, NEONSimpleSigmoid) {
TestSimpleSigmoid<DeviceType::NEON>();
}
#endif
TEST_F(ActivationOpTest, OPENCLSimpleSigmoid) {
TestSimpleSigmoid<DeviceType::OPENCL>();
}
......
......@@ -65,16 +65,16 @@ static void AddNBenchmark(int iters, int inputs, int n, int h, int w, int c) {
} \
BENCHMARK(BM_ADDN_##INPUTS##_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE)
#define BM_ADDN(INPUTS, N, H, W, C, TYPE) \
BM_ADDN_MACRO(INPUTS, N, H, W, C, TYPE, CPU); \
BM_ADDN_MACRO(INPUTS, N, H, W, C, TYPE, OPENCL);
#define BM_ADDN(INPUTS, N, H, W, C) \
BM_ADDN_MACRO(INPUTS, N, H, W, C, float, CPU); \
BM_ADDN_MACRO(INPUTS, N, H, W, C, float, NEON); \
BM_ADDN_MACRO(INPUTS, N, H, W, C, float, OPENCL); \
BM_ADDN_MACRO(INPUTS, N, H, W, C, half, OPENCL);
BM_ADDN(2, 1, 256, 256, 32, float);
BM_ADDN(2, 1, 128, 128, 32, float);
// BM_ADDN(2, 1, 240, 240, 256, half);
BM_ADDN(4, 1, 128, 128, 3, float);
BM_ADDN(2, 1, 256, 256, 3, float);
BM_ADDN(2, 1, 512, 512, 3, float);
// BM_ADDN(4, 1, 240, 240, 256, half);
BM_ADDN(2, 1, 256, 256, 32);
BM_ADDN(2, 1, 128, 128, 32);
BM_ADDN(4, 1, 128, 128, 3);
BM_ADDN(2, 1, 256, 256, 3);
BM_ADDN(2, 1, 512, 512, 3);
} // namespace mace
} // namespace mace
......@@ -33,12 +33,8 @@ void SimpleAdd2() {
TEST_F(AddnOpTest, CPUSimpleAdd2) { SimpleAdd2<DeviceType::CPU>(); }
/*
TEST_F(AddnOpTest, NEONSimpleAdd2) { SimpleAdd2<DeviceType::NEON>(); }
TEST_F(AddnOpTest, OPENCLSimpleAdd2) { SimpleAdd2<DeviceType::OPENCL>(); }
*/
template <DeviceType D>
void SimpleAdd3() {
// Construct graph
......@@ -65,9 +61,7 @@ void SimpleAdd3() {
TEST_F(AddnOpTest, CPUSimpleAdd3) { SimpleAdd3<DeviceType::CPU>(); }
/*
TEST_F(AddnOpTest, NEONSimpleAdd3) { SimpleAdd3<DeviceType::NEON>(); }
*/
template <DeviceType D>
void RandomTest() {
......
......@@ -82,21 +82,24 @@ static void BatchNorm(
} \
BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_BATCH_NORM(N, C, H, W, TYPE) \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, OPENCL);
#define BM_BATCH_NORM(N, C, H, W) \
BM_BATCH_NORM_MACRO(N, C, H, W, float, CPU); \
BM_BATCH_NORM_MACRO(N, C, H, W, float, NEON); \
BM_BATCH_NORM_MACRO(N, C, H, W, float, OPENCL); \
BM_BATCH_NORM_MACRO(N, C, H, W, half, OPENCL);
BM_BATCH_NORM(1, 1, 512, 512, float);
BM_BATCH_NORM(1, 3, 128, 128, float);
BM_BATCH_NORM(1, 3, 512, 512, float);
BM_BATCH_NORM(1, 32, 112, 112, float);
BM_BATCH_NORM(1, 64, 256, 256, float);
BM_BATCH_NORM(1, 64, 512, 512, float);
BM_BATCH_NORM(1, 128, 56, 56, float);
BM_BATCH_NORM(1, 128, 256, 256, float);
BM_BATCH_NORM(1, 256, 14, 14, float);
BM_BATCH_NORM(1, 512, 14, 14, float);
BM_BATCH_NORM(1, 1024, 7, 7, float);
BM_BATCH_NORM(32, 1, 256, 256, float);
BM_BATCH_NORM(32, 3, 256, 256, float);
} // namespace mace
BM_BATCH_NORM(1, 1, 512, 512);
BM_BATCH_NORM(1, 3, 128, 128);
BM_BATCH_NORM(1, 3, 512, 512);
BM_BATCH_NORM(1, 32, 112, 112);
BM_BATCH_NORM(1, 64, 256, 256);
BM_BATCH_NORM(1, 64, 512, 512);
BM_BATCH_NORM(1, 128, 56, 56);
BM_BATCH_NORM(1, 128, 256, 256);
BM_BATCH_NORM(1, 256, 14, 14);
BM_BATCH_NORM(1, 512, 14, 14);
BM_BATCH_NORM(1, 1024, 7, 7);
BM_BATCH_NORM(32, 1, 256, 256);
BM_BATCH_NORM(32, 3, 256, 256);
} // namespace mace
......@@ -72,23 +72,18 @@ void Simple() {
TEST_F(BatchNormOpTest, SimpleCPU) { Simple<DeviceType::CPU>(); }
/*
TEST_F(BatchNormOpTest, SimpleNEON) {
Simple<DeviceType::NEON>();
}
*/
TEST_F(BatchNormOpTest, SimpleNEON) { Simple<DeviceType::NEON>(); }
TEST_F(BatchNormOpTest, SimpleOPENCL) { Simple<DeviceType::OPENCL>(); }
/*
TEST_F(BatchNormOpTest, SimpleRandomNeon) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t channels = 3 + rand() % 50;
index_t height = 64;
index_t width = 64;
index_t channels = 3 + rand() % 50;
// Construct graph
OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
......@@ -97,18 +92,17 @@ TEST_F(BatchNormOpTest, SimpleRandomNeon) {
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, channels, height,
width});
net.AddRandomInput<DeviceType::CPU, float>("Input",
{batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
// run cpu
net.RunOp();
......@@ -139,18 +133,17 @@ TEST_F(BatchNormOpTest, ComplexRandomNeon) {
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, channels, height,
width});
net.AddRandomInput<DeviceType::CPU, float>("Input",
{batch, height, width, channels});
net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
// run cpu
net.RunOp();
......@@ -164,7 +157,6 @@ width});
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
}
*/
TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
srand(time(NULL));
......
......@@ -47,10 +47,10 @@ static void BMBatchToSpace(
} \
BENCHMARK(BM_BATCH_TO_SPACE_##N##_##H##_##W##_##C##_##ARG##_##TYPE##_##DEVICE)
#define BM_BATCH_TO_SPACE(N, H, W, C, ARG, TYPE) \
BM_BATCH_TO_SPACE_MACRO(N, H, W, C, ARG, TYPE, OPENCL);
#define BM_BATCH_TO_SPACE(N, H, W, C, ARG) \
BM_BATCH_TO_SPACE_MACRO(N, H, W, C, ARG, float, OPENCL);
BM_BATCH_TO_SPACE(128, 8, 8, 128, 2, float);
BM_BATCH_TO_SPACE(4, 128, 128, 32, 2, float);
BM_BATCH_TO_SPACE(16, 64, 64, 32, 4, float);
} // namespace mace
\ No newline at end of file
BM_BATCH_TO_SPACE(128, 8, 8, 128, 2);
BM_BATCH_TO_SPACE(4, 128, 128, 32, 2);
BM_BATCH_TO_SPACE(16, 64, 64, 32, 4);
} // namespace mace
......@@ -13,16 +13,6 @@ void Register_BiasAdd(OperatorRegistry *op_registry) {
.Build(),
BiasAddOp<DeviceType::CPU, float>);
/*
#if __ARM_NEON
REGISTER_OPERATOR(op_registry,OpKeyBuilder("BiasAdd")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
BiasAddOp<DeviceType::NEON, float>);
#endif // __ARM_NEON
*/
REGISTER_OPERATOR(op_registry, OpKeyBuilder("BiasAdd")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
......
......@@ -59,21 +59,22 @@ static void BiasAdd(int iters, int batch, int channels, int height, int width) {
} \
BENCHMARK(BM_BIAS_ADD_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_BIAS_ADD(N, C, H, W, TYPE) \
BM_BIAS_ADD_MACRO(N, C, H, W, TYPE, CPU); \
BM_BIAS_ADD_MACRO(N, C, H, W, TYPE, OPENCL);
#define BM_BIAS_ADD(N, C, H, W) \
BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); \
BM_BIAS_ADD_MACRO(N, C, H, W, float, OPENCL); \
BM_BIAS_ADD_MACRO(N, C, H, W, half, OPENCL);
BM_BIAS_ADD(1, 1, 512, 512, float);
BM_BIAS_ADD(1, 3, 128, 128, float);
BM_BIAS_ADD(1, 3, 512, 512, float);
BM_BIAS_ADD(1, 32, 112, 112, float);
BM_BIAS_ADD(1, 64, 256, 256, float);
BM_BIAS_ADD(1, 64, 512, 512, float);
BM_BIAS_ADD(1, 128, 56, 56, float);
BM_BIAS_ADD(1, 128, 256, 256, float);
BM_BIAS_ADD(1, 256, 14, 14, float);
BM_BIAS_ADD(1, 512, 14, 14, float);
BM_BIAS_ADD(1, 1024, 7, 7, float);
BM_BIAS_ADD(32, 1, 256, 256, float);
BM_BIAS_ADD(32, 3, 256, 256, float);
} // namespace mace
BM_BIAS_ADD(1, 1, 512, 512);
BM_BIAS_ADD(1, 3, 128, 128);
BM_BIAS_ADD(1, 3, 512, 512);
BM_BIAS_ADD(1, 32, 112, 112);
BM_BIAS_ADD(1, 64, 256, 256);
BM_BIAS_ADD(1, 64, 512, 512);
BM_BIAS_ADD(1, 128, 56, 56);
BM_BIAS_ADD(1, 128, 256, 256);
BM_BIAS_ADD(1, 256, 14, 14);
BM_BIAS_ADD(1, 512, 14, 14);
BM_BIAS_ADD(1, 1024, 7, 7);
BM_BIAS_ADD(32, 1, 256, 256);
BM_BIAS_ADD(32, 3, 256, 256);
} // namespace mace
......@@ -20,4 +20,4 @@ void Register_BufferToImage(OperatorRegistry *op_registry) {
BufferToImageOp<DeviceType::OPENCL, half>);
}
} // namespace mace
} // namespace mace
......@@ -35,5 +35,5 @@ class BufferToImageOp: public Operator<D, T> {
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_OPS_BUFFER_TO_IMAGE_H_
} // namespace mace
#endif // MACE_OPS_BUFFER_TO_IMAGE_H_
......@@ -13,20 +13,6 @@ void Register_Conv2D(OperatorRegistry *op_registry) {
.Build(),
Conv2dOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::CPU)
.TypeConstraint<half>("T")
.Build(),
Conv2dOp<DeviceType::CPU, half>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
Conv2dOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
......
......@@ -29,7 +29,7 @@ static void Conv2d(int iters,
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
net.AddRandomInput<D, float>("Filter",
{kernel_h, kernel_w, channels, output_channels});
{kernel_h, kernel_w, output_channels, channels});
net.AddRandomInput<D, float>("Bias", {output_channels});
if (D == DeviceType::OPENCL) {
......@@ -92,50 +92,46 @@ static void Conv2d(int iters,
BENCHMARK( \
BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL);
// ICNet
BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024, half);
//// SNPE GPU ExecutionDuration = 448us, % ALU Utilization = 105
BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128, half);
//// SNPE GPU ExecutionDuration = 258us, % ALU Utilization = 108
BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128, half);
BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128, half);
//// SNPE GPU ExecutionDuration = 506us, % ALU Utilization = 106.8
BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32, half);
BM_CONV_2D(1, 3, 512, 512, 7, 7, 2, SAME, 64, half);
BM_CONV_2D(1, 512, 64, 64, 1, 1, 1, SAME, 256, half);
BM_CONV_2D(1, 128, 16, 16, 3, 3, 1, VALID, 32, half);
BM_CONV_2D(1, 128, 64, 64, 3, 3, 1, VALID, 32, half);
BM_CONV_2D(1, 128, 128, 128, 3, 3, 1, VALID, 32, half);
// Test RGB <-> YUV
// BM_CONV_2D(1, 3, 2160, 1080, 1, 1, 1, VALID, 3, float);
// BM_CONV_2D(1, 3, 480, 480, 1, 1, 1, VALID, 3, float);
//
// BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
// BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad
// alignments
// BM_CONV_2D(1, 3, 512, 512, 1, 1, 1, VALID, 3, float);
// BM_CONV_2D(1, 32, 112, 112, 1, 1, 1, VALID, 64, float);
// BM_CONV_2D(1, 64, 56, 56, 1, 1, 1, VALID, 128, float);
// BM_CONV_2D(1, 256, 28, 28, 1, 1, 1, VALID, 256, float);
// BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, VALID, 1024, float);
// BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
// BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
// BM_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3, float);
// BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
// BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float);
// BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float);
// BM_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3, float);
// BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float);
// BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float);
// BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float);
// BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128, float);
// BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float);
// BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float);
// BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128, float);
} // namespace mace
#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, CPU); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, OPENCL); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, half, OPENCL);
BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024);
BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128);
BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128);
BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128);
BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32);
BM_CONV_2D(1, 3, 512, 512, 7, 7, 2, SAME, 64);
BM_CONV_2D(1, 512, 64, 64, 1, 1, 1, SAME, 256);
BM_CONV_2D(1, 128, 16, 16, 3, 3, 1, VALID, 32);
BM_CONV_2D(1, 128, 64, 64, 3, 3, 1, VALID, 32);
BM_CONV_2D(1, 128, 128, 128, 3, 3, 1, VALID, 32);
BM_CONV_2D(1, 3, 480, 480, 1, 1, 1, VALID, 3);
BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128);
BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128); // Test bad alignments
BM_CONV_2D(1, 3, 512, 512, 1, 1, 1, VALID, 3);
BM_CONV_2D(1, 32, 112, 112, 1, 1, 1, VALID, 64);
BM_CONV_2D(1, 64, 56, 56, 1, 1, 1, VALID, 128);
BM_CONV_2D(1, 256, 28, 28, 1, 1, 1, VALID, 256);
BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, VALID, 1024);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128);
BM_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128);
BM_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128);
} // namespace mace
......@@ -10,81 +10,6 @@ using namespace mace;
class Conv2dOpTest : public OpsTestBase {};
template <DeviceType D>
void TestSimple3x3VALID() {
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add args
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 2, 3, 3},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>(
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, float>("Bias", {1}, {0.1f});
// Run
net.RunOp(D);
// Check
auto expected = CreateTensor<float>({1, 1, 1, 1}, {18.1f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
template <DeviceType D>
void TestSimple3x3SAME() {
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 2, 3, 3},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>(
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, float>("Bias", {1}, {0.1f});
// Run
net.RunOp(D);
// Check
auto expected = CreateTensor<float>(
{1, 1, 3, 3},
{8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
#if __ARM_NEON
TEST_F(Conv2dOpTest, NEONSimple) {
TestSimple3x3VALID<DeviceType::NEON>();
TestSimple3x3SAME<DeviceType::NEON>();
}
#endif
template <DeviceType D, typename T>
void TestNHWCSimple3x3VALID() {
OpsTestNet net;
......@@ -93,7 +18,7 @@ void TestNHWCSimple3x3VALID() {
"Input", {1, 3, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 2, 1},
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
......@@ -150,7 +75,7 @@ void TestNHWCSimple3x3SAME() {
"Input", {1, 3, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 2, 1},
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
......@@ -211,42 +136,6 @@ TEST_F(Conv2dOpTest, OPENCLSimple) {
TestNHWCSimple3x3SAME<DeviceType::OPENCL, float>();
}
template <DeviceType D>
void TestSimple3x3WithoutBias() {
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 2, 3, 3},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>(
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
// Run
net.RunOp(D);
// Check
auto expected = CreateTensor<float>({1, 1, 1, 1}, {18.0f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
#ifdef __ARM_NEON
TEST_F(Conv2dOpTest, NEONWithouBias) {
TestSimple3x3WithoutBias<DeviceType::NEON>();
}
#endif
template <DeviceType D, typename T>
void TestNHWCSimple3x3WithoutBias() {
OpsTestNet net;
......@@ -256,7 +145,7 @@ void TestNHWCSimple3x3WithoutBias() {
"Input", {1, 3, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 2, 1},
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
......@@ -309,47 +198,6 @@ TEST_F(Conv2dOpTest, OPENCLWithoutBias) {
TestNHWCSimple3x3WithoutBias<DeviceType::OPENCL, float>();
}
template <DeviceType D>
static void TestCombined3x3() {
// Construct graph
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 2, 5, 5}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>(
"Filter", {2, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f});
// Run
net.RunOp(D);
// Check
auto expected = CreateTensor<float>(
{1, 2, 3, 3}, {8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f,
4.2f, 6.2f, 4.2f, 6.2f, 9.2f, 6.2f, 4.2f, 6.2f, 4.2f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
#ifdef __ARM_NEON
TEST_F(Conv2dOpTest, NEONCombined) { TestCombined3x3<DeviceType::NEON>(); }
#endif
template <DeviceType D, typename T>
static void TestNHWCCombined3x3() {
// Construct graph
......@@ -362,9 +210,9 @@ static void TestNHWCCombined3x3() {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 2, 2},
{1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f});
{1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f,
1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f,
1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f});
net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f});
if (D == DeviceType::OPENCL) {
......@@ -436,8 +284,8 @@ void TestConv1x1() {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>(
"Filter", {1, 1, 5, 2},
{1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f});
"Filter", {1, 1, 2, 5},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f});
if (D == DeviceType::OPENCL) {
......@@ -522,7 +370,7 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels});
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu
......@@ -606,7 +454,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
float_input_data);
std::vector<float> float_filter_data;
GenerateRandomRealTypeData(
{kernel_h, kernel_w, input_channels, output_channels},
{kernel_h, kernel_w, output_channels, input_channels},
float_filter_data);
std::vector<float> float_bias_data;
GenerateRandomRealTypeData({output_channels}, float_bias_data);
......@@ -614,7 +462,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>(
"Input", {batch, height, width, input_channels}, float_input_data);
net.AddInputFromArray<D, float>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels},
"Filter", {kernel_h, kernel_w, output_channels, input_channels},
float_filter_data);
net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data);
......@@ -748,7 +596,7 @@ static void TestDilationConvNxN(const std::vector<index_t> &shape, const int dil
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels});
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu
......
......@@ -13,14 +13,6 @@ void Register_DepthwiseConv2d(OperatorRegistry *op_registry) {
.Build(),
DepthwiseConv2dOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
DepthwiseConv2dOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
......
......@@ -288,12 +288,6 @@ void TestNxNS12(const index_t height, const index_t width) {
}
}
#if __ARM_NEON
TEST_F(DepthwiseConv2dOpTest, NeonSimpleNxNS12) {
TestNxNS12<DeviceType::NEON, float>(4, 4);
}
#endif
TEST_F(DepthwiseConv2dOpTest, OpenCLSimpleNxNS12) {
TestNxNS12<DeviceType::OPENCL, float>(4, 4);
}
......@@ -302,13 +296,6 @@ TEST_F(DepthwiseConv2dOpTest, OpenCLSimpleNxNS12Half) {
TestNxNS12<DeviceType::OPENCL, half>(4, 4);
}
#if __ARM_NEON
TEST_F(DepthwiseConv2dOpTest, NeonAlignedNxNS12) {
TestNxNS12<DeviceType::NEON, float>(64, 64);
TestNxNS12<DeviceType::NEON, float>(128, 128);
}
#endif
TEST_F(DepthwiseConv2dOpTest, OpenCLAlignedNxNS12) {
TestNxNS12<DeviceType::OPENCL, float>(64, 64);
TestNxNS12<DeviceType::OPENCL, float>(128, 128);
......@@ -319,12 +306,6 @@ TEST_F(DepthwiseConv2dOpTest, OpenCLAlignedNxNS12Half) {
TestNxNS12<DeviceType::OPENCL, half>(128, 128);
}
#if __ARM_NEON
TEST_F(DepthwiseConv2dOpTest, NeonUnalignedNxNS12) {
TestNxNS12<DeviceType::NEON, float>(107, 113);
}
#endif
TEST_F(DepthwiseConv2dOpTest, OpenCLUnalignedNxNS12) {
TestNxNS12<DeviceType::OPENCL, float>(107, 113);
}
......
......@@ -89,21 +89,22 @@ static void DepthwiseConv2d(int iters,
BENCHMARK( \
BM_DEPTHWISE_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL);
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC) \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, CPU); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, float, OPENCL); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, half, OPENCL);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1, float);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1, float);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 1, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1);
//BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1);
} // namespace mace
......@@ -14,14 +14,6 @@ void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
.Build(),
FoldedBatchNormOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
FoldedBatchNormOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry,
OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::OPENCL)
......
......@@ -13,12 +13,6 @@ void Register_FusedConv2D(OperatorRegistry *op_registry) {
.Build(),
FusedConv2dOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FusedConv2D")
.Device(DeviceType::CPU)
.TypeConstraint<half>("T")
.Build(),
FusedConv2dOp<DeviceType::CPU, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FusedConv2D")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
......
......@@ -298,7 +298,7 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels});
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu
......@@ -375,7 +375,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
float_input_data);
std::vector<float> float_filter_data;
GenerateRandomRealTypeData(
{kernel_h, kernel_w, input_channels, output_channels},
{kernel_h, kernel_w, output_channels, input_channels},
float_filter_data);
std::vector<float> float_bias_data;
GenerateRandomRealTypeData({output_channels}, float_bias_data);
......@@ -383,7 +383,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
net.AddInputFromArray<D, float>(
"Input", {batch, height, width, input_channels}, float_input_data);
net.AddInputFromArray<D, float>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels},
"Filter", {kernel_h, kernel_w, output_channels, input_channels},
float_filter_data);
net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data);
......@@ -462,7 +462,7 @@ static void TestGeneralConvNxNS12(const std::vector<index_t> &image_shape,
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels});
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu
......@@ -540,7 +540,7 @@ static void TestAtrousConvNxN(const std::vector<index_t> &shape, const int dilat
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels});
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu
......@@ -622,7 +622,7 @@ static void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape,
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, float>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels});
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, float>("Bias", {output_channels});
// run on cpu
......
......@@ -12,14 +12,6 @@ void Register_GlobalAvgPooling(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T")
.Build(),
GlobalAvgPoolingOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("GlobalAvgPooling")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
GlobalAvgPoolingOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
}
} // namespace mace
......@@ -31,29 +31,3 @@ TEST_F(GlobalAvgPoolingOpTest, 3x7x7_CPU) {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
#if __ARM_NEON
TEST_F(GlobalAvgPoolingOpTest, 3x7x7_NEON) {
// Construct graph
OpsTestNet net;
OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
std::vector<float> input(147);
for (int i = 0; i < 147; ++i) {
input[i] = i / 49 + 1;
}
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 7, 7}, input);
// Run
net.RunOp(DeviceType::NEON);
// Check
auto expected = CreateTensor<float>({1, 3, 1, 1}, {1, 2, 3});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
#endif
......@@ -61,10 +61,10 @@ static void MatMulBenchmark(
} \
BENCHMARK(BM_MATMUL_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE)
#define BM_MATMUL(N, H, C, W, TYPE) \
BM_MATMUL_MACRO(N, H, C, W, TYPE, OPENCL);
#define BM_MATMUL(N, H, C, W) \
BM_MATMUL_MACRO(N, H, C, W, half, OPENCL);
BM_MATMUL(16, 32, 128, 49, half);
BM_MATMUL(16, 32, 128, 961, half);
BM_MATMUL(16, 32, 128, 3969, half);
} // namespace mace
BM_MATMUL(16, 32, 128, 49);
BM_MATMUL(16, 32, 128, 961);
BM_MATMUL(16, 32, 128, 3969);
} // namespace mace
......@@ -18,14 +18,6 @@ void Register_Pooling(OperatorRegistry *op_registry) {
.Build(),
PoolingOp<DeviceType::CPU, half>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
PoolingOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Pooling")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
......
......@@ -13,14 +13,6 @@ void Register_ResizeBilinear(OperatorRegistry *op_registry) {
.Build(),
ResizeBilinearOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBilinear")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
ResizeBilinearOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ResizeBilinear")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
......
......@@ -69,18 +69,18 @@ static void ResizeBilinearBenchmark(int iters,
BENCHMARK( \
BM_RESIZE_BILINEAR_##N##_##C##_##H0##_##W0##_##H1##_##W1##_##TYPE##_##DEVICE)
#define BM_RESIZE_BILINEAR(N, C, H0, W0, H1, W1, TYPE) \
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, CPU); \
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, TYPE, OPENCL);
#define BM_RESIZE_BILINEAR(N, C, H0, W0, H1, W1) \
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, float, CPU); \
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, float, OPENCL); \
BM_RESIZE_BILINEAR_MACRO(N, C, H0, W0, H1, W1, half, OPENCL);
// SNPE 835 GPU: 6870us
BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480, float);
BM_RESIZE_BILINEAR(1, 128, 120, 120, 480, 480);
BM_RESIZE_BILINEAR(1, 256, 7, 7, 15, 15, float);
BM_RESIZE_BILINEAR(1, 256, 15, 15, 30, 30, float);
BM_RESIZE_BILINEAR(1, 128, 30, 30, 60, 60, float);
BM_RESIZE_BILINEAR(1, 128, 240, 240, 480, 480, float);
BM_RESIZE_BILINEAR(1, 3, 4032, 3016, 480, 480, float);
BM_RESIZE_BILINEAR(1, 3, 480, 480, 4032, 3016, float);
BM_RESIZE_BILINEAR(1, 256, 7, 7, 15, 15);
BM_RESIZE_BILINEAR(1, 256, 15, 15, 30, 30);
BM_RESIZE_BILINEAR(1, 128, 30, 30, 60, 60);
BM_RESIZE_BILINEAR(1, 128, 240, 240, 480, 480);
BM_RESIZE_BILINEAR(1, 3, 4032, 3016, 480, 480);
BM_RESIZE_BILINEAR(1, 3, 480, 480, 4032, 3016);
} // namespace mace
} // namespace mace
......@@ -55,13 +55,14 @@ static void SoftmaxBenchmark(
} \
BENCHMARK(BM_SOFTMAX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define BM_SOFTMAX(N, C, H, W, TYPE) \
BM_SOFTMAX_MACRO(N, C, H, W, TYPE, CPU); \
BM_SOFTMAX_MACRO(N, C, H, W, TYPE, OPENCL);
#define BM_SOFTMAX(N, C, H, W) \
BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \
BM_SOFTMAX_MACRO(N, C, H, W, float, OPENCL); \
BM_SOFTMAX_MACRO(N, C, H, W, half, OPENCL);
BM_SOFTMAX(1, 1, 512, 512, float);
BM_SOFTMAX(1, 3, 128, 128, float);
BM_SOFTMAX(1, 3, 512, 512, float);
BM_SOFTMAX(1, 32, 112, 112, float);
BM_SOFTMAX(1, 64, 256, 256, float);
BM_SOFTMAX(1, 1, 512, 512);
BM_SOFTMAX(1, 3, 128, 128);
BM_SOFTMAX(1, 3, 512, 512);
BM_SOFTMAX(1, 32, 112, 112);
BM_SOFTMAX(1, 64, 256, 256);
} // namespace mace
......@@ -49,10 +49,10 @@ static void BMSpaceToBatch(
BENCHMARK( \
BM_SPACE_TO_BATCH_##N##_##H##_##W##_##C##_##SHAPE##_##TYPE##_##DEVICE)
#define BM_SPACE_TO_BATCH(N, H, W, C, SHAPE, TYPE) \
BM_SPACE_TO_BATCH_MACRO(N, H, W, C, SHAPE, TYPE, OPENCL);
#define BM_SPACE_TO_BATCH(N, H, W, C, SHAPE) \
BM_SPACE_TO_BATCH_MACRO(N, H, W, C, SHAPE, float, OPENCL);
BM_SPACE_TO_BATCH(128, 16, 16, 128, 2, float);
BM_SPACE_TO_BATCH(1, 256, 256, 32, 2, float);
BM_SPACE_TO_BATCH(1, 256, 256, 32, 4, float);
} // namespace mace
\ No newline at end of file
BM_SPACE_TO_BATCH(128, 16, 16, 128, 2);
BM_SPACE_TO_BATCH(1, 256, 256, 32, 2);
BM_SPACE_TO_BATCH(1, 256, 256, 32, 4);
} // namespace mace
......@@ -19,9 +19,9 @@ void TransposeFilter(const std::vector<float> &input,
const float *input_ptr = input.data();
for (index_t h = 0; h < input_shape[0]; ++h) {
for (index_t w = 0; w < input_shape[1]; ++w) {
for (index_t ic = 0; ic < input_shape[2]; ++ic) {
for (index_t oc = 0; oc < input_shape[3]; ++oc) {
int offset = ((oc * input_shape[2] + ic) * input_shape[0] + h) * input_shape[1] + w;
for (index_t oc = 0; oc < input_shape[2]; ++oc) {
for (index_t ic = 0; ic < input_shape[3]; ++ic) {
int offset = ((oc * input_shape[3] + ic) * input_shape[0] + h) * input_shape[1] + w;
output[offset] = *input_ptr;
++input_ptr;
}
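The loop above reflects the new filter layout: the test helper now reads weights stored as HWOI (height, width, out_channels, in_channels) and rewrites them into OIHW order via the offset formula shown. The following standalone sketch is not part of the commit; it only double-checks that offset formula on a hypothetical 3x3 filter with 2 output and 2 input channels.

#include <cassert>
#include <vector>

int main() {
  const int H = 3, W = 3, O = 2, I = 2;  // hypothetical filter dimensions
  std::vector<float> hwoi(H * W * O * I), oihw(H * W * O * I);
  for (size_t n = 0; n < hwoi.size(); ++n) hwoi[n] = static_cast<float>(n);

  // Same traversal order and offset formula as TransposeFilter above.
  const float *src = hwoi.data();
  for (int h = 0; h < H; ++h)
    for (int w = 0; w < W; ++w)
      for (int oc = 0; oc < O; ++oc)
        for (int ic = 0; ic < I; ++ic)
          oihw[((oc * I + ic) * H + h) * W + w] = *src++;

  // Element (h=1, w=2, oc=1, ic=0) must land at the corresponding OIHW offset.
  assert(oihw[((1 * I + 0) * H + 1) * W + 2] == hwoi[((1 * W + 2) * O + 1) * I + 0]);
  return 0;
}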
......@@ -43,7 +43,7 @@ void WinogradConvolution(const index_t batch,
OpsTestNet net;
// Add input data
std::vector<float> filter_data;
std::vector<index_t> filter_shape = {3, 3, in_channels, out_channels};
std::vector<index_t> filter_shape = {3, 3, out_channels, in_channels};
GenerateRandomRealTypeData<float>(filter_shape, filter_data);
net.AddRandomInput<D, float>("Input", {batch, height, width, in_channels});
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
......
......@@ -48,12 +48,12 @@ static void BMWinogradTransform(
BENCHMARK( \
BM_WINOGRAD_TRANSFORM_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE)
#define BM_WINOGRAD_TRANSFORM(N, H, W, C, TYPE) \
BM_WINOGRAD_TRANSFORM_MACRO(N, H, W, C, TYPE, OPENCL);
#define BM_WINOGRAD_TRANSFORM(N, H, W, C) \
BM_WINOGRAD_TRANSFORM_MACRO(N, H, W, C, half, OPENCL);
BM_WINOGRAD_TRANSFORM(1, 16, 16, 128, half);
BM_WINOGRAD_TRANSFORM(1, 64, 64, 128, half);
BM_WINOGRAD_TRANSFORM(1, 128, 128, 128, half);
BM_WINOGRAD_TRANSFORM(1, 16, 16, 128);
BM_WINOGRAD_TRANSFORM(1, 64, 64, 128);
BM_WINOGRAD_TRANSFORM(1, 128, 128, 128);
template <DeviceType D, typename T>
static void BMWinogradInverseTransform(
......@@ -100,11 +100,11 @@ static void BMWinogradInverseTransform(
BENCHMARK( \
BM_WINOGRAD_INVERSE_TRANSFORM_##N##_##H##_##W##_##C##_##TYPE##_##DEVICE)
#define BM_WINOGRAD_INVERSE_TRANSFORM(N, H, W, C, TYPE) \
BM_WINOGRAD_INVERSE_TRANSFORM_MACRO(N, H, W, C, TYPE, OPENCL);
#define BM_WINOGRAD_INVERSE_TRANSFORM(N, H, W, C) \
BM_WINOGRAD_INVERSE_TRANSFORM_MACRO(N, H, W, C, half, OPENCL);
BM_WINOGRAD_INVERSE_TRANSFORM(1, 14, 14, 32, half);
BM_WINOGRAD_INVERSE_TRANSFORM(1, 62, 62, 32, half);
BM_WINOGRAD_INVERSE_TRANSFORM(1, 126, 126, 32, half);
BM_WINOGRAD_INVERSE_TRANSFORM(1, 14, 14, 32);
BM_WINOGRAD_INVERSE_TRANSFORM(1, 62, 62, 32);
BM_WINOGRAD_INVERSE_TRANSFORM(1, 126, 126, 32);
} // namespace mace
\ No newline at end of file
} // namespace mace
......@@ -30,6 +30,11 @@ Integer RoundUpDiv8(Integer i) {
return (i + 7) >> 3;
}
template <typename Integer>
Integer RoundUpDiv(Integer i, Integer factor) {
return (i + factor - 1) / factor;
}
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
return (a + b - 1) / b;
......
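The new RoundUpDiv helper is the general-factor form of the ceiling division already provided by RoundUpDiv4 and RoundUpDiv8. A minimal usage sketch, not part of the commit, with the template re-declared locally so it compiles standalone:

#include <cassert>

template <typename Integer>
Integer RoundUpDiv(Integer i, Integer factor) {
  return (i + factor - 1) / factor;  // ceiling division for non-negative i
}

int main() {
  assert(RoundUpDiv(10, 4) == 3);  // same result as RoundUpDiv4(10) == (10 + 3) >> 2
  assert(RoundUpDiv(16, 8) == 2);  // same result as RoundUpDiv8(16) == (16 + 7) >> 3
  assert(RoundUpDiv(17, 8) == 3);
  return 0;
}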
......@@ -18,8 +18,8 @@ BAZEL_BIN_PATH=${BAZEL_BIN_PATH#//}
BAZEL_BIN_PATH=bazel-bin/$BAZEL_BIN_PATH
BIN_NAME=`echo $BAZEL_TARGET | cut -d: -f2`
ANDROID_ABI=arm64-v8a
ANDROID_ABI=armeabi-v7a
ANDROID_ABI=arm64-v8a
STRIP="--strip always"
VLOG_LEVEL=0
PROFILING="1"
......@@ -43,7 +43,8 @@ bazel build -c opt $STRIP --verbose_failures $BAZEL_TARGET \
--copt="-D_GLIBCXX_USE_C99_MATH_TR1" \
--copt="-DMACE_DISABLE_NO_TUNING_WARNING" \
--copt="-Werror=return-type" \
--define neon=false \
--copt="-O3" \
--define neon=true \
--define openmp=true
if [ $? -ne 0 ]; then
......