Commit e39003da authored by: 李寅

Reimplemented neon kernels

Parent 64ee40e0
......@@ -186,7 +186,7 @@ bool Run(MaceEngine *engine,
return true;
}
DEFINE_string(device, "CPU", "Device [CPU|OPENCL]");
DEFINE_string(device, "CPU", "Device [CPU|NEON|OPENCL]");
DEFINE_string(input_node, "input_node0,input_node1",
"input nodes, separated by comma");
DEFINE_string(output_node, "output_node0,output_node1",
......@@ -264,6 +264,8 @@ int Main(int argc, char **argv) {
DeviceType device_type = CPU;
if (FLAGS_device == "OPENCL") {
device_type = OPENCL;
} else if (FLAGS_device == "NEON") {
device_type = NEON;
}
// config runtime
......@@ -271,7 +273,7 @@ int Main(int argc, char **argv) {
mace::ConfigOpenCLRuntime(
static_cast<GPUPerfHint>(FLAGS_gpu_perf_hint),
static_cast<GPUPriorityHint>(FLAGS_gpu_priority_hint));
} else if (device_type == CPU) {
} else if (device_type == CPU || device_type == NEON) {
mace::ConfigOmpThreadsAndAffinity(
FLAGS_omp_num_threads,
static_cast<CPUPowerOption>(FLAGS_cpu_power_option));
......
......@@ -41,6 +41,8 @@ class BufferBase {
virtual bool OnHost() const = 0;
virtual void Clear() = 0;
virtual index_t offset() const { return 0; }
template <typename T>
......@@ -158,6 +160,12 @@ class Buffer : public BufferBase {
bool OnHost() const { return allocator_->OnHost(); }
void Clear() {
if (buf_ != nullptr) {
memset(buf_, 0, size_);
}
}
private:
Allocator *allocator_;
void *buf_;
......@@ -242,6 +250,10 @@ class Image : public BufferBase {
bool OnHost() const { return allocator_->OnHost(); }
void Clear() {
MACE_NOT_IMPLEMENTED;
}
private:
Allocator *allocator_;
std::vector<size_t> shape_;
......@@ -322,6 +334,10 @@ class BufferSlice : public BufferBase {
bool OnHost() const { return buffer_->OnHost(); }
void Clear() {
MACE_NOT_IMPLEMENTED;
}
private:
BufferBase *buffer_;
void *mapped_buf_;
......
......@@ -93,10 +93,9 @@ extern void Register_Slice(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_SpaceToDepth(OperatorRegistry *op_registry);
extern void Register_Transpose(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
} // namespace ops
OperatorRegistry::OperatorRegistry() {
......@@ -130,6 +129,7 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_Softmax(this);
ops::Register_SpaceToBatchND(this);
ops::Register_SpaceToDepth(this);
ops::Register_Transpose(this);
ops::Register_WinogradInverseTransform(this);
ops::Register_WinogradTransform(this);
}
......
......@@ -146,21 +146,26 @@ class Tensor {
template <typename T>
inline const T *data() const {
MACE_CHECK(buffer_ != nullptr, "buffer is null");
MACE_CHECK_NOTNULL(buffer_);
return buffer_->data<T>();
}
inline void *raw_mutable_data() {
MACE_CHECK(buffer_ != nullptr, "buffer is null");
MACE_CHECK_NOTNULL(buffer_);
return buffer_->raw_mutable_data();
}
template <typename T>
inline T *mutable_data() {
MACE_CHECK(buffer_ != nullptr, "buffer is null");
MACE_CHECK_NOTNULL(buffer_);
return static_cast<T *>(buffer_->raw_mutable_data());
}
inline void Clear() {
MACE_CHECK_NOTNULL(buffer_);
buffer_->Clear();
}
inline void Reshape(const std::vector<index_t> &shape) {
shape_ = shape;
MACE_CHECK(raw_size() <= buffer_->size());
......@@ -258,22 +263,19 @@ class Tensor {
inline void DebugPrint() const {
using namespace numerical_chars; // NOLINT(build/namespaces)
std::stringstream os;
os << "Tensor " << name_ << " size: [";
for (index_t i : shape_) {
os << i << ", ";
}
os << "], content:\n";
os.str("");
os.clear();
MappingGuard guard(this);
for (int i = 0; i < size(); ++i) {
if (i != 0 && i % shape_[3] == 0) {
if (i != 0 && i % shape_.back() == 0) {
os << "\n";
}
CASES(dtype_, (os << (this->data<T>()[i]) << ", "));
}
LOG(INFO) << "Tensor size: [" << dim(0) << ", " << dim(1) << ", " << dim(2)
<< ", " << dim(3) << "], content:\n"
<< os.str();
LOG(INFO) << os.str();
}
class MappingGuard {
......
......@@ -21,6 +21,7 @@ Tensor *Workspace::CreateTensor(const std::string &name,
VLOG(3) << "Creating Tensor " << name;
tensor_map_[name] =
std::move(std::unique_ptr<Tensor>(new Tensor(alloc, type)));
tensor_map_[name]->SetSourceOpName(name);
}
return GetTensor(name);
}
......
......@@ -11,13 +11,21 @@ load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled")
cc_library(
name = "kernels",
srcs = glob([
"*.cc",
"opencl/*.cc",
]),
srcs = glob(
[
"*.cc",
"opencl/*.cc",
"arm/*.cc",
],
exclude = [
"*_test.cc",
"arm/*_test.cc",
],
),
hdrs = glob([
"*.h",
"opencl/*.h",
"arm/*.h",
]),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]),
linkopts = if_android(["-lm"]),
......@@ -28,14 +36,20 @@ cc_library(
)
cc_test(
name = "kernel_test",
name = "kernels_test",
testonly = 1,
srcs = glob(["test/*.cc"]),
linkopts = if_android(["-pie"]),
srcs = glob(
[
"*_test.cc",
"arm/*_test.cc",
],
),
copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
":kernels",
"//mace/core",
"@gtest//:gtest",
"@gtest//:gtest_main",
],
)
......
......@@ -134,11 +134,20 @@ class ActivationFunctor {
};
template <>
void ActivationFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future);
class ActivationFunctor<DeviceType::NEON, float> {
public:
ActivationFunctor(ActivationType type, float relux_max_limit)
: activation_(type), relux_max_limit_(relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future);
private:
ActivationType activation_;
float relux_max_limit_;
};
template <typename T>
class ActivationFunctor<DeviceType::OPENCL, T> {
......
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
void ActivationFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future) {
const float *input_ptr = input->data<float>();
float *output_ptr = output->mutable_data<float>();
if (activation_ == PRELU) {
MACE_CHECK_NOTNULL(alpha);
const float *alpha_ptr = alpha->data<float>();
PReLUActivation(input_ptr, output->size(), input->dim(1), alpha_ptr,
output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/batch_norm.h"
namespace mace {
namespace kernels {
void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
// The inference formula is
// Y = \frac{scale}{\sqrt{var + epsilon}} * X +
//     (offset - \frac{scale * mean}{\sqrt{var + epsilon}})
// which folds into
// new_scale = \frac{scale}{\sqrt{var + epsilon}}
// new_offset = offset - mean * new_scale
// Y = new_scale * X + new_offset
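// Illustrative example (values chosen for clarity, not taken from any model):
// with scale = 2, var = 3, epsilon = 1, mean = 1, offset = 0.5 we get
//   new_scale  = 2 / sqrt(3 + 1) = 1.0
//   new_offset = 0.5 - 1 * 1.0   = -0.5
// so every element is transformed as Y = 1.0 * X - 0.5.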
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
const index_t height = input->dim(2);
const index_t width = input->dim(3);
const float *input_ptr = input->data<float>();
const float *scale_ptr = scale->data<float>();
const float *offset_ptr = offset->data<float>();
float *output_ptr = output->mutable_data<float>();
std::vector<float> new_scale;
std::vector<float> new_offset;
if (!folded_constant_) {
new_scale.resize(channels);
new_offset.resize(channels);
const float *mean_ptr = mean->data<float>();
const float *var_ptr = var->data<float>();
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) {
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon);
new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
}
}
const float *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
const float *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
index_t channel_size = height * width;
index_t batch_size = channels * channel_size;
// NEON is slower here, so stick to the trivial implementation
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
index_t offset = b * batch_size + c * channel_size;
for (index_t hw = 0; hw < height * width; ++hw) {
output_ptr[offset + hw] =
scale_data[c] * input_ptr[offset + hw] + offset_data[c];
}
}
}
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/arm/conv_winograd.h"
// Winograd consistently outperforms the direct NEON kernels in benchmarks
#define USE_WINOGRAD 1
namespace mace {
namespace kernels {
namespace {
void Conv2dNCHW(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
float *output) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
for (index_t c = 0; c < in_channels; ++c) {
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h;
index_t iw = w * stride_w + kw * dilation_w;
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((m * in_channels) + c) * filter_height + kh) * filter_width
+ kw;
output[out_offset] += input[in_offset] * filter[filter_offset];
}
}
}
}
}
}
}
}
} // namespace
extern void Conv2dNeonK1x1S1(const float *input,
const float *filter,
const index_t batch,
const index_t height,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output);
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
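// Dispatch overview: depending on filter size, stride and dilation this
// functor picks one of
//   * the Winograd F(2x2, 3x3) path (3x3, stride 1, no dilation),
//   * a hand-written NEON kernel (3x3 s1, 3x3 s2, 1x1 s1), or
//   * the generic Conv2dNCHW reference loop.
// Input and output may be padded to "extra" sizes so the tiled kernels can
// run without bounds checks; the valid region is copied back afterwards.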
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter->shape().data(),
dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(), filter->shape().data(),
paddings_.data(), dilations_, strides_, RoundType::FLOOR,
output_shape.data());
}
output->Resize(output_shape);
output->Clear();
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter->dim(2);
index_t filter_w = filter->dim(3);
MACE_CHECK(filter->dim(0) == channels, filter->dim(0), " != ", channels);
MACE_CHECK(filter->dim(1) == input_channels, filter->dim(1), " != ",
input_channels);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
index_t padded_input_height = input_height + paddings[0];
index_t padded_input_width = input_width + paddings[1];
index_t extra_input_height = padded_input_height;
index_t extra_input_width = padded_input_width;
index_t extra_output_height = height;
index_t extra_output_width = width;
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
std::function<void(const float *input, float *output)> conv_func;
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
memset(output_data, 0, sizeof(float) * batch * channels * height * width);
if (USE_WINOGRAD && filter_h == 3 && filter_w == 3 && stride_h == 1
&& stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
extra_output_height = RoundUp<index_t>(height, 2);
extra_input_height = std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, 2);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
index_t tile_height_count = (extra_output_height + 1) / 2;
index_t tile_width_count = (extra_output_width + 1) / 2;
index_t tile_count = tile_height_count * tile_width_count;
transformed_input_.Resize({16, batch, input_channels, tile_count});
transformed_filter_.Resize({16, channels, input_channels});
transformed_output_.Resize({16, batch, channels, tile_count});
conv_func = [=](const float *pad_input, float *pad_output) {
WinoGradConv3x3s1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
channels,
transformed_input_.mutable_data<float>(),
transformed_filter_.mutable_data<float>(),
transformed_output_.mutable_data<float>(),
is_filter_transformed_,
pad_output);
is_filter_transformed_ = true;
};
} else if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
extra_output_height = RoundUp<index_t>(height, 2);
extra_input_height = std::max(padded_input_height, extra_output_height + 2);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width = std::max(padded_input_width, extra_output_width + 2);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S1(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
&& dilation_h == 1 && dilation_w == 1) {
extra_output_height = height;
extra_input_height =
std::max(padded_input_height, (extra_output_height - 1) * 2 + 3);
extra_output_width = RoundUp<index_t>(width, 4);
extra_input_width =
std::max(padded_input_width, (extra_output_width - 1) * 2 + 3);
if (extra_input_height != padded_input_height) {
pad_bottom += (extra_input_height - padded_input_height);
}
if (extra_input_width != padded_input_width) {
pad_right += (extra_input_width - padded_input_width);
}
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK3x3S2(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
pad_output);
};
} else if (filter_h == 1 && filter_w == 1 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNeonK1x1S1(input_data,
filter_data,
batch,
height,
width,
input_channels,
channels,
output_data);
};
} else {
conv_func = [=](const float *pad_input, float *pad_output) {
Conv2dNCHW(pad_input,
filter_data,
batch,
extra_input_height,
extra_input_width,
input_channels,
extra_output_height,
extra_output_width,
channels,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_output);
};
}
const Tensor *pad_input_ptr = input;
// Keep this alive during kernel execution
if (extra_input_height != input_height || extra_input_width != input_width) {
ConstructNCHWInputWithSpecificPadding(input,
pad_top,
pad_bottom,
pad_left,
pad_right,
&padded_input_);
pad_input_ptr = &padded_input_;
}
const float *pad_input_data = pad_input_ptr->data<float>();
Tensor *pad_output_ptr = output;
// Keep this alive during kernel execution
if (extra_output_height != height || extra_output_width != width) {
std::vector<index_t> extra_output_shape
{batch, channels, extra_output_height, extra_output_width};
padded_output_.Resize(extra_output_shape);
pad_output_ptr = &padded_output_;
}
float *pad_output_data = pad_output_ptr->mutable_data<float>();
conv_func(pad_input_data, pad_output_data);
// unpack output
if (extra_output_height != height || extra_output_width != width) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t h = 0; h < height; ++h) {
memcpy(
output_data + b * channels * height * width + c * height * width
+ h * width,
pad_output_data
+ b * channels * extra_output_height * extra_output_width
+ c * extra_output_height * extra_output_width
+ h * extra_output_width,
sizeof(float) * width);
}
}
}
}
if (bias_data != nullptr) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t i = 0; i < height * width; ++i) {
output_data[(b * channels + c) * height * width + i] += bias_data[c];
}
}
}
}
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/kernels/gemm.h"
namespace mace {
namespace kernels {
void Conv2dNeonK1x1S1(const float *input,
const float *filter,
const index_t batch,
const index_t height,
const index_t width,
const index_t in_channels,
const index_t out_channels,
float *output) {
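// A 1x1 convolution over NCHW data is a per-batch matrix multiplication:
// filter [out_channels x in_channels] times the input viewed as
// [in_channels x (height * width)], giving [out_channels x (height * width)].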
for (index_t b = 0; b < batch; ++b) {
Gemm(filter,
input + b * in_channels * height * width,
1,
out_channels,
in_channels,
height * width,
output + b * out_channels * height * width);
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
namespace mace {
namespace kernels {
// Ho = 2, Wo = 4, Co = 2
void Conv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output) {
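// Register tiling: each vectorized iteration produces a block of
// 2 output channels x 2 rows x 4 columns and accumulates into the output
// buffer (the kernel adds to, rather than overwrites, existing values).
// A trailing odd output channel, and builds without AArch64 NEON, take the
// scalar fallback loops below.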
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = in_channels * in_image_size;
const index_t out_batch_size = out_channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; m += 2) {
if (m + 1 < out_channels) {
float *out_ptr0_base = output + b * out_batch_size + m * out_image_size;
float *out_ptr1_base =
output + b * out_batch_size + (m + 1) * out_image_size;
for (index_t c = 0; c < in_channels; ++c) {
float *out_ptr0 = out_ptr0_base;
float *out_ptr1 = out_ptr1_base;
const float *in_ptr0 = input + b * in_batch_size + c * in_image_size;
const float *in_ptr1 =
input + b * in_batch_size + c * in_image_size + 1 * in_width;
const float *in_ptr2 =
input + b * in_batch_size + c * in_image_size + 2 * in_width;
const float *in_ptr3 =
input + b * in_batch_size + c * in_image_size + 3 * in_width;
const float *filter_ptr0 = filter + m * in_channels * 9 + c * 9;
const float *filter_ptr1 = filter + (m + 1) * in_channels * 9 + c * 9;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
// load filter (2 outch x 3 height x 3 width): vf_outch_height
float32x4_t vf00, vf01, vf02;
float32x4_t vf10, vf11, vf12;
vf00 = vld1q_f32(filter_ptr0);
vf01 = vld1q_f32(filter_ptr0 + 3);
vf02 = vld1q_f32(filter_ptr0 + 6);
vf10 = vld1q_f32(filter_ptr1);
vf11 = vld1q_f32(filter_ptr1 + 3);
vf12 = vld1q_f32(filter_ptr1 + 6);
for (index_t h = 0; h + 1 < out_height; h += 2) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input (4 height x 3 slide): vi_height_slide
float32x4_t vi00, vi01, vi02; // reg count: 14
float32x4_t vi10, vi11, vi12;
float32x4_t vi20, vi21, vi22;
float32x4_t vi30, vi31, vi32;
float32x4_t vo20, vo30; // tmp use
// output (2 outch x 2 height x 4 width): vo_outch_height
float32x4_t vo00, vo01;
float32x4_t vo10, vo11;
// load input
vi00 = vld1q_f32(in_ptr0);
vo00 = vld1q_f32(in_ptr0 + 4); // reuse vo00: vi0n
vi10 = vld1q_f32(in_ptr1);
vo10 = vld1q_f32(in_ptr1 + 4);
vi20 = vld1q_f32(in_ptr2);
vo20 = vld1q_f32(in_ptr2 + 4);
vi30 = vld1q_f32(in_ptr3);
vo30 = vld1q_f32(in_ptr3 + 4);
vi01 = vextq_f32(vi00, vo00, 1);
vi02 = vextq_f32(vi00, vo00, 2);
vi11 = vextq_f32(vi10, vo10, 1);
vi12 = vextq_f32(vi10, vo10, 2);
vi21 = vextq_f32(vi20, vo20, 1);
vi22 = vextq_f32(vi20, vo20, 2);
vi31 = vextq_f32(vi30, vo30, 1);
vi32 = vextq_f32(vi30, vo30, 2);
// load output
vo00 = vld1q_f32(out_ptr0);
vo01 = vld1q_f32(out_ptr0 + out_width);
vo10 = vld1q_f32(out_ptr1);
vo11 = vld1q_f32(out_ptr1 + out_width);
// outch 0, height 0
vo00 = vfmaq_laneq_f32(vo00, vi00, vf00, 0); // reg count: 18
vo00 = vfmaq_laneq_f32(vo00, vi01, vf00, 1);
vo00 = vfmaq_laneq_f32(vo00, vi02, vf00, 2);
vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0);
vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1);
vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2);
vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 0);
vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 1);
vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 2);
// outch 0, height 1
vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0);
vo01 = vfmaq_laneq_f32(vo01, vi11, vf00, 1);
vo01 = vfmaq_laneq_f32(vo01, vi12, vf00, 2);
vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0);
vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1);
vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2);
vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 0);
vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 1);
vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 2);
// outch 1, height 0
vo10 = vfmaq_laneq_f32(vo10, vi00, vf10, 0);
vo10 = vfmaq_laneq_f32(vo10, vi01, vf10, 1);
vo10 = vfmaq_laneq_f32(vo10, vi02, vf10, 2);
vo10 = vfmaq_laneq_f32(vo10, vi10, vf11, 0);
vo10 = vfmaq_laneq_f32(vo10, vi11, vf11, 1);
vo10 = vfmaq_laneq_f32(vo10, vi12, vf11, 2);
vo10 = vfmaq_laneq_f32(vo10, vi20, vf12, 0);
vo10 = vfmaq_laneq_f32(vo10, vi21, vf12, 1);
vo10 = vfmaq_laneq_f32(vo10, vi22, vf12, 2);
// outch 1, height 1
vo11 = vfmaq_laneq_f32(vo11, vi10, vf10, 0);
vo11 = vfmaq_laneq_f32(vo11, vi11, vf10, 1);
vo11 = vfmaq_laneq_f32(vo11, vi12, vf10, 2);
vo11 = vfmaq_laneq_f32(vo11, vi20, vf11, 0);
vo11 = vfmaq_laneq_f32(vo11, vi21, vf11, 1);
vo11 = vfmaq_laneq_f32(vo11, vi22, vf11, 2);
vo11 = vfmaq_laneq_f32(vo11, vi30, vf12, 0);
vo11 = vfmaq_laneq_f32(vo11, vi31, vf12, 1);
vo11 = vfmaq_laneq_f32(vo11, vi32, vf12, 2);
vst1q_f32(out_ptr0, vo00);
vst1q_f32(out_ptr0 + out_width, vo01);
vst1q_f32(out_ptr1, vo10);
vst1q_f32(out_ptr1 + out_width, vo11);
in_ptr0 += 4;
in_ptr1 += 4;
in_ptr2 += 4;
in_ptr3 += 4;
out_ptr0 += 4;
out_ptr1 += 4;
} // w
in_ptr0 += 2 + in_width;
in_ptr1 += 2 + in_width;
in_ptr2 += 2 + in_width;
in_ptr3 += 2 + in_width;
out_ptr0 += out_width;
out_ptr1 += out_width;
} // h
#else
for (index_t io = 0; io < 2; ++io) {
for (index_t ih = 0; ih < out_height; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
out_ptr0[io * out_image_size + ih * out_width + iw] +=
in_ptr0[(ih + i) * in_width + (iw + j)]
* filter_ptr0[io * in_channels * 9 + i * 3 + j];
}
}
}
}
} // for
#endif
} // c
} else {
for (index_t mm = m; mm < out_channels; ++mm) {
float
*out_ptr0_base = output + b * out_batch_size + mm * out_image_size;
for (index_t c = 0; c < in_channels; ++c) {
float *out_ptr0 = out_ptr0_base;
const float
*in_ptr0 = input + b * in_batch_size + c * in_image_size;
const float *in_ptr1 =
input + b * in_batch_size + c * in_image_size + 1 * in_width;
const float *in_ptr2 =
input + b * in_batch_size + c * in_image_size + 2 * in_width;
const float *in_ptr3 =
input + b * in_batch_size + c * in_image_size + 3 * in_width;
const float *filter_ptr0 = filter + mm * in_channels * 9 + c * 9;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float32x4_t vf00, vf01, vf02;
vf00 = vld1q_f32(filter_ptr0);
vf01 = vld1q_f32(filter_ptr0 + 3);
vf02 = vld1q_f32(filter_ptr0 + 6);
for (index_t h = 0; h + 1 < out_height; h += 2) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
// input (4 height x 3 slide): vi_height_slide
float32x4_t vi00, vi01, vi02, vi0n;
float32x4_t vi10, vi11, vi12, vi1n;
float32x4_t vi20, vi21, vi22, vi2n;
float32x4_t vi30, vi31, vi32, vi3n;
// output (1 outch x 2 height x 4 width): vo_outch_height
float32x4_t vo00, vo01;
// load input
vi00 = vld1q_f32(in_ptr0);
vi0n = vld1q_f32(in_ptr0 + 4);
vi10 = vld1q_f32(in_ptr1);
vi1n = vld1q_f32(in_ptr1 + 4);
vi20 = vld1q_f32(in_ptr2);
vi2n = vld1q_f32(in_ptr2 + 4);
vi30 = vld1q_f32(in_ptr3);
vi3n = vld1q_f32(in_ptr3 + 4);
vi01 = vextq_f32(vi00, vi0n, 1);
vi02 = vextq_f32(vi00, vi0n, 2);
vi11 = vextq_f32(vi10, vi1n, 1);
vi12 = vextq_f32(vi10, vi1n, 2);
vi21 = vextq_f32(vi20, vi2n, 1);
vi22 = vextq_f32(vi20, vi2n, 2);
vi31 = vextq_f32(vi30, vi3n, 1);
vi32 = vextq_f32(vi30, vi3n, 2);
// load output
vo00 = vld1q_f32(out_ptr0);
vo01 = vld1q_f32(out_ptr0 + out_width);
// outch 0, height 0
vo00 = vfmaq_laneq_f32(vo00, vi00, vf00, 0);
vo00 = vfmaq_laneq_f32(vo00, vi01, vf00, 1);
vo00 = vfmaq_laneq_f32(vo00, vi02, vf00, 2);
vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0);
vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1);
vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2);
vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 0);
vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 1);
vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 2);
// outch 0, height 1
vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0);
vo01 = vfmaq_laneq_f32(vo01, vi11, vf00, 1);
vo01 = vfmaq_laneq_f32(vo01, vi12, vf00, 2);
vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0);
vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1);
vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2);
vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 0);
vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 1);
vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 2);
vst1q_f32(out_ptr0, vo00);
vst1q_f32(out_ptr0 + out_width, vo01);
in_ptr0 += 4;
in_ptr1 += 4;
in_ptr2 += 4;
in_ptr3 += 4;
out_ptr0 += 4;
} // w
in_ptr0 += 2 + in_width;
in_ptr1 += 2 + in_width;
in_ptr2 += 2 + in_width;
in_ptr3 += 2 + in_width;
out_ptr0 += out_width;
} // h
#else
for (index_t ih = 0; ih < out_height; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
out_ptr0[ih * out_width + iw] +=
in_ptr0[(ih + i) * in_width + (iw + j)]
* filter_ptr0[i * 3 + j];
}
}
}
}
#endif
} // c
} // mm
} // if
} // m
} // b
}
void Conv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
float *output) {
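// For stride 2, vld2q_f32 de-interleaves 8 consecutive input floats into
// even columns [0,2,4,6] and odd columns [1,3,5,7], which directly provides
// the first two filter taps of each row; the third tap is obtained by
// shifting in the next even element with vextq_f32.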
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = in_channels * in_image_size;
const index_t out_batch_size = out_channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t c = 0; c < in_channels; ++c) {
const float *in_base = input + b * in_batch_size + c * in_image_size;
const float
*filter_ptr = filter + m * in_channels * 9 + c * 9;
float *out_base = output + b * out_batch_size + m * out_image_size;
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float32x4_t vf00, vf01, vf02;
vf00 = vld1q_f32(filter_ptr);
vf01 = vld1q_f32(filter_ptr + 3);
vf02 = vld1q_f32(filter_ptr + 6);
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w + 3 < out_width; w += 4) {
float32x4x2_t vi0, vi1, vi2;
float32x4_t vi0n, vi1n, vi2n;
// input (3 height x 3 slide): vi_height_slide
float32x4_t vi00, vi01, vi02;
float32x4_t vi10, vi11, vi12;
float32x4_t vi20, vi21, vi22;
// output (1 outch x 1 height x 4 width): vo
float32x4_t vo;
// load input
index_t in_h = h * 2;
index_t in_w = w * 2;
index_t in_offset = in_h * in_width + in_w;
vi0 = vld2q_f32(in_base + in_offset); // [0.2.4.6, 1.3.5.7]
vi1 = vld2q_f32(in_base + in_offset + in_width);
vi2 = vld2q_f32(in_base + in_offset + 2 * in_width);
vi0n = vld1q_f32(in_base + in_offset + 8); // [8.9.10.11]
vi1n = vld1q_f32(in_base + in_offset + in_width + 8);
vi2n = vld1q_f32(in_base + in_offset + 2 * in_width + 8);
// load output
index_t out_offset = h * out_width + w;
vo = vld1q_f32(out_base + out_offset);
vi00 = vi0.val[0]; // [0.2.4.6]
vi01 = vi0.val[1]; // [1.3.5.7]
vi02 = vextq_f32(vi00, vi0n, 1); // [2.4.6.8]
vi10 = vi1.val[0];
vi11 = vi1.val[1];
vi12 = vextq_f32(vi10, vi1n, 1);
vi20 = vi2.val[0];
vi21 = vi2.val[1];
vi22 = vextq_f32(vi20, vi2n, 1);
// outch 0, height 0
vo = vfmaq_laneq_f32(vo, vi00, vf00, 0);
vo = vfmaq_laneq_f32(vo, vi01, vf00, 1);
vo = vfmaq_laneq_f32(vo, vi02, vf00, 2);
vo = vfmaq_laneq_f32(vo, vi10, vf01, 0);
vo = vfmaq_laneq_f32(vo, vi11, vf01, 1);
vo = vfmaq_laneq_f32(vo, vi12, vf01, 2);
vo = vfmaq_laneq_f32(vo, vi20, vf02, 0);
vo = vfmaq_laneq_f32(vo, vi21, vf02, 1);
vo = vfmaq_laneq_f32(vo, vi22, vf02, 2);
vst1q_f32(out_base + out_offset, vo);
} // w
} // h
#else
for (index_t ih = 0; ih < out_height; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 3; ++j) {
out_base[ih * out_width + iw] +=
in_base[(ih * 2 + i) * in_width + (iw * 2 + j)]
* filter_ptr[i * 3 + j];
}
}
}
}
#endif
} // c
} // m
} // b
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include <math.h>
#include <algorithm>
#include "mace/kernels/arm/conv_winograd.h"
#include "mace/kernels/gemm.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
namespace {
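// The transforms below implement Winograd F(2x2, 3x3): every 4x4 input tile
// yields a 2x2 output tile. The standard transform matrices, which match the
// scalar expressions used below, are
//        [1  0 -1  0]        [ 1    0    0 ]        [1 1  1  0]
//   BT = [0  1  1  0]    G = [1/2  1/2  1/2]   AT = [0 1 -1 -1]
//        [0 -1  1  0]        [1/2 -1/2  1/2]
//        [0  1  0 -1]        [ 0    0    1 ]
// input: s = BT * d * B,  filter: u = G * g * GT,  output: y = AT * m * A.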
// NCHW => TNCB (T: in tile pixels, B: tile indices)
void TransformInput(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t tile_count,
float *output) {
const index_t stride = batch * in_channels * tile_count;
const index_t in_height_width = in_height * in_width;
#pragma omp parallel for
for (index_t nc = 0; nc < batch * in_channels; ++nc) {
index_t tile_index = nc * tile_count;
for (index_t h = 0; h < in_height - 2; h += 2) {
for (index_t w = 0; w < in_width - 2; w += 2) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14,
d15;
float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14,
s15;
// load tile data
const index_t tile_offset = nc * in_height_width + h * in_width + w;
d0 = input[tile_offset];
d1 = input[tile_offset + 1];
d2 = input[tile_offset + 2];
d3 = input[tile_offset + 3];
d4 = input[tile_offset + in_width];
d5 = input[tile_offset + in_width + 1];
d6 = input[tile_offset + in_width + 2];
d7 = input[tile_offset + in_width + 3];
d8 = input[tile_offset + 2 * in_width];
d9 = input[tile_offset + 2 * in_width + 1];
d10 = input[tile_offset + 2 * in_width + 2];
d11 = input[tile_offset + 2 * in_width + 3];
d12 = input[tile_offset + 3 * in_width];
d13 = input[tile_offset + 3 * in_width + 1];
d14 = input[tile_offset + 3 * in_width + 2];
d15 = input[tile_offset + 3 * in_width + 3];
// s = BT * d * B
s0 = (d0 - d8) - (d2 - d10);
s1 = (d1 - d9) + (d2 - d10);
s2 = (d2 - d10) - (d1 - d9);
s3 = (d1 - d9) - (d3 - d11);
s4 = (d4 + d8) - (d6 + d10);
s5 = (d5 + d9) + (d6 + d10);
s6 = (d6 + d10) - (d5 + d9);
s7 = (d5 + d9) - (d7 + d11);
s8 = (d8 - d4) - (d10 - d6);
s9 = (d9 - d5) + (d10 - d6);
s10 = (d10 - d6) - (d9 - d5);
s11 = (d9 - d5) - (d11 - d7);
s12 = (d4 - d12) - (d6 - d14);
s13 = (d5 - d13) + (d6 - d14);
s14 = (d6 - d14) - (d5 - d13);
s15 = (d5 - d13) - (d7 - d15);
// store output
output[tile_index + 0 * stride] = s0;
output[tile_index + 1 * stride] = s1;
output[tile_index + 2 * stride] = s2;
output[tile_index + 3 * stride] = s3;
output[tile_index + 4 * stride] = s4;
output[tile_index + 5 * stride] = s5;
output[tile_index + 6 * stride] = s6;
output[tile_index + 7 * stride] = s7;
output[tile_index + 8 * stride] = s8;
output[tile_index + 9 * stride] = s9;
output[tile_index + 10 * stride] = s10;
output[tile_index + 11 * stride] = s11;
output[tile_index + 12 * stride] = s12;
output[tile_index + 13 * stride] = s13;
output[tile_index + 14 * stride] = s14;
output[tile_index + 15 * stride] = s15;
++tile_index;
}
}
}
}
// OCHW => TOC
// No need to optimize: this transform will eventually live in the model converter
void TransformFilter(const float *filter,
const index_t in_channels,
const index_t out_channels,
float *output) {
const index_t stride = out_channels * in_channels;
#pragma omp parallel for collapse(2)
for (index_t m = 0; m < out_channels; ++m) {
for (index_t c = 0; c < in_channels; ++c) {
float g0, g1, g2, g3, g4, g5, g6, g7, g8;
float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14,
s15;
// load filter
index_t filter_offset = (m * in_channels + c) * 9;
g0 = filter[filter_offset];
g1 = filter[filter_offset + 1];
g2 = filter[filter_offset + 2];
g3 = filter[filter_offset + 3];
g4 = filter[filter_offset + 4];
g5 = filter[filter_offset + 5];
g6 = filter[filter_offset + 6];
g7 = filter[filter_offset + 7];
g8 = filter[filter_offset + 8];
// s = G * g * GT
s0 = g0;
s1 = (g0 + g2 + g1) * 0.5f;
s2 = (g0 + g2 - g1) * 0.5f;
s3 = g2;
s4 = (g0 + g6 + g3) * 0.5f;
s5 = ((g0 + g6 + g3) + (g2 + g8 + g5) + (g1 + g7 + g4)) * 0.25f;
s6 = ((g0 + g6 + g3) + (g2 + g8 + g5) - (g1 + g7 + g4)) * 0.25f;
s7 = (g2 + g8 + g5) * 0.5f;
s8 = (g0 + g6 - g3) * 0.5f;
s9 = ((g0 + g6 - g3) + (g2 + g8 - g5) + (g1 + g7 - g4)) * 0.25f;
s10 = ((g0 + g6 - g3) + (g2 + g8 - g5) - (g1 + g7 - g4)) * 0.25f;
s11 = (g2 + g8 - g5) * 0.5f;
s12 = g6;
s13 = (g6 + g8 + g7) * 0.5f;
s14 = (g6 + g8 - g7) * 0.5f;
s15 = g8;
// store output
index_t output_offset = m * in_channels + c;
output[output_offset + 0 * stride] = s0;
output[output_offset + 1 * stride] = s1;
output[output_offset + 2 * stride] = s2;
output[output_offset + 3 * stride] = s3;
output[output_offset + 4 * stride] = s4;
output[output_offset + 5 * stride] = s5;
output[output_offset + 6 * stride] = s6;
output[output_offset + 7 * stride] = s7;
output[output_offset + 8 * stride] = s8;
output[output_offset + 9 * stride] = s9;
output[output_offset + 10 * stride] = s10;
output[output_offset + 11 * stride] = s11;
output[output_offset + 12 * stride] = s12;
output[output_offset + 13 * stride] = s13;
output[output_offset + 14 * stride] = s14;
output[output_offset + 15 * stride] = s15;
}
}
}
// TOC * TNCB => TNOB
void BatchGemm(const float *input,
const float *filter,
index_t batch,
index_t in_channels,
index_t out_channels,
index_t tile_count,
float *output) {
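// For each of the 16 Winograd tile positions, multiply the transformed filter
// [out_channels x in_channels] by the transformed input
// [in_channels x tile_count]. With batch == 1 a single batched Gemm call
// covers all 16 positions; otherwise positions and batches are looped over
// explicitly.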
const index_t in_stride = batch * in_channels * tile_count;
const index_t in_channels_tile_count = in_channels * tile_count;
const index_t filter_stride = out_channels * in_channels;
const index_t out_stride = batch * out_channels * tile_count;
const index_t out_channels_tile_count = out_channels * tile_count;
if (batch == 1) {
Gemm(filter, input, 16, out_channels, in_channels, tile_count, output);
} else {
for (int i = 0; i < 16; ++i) {
for (int b = 0; b < batch; ++b) {
const float
*in_ptr = input + i * in_stride + b * in_channels_tile_count;
const float *filter_ptr = filter + i * filter_stride;
float *out_ptr = output + i * out_stride + b * out_channels_tile_count;
Gemm(filter_ptr,
in_ptr,
1,
out_channels, /* rows */
in_channels, /* K */
tile_count, /* cols */
out_ptr);
}
}
}
}
// TNOB => ToNOB => NOHoWo
void TransformOutput(const float *input,
index_t batch,
index_t out_height,
index_t out_width,
index_t out_channels,
index_t tile_count,
float *output) {
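// Inverse transform: for every (batch, out_channel) pair, the 16 GEMM results
// of each tile are reduced with AT * m * A to a 2x2 block of output pixels.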
const index_t in_stride = batch * out_channels * tile_count;
#pragma omp parallel for
for (index_t nm = 0; nm < batch * out_channels; ++nm) {
index_t tile_offset = nm * tile_count;
for (index_t h = 0; h < out_height; h += 2) {
for (index_t w = 0; w < out_width; w += 2) {
float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14,
d15;
float s0, s1, s2, s3, s4, s5, s6, s7;
float v0, v1, v2, v3;
d0 = input[tile_offset + 0 * in_stride];
d1 = input[tile_offset + 1 * in_stride];
d2 = input[tile_offset + 2 * in_stride];
d3 = input[tile_offset + 3 * in_stride];
d4 = input[tile_offset + 4 * in_stride];
d5 = input[tile_offset + 5 * in_stride];
d6 = input[tile_offset + 6 * in_stride];
d7 = input[tile_offset + 7 * in_stride];
d8 = input[tile_offset + 8 * in_stride];
d9 = input[tile_offset + 9 * in_stride];
d10 = input[tile_offset + 10 * in_stride];
d11 = input[tile_offset + 11 * in_stride];
d12 = input[tile_offset + 12 * in_stride];
d13 = input[tile_offset + 13 * in_stride];
d14 = input[tile_offset + 14 * in_stride];
d15 = input[tile_offset + 15 * in_stride];
s0 = d0 + d1 + d2;
s1 = d1 - d2 - d3;
s2 = d4 + d5 + d6;
s3 = d5 - d6 - d7;
s4 = d8 + d9 + d10;
s5 = d9 - d10 - d11;
s6 = d12 + d13 + d14;
s7 = d13 - d14 - d15;
v0 = s0 + s2 + s4;
v1 = s1 + s3 + s5;
v2 = s2 - s4 - s6;
v3 = s3 - s5 - s7;
index_t out_offset = nm * out_height * out_width + h * out_width + w;
output[out_offset] = v0;
output[out_offset + 1] = v1;
output[out_offset + out_width] = v2;
output[out_offset + out_width + 1] = v3;
++tile_offset;
}
}
}
}
void ConvRef3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
float *output) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
#pragma omp parallel for collapse(4)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
output[out_offset] = 0;
for (index_t c = 0; c < in_channels; ++c) {
for (index_t kh = 0; kh < 3; ++kh) {
for (index_t kw = 0; kw < 3; ++kw) {
index_t ih = h + kh;
index_t iw = w + kw;
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t
filter_offset = (((m * in_channels) + c) * 3 + kh) * 3 + kw;
output[out_offset] += input[in_offset] * filter[filter_offset];
}
}
}
}
}
}
}
}
} // namespace
void WinoGradConv3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
float *transformed_input,
float *transformed_filter,
float *transformed_output,
bool is_filter_transformed,
float *output) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t tile_height_count = (out_height + 1) / 2;
index_t tile_width_count = (out_width + 1) / 2;
index_t tile_count = tile_height_count * tile_width_count;
TransformInput(input,
batch,
in_height,
in_width,
in_channels,
tile_count,
transformed_input);
// TODO(liyin): Move this into the model converter. It is fast and only runs
// once, so it is not a performance concern here.
if (!is_filter_transformed) {
TransformFilter(filter, in_channels, out_channels, transformed_filter);
}
BatchGemm(transformed_input,
transformed_filter,
batch,
in_channels,
out_channels,
tile_count,
transformed_output);
TransformOutput(transformed_output,
batch,
out_height,
out_width,
out_channels,
tile_count,
output);
}
void WinoGradConv3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
float *output) {
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t tile_height_count = (out_height + 1) / 2;
index_t tile_width_count = (out_width + 1) / 2;
index_t tile_count = tile_height_count * tile_width_count;
index_t transformed_input_size = 16 * batch * in_channels * tile_count;
index_t transformed_filter_size = 16 * out_channels * in_channels;
index_t transformed_output_size = 16 * batch * out_channels * tile_count;
float *transformed_input = new float[transformed_input_size]; // TNCB
float *transformed_filter = new float[transformed_filter_size]; // TOC
float *transformed_output = new float[transformed_output_size];
WinoGradConv3x3s1(input,
filter,
batch,
in_height,
in_width,
in_channels,
out_channels,
transformed_input,
transformed_filter,
transformed_output,
false,
output);
delete[] transformed_input;
delete[] transformed_filter;
delete[] transformed_output;
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_ARM_CONV_WINOGRAD_H_
#define MACE_KERNELS_ARM_CONV_WINOGRAD_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
namespace mace {
namespace kernels {
void WinoGradConv3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
float *output);
void WinoGradConv3x3s1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_channels,
float *transformed_input,
float *transformed_filter,
float *transformed_output,
bool is_filter_transformed,
float *output);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_ARM_CONV_WINOGRAD_H_
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include <gtest/gtest.h>
#include <random>
#include <algorithm>
#include "mace/kernels/arm/conv_winograd.h"
#include "mace/core/types.h"
namespace mace {
namespace kernels {
TEST(ConvWinogradTest, winograd) {
index_t batch = 1;
index_t in_height = 32;
index_t in_width = 32;
index_t in_channels = 64;
index_t out_channels = 128;
index_t out_height = in_height - 2;
index_t out_width = in_width - 2;
index_t input_size = batch * in_channels * in_height * in_width;
index_t filter_size = 3 * 3 * in_channels * out_channels;
index_t output_size = batch * out_channels * out_height * out_width;
float *input_data = new float[input_size];
float *filter_data = new float[filter_size];
float *output_data = new float[output_size];
float *output_data_ref = new float[output_size];
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(input_data, input_data + input_size,
[&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen)));
});
std::generate(filter_data, filter_data + filter_size,
[&gen, &nd] {
return std::max(-1.0f, std::min(1.0f, nd(gen)));
});
kernels::ConvRef3x3s1(input_data,
filter_data,
batch,
in_height,
in_width,
in_channels,
out_channels,
output_data_ref);
kernels::WinoGradConv3x3s1(input_data,
filter_data,
batch,
in_height,
in_width,
in_channels,
out_channels,
output_data);
// compare the Winograd result against the reference convolution
for (index_t i = 0; i < output_size; ++i) {
EXPECT_NEAR(output_data_ref[i], output_data[i], 0.1);
}
delete[] input_data;
delete[] filter_data;
delete[] output_data;
delete[] output_data_ref;
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/depthwise_conv2d.h"
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
namespace {
void DepthwiseConv2dNCHW(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t multiplier = out_channels / in_channels;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset =
((b * out_channels + m) * out_height + h) * out_width + w;
index_t c = m / multiplier;
index_t o = m % multiplier;
float sum = 0;
for (index_t kh = 0; kh < filter_height; ++kh) {
for (index_t kw = 0; kw < filter_width; ++kw) {
index_t ih = h * stride_h + kh * dilation_h - pad_top;
index_t iw = w * stride_w + kw * dilation_w - pad_left;
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
index_t in_offset =
((b * in_channels + c) * in_height + ih) * in_width + iw;
index_t filter_offset =
(((o * in_channels) + c) * filter_height + kh) * filter_width
+ kw;
sum += input[in_offset] * filter[filter_offset];
}
}
}
output[out_offset] = sum;
}
}
}
}
}
} // namespace
extern void DepthwiseConv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const int valid_h_start,
const int valid_h_stop,
const int valid_w_start,
const int valid_w_stop,
float *output);
void DepthwiseConv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const int valid_h_start,
const int valid_h_stop,
const int valid_w_start,
const int valid_w_stop,
float *output);
void DepthwiseConv2dFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape
{filter->dim(0) * filter->dim(1), filter->dim(1), filter->dim(2),
filter->dim(3)};
if (paddings_.empty()) {
CalcNCHWPaddingAndOutputSize(input->shape().data(),
filter_shape.data(),
dilations_,
strides_,
padding_type_,
output_shape.data(),
paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::FLOOR,
output_shape.data());
}
output->Resize(output_shape);
output->Clear();
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t filter_h = filter_shape[2];
index_t filter_w = filter_shape[3];
MACE_CHECK(filter_shape[0] == channels, filter_shape[0], " != ", channels);
MACE_CHECK(filter_shape[1] == input_channels, filter_shape[1], " != ",
input_channels);
index_t stride_h = strides_[0];
index_t stride_w = strides_[1];
index_t dilation_h = dilations_[0];
index_t dilation_w = dilations_[1];
MACE_CHECK(batch == input_batch, "Input/Output batch size mismatch");
int pad_top = paddings[0] >> 1;
int pad_bottom = paddings[0] - pad_top;
int pad_left = paddings[1] >> 1;
int pad_right = paddings[1] - pad_left;
int valid_h_start = pad_top == 0 ? 0 : (pad_top - 1) / stride_h + 1;
int valid_h_stop = pad_bottom == 0
? height
: height - ((pad_bottom - 1) / stride_h + 1);
int valid_w_start = pad_left == 0 ? 0 : (pad_left - 1) / stride_w + 1;
int valid_w_stop = pad_right == 0
? width
: width - ((pad_right - 1) / stride_w + 1);
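// valid_[hw]_start/stop delimit the output rows/columns whose 3x3 window lies
// entirely inside the unpadded input; the NEON kernels vectorize this
// interior and handle the remaining border pixels with a scalar per-pixel
// fallback.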
std::function<void(const float *input, float *output)> conv_func;
auto input_data = input->data<float>();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S1(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
pad_top,
pad_left,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else if (filter_h == 3 && filter_w == 3 && stride_h == 2 && stride_w == 2
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNeonK3x3S2(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
pad_top,
pad_left,
valid_h_start,
valid_h_stop,
valid_w_start,
valid_w_stop,
output);
};
} else {
conv_func = [=](const float *input, float *output) {
DepthwiseConv2dNCHW(input,
filter_data,
batch,
input_height,
input_width,
input_channels,
height,
width,
channels,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
};
}
conv_func(input_data, output_data);
if (bias_data != nullptr) {
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
for (index_t i = 0; i < height * width; ++i) {
output_data[(b * channels + c) * height * width + i] += bias_data[c];
}
}
}
}
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_);
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
namespace mace {
namespace kernels {
namespace {
void DepthwiseConv2dPixel(const float *in_base,
const float *filter,
const index_t out_h,
const index_t out_w,
const index_t in_h_start,
const index_t in_w_start,
const index_t out_width,
const index_t in_height,
const index_t in_width,
int filter_height,
int filter_width,
float *out_base) {
float sum = 0;
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
index_t in_h = in_h_start + i;
index_t in_w = in_w_start + j;
if (in_h >= 0 && in_h < in_height && in_w >= 0 && in_w < in_width) {
sum += in_base[in_h * in_width + in_w] * filter[i * filter_width + j];
}
}
}
out_base[out_h * out_width + out_w] = sum;
}
} // namespace
// Ho = 2, Wo = 4, Co = 1
void DepthwiseConv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const int valid_h_start,
const int valid_h_stop,
const int valid_w_start,
const int valid_w_stop,
float *output) {
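// Depthwise mapping: output channel m reads input channel m / multiplier and
// the filter slice selected by m % multiplier. Pixels outside the valid
// region (top, left, right, bottom) go through DepthwiseConv2dPixel; the
// interior uses the vectorized 2-row x 4-column block below.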
const index_t multiplier = out_channels / in_channels;
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = in_channels * in_image_size;
const index_t out_batch_size = out_channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
index_t c = m / multiplier;
index_t multi_index = m % multiplier;
const float *in_base = input + b * in_batch_size + c * in_image_size;
const float *filter_ptr = filter + multi_index * in_channels * 9 + c * 9;
float *out_base = output + b * out_batch_size + m * out_image_size;
index_t h, w;
// top
for (h = 0; h < valid_h_start; ++h) {
for (w = 0; w < out_width; ++w) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
h,
w,
h - pad_top,
w - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
}
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float32x4_t vf00, vf01, vf02;
vf00 = vld1q_f32(filter_ptr);
vf01 = vld1q_f32(filter_ptr + 3);
vf02 = vld1q_f32(filter_ptr + 6);
for (h = valid_h_start; h + 1 < valid_h_stop; h += 2) {
// left
for (w = 0; w < valid_w_start; ++w) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
h,
w,
h - pad_top,
w - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
DepthwiseConv2dPixel(in_base,
filter_ptr,
h + 1,
w,
h + 1 - pad_top,
w - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
for (w = valid_w_start; w + 3 < valid_w_stop; w += 4) {
// input (4 height x 3 slide): vi_height_slide
float32x4_t vi00, vi01, vi02, vi0n;
float32x4_t vi10, vi11, vi12, vi1n;
float32x4_t vi20, vi21, vi22, vi2n;
float32x4_t vi30, vi31, vi32, vi3n;
// output (1 outch x 2 height x 4 width): vo_outch_height
float32x4_t vo00, vo01;
// load input
index_t in_h = h - pad_top;
index_t in_w = w - pad_left;
index_t in_offset = in_h * in_width + in_w;
vi00 = vld1q_f32(in_base + in_offset);
vi0n = vld1q_f32(in_base + in_offset + 4);
vi10 = vld1q_f32(in_base + in_offset + in_width);
vi1n = vld1q_f32(in_base + in_offset + in_width + 4);
vi20 = vld1q_f32(in_base + in_offset + 2 * in_width);
vi2n = vld1q_f32(in_base + in_offset + 2 * in_width + 4);
vi30 = vld1q_f32(in_base + in_offset + 3 * in_width);
vi3n = vld1q_f32(in_base + in_offset + 3 * in_width + 4);
vi01 = vextq_f32(vi00, vi0n, 1);
vi02 = vextq_f32(vi00, vi0n, 2);
vi11 = vextq_f32(vi10, vi1n, 1);
vi12 = vextq_f32(vi10, vi1n, 2);
vi21 = vextq_f32(vi20, vi2n, 1);
vi22 = vextq_f32(vi20, vi2n, 2);
vi31 = vextq_f32(vi30, vi3n, 1);
vi32 = vextq_f32(vi30, vi3n, 2);
// load output
index_t out_offset = h * out_width + w;
vo00 = vld1q_f32(out_base + out_offset);
vo01 = vld1q_f32(out_base + out_offset + out_width);
// outch 0, height 0
vo00 = vfmaq_laneq_f32(vo00, vi00, vf00, 0);
vo00 = vfmaq_laneq_f32(vo00, vi01, vf00, 1);
vo00 = vfmaq_laneq_f32(vo00, vi02, vf00, 2);
vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0);
vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1);
vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2);
vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 0);
vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 1);
vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 2);
// outch 0, height 1
vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0);
vo01 = vfmaq_laneq_f32(vo01, vi11, vf00, 1);
vo01 = vfmaq_laneq_f32(vo01, vi12, vf00, 2);
vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0);
vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1);
vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2);
vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 0);
vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 1);
vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 2);
vst1q_f32(out_base + out_offset, vo00);
vst1q_f32(out_base + out_offset + out_width, vo01);
} // w
// right
for (; w < out_width; ++w) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
h,
w,
h - pad_top,
w - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
DepthwiseConv2dPixel(in_base,
filter_ptr,
h + 1,
w,
h + 1 - pad_top,
w - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
} // h
#else
for (index_t ih = valid_h_start; ih < valid_h_stop; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
ih,
iw,
ih - pad_top,
iw - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
}
#endif
// bottom
for (; h < out_height; ++h) {
for (w = 0; w < out_width; ++w) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
h,
w,
h - pad_top,
w - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
}
} // m
} // b
}
void DepthwiseConv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t in_channels,
const index_t out_height,
const index_t out_width,
const index_t out_channels,
const int pad_top,
const int pad_left,
const int valid_h_start,
const int valid_h_stop,
const int valid_w_start,
const int valid_w_stop,
float *output) {
const index_t multiplier = out_channels / in_channels;
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = in_channels * in_image_size;
const index_t out_batch_size = out_channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t m = 0; m < out_channels; ++m) {
index_t c = m / multiplier;
index_t multi_index = m % multiplier;
const float *in_base = input + b * in_batch_size + c * in_image_size;
const float *filter_ptr = filter + multi_index * in_channels * 9 + c * 9;
float *out_base = output + b * out_batch_size + m * out_image_size;
index_t h, w;
// top
for (h = 0; h < valid_h_start; ++h) {
for (w = 0; w < out_width; ++w) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
h,
w,
h * 2 - pad_top,
w * 2 - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
}
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
// load filter (1 outch x 3 height x 3 width): vf_outch_height
float32x4_t vf00, vf01, vf02;
vf00 = vld1q_f32(filter_ptr);
vf01 = vld1q_f32(filter_ptr + 3);
vf02 = vld1q_f32(filter_ptr + 6);
for (h = valid_h_start; h < valid_h_stop; ++h) {
// left
for (w = 0; w < valid_w_start; ++w) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
h,
w,
h * 2 - pad_top,
w * 2 - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
for (w = valid_w_start; w + 3 < valid_w_stop; w += 4) {
float32x4x2_t vi0, vi1, vi2;
float32x4_t vi0n, vi1n, vi2n;
// input (3 height x 3 slide): vi_height_slide
float32x4_t vi00, vi01, vi02;
float32x4_t vi10, vi11, vi12;
float32x4_t vi20, vi21, vi22;
// output (1 outch x 1 height x 4 width): vo
float32x4_t vo;
// load input
index_t in_h = h * 2 - pad_top;
index_t in_w = w * 2 - pad_left;
index_t in_offset = in_h * in_width + in_w;
vi0 = vld2q_f32(in_base + in_offset); // [0.2.4.6, 1.3.5.7]
vi1 = vld2q_f32(in_base + in_offset + in_width);
vi2 = vld2q_f32(in_base + in_offset + 2 * in_width);
vi0n = vld1q_f32(in_base + in_offset + 8); // [8.9.10.11]
vi1n = vld1q_f32(in_base + in_offset + in_width + 8);
vi2n = vld1q_f32(in_base + in_offset + 2 * in_width + 8);
            // load output
index_t out_offset = h * out_width + w;
vo = vld1q_f32(out_base + out_offset);
vi00 = vi0.val[0]; // [0.2.4.6]
vi01 = vi0.val[1]; // [1.3.5.7]
vi02 = vextq_f32(vi00, vi0n, 1); // [2.4.6.8]
vi10 = vi1.val[0];
vi11 = vi1.val[1];
vi12 = vextq_f32(vi10, vi1n, 1);
vi20 = vi2.val[0];
vi21 = vi2.val[1];
vi22 = vextq_f32(vi20, vi2n, 1);
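            // After vld2q deinterleaving, vi*0 holds the even input columns,
            // vi*1 the odd columns, and vi*2 the even columns shifted by one
            // pair, i.e. the three horizontal taps of a stride-2 3x3 filter
            // for four outputs at once.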
// outch 0, height 0
vo = vfmaq_laneq_f32(vo, vi00, vf00, 0);
vo = vfmaq_laneq_f32(vo, vi01, vf00, 1);
vo = vfmaq_laneq_f32(vo, vi02, vf00, 2);
vo = vfmaq_laneq_f32(vo, vi10, vf01, 0);
vo = vfmaq_laneq_f32(vo, vi11, vf01, 1);
vo = vfmaq_laneq_f32(vo, vi12, vf01, 2);
vo = vfmaq_laneq_f32(vo, vi20, vf02, 0);
vo = vfmaq_laneq_f32(vo, vi21, vf02, 1);
vo = vfmaq_laneq_f32(vo, vi22, vf02, 2);
vst1q_f32(out_base + out_offset, vo);
} // w
// right
for (; w < out_width; ++w) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
h,
w,
h * 2 - pad_top,
w * 2 - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
} // h
#else
for (index_t ih = valid_h_start; ih < valid_h_stop; ++ih) {
for (index_t iw = 0; iw < out_width; ++iw) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
ih,
iw,
ih * 2 - pad_top,
iw * 2 - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
}
#endif
// bottom
for (; h < out_height; ++h) {
for (w = 0; w < out_width; ++w) {
DepthwiseConv2dPixel(in_base,
filter_ptr,
h,
w,
h * 2 - pad_top,
w * 2 - pad_left,
out_width,
in_height,
in_width,
3,
3,
out_base);
}
}
} // m
} // b
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/pooling.h"
namespace mace {
namespace kernels {
namespace {
void MaxPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = std::numeric_limits<float>::lowest();
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res = std::max(res, input[input_offset]);
}
}
}
output[out_offset] = res;
}
}
}
}
}
void AvgPooling(const float *input,
const index_t batch,
const index_t in_height,
const index_t in_width,
const index_t channels,
const index_t out_height,
const index_t out_width,
const int filter_height,
const int filter_width,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int pad_top,
const int pad_left,
float *output) {
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
const index_t out_offset = out_base + h * out_width + w;
float res = 0;
int block_size = 0;
for (int fh = 0; fh < filter_height; ++fh) {
for (int fw = 0; fw < filter_width; ++fw) {
int inh = h * stride_h + dilation_h * fh - pad_top;
int inw = w * stride_w + dilation_w * fw - pad_left;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
res += input[input_offset];
++block_size;
}
}
}
output[out_offset] = res / block_size;
}
}
}
}
}
} // namespace
void PoolingFunctor<DeviceType::NEON,
float>::operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {
input_tensor->dim(1), input_tensor->dim(1), kernels_[0], kernels_[1]};
std::vector<int> paddings(2);
if (paddings_.empty()) {
kernels::CalcNCHWPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), dilations_,
strides_, padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input_tensor->shape().data(), filter_shape.data(),
paddings_.data(), dilations_, strides_, RoundType::CEIL,
output_shape.data());
}
output_tensor->Resize(output_shape);
const float *input = input_tensor->data<float>();
float *output = output_tensor->mutable_data<float>();
const index_t *input_shape = input_tensor->shape().data();
index_t batch = output_shape[0];
index_t channels = output_shape[1];
index_t height = output_shape[2];
index_t width = output_shape[3];
index_t input_channels = input_shape[1];
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
index_t in_image_size = input_height * input_width;
int filter_h = kernels_[0];
int filter_w = kernels_[1];
int stride_h = strides_[0];
int stride_w = strides_[1];
int dilation_h = dilations_[0];
int dilation_w = dilations_[1];
int pad_top = paddings[0] / 2;
int pad_left = paddings[1] / 2;
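  // paddings[] holds the total padding of each spatial dimension; the top and
  // left sides take the floor half, and the remainder implicitly falls on the
  // bottom/right (consistent with ConstructNCHWInputWithPadding).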
if (pooling_type_ == PoolingType::MAX) {
MaxPooling(input,
batch,
input_height,
input_width,
channels,
height,
width,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
} else if (pooling_type_ == PoolingType::AVG) {
AvgPooling(input,
batch,
input_height,
input_width,
channels,
height,
width,
filter_h,
filter_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
pad_top,
pad_left,
output);
} else {
MACE_NOT_IMPLEMENTED;
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/softmax.h"
namespace mace {
namespace kernels {
void SoftmaxFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t class_count = input->dim(1);
const index_t class_size = input->dim(2) * input->dim(3);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
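  // Softmax over the channel (class) dimension, computed independently for
  // each spatial position; subtracting the per-position max before exp keeps
  // the computation numerically stable.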
for (index_t b = 0; b < batch; ++b) {
std::vector<float>
max_val(class_size, std::numeric_limits<float>::lowest());
std::vector<float> sum_val(class_size, 0.f);
// calculate max for each class
for (index_t c = 0; c < class_count; ++c) {
const float *input_ptr = input_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
max_val[k] = std::max(max_val[k], input_ptr[k]);
}
}
// calculate data - max for each class
#pragma omp parallel for
for (index_t c = 0; c < class_count; ++c) {
const float *input_ptr = input_data + (b * class_count + c) * class_size;
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] = ::exp(input_ptr[k] - max_val[k]);
}
}
// calculate sum for each class
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
sum_val[k] += output_ptr[k];
}
}
// calculate (data - max) / sum for each class
for (index_t c = 0; c < class_count; ++c) {
float *output_ptr = output_data + (b * class_count + c) * class_size;
for (index_t k = 0; k < class_size; ++k) {
output_ptr[k] /= sum_val[k];
}
}
}
}
} // namespace kernels
} // namespace mace
......@@ -133,14 +133,21 @@ struct BatchNormFunctor : BatchNormFunctorBase {
};
template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future);
struct BatchNormFunctor<DeviceType::NEON, float> : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant,
const ActivationType activation,
const float relux_max_limit)
: BatchNormFunctorBase(folded_constant, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future);
};
template <typename T>
struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
......
......@@ -212,6 +212,12 @@ switch (w_count) { \
case 2: \
MACE_DO_CONV2D(CC, CH, 2); \
break; \
case 3: \
MACE_DO_CONV2D(CC, CH, 3); \
break; \
case 4: \
MACE_DO_CONV2D(CC, CH, 4); \
break; \
default: \
LOG(FATAL) << "Unsupported w tile: " << w_count; \
}
......@@ -242,6 +248,42 @@ switch (c_count) { \
case 4: \
MACE_CASE_H_CONV2D(4); \
break; \
case 5: \
MACE_CASE_H_CONV2D(5); \
break; \
case 6: \
MACE_CASE_H_CONV2D(6); \
break; \
case 7: \
MACE_CASE_H_CONV2D(7); \
break; \
case 8: \
MACE_CASE_H_CONV2D(8); \
break; \
case 9: \
MACE_CASE_H_CONV2D(9); \
break; \
case 10: \
MACE_CASE_H_CONV2D(10); \
break; \
case 11: \
MACE_CASE_H_CONV2D(11); \
break; \
case 12: \
MACE_CASE_H_CONV2D(12); \
break; \
case 13: \
MACE_CASE_H_CONV2D(13); \
break; \
case 14: \
MACE_CASE_H_CONV2D(14); \
break; \
case 15: \
MACE_CASE_H_CONV2D(15); \
break; \
case 16: \
MACE_CASE_H_CONV2D(16); \
break; \
default: \
LOG(FATAL) << "Unsupported c tile: " << c_count; \
}
......@@ -373,11 +415,35 @@ struct Conv2dFunctor : Conv2dFunctorBase {
};
template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
struct Conv2dFunctor<DeviceType::NEON, float> : Conv2dFunctorBase {
Conv2dFunctor(const int *strides,
const Padding &padding_type,
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit)
: Conv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit),
is_filter_transformed_(false) {}
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
// TODO(liyin): share tmp buffers among ops
Tensor padded_input_;
Tensor padded_output_;
Tensor transformed_input_;
Tensor transformed_filter_;
Tensor transformed_output_;
bool is_filter_transformed_;
};
template <typename T>
struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase {
......
......@@ -9,7 +9,7 @@
namespace mace {
namespace kernels {
void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
......@@ -186,6 +186,55 @@ void CalcOutputSize(const index_t *input_shape, // NHWC
output_shape[3] = filter_shape[2];
}
void CalcNCHWOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *padding_size,
const int *dilations,
const int *strides,
const RoundType round_type,
index_t *output_shape) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
  /*
   * Convolution arithmetic:
   * o = floor((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
   * Pooling arithmetic:
   * o = ceil((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
   * where the padding_size passed in is the total padding of a dimension,
   * i.e. 2 * p.
   * For details, see https://arxiv.org/pdf/1603.07285.pdf or
   * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
   */
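  // Worked example (hypothetical sizes): i = 7, total padding = 2, k = 3,
  // d = 1, s = 2 with FLOOR rounding gives o = floor((7 + 2 - 3) / 2) + 1 = 4.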
output_shape[0] = input_shape[0];
if (round_type == FLOOR) {
output_shape[2] = static_cast<index_t>(
std::floor(1.0 * (input_shape[2] + padding_size[0] - filter_shape[2] -
(filter_shape[2] - 1) * (dilations[0] - 1)) /
strides[0]) +
1);
output_shape[3] = static_cast<index_t>(
std::floor(1.0 * (input_shape[3] + padding_size[1] - filter_shape[3] -
(filter_shape[3] - 1) * (dilations[1] - 1)) /
strides[1]) +
1);
} else {
output_shape[2] = static_cast<index_t>(
std::ceil(1.0 * (input_shape[2] + padding_size[0] - filter_shape[2] -
(filter_shape[2] - 1) * (dilations[0] - 1)) /
strides[0]) +
1);
output_shape[3] = static_cast<index_t>(
std::ceil(1.0 * (input_shape[3] + padding_size[1] - filter_shape[3] -
(filter_shape[3] - 1) * (dilations[1] - 1)) /
strides[1]) +
1);
}
output_shape[1] = filter_shape[0];
}
void CalPaddingSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
......@@ -230,10 +279,11 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
0, (output_width - 1) * strides[1] + k_extent_width - input_shape[3]);
}
void ConstructInputWithPadding(const Tensor *input_tensor,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value) {
void ConstructNCHWInputWithPadding(const Tensor *input_tensor,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value) {
Tensor::MappingGuard input_mapper(input_tensor);
const float *input = input_tensor->data<float>();
const index_t *input_shape = input_tensor->shape().data();
......@@ -244,7 +294,7 @@ void ConstructInputWithPadding(const Tensor *input_tensor,
index_t width = input_shape[3];
std::vector<index_t> output_shape(
{batch, channels, paddings[0] + height, paddings[1] + width});
{batch, channels, paddings[0] + height, paddings[1] + width});
const index_t output_width = output_shape[3];
const int padded_top = paddings[0] / 2;
......@@ -268,6 +318,7 @@ void ConstructInputWithPadding(const Tensor *input_tensor,
const int padded_bottom = paddings[0] - padded_top;
const int padded_right = paddings[1] - padded_left;
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
for (int k = 0; k < padded_top; ++k) {
......@@ -301,6 +352,51 @@ void ConstructInputWithPadding(const Tensor *input_tensor,
}
}
void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
const int pad_top,
const int pad_bottom,
const int pad_left,
const int pad_right,
Tensor *output_tensor) {
Tensor::MappingGuard input_mapper(input_tensor);
const float *input = input_tensor->data<float>();
const index_t *input_shape = input_tensor->shape().data();
index_t batch = input_shape[0];
index_t channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
const int pad_height = pad_top + pad_bottom;
const int pad_width = pad_left + pad_right;
std::vector<index_t> output_shape(
{batch, channels, height + pad_height, width + pad_width});
output_tensor->Resize(output_shape);
Tensor::MappingGuard padded_output_mapper(output_tensor);
float *output_data = output_tensor->mutable_data<float>();
const index_t output_height = output_shape[2];
const index_t output_width = output_shape[3];
const index_t in_image_size = height * width;
const index_t out_image_size = output_height * output_width;
const index_t in_batch_size = channels * in_image_size;
const index_t out_batch_size = channels * out_image_size;
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
for (int k = 0; k < height; ++k) {
memcpy(output_data + i * out_batch_size + j * out_image_size
+ (pad_top + k) * output_width + pad_left,
input + i * in_batch_size + j * in_image_size + k * width,
width * sizeof(float));
}
// Skip the padded bottom in this channel and top in the next channel
}
}
}
void ConstructNHWCInputWithPadding(const Tensor *input_tensor,
const int *paddings,
Tensor *output_tensor,
......
......@@ -22,16 +22,16 @@ enum RoundType {
namespace kernels {
void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
void CalcNCHWPaddingAndOutputSize(const index_t *input_shape,
const index_t *filter_shape,
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size);
void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
void CalcNHWCPaddingAndOutputSize(const index_t *input_shape,
const index_t *filter_shape,
const int *dilations,
const int *strides,
Padding padding,
......@@ -46,6 +46,14 @@ void CalcOutputSize(const index_t *input_shape, // NHWC
const RoundType round_type,
index_t *output_shape);
void CalcNCHWOutputSize(const index_t *input_shape,
const index_t *filter_shape,
const int *padding_size,
const int *dilations,
const int *strides,
const RoundType round_type,
index_t *output_shape);
void CalPaddingSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
......@@ -53,10 +61,15 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
Padding padding,
int *padding_size);
void ConstructInputWithPadding(const Tensor *input,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value = false);
void ConstructNCHWInputWithSpecificPadding(const Tensor *input,
const int pad_top, const int pad_bottom,
const int pad_left, const int pad_right,
Tensor *output_tensor);
void ConstructNCHWInputWithPadding(const Tensor *input,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value = false);
void ConstructNHWCInputWithPadding(const Tensor *input,
const int *paddings,
......
......@@ -14,6 +14,7 @@
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/kernels/activation.h"
#include "mace/public/mace.h"
namespace mace {
......@@ -407,12 +408,27 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
};
template <>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
struct DepthwiseConv2dFunctor<DeviceType::NEON, float>
: DepthwiseConv2dFunctorBase {
DepthwiseConv2dFunctor(const int *strides,
const Padding padding_type,
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit)
: DepthwiseConv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
};
template <typename T>
struct DepthwiseConv2dFunctor<DeviceType::OPENCL, T>
......
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include <math.h>
#include <string.h>
#include <algorithm>
#include "mace/kernels/gemm.h"
#include "mace/utils/utils.h"
#include "mace/utils/logging.h"
namespace mace {
namespace kernels {
namespace {
void GemmRef(const float *A,
const float *B,
const index_t height,
const index_t K,
const index_t width,
float *C) {
memset(C, 0, sizeof(float) * height * width);
for (int i = 0; i < height; ++i) {
for (int j = 0; j < width; ++j) {
for (int k = 0; k < K; ++k) {
C[i * width + j] += A[i * K + k] * B[k * width + j];
}
}
}
}
inline void GemmBlock(const float *A,
const float *B,
const index_t height,
const index_t K,
const index_t width,
const index_t stride_k,
const index_t stride_w,
float *C) {
for (int i = 0; i < height; ++i) {
for (int j = 0; j < width; ++j) {
for (int k = 0; k < K; ++k) {
C[i * stride_w + j] += A[i * stride_k + k] * B[k * stride_w + j];
}
}
}
}
// TODO(liyin): may need to implement an 8x8x3 variant since RGB inputs have 3 channels
inline void Gemm884(const float *a_ptr,
const float *b_ptr,
index_t stride_w,
index_t stride_k,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
a15;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3, c4, c5, c6, c7;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_k);
a3 = vld1q_f32(a_ptr + 1 * stride_k + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_k);
a5 = vld1q_f32(a_ptr + 2 * stride_k + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_k);
a7 = vld1q_f32(a_ptr + 3 * stride_k + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_k);
a9 = vld1q_f32(a_ptr + 4 * stride_k + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_k);
a11 = vld1q_f32(a_ptr + 5 * stride_k + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_k);
a13 = vld1q_f32(a_ptr + 6 * stride_k + 4);
a14 = vld1q_f32(a_ptr + 7 * stride_k);
a15 = vld1q_f32(a_ptr + 7 * stride_k + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_w);
b2 = vld1q_f32(b_ptr + 2 * stride_w);
b3 = vld1q_f32(b_ptr + 3 * stride_w);
b4 = vld1q_f32(b_ptr + 4 * stride_w);
b5 = vld1q_f32(b_ptr + 5 * stride_w);
b6 = vld1q_f32(b_ptr + 6 * stride_w);
b7 = vld1q_f32(b_ptr + 7 * stride_w);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_w);
c2 = vld1q_f32(c_ptr + 2 * stride_w);
c3 = vld1q_f32(c_ptr + 3 * stride_w);
c4 = vld1q_f32(c_ptr + 4 * stride_w);
c5 = vld1q_f32(c_ptr + 5 * stride_w);
c6 = vld1q_f32(c_ptr + 6 * stride_w);
c7 = vld1q_f32(c_ptr + 7 * stride_w);
#define MACE_CONV_1x1_REG_CAL(RC, RA, RAN) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b3, a##RA, 3); \
c##RC = vfmaq_laneq_f32(c##RC, b4, a##RAN, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b5, a##RAN, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
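  // Each invocation accumulates one 8-deep dot product into a 1x4 row of C:
  // c[RC] += sum_k A[row][k] * B[k][0..3], with the eight A elements split
  // across registers RA (k = 0..3) and RAN (k = 4..7).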
MACE_CONV_1x1_REG_CAL(0, 0, 1);
MACE_CONV_1x1_REG_CAL(1, 2, 3);
MACE_CONV_1x1_REG_CAL(2, 4, 5);
MACE_CONV_1x1_REG_CAL(3, 6, 7);
MACE_CONV_1x1_REG_CAL(4, 8, 9);
MACE_CONV_1x1_REG_CAL(5, 10, 11);
MACE_CONV_1x1_REG_CAL(6, 12, 13);
MACE_CONV_1x1_REG_CAL(7, 14, 15);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_w, c1);
vst1q_f32(c_ptr + 2 * stride_w, c2);
vst1q_f32(c_ptr + 3 * stride_w, c3);
vst1q_f32(c_ptr + 4 * stride_w, c4);
vst1q_f32(c_ptr + 5 * stride_w, c5);
vst1q_f32(c_ptr + 6 * stride_w, c6);
vst1q_f32(c_ptr + 7 * stride_w, c7);
#else
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_k, stride_w, c_ptr);
#endif
}
inline void GemmTile(const float *A,
const float *B,
const index_t height,
const index_t K,
const index_t width,
const index_t stride_k,
const index_t stride_w,
float *C) {
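  // Tile the output into 8x4 register blocks with K consumed in chunks of 8;
  // full tiles go through Gemm884 (NEON when available) and the ragged edges
  // fall back to the scalar GemmBlock.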
index_t h, w, k;
for (h = 0; h + 7 < height; h += 8) {
for (w = 0; w + 3 < width; w += 4) {
for (k = 0; k + 7 < K; k += 8) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
Gemm884(a_ptr, b_ptr, stride_w, stride_k, c_ptr);
}
if (k < K) {
const float *a_ptr = A + (h * stride_k + k);
const float *b_ptr = B + (k * stride_w + w);
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr, b_ptr, 8, K - k, 4, stride_k, stride_w, c_ptr);
}
}
if (w < width) {
const float *a_ptr = A + h * stride_k;
const float *b_ptr = B + w;
float *c_ptr = C + (h * stride_w + w);
GemmBlock(a_ptr,
b_ptr,
8,
K,
width - w,
stride_k,
stride_w,
c_ptr);
}
}
if (h < height) {
// TODO(liyin): may use Gemm444
const float *a_ptr = A + (h * stride_k);
const float *b_ptr = B;
float *c_ptr = C + h * stride_w;
GemmBlock(a_ptr,
b_ptr,
height - h,
K,
width,
stride_k,
stride_w,
c_ptr);
}
}
} // namespace
void Gemm(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
float *C) {
memset(C, 0, sizeof(float) * batch * height * width);
  // A larger block size is better as long as the three working blocks fit in
  // the fast (L1) cache. Assuming a 32 KB L1 cache and three resident blocks
  // (A, B, C), the block size should be about sqrt(32 KB / sizeof(T) / 3).
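  // With float this gives sqrt(32 * 1024 / 4 / 3), about 52; 48 is used below.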
const index_t block_size = 48;
const index_t block_tile_height = RoundUpDiv(height, block_size);
const index_t block_tile_width = RoundUpDiv(width, block_size);
const index_t block_tile_k = RoundUpDiv(K, block_size);
const index_t remain_height = height % block_size;
const index_t remain_width = width % block_size;
const index_t remain_k = K % block_size;
#pragma omp parallel for collapse(3)
for (index_t n = 0; n < batch; ++n) {
for (index_t bh = 0; bh < block_tile_height; ++bh) {
for (index_t bw = 0; bw < block_tile_width; ++bw) {
const float *a_base = A + n * height * K;
const float *b_base = B + n * K * width;
float *c_base = C + n * height * width;
const index_t ih_begin = bh * block_size;
const index_t ih_end =
bh * block_size + (bh == block_tile_height - 1 && remain_height > 0
? remain_height : block_size);
const index_t iw_begin = bw * block_size;
const index_t iw_end =
bw * block_size
+ (bw == block_tile_width - 1 && remain_width > 0 ? remain_width
: block_size);
for (index_t bk = 0; bk < block_tile_k; ++bk) {
const index_t ik_begin = bk * block_size;
const index_t ik_end =
bk * block_size
+ (bk == block_tile_k - 1 && remain_k > 0 ? remain_k
: block_size);
// inside block:
// calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
GemmTile(a_base + (ih_begin * K + ik_begin),
b_base + (ik_begin * width + iw_begin),
ih_end - ih_begin,
ik_end - ik_begin,
iw_end - iw_begin,
K,
width,
c_base + (ih_begin * width + iw_begin));
} // bk
} // bw
} // bh
} // n
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_GEMM_H_
#define MACE_KERNELS_GEMM_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
namespace mace {
namespace kernels {
void Gemm(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
float *C);
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_GEMM_H_
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include <gtest/gtest.h>
#include <algorithm>
#include <random>
#include "mace/kernels/gemm.h"
#include "mace/core/types.h"
namespace mace {
TEST(GEMMTest, gemm) {
index_t N = 17;
index_t M = 33;
index_t K = 64;
float *A = new float[N * K];
float *B = new float[K * M];
float *C = new float[N * M];
float *C_ref = new float[N * M];
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(A, A + N * K,
[&gen, &nd] {
return nd(gen);
});
std::generate(B, B + K * M,
[&gen, &nd] {
return nd(gen);
});
  kernels::Gemm(A, B, 1, N, K, M, C);
kernels::GemmRef(A, B, N, K, M, C_ref);
for (int i = 0; i < N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
  delete[] A;
  delete[] B;
  delete[] C;
  delete[] C_ref;
}
} // namespace mace
......@@ -16,142 +16,12 @@
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
#include "mace/kernels/gemm.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
template<typename T,
int register_tile_size,
int h_count,
int w_count,
int k_count>
inline void MatMulKernelFunc(const T *A,
const T *B,
T *C,
index_t offset_h,
index_t offset_w,
index_t offset_k,
index_t stride_h,
index_t stride_w,
index_t stride_k) {
T a_tmp[register_tile_size][register_tile_size] = {0};
T b_tmp[register_tile_size][register_tile_size] = {0};
T c_tmp[register_tile_size][register_tile_size] = {0};
for (int h = 0; h < h_count; ++h) {
for (int k = 0; k < k_count; ++k) {
a_tmp[h][k] = A[(offset_h + h) * stride_k + (offset_k + k)];
}
}
for (int k = 0; k < k_count; ++k) {
for (int w = 0; w < w_count; ++w) {
b_tmp[k][w] = B[(offset_k + k) * stride_w + (offset_w + w)];
}
}
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
static_assert(register_tile_size == 4, "register tile size must be 4");
float32x4_t a_dup;
float32x4_t b_vec[4] =
{vld1q_f32(b_tmp[0]), vld1q_f32(b_tmp[1]), vld1q_f32(b_tmp[2]),
vld1q_f32(b_tmp[3])};
float32x4_t
c_vec[4] = {vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0)};
for (int h = 0; h < register_tile_size; ++h) {
for (int k = 0; k < register_tile_size; ++k) {
a_dup = vdupq_n_f32(a_tmp[h][k]);
c_vec[h] = vfmaq_f32(c_vec[h], a_dup, b_vec[k]);
}
}
for (int h = 0; h < register_tile_size; ++h) {
vst1q_f32(c_tmp[h], c_vec[h]);
}
#else
for (int h = 0; h < register_tile_size; ++h) {
for (int w = 0; w < register_tile_size; ++w) {
for (int k = 0; k < register_tile_size; ++k) {
c_tmp[h][w] += a_tmp[h][k] * b_tmp[k][w];
}
}
}
#endif
for (int h = 0; h < h_count; ++h) {
for (int w = 0; w < w_count; ++w) {
C[(offset_h + h) * stride_w + (offset_w + w)] += c_tmp[h][w];
}
}
}
#define MACE_DO_MATMUL(HC, WC, KC) \
MatMulKernelFunc<T, register_tile_size, HC, WC, KC>(a_ptr_batch_base, \
b_ptr_batch_base, \
c_ptr_batch_base, \
ih, \
iw, \
ik, \
height, \
width, \
K);
#define MACE_CASE_K_MATMUL(HC, WC) \
switch (k_count) { \
case 1: \
MACE_DO_MATMUL(HC, WC, 1); \
break; \
case 2: \
MACE_DO_MATMUL(HC, WC, 2); \
break; \
case 3: \
MACE_DO_MATMUL(HC, WC, 3); \
break; \
case 4: \
MACE_DO_MATMUL(HC, WC, 4); \
break; \
default: \
LOG(FATAL) << "Unsupported k tile: " << k_count; \
}
#define MACE_CASE_W_MATMUL(HC) \
switch (w_count) { \
case 1: \
MACE_CASE_K_MATMUL(HC, 1); \
break; \
case 2: \
MACE_CASE_K_MATMUL(HC, 2); \
break; \
case 3: \
MACE_CASE_K_MATMUL(HC, 3); \
break; \
case 4: \
MACE_CASE_K_MATMUL(HC, 4); \
break; \
default: \
LOG(FATAL) << "Unsupported w tile: " << w_count; \
}
#define MACE_CASE_H_MATMUL \
switch (h_count) { \
case 1: \
MACE_CASE_W_MATMUL(1); \
break; \
case 2: \
MACE_CASE_W_MATMUL(2); \
break; \
case 3: \
MACE_CASE_W_MATMUL(3); \
break; \
case 4: \
MACE_CASE_W_MATMUL(4); \
break; \
default: \
LOG(FATAL) << "Unsupported h tile: " << h_count; \
}
template<DeviceType D, typename T>
struct MatMulFunctor {
void operator()(const Tensor *A,
......@@ -185,51 +55,7 @@ struct MatMulFunctor {
constexpr index_t register_tile_size = 4;
memset(c_ptr_base, 0, batch * height * width * sizeof(T));
#pragma omp parallel for collapse(3)
for (index_t n = 0; n < batch; ++n) {
// handle block
for (index_t bh = 0; bh < block_tile_height; ++bh) {
for (index_t bw = 0; bw < block_tile_width; ++bw) {
const T *a_ptr_batch_base = a_ptr_base + n * height * K;
const T *b_ptr_batch_base = b_ptr_base + n * K * width;
T *c_ptr_batch_base = c_ptr_base + n * height * width;
const index_t ih_begin = bh * block_size;
const index_t ih_end =
bh * block_size + (bh == block_tile_height - 1 && remain_height > 0
? remain_height : block_size);
const index_t iw_begin = bw * block_size;
const index_t iw_end =
bw * block_size
+ (bw == block_tile_width - 1 && remain_width > 0 ? remain_width
: block_size);
for (index_t bk = 0; bk < block_tile_k; ++bk) {
const index_t ik_begin = bk * block_size;
const index_t ik_end =
bk * block_size
+ (bk == block_tile_k - 1 && remain_k > 0 ? remain_k
: block_size);
// inside block:
// calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
for (index_t ih = ih_begin; ih < ih_end;
ih += register_tile_size) {
for (index_t iw = iw_begin; iw < iw_end;
iw += register_tile_size) {
for (index_t ik = ik_begin; ik < ik_end;
ik += register_tile_size) {
const int h_count = std::min(register_tile_size, ih_end - ih);
const int w_count = std::min(register_tile_size, iw_end - iw);
const int k_count = std::min(register_tile_size, ik_end - ik);
MACE_CASE_H_MATMUL;
} // ik
} // iw
} // ih
} // bk
} // bw
} // bh
} // n
Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base);
}
};
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include <float.h>
#include <limits>
namespace mace {
namespace kernels {
void PoolingAvgNeonK2x2S2x2(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape,
const int *paddings) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int padding_top = paddings[0] / 2;
int padding_bottom = paddings[0] - padding_top;
int padding_left = paddings[1] / 2;
int padding_right = paddings[1] - padding_left;
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
float avg_factors[4] = {0.25, 0.25, 0.25, 0.25};
#pragma omp parallel for collapse(2)
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      // Compute offsets from the loop indices so the collapsed OpenMP loop
      // does not race on shared accumulators.
      const index_t input_offset = (b * channels + c) * in_image_size;
      const index_t output_offset = (b * channels + c) * out_image_size;
      float *outptr = output + output_offset;
const float *r0, *r1;
for (int h = 0; h < out_height; ++h) {
int w = 0;
int num_vectors = 0;
if (!((h == 0 && padding_top > 0) ||
(h == out_height - 1 && padding_bottom > 0))) {
r0 = input + input_offset + (h * 2 - padding_top) * in_width;
r1 = r0 + in_width;
if (padding_left > 0) {
*outptr = (r0[0] + r1[0]) * 0.25;
++r0;
++r1;
++outptr;
++w;
}
if (padding_right > 0) {
num_vectors = (out_width - w - 1) >> 2;
} else {
num_vectors = (out_width - w) >> 2;
}
}
w += num_vectors << 2;
float32x4_t factors = vld1q_f32(avg_factors);
for (; num_vectors > 0; --num_vectors) {
float32x4_t r00 = vld1q_f32(r0);
float32x4_t r10 = vld1q_f32(r1);
float32x4_t r01 = vld1q_f32(r0 + 4);
float32x4_t r11 = vld1q_f32(r1 + 4);
float32x4_t sum0 = vaddq_f32(r00, r10);
float32x4_t sum1 = vaddq_f32(r01, r11);
float32x4_t sum_result = vpaddq_f32(sum0, sum1);
float32x4_t avg_result = vmulq_f32(sum_result, factors);
vst1q_f32(outptr, avg_result);
r0 += 8;
r1 += 8;
outptr += 4;
}
for (; w < out_width; ++w) {
float sum = 0.0;
for (int kh = 0; kh < 2; ++kh) {
for (int kw = 0; kw < 2; ++kw) {
int inh = h * 2 - padding_top + kh;
int inw = w * 2 - padding_left + kw;
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
sum += input[input_offset + inh * in_width + inw];
}
}
}
*outptr = sum * 0.25;
++outptr;
}
}
}
}
}
// assume the input has already been padded
void PoolingAvgNeonK2x2S2x2Padded(const float *input,
const index_t *in_shape,
float *output,
const index_t *out_shape) {
index_t batch = in_shape[0];
index_t channels = in_shape[1];
index_t in_height = in_shape[2];
index_t in_width = in_shape[3];
index_t out_height = out_shape[2];
index_t out_width = out_shape[3];
int in_image_size = in_height * in_width;
int out_image_size = out_height * out_width;
float avg_factors[4] = {0.25, 0.25, 0.25, 0.25};
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
      // Per-(batch, channel) offsets, again to avoid racing on shared state.
      const index_t input_offset = (b * channels + c) * in_image_size;
      const index_t output_offset = (b * channels + c) * out_image_size;
      const float *img0 = input + input_offset;
      float *outptr = output + output_offset;
const float *r0 = img0;
const float *r1 = img0 + in_width;
for (int h = 0; h < out_height; ++h) {
int num_vectors = out_width >> 2;
int remain = out_width - (num_vectors << 2);
float32x4_t factors = vld1q_f32(avg_factors);
for (; num_vectors > 0; --num_vectors) {
float32x4_t r00 = vld1q_f32(r0);
float32x4_t r10 = vld1q_f32(r1);
float32x4_t r01 = vld1q_f32(r0 + 4);
float32x4_t r11 = vld1q_f32(r1 + 4);
float32x4_t sum0 = vaddq_f32(r00, r10);
float32x4_t sum1 = vaddq_f32(r01, r11);
float32x4_t sum_result = vpaddq_f32(sum0, sum1);
float32x4_t avg_result = vmulq_f32(sum_result, factors);
vst1q_f32(outptr, avg_result);
r0 += 8;
r1 += 8;
outptr += 4;
}
for (; remain > 0; --remain) {
*outptr = (r0[0] + r0[1] + r1[0] + r1[1]) * 0.25;
r0 += 2;
r1 += 2;
outptr++;
}
r0 += in_width;
r1 += in_width;
}
}
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/batch_norm.h"
#include <arm_neon.h>
#include <cmath>
#include <vector>
namespace mace {
namespace kernels {
template <>
void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const float epsilon,
Tensor *output,
StatsFuture *future) {
  // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
  // The inference-time formula is
  //   Y = scale / sqrt(var + epsilon) * X
  //       + (offset - scale * mean / sqrt(var + epsilon))
  // which is folded per channel into
  //   new_scale  = scale / sqrt(var + epsilon)
  //   new_offset = offset - mean * new_scale
  //   Y = new_scale * X + new_offset
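  // e.g. with hypothetical values scale = 2, var = 3, epsilon = 1, mean = 1
  // and offset = 5: new_scale = 2 / sqrt(4) = 1, new_offset = 5 - 1 * 1 = 4,
  // so Y = X + 4.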
const index_t n = input->dim(0);
const index_t sample_size = input->dim(1) * input->dim(2);
const index_t channel = input->dim(3);
const float *input_ptr = input->data<float>();
const float *scale_ptr = scale->data<float>();
const float *offset_ptr = offset->data<float>();
const float *mean_ptr = mean->data<float>();
const float *var_ptr = var->data<float>();
float *output_ptr = output->mutable_data<float>();
const index_t ch_blks = channel >> 2;
const index_t remain_chs = channel - (ch_blks << 2);
std::vector<float> new_scale(channel);
std::vector<float> new_offset(channel);
#pragma omp parallel for
for (index_t c = 0; c < channel; ++c) {
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon);
new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
}
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < n; ++i) {
for (index_t j = 0; j < sample_size; ++j) {
const float *input_sample_ptr =
input_ptr + (i * sample_size + j) * channel;
float *output_sample_ptr = output_ptr + (i * sample_size + j) * channel;
const float *new_scale_ptr = new_scale.data();
const float *new_offset_ptr = new_offset.data();
for (index_t cb = 0; cb < ch_blks; ++cb) {
float32x4_t new_scale_f = vld1q_f32(new_scale_ptr);
float32x4_t new_offset_f = vld1q_f32(new_offset_ptr);
float32x4_t input_f = vld1q_f32(input_sample_ptr);
float32x4_t output_f = vfmaq_f32(new_offset_f, input_f, new_scale_f);
vst1q_f32(output_sample_ptr, output_f);
input_sample_ptr += 4;
output_sample_ptr += 4;
new_scale_ptr += 4;
new_offset_ptr += 4;
}
for (index_t c = (ch_blks << 2); c < channel; ++c) {
*output_sample_ptr = new_scale[c] * *input_sample_ptr + new_offset[c];
++output_sample_ptr;
++input_sample_ptr;
++new_scale_ptr;
++new_offset_ptr;
}
}
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/conv_pool_2d_util.h"
namespace mace {
namespace kernels {
extern void Conv2dNeonK1x1S1(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
extern void Conv2dNeonK3x3S1(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
extern void Conv2dNeonK3x3S2(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
extern void Conv2dNeonK5x5S1(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(output);
std::vector<index_t> output_shape_vec(4);
std::vector<int> paddings(2);
kernels::CalcPaddingAndOutputSize(
input->shape().data(), filter->shape().data(), dilations_, strides_,
paddings_, output_shape_vec.data(), paddings.data());
output->Resize(output_shape_vec);
typedef void (*Conv2dNeonFunction)(
const float *input, const index_t *input_shape, const float *filter,
const index_t *filter_shape, const float *bias, float *output,
const index_t *output_shape);
// Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = {
{Conv2dNeonK1x1S1, nullptr},
{nullptr, nullptr},
{Conv2dNeonK3x3S1, Conv2dNeonK3x3S2},
{nullptr, nullptr},
{Conv2dNeonK5x5S1, nullptr}};
  // nullptr entries are not implemented yet
index_t kernel_h = filter->dim(2);
index_t kernel_w = filter->dim(3);
if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
LOG(WARNING) << "NEON conv2d kernel with "
<< "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version";
Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input, filter, bias, output, future);
return;
}
Tensor padded_input;
// Keep this alive during kernel execution
if (paddings[0] > 0 || paddings[1] > 0) {
ConstructInputWithPadding(input, paddings.data(), &padded_input);
input = &padded_input;
}
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard filter_mapper(filter);
Tensor::MappingGuard bias_mapper(bias);
Tensor::MappingGuard output_mapper(output);
auto input_data = input->data<float>();
auto input_shape = input->shape().data();
auto filter_data = filter->data<float>();
auto bias_data = bias == nullptr ? nullptr : bias->data<float>();
auto output_data = output->mutable_data<float>();
auto output_shape = output->shape().data();
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input_data, input_shape, filter_data, nullptr, bias_data,
output_data, output_shape);
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include <algorithm>
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
static constexpr index_t kInputChannelBlockSize = 2;
static constexpr index_t kOutputChannelBlockSize = 4;
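// A 1x1 convolution over NCHW data is effectively a GEMM across pixels: the
// kernels below accumulate a block of 2 input channels into 4 output
// channels, 4 pixels per iteration.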
static __attribute__((__aligned__(64)))
int32_t mask_array[8] = {0, 0, 0, 0, -1, -1, -1, -1};
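// Indexing mask_array at pixel_size (1..3) yields a mask that keeps only the
// last pixel_size lanes of a 4-lane vector; the tail code rewinds its pointers
// by 4 - pixel_size and masks out the lanes that were already processed so
// they accumulate nothing.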
static inline void NeonConv2x4Kernel(index_t input_channels,
index_t pixel_size,
const float *input,
const float *filter,
float *output) {
const float *input0 = input;
const float *input1 = input + pixel_size;
const float32x2_t vfilter0x = vld1_f32(filter);
filter += input_channels;
const float32x2_t vfilter1x = vld1_f32(filter);
filter += input_channels;
const float32x2_t vfilter2x = vld1_f32(filter);
filter += input_channels;
const float32x2_t vfilter3x = vld1_f32(filter);
float *output0 = output;
float *output1 = output0 + pixel_size;
float *output2 = output1 + pixel_size;
float *output3 = output2 + pixel_size;
while (pixel_size >= 4) {
float32x4_t voutput0 = vld1q_f32(output0);
float32x4_t voutput1 = vld1q_f32(output1);
float32x4_t voutput2 = vld1q_f32(output2);
float32x4_t voutput3 = vld1q_f32(output3);
const float32x4_t vinput0 = vld1q_f32(input0);
input0 += 4;
voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);
const float32x4_t vinput1 = vld1q_f32(input1);
input1 += 4;
voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
voutput3 = vfmaq_lane_f32(voutput3, vinput1, vfilter3x, 1);
vst1q_f32(output0, voutput0);
output0 += 4;
vst1q_f32(output1, voutput1);
output1 += 4;
vst1q_f32(output2, voutput2);
output2 += 4;
vst1q_f32(output3, voutput3);
output3 += 4;
pixel_size -= 4;
}
if (pixel_size != 0) {
const int32x4_t vmask = vld1q_s32(&mask_array[pixel_size]);
output0 = output0 + pixel_size - 4;
float32x4_t voutput0 = vld1q_f32(output0);
output1 = output1 + pixel_size - 4;
float32x4_t voutput1 = vld1q_f32(output1);
output2 = output2 + pixel_size - 4;
float32x4_t voutput2 = vld1q_f32(output2);
output3 = output3 + pixel_size - 4;
float32x4_t voutput3 = vld1q_f32(output3);
const float32x4_t vinput0 = vreinterpretq_f32_s32(vandq_s32(
vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4]))));
voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);
const float32x4_t vinput1 = vreinterpretq_f32_s32(vandq_s32(
vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4]))));
voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
voutput3 = vfmaq_lane_f32(voutput3, vinput1, vfilter3x, 1);
vst1q_f32(output0, voutput0);
vst1q_f32(output1, voutput1);
vst1q_f32(output2, voutput2);
vst1q_f32(output3, voutput3);
}
}
static inline void NeonConv2x4SubBlockKernel(
index_t input_channels_subblock_size,
index_t output_channels_subblock_size,
index_t input_channels,
index_t pixel_size,
const float *input,
const float *filter,
float *output) {
const float *input0 = input;
const float *input1 = input + pixel_size;
float32x2_t vfilter0x, vfilter1x, vfilter2x, vfilter3x;
vfilter0x = vld1_dup_f32(&filter[0]);
if (input_channels_subblock_size > 1) {
vfilter0x = vld1_lane_f32(&filter[1], vfilter0x, 1);
}
if (output_channels_subblock_size > 1) {
filter += input_channels;
vfilter1x = vld1_dup_f32(&filter[0]);
if (input_channels_subblock_size > 1) {
vfilter1x = vld1_lane_f32(&filter[1], vfilter1x, 1);
}
if (output_channels_subblock_size > 2) {
filter += input_channels;
vfilter2x = vld1_dup_f32(&filter[0]);
if (input_channels_subblock_size > 1) {
vfilter2x = vld1_lane_f32(&filter[1], vfilter2x, 1);
}
if (output_channels_subblock_size > 3) {
filter += input_channels;
vfilter3x = vld1_dup_f32(&filter[0]);
if (input_channels_subblock_size > 1) {
vfilter3x = vld1_lane_f32(&filter[1], vfilter3x, 1);
}
}
}
}
float *output0 = output;
float *output1 = output0 + pixel_size;
float *output2 = output1 + pixel_size;
float *output3 = output2 + pixel_size;
while (pixel_size >= 4) {
float32x4_t voutput0, voutput1, voutput2, voutput3;
voutput0 = vld1q_f32(output0);
if (output_channels_subblock_size > 1) {
voutput1 = vld1q_f32(output1);
if (output_channels_subblock_size > 2) {
voutput2 = vld1q_f32(output2);
if (output_channels_subblock_size > 3) {
voutput3 = vld1q_f32(output3);
}
}
}
const float32x4_t vinput0 = vld1q_f32(input0);
input0 += 4;
voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);
if (input_channels_subblock_size > 1) {
const float32x4_t vinput1 = vld1q_f32(input1);
input1 += 4;
voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
voutput3 = vfmaq_lane_f32(voutput3, vinput1, vfilter3x, 1);
}
vst1q_f32(output0, voutput0);
output0 += 4;
if (output_channels_subblock_size > 1) {
vst1q_f32(output1, voutput1);
output1 += 4;
if (output_channels_subblock_size > 2) {
vst1q_f32(output2, voutput2);
output2 += 4;
if (output_channels_subblock_size > 3) {
vst1q_f32(output3, voutput3);
output3 += 4;
}
}
}
pixel_size -= 4;
}
if (pixel_size != 0) {
const int32x4_t vmask = vld1q_s32(&mask_array[pixel_size]);
float32x4_t voutput0, voutput1, voutput2, voutput3;
output0 += pixel_size - 4;
voutput0 = vld1q_f32(output0);
if (output_channels_subblock_size > 1) {
output1 += pixel_size - 4;
voutput1 = vld1q_f32(output1);
if (output_channels_subblock_size > 2) {
output2 += pixel_size - 4;
voutput2 = vld1q_f32(output2);
if (output_channels_subblock_size > 3) {
output3 += pixel_size - 4;
voutput3 = vld1q_f32(output3);
}
}
}
const float32x4_t vinput0 = vreinterpretq_f32_s32(vandq_s32(
vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4]))));
voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);
if (input_channels_subblock_size > 1) {
const float32x4_t vinput1 = vreinterpretq_f32_s32(vandq_s32(
vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4]))));
voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
voutput3 = vfmaq_lane_f32(voutput3, vinput1, vfilter3x, 1);
}
vst1q_f32(output0, voutput0);
if (output_channels_subblock_size > 1) {
vst1q_f32(output1, voutput1);
if (output_channels_subblock_size > 2) {
vst1q_f32(output2, voutput2);
if (output_channels_subblock_size > 3) {
vst1q_f32(output3, voutput3);
}
}
}
}
}
void Conv2dNeonK1x1S1(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, filter_h, filter_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
const index_t batch = output_shape[0];
const index_t channels = output_shape[1];
const index_t height = output_shape[2];
const index_t width = output_shape[3];
const index_t input_batch = input_shape[0];
const index_t input_channels = input_shape[1];
const index_t input_height = input_shape[2];
const index_t input_width = input_shape[3];
MACE_CHECK(input_batch == batch && input_height == height &&
input_width == width);
const index_t total_pixels = height * width;
const index_t round_up_channels = RoundUp(channels, kOutputChannelBlockSize);
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (int i = 0; i < channels; ++i) {
float *output_ptr_base =
output + n * channels * total_pixels + i * total_pixels;
std::fill(output_ptr_base, output_ptr_base + total_pixels,
bias ? bias[i] : 0);
}
}
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < round_up_channels; c += kOutputChannelBlockSize) {
const float *input_ptr = input + n * input_channels * total_pixels;
const float *filter_ptr = filter + c * input_channels;
float *output_ptr =
output + n * channels * total_pixels + c * total_pixels;
const index_t output_channel_block_size =
std::min(channels - c, kOutputChannelBlockSize);
index_t remain_input_channels = input_channels;
if (c + kOutputChannelBlockSize <= channels) {
while (remain_input_channels >= kInputChannelBlockSize) {
NeonConv2x4Kernel(input_channels, total_pixels, input_ptr, filter_ptr,
output_ptr);
input_ptr += kInputChannelBlockSize * total_pixels;
filter_ptr += kInputChannelBlockSize;
remain_input_channels -= kInputChannelBlockSize;
}
}
while (remain_input_channels != 0) {
const index_t input_channel_block_size =
std::min(remain_input_channels, kInputChannelBlockSize);
NeonConv2x4SubBlockKernel(
input_channel_block_size, output_channel_block_size, input_channels,
total_pixels, input_ptr, filter_ptr, output_ptr);
input_ptr += kInputChannelBlockSize * total_pixels;
filter_ptr += kInputChannelBlockSize;
remain_input_channels -= input_channel_block_size;
}
}
}
}
void Conv2dNeonPixelK1x1S1(
const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
const index_t batch = output_shape[0];
const index_t channels = output_shape[1];
const index_t height = output_shape[2];
const index_t width = output_shape[3];
const index_t input_batch = input_shape[0];
const index_t input_channels = input_shape[1];
const index_t input_height = input_shape[2];
const index_t input_width = input_shape[3];
MACE_CHECK(input_batch == batch && input_height == height &&
input_width == width);
const index_t total_pixels = height * width;
// Process 4 * 2 = 8 pixels for each innermost loop
      // TODO(heliangliang): Does a 64-bit vs. 32-bit index matter? Needs benchmarking.
const index_t total_loops = total_pixels >> 3;
const index_t loop_remaining = total_pixels & 7;
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < channels; ++c) {
const float *filter_ptr = filter + c * input_channels;
      // TODO(heliangliang): Will GCC optimize these out?
float *channel_output_start =
output + n * channels * height * width + c * height * width;
const float *input_ptr =
input + n * input_channels * input_height * input_width;
// Fill with bias
float *output_ptr = channel_output_start;
std::fill(output_ptr, output_ptr + total_pixels, bias ? bias[c] : 0);
index_t inc = 0;
// Process 4 input channels in batch
for (; inc + 3 < input_channels; inc += 4) {
float *output_ptr = channel_output_start;
        // The beginning of each input feature map channel
MACE_ASSERT(input_ptr ==
input + n * input_channels * input_height * input_width +
inc * input_height * input_width);
const float *input_ptr1 = input_ptr + total_pixels;
const float *input_ptr2 = input_ptr1 + total_pixels;
const float *input_ptr3 = input_ptr2 + total_pixels;
// filter is in c_out, c_in, 1, 1 order
MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
const float k0 = filter_ptr[0];
const float k1 = filter_ptr[1];
const float k2 = filter_ptr[2];
const float k3 = filter_ptr[3];
filter_ptr += 4;
const float32x4_t vk0 = vdupq_n_f32(k0);
const float32x4_t vk1 = vdupq_n_f32(k1);
const float32x4_t vk2 = vdupq_n_f32(k2);
const float32x4_t vk3 = vdupq_n_f32(k3);
index_t loop_itr = total_loops;
for (; loop_itr > 0; --loop_itr) {
// Process 2 group of 4 floats
float32x4_t out0 = vld1q_f32(output_ptr);
float32x4_t out4 = vld1q_f32(output_ptr + 4);
const float32x4_t in00 = vld1q_f32(input_ptr);
const float32x4_t in04 = vld1q_f32(input_ptr + 4);
out0 = vfmaq_f32(out0, in00, vk0);
out4 = vfmaq_f32(out4, in04, vk0);
const float32x4_t in10 = vld1q_f32(input_ptr1);
const float32x4_t in14 = vld1q_f32(input_ptr1 + 4);
out0 = vfmaq_f32(out0, in10, vk1);
out4 = vfmaq_f32(out4, in14, vk1);
const float32x4_t in20 = vld1q_f32(input_ptr2);
const float32x4_t in24 = vld1q_f32(input_ptr2 + 4);
out0 = vfmaq_f32(out0, in20, vk2);
out4 = vfmaq_f32(out4, in24, vk2);
const float32x4_t in30 = vld1q_f32(input_ptr3);
const float32x4_t in34 = vld1q_f32(input_ptr3 + 4);
out0 = vfmaq_f32(out0, in30, vk3);
out4 = vfmaq_f32(out4, in34, vk3);
// Save output
vst1q_f32(output_ptr, out0);
vst1q_f32(output_ptr + 4, out4);
output_ptr += 8;
input_ptr += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
}
// Process the remaining pixels
index_t remaining_pixels = loop_remaining;
for (; remaining_pixels > 0; --remaining_pixels) {
const float mul = *input_ptr * k0;
const float mul1 = *input_ptr1 * k1;
const float mul2 = *input_ptr2 * k2;
const float mul3 = *input_ptr3 * k3;
*output_ptr += mul + mul1 + mul2 + mul3;
++output_ptr;
++input_ptr;
++input_ptr1;
++input_ptr2;
++input_ptr3;
}
// Skip these 4 feature maps
input_ptr += 3 * total_pixels;
}
// Process the remaining channels
for (; inc < input_channels; ++inc) {
float *output_ptr = channel_output_start;
MACE_ASSERT(input_ptr ==
input + n * input_channels * input_height * input_width +
inc * input_height * input_width);
MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
const float k0 = filter_ptr[0];
++filter_ptr;
const float32x4_t vk0 = vdupq_n_f32(k0);
index_t loop_itr = total_loops;
for (; loop_itr > 0; --loop_itr) {
float32x4_t out0 = vld1q_f32(output_ptr);
float32x4_t out4 = vld1q_f32(output_ptr + 4);
const float32x4_t in0 = vld1q_f32(input_ptr);
const float32x4_t in4 = vld1q_f32(input_ptr + 4);
out0 = vfmaq_f32(out0, in0, vk0);
out4 = vfmaq_f32(out4, in4, vk0);
// Save output
vst1q_f32(output_ptr, out0);
vst1q_f32(output_ptr + 4, out4);
output_ptr += 8;
input_ptr += 8;
}
// Process the remaining pixels
index_t remaining_pixels = loop_remaining;
for (; remaining_pixels > 0; --remaining_pixels) {
const float mul = *input_ptr * k0;
*output_ptr += mul;
++output_ptr;
++input_ptr;
}
}
}
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include <algorithm>  // std::fill
namespace mace {
namespace kernels {
static const int kRegisterSize = 4;
static const int kFilterSize = 9;
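// 3x3 stride-1 kernel: each inner iteration produces two output rows by four
// output columns, building the shifted 3-wide input windows with vext. When
// filter_shape is nullptr this is a regular convolution over all input
// channels; otherwise it runs depthwise with the given channel multiplier.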
void Conv2dNeonK3x3S1(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
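// Number of output rows handled by the two-rows-at-a-time main loop
// (output height rounded down to an even number).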
int height_count = (output_shape[2] >> 1) << 1;
int output_batch = output_shape[0];
int output_channels = output_shape[1];
int output_height = output_shape[2];
int output_width = output_shape[3];
int input_batch = input_shape[0];
int input_channels = input_shape[1];
int input_height = input_shape[2];
int input_width = input_shape[3];
int multiplier = filter_shape == nullptr ? 0 : filter_shape[0];
int filter_in_channels = filter_shape == nullptr ? input_channels : 1;
#pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) {
for (int oc = 0; oc < output_channels; ++oc) {
float *output_ptr_base =
output + b * output_channels * output_height * output_width;
const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize;
const float *input_ptr =
input + b * input_channels * input_height * input_width;
if (filter_shape != nullptr) {
input_ptr += (oc / multiplier) * input_height * input_width;
}
float *output_ptr = output_ptr_base + oc * output_height * output_width;
std::fill(output_ptr, output_ptr + output_height * output_width,
bias ? bias[oc] : 0);
for (int ic = 0; ic < filter_in_channels; ++ic) {
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr),
vld1q_f32(filter_ptr + 3),
vld1q_f32(filter_ptr + 6)};
const float *row_ptr_v[kRegisterSize] = {
input_ptr, input_ptr + input_width, input_ptr + 2 * input_width,
input_ptr + 3 * input_width};
float *output_ptr_v[] = {output_ptr, output_ptr + output_width};
for (int h = 0; h < height_count; h += 2) {
int count = output_width >> 2;
int remain_count = output_width & 3;
for (; count > 0; --count) {
float32x4_t n_sum0 = vdupq_n_f32(.0f);
float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]);
float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize);
float32x4_t n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1);
float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[0], 0);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[0], 1);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[0], 2);
float32x4_t n_row1_former = vld1q_f32(row_ptr_v[1]);
float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
float32x4_t n_row1_ext0 =
vextq_f32(n_row1_former, n_row1_latter, 1);
float32x4_t n_row1_ext1 =
vextq_f32(n_row1_former, n_row1_latter, 2);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_former, n_filter_v[1], 0);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext0, n_filter_v[1], 1);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row1_ext1, n_filter_v[1], 2);
n_row_former = vld1q_f32(row_ptr_v[2]);
n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize);
n_row_ext0 = vextq_f32(n_row_former, n_row_latter, 1);
n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 2);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_former, n_filter_v[2], 0);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext0, n_filter_v[2], 1);
n_sum0 = vfmaq_laneq_f32(n_sum0, n_row_ext1, n_filter_v[2], 2);
// second row
float32x4_t n_sum1 = vdupq_n_f32(.0f);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[0], 0);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[0], 1);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[0], 2);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_former, n_filter_v[1], 0);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext0, n_filter_v[1], 1);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row_ext1, n_filter_v[1], 2);
n_row1_former = vld1q_f32(row_ptr_v[3]);
n_row1_latter = vld1q_f32(row_ptr_v[3] + kRegisterSize);
n_row1_ext0 = vextq_f32(n_row1_former, n_row1_latter, 1);
n_row1_ext1 = vextq_f32(n_row1_former, n_row1_latter, 2);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_former, n_filter_v[2], 0);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext0, n_filter_v[2], 1);
n_sum1 = vfmaq_laneq_f32(n_sum1, n_row1_ext1, n_filter_v[2], 2);
float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]);
float32x4_t n_output_row1 = vld1q_f32(output_ptr_v[1]);
n_output_row = vaddq_f32(n_output_row, n_sum0);
n_output_row1 = vaddq_f32(n_output_row1, n_sum1);
vst1q_f32(output_ptr_v[0], n_output_row);
vst1q_f32(output_ptr_v[1], n_output_row1);
output_ptr_v[0] += kRegisterSize;
output_ptr_v[1] += kRegisterSize;
for (int i = 0; i < kRegisterSize; ++i) {
row_ptr_v[i] += kRegisterSize;
}
}
for (; remain_count > 0; --remain_count) {
float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
vld1q_f32(row_ptr_v[1]),
vld1q_f32(row_ptr_v[2])};
float32x4_t n_sum0 = vmulq_f32(n_row_v[0], n_filter_v[0]);
n_sum0 = vmlaq_f32(n_sum0, n_row_v[1], n_filter_v[1]);
n_sum0 = vmlaq_f32(n_sum0, n_row_v[2], n_filter_v[2]);
n_sum0 = vsetq_lane_f32(*output_ptr_v[0], n_sum0, 3);
*output_ptr_v[0] = vaddvq_f32(n_sum0);
float32x4_t n_row3 = vld1q_f32(row_ptr_v[3]);
float32x4_t n_sum1 = vmulq_f32(n_row_v[1], n_filter_v[0]);
n_sum1 = vmlaq_f32(n_sum1, n_row_v[2], n_filter_v[1]);
n_sum1 = vmlaq_f32(n_sum1, n_row3, n_filter_v[2]);
n_sum1 = vsetq_lane_f32(*output_ptr_v[1], n_sum1, 3);
*output_ptr_v[1] = vaddvq_f32(n_sum1);
++output_ptr_v[0];
++output_ptr_v[1];
for (int i = 0; i < kRegisterSize; ++i) {
row_ptr_v[i] += 1;
}
}
output_ptr_v[0] += output_width;
output_ptr_v[1] += output_width;
for (int i = 0; i < kRegisterSize; ++i) {
row_ptr_v[i] += 2 + input_width;
}
}
if (output_height != height_count) {
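// Odd output height: the last row is processed alone with the same
// four-columns-wide inner loop, accumulating a single sum vector.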
int count = output_width >> 2;
int remain_count = output_width & 3;
for (; count > 0; --count) {
float32x4_t n_sum = vdupq_n_f32(.0f);
float32x4_t n_row_former = vld1q_f32(row_ptr_v[0]);
float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + kRegisterSize);
float32x4_t n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
float32x4_t n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[0], 0);
n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[0], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[0], 2);
n_row_former = vld1q_f32(row_ptr_v[1]);
n_row_latter = vld1q_f32(row_ptr_v[1] + kRegisterSize);
n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[1], 0);
n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[1], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[1], 2);
n_row_former = vld1q_f32(row_ptr_v[2]);
n_row_latter = vld1q_f32(row_ptr_v[2] + kRegisterSize);
n_row_ext1 = vextq_f32(n_row_former, n_row_latter, 1);
n_row_ext2 = vextq_f32(n_row_former, n_row_latter, 2);
n_sum = vfmaq_laneq_f32(n_sum, n_row_former, n_filter_v[2], 0);
n_sum = vfmaq_laneq_f32(n_sum, n_row_ext1, n_filter_v[2], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row_ext2, n_filter_v[2], 2);
float32x4_t n_output_row = vld1q_f32(output_ptr_v[0]);
n_output_row = vaddq_f32(n_output_row, n_sum);
vst1q_f32(output_ptr_v[0], n_output_row);
output_ptr_v[0] += kRegisterSize;
for (int i = 0; i < 3; ++i) {
row_ptr_v[i] += kRegisterSize;
}
}
for (; remain_count > 0; --remain_count) {
float32x4_t n_row_v[] = {
vld1q_f32(row_ptr_v[0]), vld1q_f32(row_ptr_v[1]),
vld1q_f32(row_ptr_v[2]),
};
float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
n_sum = vsetq_lane_f32(*output_ptr_v[0], n_sum, 3);
*output_ptr_v[0] = vaddvq_f32(n_sum);
++output_ptr_v[0];
for (int i = 0; i < 3; ++i) {
row_ptr_v[i] += 1;
}
}
}
filter_ptr += kFilterSize;
input_ptr += input_height * input_width;
}
}
}
}
void Conv2dNeonK3x3S2(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
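// After finishing one output row each input row pointer has advanced by
// 2 * output_width floats; tail_step covers the rest of the two input rows
// skipped by the vertical stride of 2.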
int tail_step = 2 * (input_shape[3] - output_shape[3]);
int output_batch = output_shape[0];
int output_channels = output_shape[1];
int output_height = output_shape[2];
int output_width = output_shape[3];
int input_batch = input_shape[0];
int input_channels = input_shape[1];
int input_height = input_shape[2];
int input_width = input_shape[3];
int multiplier = filter_shape == nullptr ? 0 : filter_shape[0];
int filter_in_channels = filter_shape == nullptr ? input_channels : 1;
#pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) {
for (int oc = 0; oc < output_channels; ++oc) {
float *output_ptr_base =
output + b * output_channels * output_height * output_width;
const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize;
const float *input_ptr =
input + b * input_channels * input_height * input_width;
if (filter_shape != nullptr) {
input_ptr += (oc / multiplier) * input_height * input_width;
}
float *output_ptr = output_ptr_base + oc * output_height * output_width;
std::fill(output_ptr, output_ptr + output_height * output_width,
bias ? bias[oc] : 0);
for (int ic = 0; ic < filter_in_channels; ++ic) {
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr),
vld1q_f32(filter_ptr + 3),
vld1q_f32(filter_ptr + 6)};
const float *row_ptr_v[3] = {input_ptr, input_ptr + input_width,
input_ptr + 2 * input_width};
float *output_ptr_inner = output_ptr;
for (int h = 0; h < output_height; ++h) {
int count = output_width >> 2;
int remain_count = output_width & 3;
for (; count > 0; --count) {
float32x4_t n_sum = vdupq_n_f32(.0f);
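// vld2q de-interleaves even and odd columns so the first two taps of each
// stride-2 window sit in contiguous lanes; the third tap (column 2j + 2) is
// obtained by shifting the even lanes with vext.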
float32x4x2_t n_row_former = vld2q_f32(row_ptr_v[0]);
float32x4_t n_row_latter = vld1q_f32(row_ptr_v[0] + 8);
float32x4_t n_row_ext =
vextq_f32(n_row_former.val[0], n_row_latter, 1);
n_sum =
vfmaq_laneq_f32(n_sum, n_row_former.val[0], n_filter_v[0], 0);
n_sum =
vfmaq_laneq_f32(n_sum, n_row_former.val[1], n_filter_v[0], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row_ext, n_filter_v[0], 2);
float32x4x2_t n_row1_former = vld2q_f32(row_ptr_v[1]);
float32x4_t n_row1_latter = vld1q_f32(row_ptr_v[1] + 8);
float32x4_t n_row1_ext =
vextq_f32(n_row1_former.val[0], n_row1_latter, 1);
n_sum =
vfmaq_laneq_f32(n_sum, n_row1_former.val[0], n_filter_v[1], 0);
n_sum =
vfmaq_laneq_f32(n_sum, n_row1_former.val[1], n_filter_v[1], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row1_ext, n_filter_v[1], 2);
float32x4x2_t n_row2_former = vld2q_f32(row_ptr_v[2]);
float32x4_t n_row2_latter = vld1q_f32(row_ptr_v[2] + 8);
float32x4_t n_row2_ext =
vextq_f32(n_row2_former.val[0], n_row2_latter, 1);
n_sum =
vfmaq_laneq_f32(n_sum, n_row2_former.val[0], n_filter_v[2], 0);
n_sum =
vfmaq_laneq_f32(n_sum, n_row2_former.val[1], n_filter_v[2], 1);
n_sum = vfmaq_laneq_f32(n_sum, n_row2_ext, n_filter_v[2], 2);
float32x4_t n_output_row = vld1q_f32(output_ptr_inner);
n_output_row = vaddq_f32(n_output_row, n_sum);
vst1q_f32(output_ptr_inner, n_output_row);
output_ptr_inner += kRegisterSize;
for (int i = 0; i < 3; ++i) {
row_ptr_v[i] += 2 * kRegisterSize;
}
}
for (; remain_count > 0; --remain_count) {
float32x4_t n_row_v[] = {vld1q_f32(row_ptr_v[0]),
vld1q_f32(row_ptr_v[1]),
vld1q_f32(row_ptr_v[2])};
float32x4_t n_sum = vmulq_f32(n_row_v[0], n_filter_v[0]);
n_sum = vmlaq_f32(n_sum, n_row_v[1], n_filter_v[1]);
n_sum = vmlaq_f32(n_sum, n_row_v[2], n_filter_v[2]);
n_sum = vsetq_lane_f32(*output_ptr_inner, n_sum, 3);
*output_ptr_inner = vaddvq_f32(n_sum);
++output_ptr_inner;
for (int i = 0; i < 3; ++i) {
row_ptr_v[i] += 2;
}
}
for (int i = 0; i < 3; ++i) {
row_ptr_v[i] += tail_step;
}
}
filter_ptr += kFilterSize;
input_ptr += input_height * input_width;
}
}
}
}
} // namespace kernels
} // namespace mace
This diff is collapsed.
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/depthwise_conv2d.h"
namespace mace {
namespace kernels {
extern void Conv2dNeonK3x3S1(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
extern void Conv2dNeonK3x3S2(const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
template <>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
typedef void (*Conv2dNeonFunction)(
const float *input, const index_t *input_shape, const float *filter,
const index_t *filter_shape, const float *bias, float *output,
const index_t *output_shape);
// Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = {
{nullptr, nullptr},
{nullptr, nullptr},
{Conv2dNeonK3x3S1, Conv2dNeonK3x3S2},
{nullptr, nullptr},
{nullptr, nullptr}};
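// Indexed as selector[kernel_h - 1][strides_[0] - 1]; only the 3x3 kernel
// with stride 1 or 2 has a vectorized implementation so far.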
// Fall back to the generic CPU implementation for anything else.
index_t kernel_h = filter->dim(2);
index_t kernel_w = filter->dim(3);
if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
LOG(WARNING) << "Depthwise-Conv2d NEON kernel with "
<< "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version";
DepthwiseConv2dFunctor<DeviceType::CPU, float>(
strides_, paddings_, dilations_)(input, filter, bias, output, future);
return;
}
const float *input_ptr = input->data<float>();
const index_t *input_shape = input->shape().data();
const float *filter_ptr = filter->data<float>();
const index_t *filter_shape = filter->shape().data();
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
float *output_ptr = output->mutable_data<float>();
const index_t *output_shape = output->shape().data();
// Keep this alive during kernel execution
Tensor padded_input;
if (paddings_[0] > 0 || paddings_[1] > 0) {
ConstructInputWithPadding(input, paddings_.data(), &padded_input);
input_ptr = padded_input.data<float>();
input_shape = padded_input.shape().data();
}
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input_ptr, input_shape, filter_ptr, filter_shape, bias_ptr,
output_ptr, output_shape);
}
} // namespace kernels
} // namespace mace
......@@ -166,8 +166,20 @@ struct PoolingFunctor : PoolingFunctorBase {
};
template <>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input_tensor, Tensor *output_tensor, StatsFuture *future);
struct PoolingFunctor<DeviceType::NEON, float> : PoolingFunctorBase {
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
const Padding padding_type,
const std::vector<int> &paddings,
const int *dilations)
: PoolingFunctorBase(
pooling_type, kernels, strides, padding_type, paddings, dilations) {
}
void operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future);
};
template <typename T>
struct PoolingFunctor<DeviceType::OPENCL, T> : PoolingFunctorBase {
......
......@@ -56,6 +56,11 @@ struct SoftmaxFunctor {
}
};
template <>
struct SoftmaxFunctor<DeviceType::NEON, float> {
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future);
};
template <typename T>
struct SoftmaxFunctor<DeviceType::OPENCL, T> {
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future);
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_TRANSPOSE_H_
#define MACE_KERNELS_TRANSPOSE_H_
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
template<DeviceType D, typename T>
struct TransposeFunctor {
explicit TransposeFunctor(const std::vector<int> &dims) : dims_(dims) {}
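// dims_ is the axis permutation: output axis d takes its index (and extent)
// from input axis dims_[d], i.e. output.shape[d] == input.shape[dims_[d]].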
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const std::vector<index_t> &input_shape = input->shape();
const std::vector<index_t> &output_shape = output->shape();
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3],
input_shape[2] * input_shape[3], input_shape[3], 1};
std::vector<index_t>
out_stride{output_shape[1] * output_shape[2] * output_shape[3],
output_shape[2] * output_shape[3], output_shape[3], 1};
std::vector<index_t> idim(4, 0);
std::vector<index_t> odim(4, 0);
for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
idim[dims_[0]] = odim[0];
idim[dims_[1]] = odim[1];
idim[dims_[2]] = odim[2];
idim[dims_[3]] = odim[3];
output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1]
+ odim[2] * out_stride[2] + odim[3]] =
input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1]
+ idim[2] * in_stride[2] + idim[3]];
}
}
}
}
}
std::vector<int> dims_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_TRANSPOSE_H_
......@@ -25,6 +25,11 @@ void Register_Activation(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
ActivationOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Activation")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
ActivationOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -25,6 +25,11 @@ void Register_BatchNorm(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
BatchNormOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
BatchNormOp<DeviceType::NEON, float>);
}
} // namespace ops
......
This diff is collapsed.
......@@ -25,6 +25,12 @@ void Register_Conv2D(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
Conv2dOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("Conv2D")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
Conv2dOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -30,10 +30,19 @@ static void Conv2d(int iters,
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
net.AddRandomInput<D, float>("Filter",
{kernel_h, kernel_w, output_channels, channels});
net.AddRandomInput<D, float>("Bias", {output_channels});
if (D == DeviceType::NEON) {
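// The NEON kernels consume NCHW inputs and OIHW filters, unlike the
// NHWC/HWOI layout used by the CPU and OPENCL paths.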
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
net.AddRandomInput<D, float>("Filter",
{output_channels, channels, kernel_h,
kernel_w});
net.AddRandomInput<D, float>("Bias", {output_channels});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
net.AddRandomInput<D, float>("Filter",
{kernel_h, kernel_w, output_channels,
channels});
net.AddRandomInput<D, float>("Bias", {output_channels});
}
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage",
......@@ -65,15 +74,17 @@ static void Conv2d(int iters,
.Finalize(net.NewOperatorDef());
}
net.Setup(D);
// Warm-up
for (int i = 0; i < 2; ++i) {
net.RunOp(D);
net.Run();
net.Sync();
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Run();
net.Sync();
}
}
......@@ -112,7 +123,8 @@ static void Conv2d(int iters,
#define BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, OPENCL); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, OPENCL);
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, OPENCL); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, NEON);
BM_CONV_2D(1, 256, 64, 64, 3, 3, 1, 1, VALID, 256);
......@@ -133,6 +145,8 @@ BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, 1, VALID, 128);
BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, 1, VALID, 128); // Test bad alignments
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, 1, SAME, 128);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, 1, SAME, 128);
BM_CONV_2D(1, 3, 224, 224, 3, 3, 2, 1, SAME, 32);
BM_CONV_2D(1, 3, 224, 224, 3, 3, 2, 1, VALID, 32);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, 1, SAME, 128);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, 1, SAME, 128);
......
This diff is collapsed.
......@@ -25,6 +25,12 @@ void Register_DepthwiseConv2d(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
DepthwiseConv2dOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
DepthwiseConv2dOp<DeviceType::NEON, float>);
}
} // namespace ops
......
......@@ -29,10 +29,19 @@ static void DepthwiseConv2d(int iters,
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, float>(
if (D == DeviceType::NEON) {
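// The NEON depthwise path also uses NCHW inputs, with filters laid out as
// {multiplier, in_channels, kernel_h, kernel_w}.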
net.AddRandomInput<D, float>("Input",
{batch, input_channels, height, width});
net.AddRandomInput<D, float>(
"Filter", {multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
} else {
net.AddRandomInput<D, float>("Input",
{batch, height, width, input_channels});
net.AddRandomInput<D, float>(
"Filter", {kernel_h, kernel_w, input_channels, multiplier});
net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
}
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage",
......@@ -64,15 +73,17 @@ static void DepthwiseConv2d(int iters,
.Finalize(net.NewOperatorDef());
}
net.Setup(D);
// Warm-up
for (int i = 0; i < 2; ++i) {
net.RunOp(D);
net.Run();
net.Sync();
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
net.Run();
net.Sync();
}
}
......@@ -108,10 +119,16 @@ static void DepthwiseConv2d(int iters,
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, M) \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, OPENCL); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, OPENCL);
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, OPENCL); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, NEON);
BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 1, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 32, 56, 56, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 32, 224, 224, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 56, 56, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 112, 112, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 224, 224, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 1);
......@@ -124,6 +141,10 @@ BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 3, 112, 112, 3, 3, 2, VALID, 1);
BM_DEPTHWISE_CONV_2D(1, 3, 224, 224, 3, 3, 2, SAME, 1);
BM_DEPTHWISE_CONV_2D(1, 8, 224, 224, 3, 3, 2, SAME, 1);
} // namespace test
} // namespace ops
......
This diff is collapsed.
......@@ -25,6 +25,11 @@ void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
.TypeConstraint<half>("T")
.Build(),
FoldedBatchNormOp<DeviceType::OPENCL, half>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
FoldedBatchNormOp<DeviceType::NEON, float>);
}
} // namespace ops
......
This diff is collapsed. (applies to 14 more files in this commit)