提交 a9fa945d 编写于 作者: L liuqi

Optimize conv1x1 with 2x4 block kernel.

上级 8c00b57e
......@@ -6,26 +6,26 @@ cc_binary(
srcs = [
"helloworld.cc",
],
copts = ["-std=c++11"],
linkopts = ["-fopenmp"] + if_android(["-ldl"]),
deps = [
"//mace/core",
"//mace/ops",
"@org_tensorflow//tensorflow/core:android_tensorflow_lib",
],
copts = ["-std=c++11"],
linkopts = ["-fopenmp",] + if_android(["-ldl"]),
)
cc_test(
name = "benchmark_example",
testonly = 1,
srcs = ["benchmark_example.cc"],
copts = ["-std=c++11"],
linkopts = ["-fopenmp"] + if_android(["-ldl"]),
linkstatic = 1,
deps = [
"//mace/core",
"//mace/core:test_benchmark_main",
],
copts = ["-std=c++11"],
linkopts = ["-fopenmp",] + if_android(["-ldl"]),
linkstatic = 1,
testonly = 1,
)
cc_binary(
......@@ -33,12 +33,12 @@ cc_binary(
srcs = [
"mace_run.cc",
],
copts = ["-std=c++11"],
linkopts = ["-fopenmp"] + if_android(["-ldl"]),
linkstatic = 1,
deps = [
"//mace/core",
"//mace/utils",
"//mace/ops",
"//mace/utils:command_line_flags",
],
copts = ["-std=c++11",],
linkopts = ["-fopenmp",] + if_android(["-ldl"]),
linkstatic = 1,
)
......@@ -5,7 +5,6 @@ package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
load("//mace:mace.bzl", "if_android")
......@@ -14,36 +13,40 @@ cc_library(
name = "kernels",
srcs = glob(["*.cc"]) + if_android(glob(["neon/*.cc"])),
hdrs = glob(["*.h"]) + if_android(glob(["neon/*.h"])),
deps = [
"//mace/core:core",
copts = [
"-std=c++11",
"-fopenmp",
],
copts = ['-std=c++11', "-fopenmp",],
linkopts = if_android(["-lm"]),
deps = [
"//mace/core",
"//mace/utils:utils",
],
)
cc_test(
name = "kernel_test",
testonly = 1,
srcs = glob(["test/*.cc"]),
copts = ["-std=c++11"],
linkopts = if_android(["-pie"]),
linkstatic = 1,
deps = [
"@gtest//:gtest_main",
":kernels",
"//mace/core:core",
"//mace/core",
"@gtest//:gtest_main",
],
copts = ['-std=c++11'],
linkopts = if_android(["-pie"]),
linkstatic = 1,
testonly = 1,
)
cc_test(
name = "benchmark",
testonly = 1,
srcs = glob(["benchmark/*.cc"]),
copts = ["-std=c++11"],
linkstatic = 1,
deps = [
":kernels",
"//mace/core:core",
"//mace/core",
"//mace/core:test_benchmark_main",
],
copts = ['-std=c++11'],
linkstatic = 1,
testonly = 1,
)
......@@ -19,7 +19,7 @@ struct ConcatFunctor {
T *output) {
const size_t input_count = input_list.size();
for (int inner_idx = 0; inner_idx < inner_dim; ++inner_idx) {
for (int i = 0; i < input_count; ++i) {
for (size_t i = 0; i < input_count; ++i) {
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
memcpy(output, input_list[i], outer_dims[i] * sizeof(T));
output += outer_dims[i];
......
......@@ -4,11 +4,297 @@
#include <arm_neon.h>
#include "mace/core/common.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
// Channel tile sizes used by the 1x1 convolution driver below: input channels
// are consumed two at a time and output channels are produced four at a time.
static constexpr index_t kInputChannelBlockSize = 2;
static constexpr index_t kOutputChannelBlockSize = 4;
// Sliding lane mask for the tail of a pixel row. Loading four int32 values
// starting at mask_array[r] (1 <= r <= 3) yields 4 - r zero lanes followed by
// r all-ones lanes; AND-ing that with the last four inputs of a row keeps only
// the r pixels that the full-vector loop has not processed yet.
static __attribute__((__aligned__(64))) int32_t mask_array[8] = {
0, 0, 0, 0, -1, -1, -1, -1
};
// Accumulates one 2-input-channel x 4-output-channel tile of a 1x1 convolution.
//
// input_channels: distance, in floats, between the filter rows of two
//                 consecutive output channels.
// pixel_size:     number of floats in one channel plane.
// input:          two consecutive channel planes of `pixel_size` floats.
// filter:         weight of (first output channel, first input channel); two
//                 weights are read from each of four rows.
// output:         four consecutive channel planes of `pixel_size` floats; the
//                 products are added to the values already stored there.
static inline void NeonConv2x4Kernel(index_t input_channels,
                                     index_t pixel_size,
                                     const float *input,
                                     const float *filter,
                                     float *output) {
  const float *input0 = input;
  const float *input1 = input + pixel_size;

  if (pixel_size < 4) {
    // The masked tail further down reloads the last four floats of every
    // plane. With fewer than four pixels per plane that window would start
    // before the beginning of the buffers, so short planes take a scalar path.
    constexpr index_t kOutputRows = 4;
    const float *filter_row = filter;
    float *output_row = output;
    for (index_t oc = 0; oc < kOutputRows; ++oc) {
      const float w0 = filter_row[0];
      const float w1 = filter_row[1];
      for (index_t p = 0; p < pixel_size; ++p) {
        output_row[p] = output_row[p] + input0[p] * w0 + input1[p] * w1;
      }
      filter_row += input_channels;
      output_row += pixel_size;
    }
    return;
  }

  // One 2-lane vector per output channel: lane 0 multiplies input0, lane 1
  // multiplies input1.
  const float32x2_t vfilter0x = vld1_f32(filter);
  filter += input_channels;
  const float32x2_t vfilter1x = vld1_f32(filter);
  filter += input_channels;
  const float32x2_t vfilter2x = vld1_f32(filter);
  filter += input_channels;
  const float32x2_t vfilter3x = vld1_f32(filter);

  float *output0 = output;
  float *output1 = output0 + pixel_size;
  float *output2 = output1 + pixel_size;
  float *output3 = output2 + pixel_size;

  // Main loop: four pixels per iteration.
  while (pixel_size >= 4) {
    float32x4_t voutput0 = vld1q_f32(output0);
    float32x4_t voutput1 = vld1q_f32(output1);
    float32x4_t voutput2 = vld1q_f32(output2);
    float32x4_t voutput3 = vld1q_f32(output3);

    const float32x4_t vinput0 = vld1q_f32(input0);
    input0 += 4;
    voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
    voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
    voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
    voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);

    const float32x4_t vinput1 = vld1q_f32(input1);
    input1 += 4;
    voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
    voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
    voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
    voutput3 = vfmaq_lane_f32(voutput3, vinput1, vfilter3x, 1);

    vst1q_f32(output0, voutput0);
    output0 += 4;
    vst1q_f32(output1, voutput1);
    output1 += 4;
    vst1q_f32(output2, voutput2);
    output2 += 4;
    vst1q_f32(output3, voutput3);
    output3 += 4;

    pixel_size -= 4;
  }

  // Tail of 1..3 pixels: step back so the 4-wide window ends exactly at the
  // end of each plane. The leading lanes were already accumulated by the main
  // loop; their inputs are masked to zero so the FMA leaves them unchanged.
  if (pixel_size != 0) {
    const int32x4_t vmask = vld1q_s32(&mask_array[pixel_size]);

    output0 = output0 + pixel_size - 4;
    float32x4_t voutput0 = vld1q_f32(output0);
    output1 = output1 + pixel_size - 4;
    float32x4_t voutput1 = vld1q_f32(output1);
    output2 = output2 + pixel_size - 4;
    float32x4_t voutput2 = vld1q_f32(output2);
    output3 = output3 + pixel_size - 4;
    float32x4_t voutput3 = vld1q_f32(output3);

    const float32x4_t vinput0 = vreinterpretq_f32_s32(
        vandq_s32(vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4]))));
    voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
    voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
    voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
    voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);

    const float32x4_t vinput1 = vreinterpretq_f32_s32(
        vandq_s32(vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4]))));
    voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
    voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
    voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
    voutput3 = vfmaq_lane_f32(voutput3, vinput1, vfilter3x, 1);

    vst1q_f32(output0, voutput0);
    vst1q_f32(output1, voutput1);
    vst1q_f32(output2, voutput2);
    vst1q_f32(output3, voutput3);
  }
}
// Partial-tile variant of the 2x4 kernel: accumulates
// `input_channels_subblock_size` (1 or 2) input channels into
// `output_channels_subblock_size` (1..4) output channels.
//
// input_channels: distance, in floats, between the filter rows of two
//                 consecutive output channels.
// pixel_size:     number of floats in one channel plane.
// input:          first input channel plane; the second plane (read only when
//                 the input sub-block size is 2) follows `pixel_size` floats
//                 later.
// filter:         weight of (first output channel, first input channel).
// output:         first output channel plane; further planes follow at
//                 `pixel_size` strides. Only the planes inside the output
//                 sub-block are loaded or stored.
static inline void NeonConv2x4SubBlockKernel(index_t input_channels_subblock_size,
                                             index_t output_channels_subblock_size,
                                             index_t input_channels,
                                             index_t pixel_size,
                                             const float *input,
                                             const float *filter,
                                             float *output) {
  const float *input0 = input;
  const float *input1 = input + pixel_size;

  if (pixel_size < 4) {
    // The masked tail further down reloads the last four floats of every
    // plane. With fewer than four pixels per plane that window would start
    // before the beginning of the buffers, so short planes take a scalar path.
    const float *filter_row = filter;
    float *output_row = output;
    for (index_t oc = 0; oc < output_channels_subblock_size; ++oc) {
      for (index_t p = 0; p < pixel_size; ++p) {
        float acc = output_row[p] + input0[p] * filter_row[0];
        if (input_channels_subblock_size > 1) {
          acc += input1[p] * filter_row[1];
        }
        output_row[p] = acc;
      }
      filter_row += input_channels;
      output_row += pixel_size;
    }
    return;
  }

  // Filter rows outside the sub-block stay zero instead of being left
  // uninitialized; the accumulators they feed are never stored.
  float32x2_t vfilter0x = vld1_dup_f32(&filter[0]);
  float32x2_t vfilter1x = vdup_n_f32(0.0f);
  float32x2_t vfilter2x = vdup_n_f32(0.0f);
  float32x2_t vfilter3x = vdup_n_f32(0.0f);
  if (input_channels_subblock_size > 1) {
    vfilter0x = vld1_lane_f32(&filter[1], vfilter0x, 1);
  }
  if (output_channels_subblock_size > 1) {
    filter += input_channels;
    vfilter1x = vld1_dup_f32(&filter[0]);
    if (input_channels_subblock_size > 1) {
      vfilter1x = vld1_lane_f32(&filter[1], vfilter1x, 1);
    }
    if (output_channels_subblock_size > 2) {
      filter += input_channels;
      vfilter2x = vld1_dup_f32(&filter[0]);
      if (input_channels_subblock_size > 1) {
        vfilter2x = vld1_lane_f32(&filter[1], vfilter2x, 1);
      }
      if (output_channels_subblock_size > 3) {
        filter += input_channels;
        vfilter3x = vld1_dup_f32(&filter[0]);
        if (input_channels_subblock_size > 1) {
          vfilter3x = vld1_lane_f32(&filter[1], vfilter3x, 1);
        }
      }
    }
  }

  float *output0 = output;
  float *output1 = output0 + pixel_size;
  float *output2 = output1 + pixel_size;
  float *output3 = output2 + pixel_size;

  // Main loop: four pixels per iteration.
  while (pixel_size >= 4) {
    float32x4_t voutput0 = vld1q_f32(output0);
    float32x4_t voutput1 = vdupq_n_f32(0.0f);
    float32x4_t voutput2 = vdupq_n_f32(0.0f);
    float32x4_t voutput3 = vdupq_n_f32(0.0f);
    if (output_channels_subblock_size > 1) {
      voutput1 = vld1q_f32(output1);
      if (output_channels_subblock_size > 2) {
        voutput2 = vld1q_f32(output2);
        if (output_channels_subblock_size > 3) {
          voutput3 = vld1q_f32(output3);
        }
      }
    }

    const float32x4_t vinput0 = vld1q_f32(input0);
    input0 += 4;
    voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
    voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
    voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
    voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);

    if (input_channels_subblock_size > 1) {
      const float32x4_t vinput1 = vld1q_f32(input1);
      input1 += 4;
      voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
      voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
      voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
      voutput3 = vfmaq_lane_f32(voutput3, vinput1, vfilter3x, 1);
    }

    vst1q_f32(output0, voutput0);
    output0 += 4;
    if (output_channels_subblock_size > 1) {
      vst1q_f32(output1, voutput1);
      output1 += 4;
      if (output_channels_subblock_size > 2) {
        vst1q_f32(output2, voutput2);
        output2 += 4;
        if (output_channels_subblock_size > 3) {
          vst1q_f32(output3, voutput3);
          output3 += 4;
        }
      }
    }

    pixel_size -= 4;
  }

  // Tail of 1..3 pixels: step back so the 4-wide window ends exactly at the
  // end of each plane. The leading lanes were already accumulated by the main
  // loop; their inputs are masked to zero so the FMA leaves them unchanged.
  if (pixel_size != 0) {
    const int32x4_t vmask = vld1q_s32(&mask_array[pixel_size]);

    float32x4_t voutput1 = vdupq_n_f32(0.0f);
    float32x4_t voutput2 = vdupq_n_f32(0.0f);
    float32x4_t voutput3 = vdupq_n_f32(0.0f);
    output0 += pixel_size - 4;
    float32x4_t voutput0 = vld1q_f32(output0);
    if (output_channels_subblock_size > 1) {
      output1 += pixel_size - 4;
      voutput1 = vld1q_f32(output1);
      if (output_channels_subblock_size > 2) {
        output2 += pixel_size - 4;
        voutput2 = vld1q_f32(output2);
        if (output_channels_subblock_size > 3) {
          output3 += pixel_size - 4;
          voutput3 = vld1q_f32(output3);
        }
      }
    }

    const float32x4_t vinput0 = vreinterpretq_f32_s32(
        vandq_s32(vmask, vreinterpretq_s32_f32(vld1q_f32(&input0[pixel_size - 4]))));
    voutput0 = vfmaq_lane_f32(voutput0, vinput0, vfilter0x, 0);
    voutput1 = vfmaq_lane_f32(voutput1, vinput0, vfilter1x, 0);
    voutput2 = vfmaq_lane_f32(voutput2, vinput0, vfilter2x, 0);
    voutput3 = vfmaq_lane_f32(voutput3, vinput0, vfilter3x, 0);

    if (input_channels_subblock_size > 1) {
      const float32x4_t vinput1 = vreinterpretq_f32_s32(
          vandq_s32(vmask, vreinterpretq_s32_f32(vld1q_f32(&input1[pixel_size - 4]))));
      voutput0 = vfmaq_lane_f32(voutput0, vinput1, vfilter0x, 1);
      voutput1 = vfmaq_lane_f32(voutput1, vinput1, vfilter1x, 1);
      voutput2 = vfmaq_lane_f32(voutput2, vinput1, vfilter2x, 1);
      voutput3 = vfmaq_lane_f32(voutput3, vinput1, vfilter3x, 1);
    }

    vst1q_f32(output0, voutput0);
    if (output_channels_subblock_size > 1) {
      vst1q_f32(output1, voutput1);
      if (output_channels_subblock_size > 2) {
        vst1q_f32(output2, voutput2);
        if (output_channels_subblock_size > 3) {
          vst1q_f32(output3, voutput3);
        }
      }
    }
  }
}
void Conv2dNeonK1x1S1(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, filter_h, filter_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
const index_t batch = output_shape[0];
const index_t channels = output_shape[1];
const index_t height = output_shape[2];
const index_t width = output_shape[3];
const index_t input_batch = input_shape[0];
const index_t input_channels = input_shape[1];
const index_t input_height = input_shape[2];
const index_t input_width = input_shape[3];
MACE_CHECK(input_batch == batch && input_height == height &&
input_width == width);
const index_t total_pixels = height * width;
const index_t round_up_channels = RoundUp(channels, kOutputChannelBlockSize);
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (int i = 0; i < channels; ++i) {
float *output_ptr_base = output + n * channels * total_pixels + i * total_pixels;
std::fill(output_ptr_base, output_ptr_base + total_pixels, bias ? bias[i] : 0);
}
}
// benchmark omp collapsed(2)
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) {
for (index_t c = 0; c < round_up_channels; c += kOutputChannelBlockSize) {
const float *input_ptr = input + n * input_channels * total_pixels;
const float *filter_ptr = filter + c * input_channels;
float *output_ptr = output + n * channels * total_pixels + c * total_pixels;
const index_t output_channel_block_size = std::min(channels - c, kOutputChannelBlockSize);
index_t remain_input_channels = input_channels;
if (c + kOutputChannelBlockSize <= channels) {
while (remain_input_channels >= kInputChannelBlockSize) {
NeonConv2x4Kernel(input_channels, total_pixels, input_ptr, filter_ptr, output_ptr);
input_ptr += kInputChannelBlockSize * total_pixels;
filter_ptr += kInputChannelBlockSize;
remain_input_channels -= kInputChannelBlockSize;
}
}
while (remain_input_channels != 0) {
const index_t input_channel_block_size = std::min(remain_input_channels, kInputChannelBlockSize);
NeonConv2x4SubBlockKernel(input_channel_block_size, output_channel_block_size,
input_channels, total_pixels, input_ptr, filter_ptr, output_ptr);
input_ptr += kInputChannelBlockSize * total_pixels;
filter_ptr += kInputChannelBlockSize;
remain_input_channels -= input_channel_block_size;
}
}
}
};
void Conv2dNeonPixelK1x1S1(const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
......
......@@ -34,7 +34,10 @@ cc_library(
["*.h"],
exclude = ["ops_test_util.h"],
),
copts = ["-std=c++11", "-fopenmp",],
copts = [
"-std=c++11",
"-fopenmp",
],
deps = [
"//mace/core",
"//mace/kernels",
......@@ -50,7 +53,7 @@ cc_test(
["*_test.cc"],
),
copts = ["-std=c++11"],
linkopts = ["-fopenmp",] + if_android(["-ldl"]),
linkopts = ["-fopenmp"] + if_android(["-ldl"]),
linkstatic = 1,
deps = [
":ops",
......@@ -64,7 +67,7 @@ cc_test(
testonly = 1,
srcs = glob(["*_benchmark.cc"]),
copts = ["-std=c++11"],
linkopts = ["-fopenmp",] + if_android(["-ldl"]),
linkopts = ["-fopenmp"] + if_android(["-ldl"]),
linkstatic = 1,
deps = [
":ops",
......
......@@ -72,6 +72,11 @@ static void Conv2d(int iters,
BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments
BM_CONV_2D(1, 3, 512, 512, 1, 1, 1, VALID, 3, float);
BM_CONV_2D(1, 32, 112, 112, 1, 1, 1, VALID, 64, float);
BM_CONV_2D(1, 64, 56, 56, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 256, 28, 28, 1, 1, 1, VALID, 256, float);
BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, VALID, 1024, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3, float);
......@@ -86,5 +91,4 @@ BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128, float);
} // namespace mace
......@@ -165,18 +165,69 @@ TEST_F(Conv2dOpTest, Conv1x1) {
}
// TODO we need more tests
TEST_F(Conv2dOpTest, ConvNxNS12) {
TEST_F(Conv2dOpTest, IdleConvNxNS12) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 10;
index_t input_channels = 1 + rand() % 10;
index_t batch = 3 ;
index_t input_channels = 64;
index_t height = 32;
index_t width = 32;
index_t output_channels = 128;
// Construct graph
auto& net = test_net();
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.Finalize(net.operator_def());
// Add args
net.AddIntsArg("strides", {stride_h, stride_w});
net.AddIntArg("padding", type);
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddRandomInput<float>("Input", {batch, input_channels, height, width});
net.AddRandomInput<float>(
"Filter", {output_channels, input_channels, kernel_h, kernel_w});
net.AddRandomInput<float>("Bias", {output_channels});
// run cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run NEON
net.RunOp(DeviceType::NEON);
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.001);
};
for (int kernel_size : {1}) {
for (int stride : {1}) {
func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, SAME);
}
}
}
TEST_F(Conv2dOpTest, DisgustConvNxNS12) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
srand(time(NULL));
// generate random input
index_t batch = 3 + rand() % 10;
index_t input_channels = 3 + rand() % 10;
index_t height = 107;
index_t width = 113;
index_t output_channels = 1 + rand() % 10;
index_t output_channels = 3 + rand() % 10;
// Construct graph
auto& net = test_net();
OpDefBuilder("Conv2D", "Conv2dTest")
......
......@@ -8,15 +8,23 @@ package(
licenses(["notice"]) # Apache 2.0
cc_library(
name = "utils",
srcs = glob([
"*.cc",
]),
hdrs = glob([
"*.h",
]),
name = "command_line_flags",
srcs = [
"command_line_flags.cc",
],
hdrs = [
"command_line_flags.h",
],
copts = ["-std=c++11"],
deps = [
"//mace/core:core",
"//mace/core",
],
)
cc_library(
name = "utils",
hdrs = [
"utils.h",
],
)
\ No newline at end of file
copts = ["-std=c++11"],
)
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_UTILS_UTILS_H_
#define MACE_UTILS_UTILS_H_
namespace mace {
// Rounds `i` up to the nearest multiple of `factor` using integer division;
// a value that is already a multiple is returned unchanged.
template <typename Integer>
Integer RoundUp(Integer i, Integer factor) {
  const auto block_count = (i + factor - 1) / factor;
  return block_count * factor;
}
// Integer division of `a` by `b`, rounded up instead of truncated.
template <typename Integer>
Integer CeilQuotient(Integer a, Integer b) {
  const auto biased_numerator = a + b - 1;
  return biased_numerator / b;
}
} // namespace mace
#endif // MACE_UTILS_UTILS_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册