Commit ff39c7a5 authored by Liangliang He

Add conv2d k1x1s1 NEON kernel

Parent b9492125
......@@ -31,7 +31,7 @@ class Allocator {
template <typename T>
T* New(size_t num_elements) {
if (num_elements > (std::numeric_limits<size_t>::max() / sizeof(T))) {
return NULL;
return nullptr;
}
void* p = New(sizeof(T) * num_elements);
T* typed_p = reinterpret_cast<T*>(p);
......
......@@ -106,16 +106,27 @@ class LogMessageFatal : public LogMessage {
if (VLOG_IS_ON(lvl)) \
::mace::internal::LogMessage(__FILE__, __LINE__, mace::INFO)
// MACE_CHECK dies with a fatal error if condition is not true. It is *not*
// controlled by NDEBUG, so the check will be executed regardless of
// compilation mode. Therefore, it is safe to do things like:
// MACE_CHECK/MACE_ASSERT die with a fatal error if the condition is not true.
// MACE_ASSERT is controlled by NDEBUG ('-c opt' for bazel) while MACE_CHECK
// will be executed regardless of compilation mode.
// Therefore, it is safe to do things like:
// MACE_CHECK(fp->Write(x) == 4)
// MACE_CHECK(fp->Write(x) == 4, "Write failed")
// which are not correct for MACE_ASSERT.
#define MACE_CHECK(condition, ...) \
if (!(condition)) \
LOG(FATAL) << "Check failed: " #condition " " \
<< ::mace::internal::MakeString(__VA_ARGS__)
#ifndef NDEBUG
#define MACE_ASSERT(condition, ...) \
if (!(condition)) \
LOG(FATAL) << "Assert failed: " #condition " " \
<< ::mace::internal::MakeString(__VA_ARGS__)
#else
#define MACE_ASSERT(condition, ...) ((void)0)
#endif
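A minimal usage sketch of the distinction between the two macros; the helper function below is hypothetical and not part of this commit:
#include <cstdio>
// Hypothetical helper, only to illustrate MACE_CHECK vs. MACE_ASSERT.
inline void WriteFloats(std::FILE* fp, const float* buf, size_t count) {
  // Always evaluated, even in optimized builds ('bazel build -c opt').
  MACE_CHECK(std::fwrite(buf, sizeof(float), count, fp) == count, "Write failed");
  // Debug-only sanity check; expands to ((void)0) when NDEBUG is defined,
  // so the condition must not carry side effects.
  MACE_ASSERT(count > 0, "expected a non-empty buffer");
}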
template <typename T>
T&& CheckNotNull(const char* file, int line, const char* exprtext, T&& t) {
if (t == nullptr) {
......
......@@ -6,7 +6,9 @@
#include <cstdlib>
#include <algorithm>
#include <regex>
#include <vector>
#include "mace/core/logging.h"
#include "mace/core/testing/env_time.h"
#include "mace/core/testing/test_benchmark.h"
......@@ -52,12 +54,23 @@ Benchmark* Benchmark::ArgPair(int x, int y) {
// Run all benchmarks
void Benchmark::Run() {
Run("all");
}
void Benchmark::Run(const char* pattern) {
if (!all_benchmarks) return;
if (std::string(pattern) == "all") {
pattern = ".*";
}
std::regex regex(pattern);
// Compute name width.
int width = 10;
char name[100];
std::smatch match;
for (auto b : *all_benchmarks) {
if (!std::regex_match(b->name_, match, regex)) continue;
for (auto arg : b->args_) {
strcpy(name, b->name_.c_str());
if (arg.first >= 0) {
......@@ -74,7 +87,7 @@ void Benchmark::Run() {
printf("%-*s %10s %10s\n", width, "Benchmark", "Time(ns)", "Iterations");
printf("%s\n", string(width + 22, '-').c_str());
for (auto b : *all_benchmarks) {
if (!std::regex_match(b->name_, match, regex)) continue;
for (auto arg : b->args_) {
strcpy(name, b->name_.c_str());
if (arg.first >= 0) {
......
......@@ -28,6 +28,7 @@ class Benchmark {
Benchmark* ArgPair(int x, int y);
static void Run();
static void Run(const char* pattern);
private:
string name_;
......
......@@ -9,7 +9,12 @@
int main(int argc, char** argv) {
std::cout << "Running main() from test_main.cc\n";
mace::testing::Benchmark::Run();
// TODO Use gflags
if (argc == 2) {
mace::testing::Benchmark::Run(argv[1]);
} else {
mace::testing::Benchmark::Run("all");
}
return 0;
}
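With this change the benchmark binary accepts an optional regex as its single argument; note that std::regex_match requires the pattern to cover the whole benchmark name. For example (assuming the ops_benchmark target added in the BUILD changes below), running `ops_benchmark "BM_CONV_2D.*NEON"` executes only the NEON conv2d benchmarks, while running it with no argument, or with "all", runs everything registered.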
......@@ -108,14 +108,14 @@ class Conv2dFunctor {
const int* dilations_; // [dilation_h, dilation_w]
};
template<>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // NCHW
const index_t* input_shape,
const float* filter, // c_out, c_in, kernel_h, kernel_w
const index_t* filter_shape,
const float* bias, // c_out
float* output, // NCHW
const index_t* output_shape);
template <>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input,
const index_t* input_shape,
const float* filter,
const index_t* filter_shape,
const float* bias,
float* output,
const index_t* output_shape);
} // namespace kernels
} // namespace mace
......
......@@ -9,39 +9,48 @@
namespace mace {
namespace kernels {
static inline void ConstructInputWithPadding(const float* input, const index_t* input_shape,
static inline void ConstructInputWithPadding(const float* input,
const index_t* input_shape,
const int* paddings,
Tensor& output_tensor,
std::vector<index_t>& output_shape) {
Tensor* output_tensor) {
index_t batch = input_shape[0];
index_t channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
output_shape[0] = batch;
output_shape[1] = channels;
output_shape[2] = paddings[0] + height;
output_shape[3] = paddings[1] + width;
index_t output_width = output_shape[3];
int padded_left = paddings[1] / 2;
std::vector<index_t> output_shape({batch,
channels,
paddings[0] + height,
paddings[1] + width});
output_tensor.Resize(output_shape);
float* output_ptr = output_tensor.mutable_data<float>();
memset(output_ptr, 0, output_tensor.size() * sizeof(float));
output_ptr += paddings[0] / 2 * output_width;
const index_t output_width = output_shape[3];
const int padded_top = paddings[0] / 2;
const int padded_left = paddings[1] / 2;
output_tensor->Resize(output_shape);
float* output_ptr = output_tensor->mutable_data<float>();
memset(output_ptr, 0, output_tensor->size() * sizeof(float));
// Skip the padded top rows
output_ptr += padded_top * output_width;
for (; batch > 0; --batch) {
for (; channels > 0; --channels) {
for(; height > 0; --height) {
memcpy(output_ptr+padded_left, input, width*sizeof(float));
memcpy(output_ptr + padded_left, input, width * sizeof(float));
input += width;
output_ptr += output_width;
}
// Skip the padded bottom in this channel and top in the next channel
output_ptr += paddings[0] * output_width;
}
}
}
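A quick illustration of the padded layout, with shapes chosen purely for illustration (input_data is a hypothetical pointer to a 1x1x3x3 NCHW buffer):
// Hypothetical call for illustration only.
Tensor padded;
const index_t in_shape[4] = {1, 1, 3, 3};
const int pads[2] = {2, 2};  // total padding along height and width
ConstructInputWithPadding(input_data, in_shape, pads, &padded);
// padded.shape() == {1, 1, 5, 5}; padded_top == padded_left == 1,
// so each 3-float input row is copied to columns 1..3 of rows 1..3,
// and the border elements stay zero from the memset.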
extern void Conv2dNeonK1x1S1(const float* input, const index_t* input_shape,
const float* filter, const float* bias,
float* output, const index_t* output_shape);
template<>
void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // NCHW
const index_t* input_shape,
......@@ -57,9 +66,10 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // N
const float* bias, // c_out
float* output, // NCHW
const index_t* output_shape);
// Selection matrix: kernel_size x stride_size
static const Conv2dNeonFunction selector[5][2] = {
{
nullptr,
Conv2dNeonK1x1S1,
nullptr
},
{
......@@ -80,10 +90,13 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // N
}
};
// Not implemented yet
if (paddings_[0] != paddings_[1] || paddings_[0] > 5 ||
strides_[0] != strides_[1] || strides_[0] > 4 ||
dilations_[0] != 1 || dilations_[1] != 1 ||
selector[paddings_[0]-1][strides_[0]-1] == nullptr) {
index_t kernel_h = filter_shape[2];
index_t kernel_w = filter_shape[3];
if (kernel_h != kernel_w || kernel_h > 5 ||
strides_[0] != strides_[1] || strides_[0] > 2 ||
dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
LOG(WARNING) << "NEON conv2d kernel not implementated, using slow vesion";
Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input,
input_shape,
......@@ -94,19 +107,22 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const float* input, // N
output_shape
);
return;
}
// Keep this alive during kernel execution
Tensor padded_input;
std::vector<index_t> padded_input_shape(4);
ConstructInputWithPadding(input, input_shape, paddings_, padded_input, padded_input_shape);
auto conv2d_neon_func = selector[paddings_[0] - 1][strides_[0] - 1];
conv2d_neon_func(
padded_input.data<float>(),
padded_input_shape.data(),
filter,
bias,
output,
output_shape
);
if (paddings_[0] > 0 || paddings_[1] > 0) {
ConstructInputWithPadding(input, input_shape, paddings_, &padded_input);
input = padded_input.data<float>();
input_shape = padded_input.shape().data();
}
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input,
input_shape,
filter,
bias,
output,
output_shape);
}
} // namespace kernels
} // namespace mace
\ No newline at end of file
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <arm_neon.h>
#include "mace/kernels/conv_2d.h"
namespace mace {
namespace kernels {
void Conv2dNeonK1x1S1(const float* input, // NCHW
const index_t* input_shape,
const float* filter, // c_out, c_in, kernel_h, kernel_w
const float* bias, // c_out
float* output, // NCHW
const index_t* output_shape) {
const index_t batch = output_shape[0];
const index_t channels = output_shape[1];
const index_t height = output_shape[2];
const index_t width = output_shape[3];
const index_t input_batch = input_shape[0];
const index_t input_channels = input_shape[1];
const index_t input_height = input_shape[2];
const index_t input_width = input_shape[3];
MACE_CHECK(input_batch == batch &&
input_height == height &&
input_width == width);
const index_t total_pixels = height * width;
// Process 4 * 2 = 8 pixels for each innermost loop
// TODO Does a 64-bit vs. 32-bit index matter? Needs a benchmark.
const index_t total_loops = total_pixels >> 3;
const index_t loop_remaining = total_pixels & 7;
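// For example (illustrative numbers): a 3x10 feature map has 30 pixels,
// giving total_loops = 3 (24 pixels handled 8 at a time) and
// loop_remaining = 6 pixels handled by the scalar tail loop.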
// TODO: benchmark omp collapse(2)
for (index_t n = 0; n < batch; ++n) {
const float* filter_ptr = filter;
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) {
// TODO Will GCC opt these out?
float* channel_output_start =
output + n * channels * height * width + c * height * width;
const float* input_ptr = input + n * input_channels * input_height * input_width;
// Fill with bias
float* output_ptr = channel_output_start;
for (index_t ptr = 0; ptr < total_pixels; ++ptr) {
output_ptr[ptr] = bias[c]; // TODO can we avoid this?
}
index_t inc = 0;
// Process 4 input channels in batch
for (; inc + 3 < input_channels; inc += 4) {
float* output_ptr = channel_output_start;
// The beginning of each input feature map channel
MACE_ASSERT(input_ptr == input + n * input_channels *
input_height * input_width +
inc * input_height * input_width);
const float* input_ptr1 = input_ptr + total_pixels;
const float* input_ptr2 = input_ptr1 + total_pixels;
const float* input_ptr3 = input_ptr2 + total_pixels;
// filter is in c_out, c_in, 1, 1 order
MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
const float k0 = filter_ptr[0];
const float k1 = filter_ptr[1];
const float k2 = filter_ptr[2];
const float k3 = filter_ptr[3];
filter_ptr += 4;
const float32x4_t vk0 = vdupq_n_f32(k0);
const float32x4_t vk1 = vdupq_n_f32(k1);
const float32x4_t vk2 = vdupq_n_f32(k2);
const float32x4_t vk3 = vdupq_n_f32(k3);
index_t loop_itr = total_loops;
for (; loop_itr > 0; --loop_itr) {
// Process 2 groups of 4 floats
float32x4_t out0 = vld1q_f32(output_ptr);
float32x4_t out4 = vld1q_f32(output_ptr + 4);
const float32x4_t in00 = vld1q_f32(input_ptr);
const float32x4_t in04 = vld1q_f32(input_ptr + 4);
out0 = vfmaq_f32(out0, in00, vk0);
out4 = vfmaq_f32(out4, in04, vk0);
const float32x4_t in10 = vld1q_f32(input_ptr1);
const float32x4_t in14 = vld1q_f32(input_ptr1 + 4);
out0 = vfmaq_f32(out0, in10, vk1);
out4 = vfmaq_f32(out4, in14, vk1);
const float32x4_t in20 = vld1q_f32(input_ptr2);
const float32x4_t in24 = vld1q_f32(input_ptr2 + 4);
out0 = vfmaq_f32(out0, in20, vk2);
out4 = vfmaq_f32(out4, in24, vk2);
const float32x4_t in30 = vld1q_f32(input_ptr3);
const float32x4_t in34 = vld1q_f32(input_ptr3 + 4);
out0 = vfmaq_f32(out0, in30, vk3);
out4 = vfmaq_f32(out4, in34, vk3);
// Save output
vst1q_f32(output_ptr, out0);
vst1q_f32(output_ptr + 4, out4);
output_ptr += 8;
input_ptr += 8;
input_ptr1 += 8;
input_ptr2 += 8;
input_ptr3 += 8;
}
// Process the remaining pixels
index_t remaining_pixels = loop_remaining;
for (; remaining_pixels > 0; --remaining_pixels) {
const float mul = *input_ptr * k0;
const float mul1 = *input_ptr1 * k1;
const float mul2 = *input_ptr2 * k2;
const float mul3 = *input_ptr3 * k3;
*output_ptr += mul + mul1 + mul2 + mul3;
++output_ptr;
++input_ptr;
++input_ptr1;
++input_ptr2;
++input_ptr3;
}
// Advance past the 3 channels consumed via input_ptr1..3
input_ptr += 3 * total_pixels;
}
// Process the remaining channels
for (; inc < input_channels; ++inc) {
float* output_ptr = channel_output_start;
MACE_ASSERT(input_ptr == input + n * input_channels *
input_height * input_width +
inc * input_height * input_width);
MACE_ASSERT(filter_ptr == filter + c * input_channels + inc);
const float k0 = filter_ptr[0];
++filter_ptr;
const float32x4_t vk0 = vdupq_n_f32(k0);
index_t loop_itr = total_loops;
for (; loop_itr > 0; --loop_itr) {
float32x4_t out0 = vld1q_f32(output_ptr);
float32x4_t out4 = vld1q_f32(output_ptr + 4);
const float32x4_t in0 = vld1q_f32(input_ptr);
const float32x4_t in4 = vld1q_f32(input_ptr + 4);
out0 = vfmaq_f32(out0, in0, vk0);
out4 = vfmaq_f32(out4, in4, vk0);
// Save output
vst1q_f32(output_ptr, out0);
vst1q_f32(output_ptr + 4, out4);
output_ptr += 8;
input_ptr += 8;
}
// Process the remaining pixels
index_t remaining_pixels = loop_remaining;
for (; remaining_pixels > 0; --remaining_pixels) {
const float mul = *input_ptr * k0;
*output_ptr += mul;
++output_ptr;
++input_ptr;
}
}
}
}
}
} // namespace kernels
} // namespace mace
......@@ -25,7 +25,7 @@ cc_library(
name = "ops",
srcs = glob(
["*.cc"],
exclude = ["*_test.cc"],
exclude = ["*_test.cc", "*_benchmark.cc"],
),
hdrs = glob(
["*.h"],
......@@ -46,11 +46,6 @@ cc_test(
["*_test.cc"],
),
copts = ["-std=c++11"],
linkopts = if_android([
"-pie",
"-llog",
"-latomic",
]),
linkstatic = 1,
deps = [
":ops",
......@@ -58,3 +53,16 @@ cc_test(
"@gtest//:gtest_main",
],
)
cc_test(
name = "ops_benchmark",
srcs = glob(["*_benchmark.cc"]),
deps = [
":ops",
"//mace/core:core",
"//mace/core:test_benchmark_main",
],
copts = ['-std=c++11'],
linkstatic = 1,
testonly = 1,
)
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/conv_2d.h"
namespace mace {
template <DeviceType D, typename T>
static void Conv2d(int iters, int batch, int channels, int height, int width,
int kernel_h, int kernel_w, int stride,
Padding padding, int output_channels) {
mace::testing::StopTiming();
mace::testing::StartTiming();
while(iters--) {
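// The timed loop body is still a placeholder in this commit; the Conv2d op
// is not constructed or run here yet.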
}
}
#define BM_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, DEVICE) \
static void BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
Conv2d<DEVICE, TYPE>(iters, N, C, H, W, KH, KW, STRIDE, mace::Padding::P, OC); \
} \
BENCHMARK(BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON);
BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
} // namespace mace
......@@ -136,4 +136,55 @@ TEST_F(Conv2dOpTest, Combined) {
ExpectTensorNear<float>(expected, *GetOutput("Output"), 0.001);
}
TEST_F(Conv2dOpTest, Conv1x1) {
// Construct graph
OpDefBuilder("Conv2d", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.Finalize(operator_def());
// Add args
AddIntsArg("strides", {1, 1});
AddIntArg("padding", Padding::VALID);
AddIntsArg("dilations", {1, 1});
// Add input data
AddInputFromArray<float>("Input", {1, 5, 3, 10},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
AddInputFromArray<float>("Filter", {2, 5, 1, 1},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
AddInputFromArray<float>("Bias", {2}, {0.1f, 0.2f});
// Run
RunOp(DeviceType::NEON);
// Check
Tensor expected = CreateTensor<float>({1, 2, 3, 10},
{5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f, 5.1f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f,
10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f, 10.2f});
ExpectTensorNear<float>(expected, *GetOutput("Output"), 0.001);
}
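The expected values follow directly from the 1x1 convolution arithmetic: every input pixel is 1 across all 5 input channels, so each output pixel of channel 0 is 5 * 1.0 + bias 0.1 = 5.1, and each output pixel of channel 1 is 5 * 2.0 + 0.2 = 10.2.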
// TODO we need more tests
......@@ -9,8 +9,8 @@
#include "gtest/gtest.h"
#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/core/net.h"
#include "mace/core/tensor.h"
namespace mace {
......@@ -29,7 +29,7 @@ class OpDefBuilder {
return *this;
}
void Finalize(OperatorDef* op_def) const {
MACE_CHECK(op_def != NULL, "input should not be null.");
MACE_CHECK(op_def != nullptr, "input should not be null.");
*op_def = op_def_;
}
OperatorDef op_def_;
......@@ -49,6 +49,7 @@ class OpsTestBase : public ::testing::Test {
Tensor* input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape);
float* input_data = input->mutable_data<float>();
// TODO check the dims
memcpy(input_data, data.data(), data.size() * sizeof(T));
}
......@@ -96,14 +97,18 @@ class OpsTestBase : public ::testing::Test {
OperatorDef* operator_def() { return &op_def_; }
bool RunOp() {
bool RunOp(DeviceType device) {
NetDef net_def;
net_def.add_op()->CopyFrom(op_def_);
VLOG(0) << net_def.DebugString();
auto net = CreateNet(net_def, &ws_, DeviceType::CPU);
auto net = CreateNet(net_def, &ws_, device);
return net->Run();
}
bool RunOp() {
return RunOp(DeviceType::CPU);
}
Tensor* GetOutput(const char* output_name) {
return ws_.GetTensor(output_name);
}
......@@ -209,6 +214,6 @@ void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
Expector<T>::Near(x, y, abs_err);
}
} // namespace mace
} // namespace mace
#endif // MACE_OPS_TEST_UTIL_H_