提交 b7155afc 编写于 作者: 李寅

Add OpenMP as a compiler option

上级 44cdbb67
...@@ -314,6 +314,8 @@ const Argument& GetArgument(const OperatorDef& def, const string& name) { ...@@ -314,6 +314,8 @@ const Argument& GetArgument(const OperatorDef& def, const string& name) {
} }
MACE_CHECK(false, "Argument named ", name, "does not exist in operator ", MACE_CHECK(false, "Argument named ", name, "does not exist in operator ",
ProtoDebugString(def)); ProtoDebugString(def));
// should not reach here, just make compiler happy
return std::move(Argument());
} }
bool GetFlagArgument(const OperatorDef& def, bool GetFlagArgument(const OperatorDef& def,
......
...@@ -17,8 +17,8 @@ cc_library( ...@@ -17,8 +17,8 @@ cc_library(
deps = [ deps = [
"//mace/core:core", "//mace/core:core",
], ],
copts = ['-std=c++11'], copts = ['-std=c++11', "-fopenmp",],
linkopts = ["-fopenmp"] + if_android(["-lm"]), linkopts = if_android(["-lm"]),
) )
cc_test( cc_test(
......
...@@ -35,10 +35,10 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW ...@@ -35,10 +35,10 @@ void Conv2dNeonK1x1S1(const float *input, // NCHW
const index_t loop_remaining = total_pixels & 7; const index_t loop_remaining = total_pixels & 7;
// benchmark omp collapsed(2) // benchmark omp collapsed(2)
#pragma omp parallel for collapse(2)
for (index_t n = 0; n < batch; ++n) { for (index_t n = 0; n < batch; ++n) {
const float *filter_ptr = filter;
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) { for (index_t c = 0; c < channels; ++c) {
const float *filter_ptr = filter;
// TODO Will GCC opt these out? // TODO Will GCC opt these out?
float *channel_output_start = float *channel_output_start =
output + n * channels * height * width + c * height * width; output + n * channels * height * width + c * height * width;
......
...@@ -8,37 +8,6 @@ ...@@ -8,37 +8,6 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
#define KERNEL_HEAD_CODE \
int output_batch = output_shape[0]; \
int output_channels = output_shape[1]; \
int output_height = output_shape[2]; \
int output_width = output_shape[3]; \
int input_batch = input_shape[0]; \
int input_channels = input_shape[1]; \
int input_height = input_shape[2]; \
int input_width = input_shape[3]; \
int multiplier = filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels); \
int filter_in_channels = filter_shape == nullptr ? input_channels : filter_shape[1]; \
for (int b = 0; b < output_batch; ++b) { \
float *output_ptr_base = output + b * output_channels * output_height * output_width; \
for (int oc = 0; oc < output_channels; ++oc) { \
const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize; \
const float *input_ptr = input + b * input_channels * input_height * input_width; \
if (filter_shape != nullptr) { \
input_ptr += (oc / multiplier) * input_height * input_width; \
} \
float *output_ptr = output_ptr_base + oc * output_height * output_width; \
std::fill(output_ptr, output_ptr + output_height * output_width, bias ? bias[oc] : 0); \
for (int ic = 0; ic < filter_in_channels; ++ic) { \
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)};
#define KERNEL_TAIL_CODE \
filter_ptr += kFilterSize; \
input_ptr += input_height * input_width; \
} \
} \
}
static const int kRegisterSize = 4; static const int kRegisterSize = 4;
static const int kFilterSize = 9; static const int kFilterSize = 9;
...@@ -52,7 +21,29 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW ...@@ -52,7 +21,29 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
int height_count = (output_shape[2] >> 1) << 1; int height_count = (output_shape[2] >> 1) << 1;
KERNEL_HEAD_CODE int output_batch = output_shape[0];
int output_channels = output_shape[1];
int output_height = output_shape[2];
int output_width = output_shape[3];
int input_batch = input_shape[0];
int input_channels = input_shape[1];
int input_height = input_shape[2];
int input_width = input_shape[3];
int multiplier = filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels);
int filter_in_channels = filter_shape == nullptr ? input_channels : filter_shape[1];
#pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) {
for (int oc = 0; oc < output_channels; ++oc) {
float *output_ptr_base = output + b * output_channels * output_height * output_width;
const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize;
const float *input_ptr = input + b * input_channels * input_height * input_width;
if (filter_shape != nullptr) {
input_ptr += (oc / multiplier) * input_height * input_width;
}
float *output_ptr = output_ptr_base + oc * output_height * output_width;
std::fill(output_ptr, output_ptr + output_height * output_width, bias ? bias[oc] : 0);
for (int ic = 0; ic < filter_in_channels; ++ic) {
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)};
const float *row_ptr_v[kRegisterSize] = { const float *row_ptr_v[kRegisterSize] = {
input_ptr, input_ptr + input_width, input_ptr, input_ptr + input_width,
...@@ -212,7 +203,11 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW ...@@ -212,7 +203,11 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
} }
} }
KERNEL_TAIL_CODE filter_ptr += kFilterSize;
input_ptr += input_height * input_width;
}
}
}
} }
void Conv2dNeonK3x3S2(const float *input, // NCHW void Conv2dNeonK3x3S2(const float *input, // NCHW
...@@ -224,7 +219,30 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW ...@@ -224,7 +219,30 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
const index_t *output_shape) { const index_t *output_shape) {
int tail_step = 2 * (input_shape[3] - output_shape[3]); int tail_step = 2 * (input_shape[3] - output_shape[3]);
KERNEL_HEAD_CODE int output_batch = output_shape[0];
int output_channels = output_shape[1];
int output_height = output_shape[2];
int output_width = output_shape[3];
int input_batch = input_shape[0];
int input_channels = input_shape[1];
int input_height = input_shape[2];
int input_width = input_shape[3];
int multiplier = filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels);
int filter_in_channels = filter_shape == nullptr ? input_channels : filter_shape[1];
#pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) {
for (int oc = 0; oc < output_channels; ++oc) {
float *output_ptr_base = output + b * output_channels * output_height * output_width;
const float *filter_ptr = filter + oc * filter_in_channels * kFilterSize;
const float *input_ptr = input + b * input_channels * input_height * input_width;
if (filter_shape != nullptr) {
input_ptr += (oc / multiplier) * input_height * input_width;
}
float *output_ptr = output_ptr_base + oc * output_height * output_width;
std::fill(output_ptr, output_ptr + output_height * output_width, bias ? bias[oc] : 0);
for (int ic = 0; ic < filter_in_channels; ++ic) {
float32x4_t n_filter_v[3] = {vld1q_f32(filter_ptr), vld1q_f32(filter_ptr+3), vld1q_f32(filter_ptr+6)};
const float *row_ptr_v[3] = { const float *row_ptr_v[3] = {
input_ptr, input_ptr + input_width, input_ptr + 2 * input_width input_ptr, input_ptr + input_width, input_ptr + 2 * input_width
...@@ -291,10 +309,11 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW ...@@ -291,10 +309,11 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
} }
} }
KERNEL_TAIL_CODE filter_ptr += kFilterSize;
input_ptr += input_height * input_width;
}
}
}
} }
#undef KERNEL_HEAD_CODE
#undef KERNEL_TAIL_CODE
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -34,7 +34,7 @@ cc_library( ...@@ -34,7 +34,7 @@ cc_library(
["*.h"], ["*.h"],
exclude = ["ops_test_util.h"], exclude = ["ops_test_util.h"],
), ),
copts = ["-std=c++11"], copts = ["-std=c++11", "-fopenmp",],
deps = [ deps = [
"//mace/core", "//mace/core",
"//mace/kernels", "//mace/kernels",
...@@ -50,7 +50,7 @@ cc_test( ...@@ -50,7 +50,7 @@ cc_test(
["*_test.cc"], ["*_test.cc"],
), ),
copts = ["-std=c++11"], copts = ["-std=c++11"],
linkopts = if_android(["-ldl"]), linkopts = ["-fopenmp",] + if_android(["-ldl"]),
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
":ops", ":ops",
...@@ -64,7 +64,7 @@ cc_test( ...@@ -64,7 +64,7 @@ cc_test(
testonly = 1, testonly = 1,
srcs = glob(["*_benchmark.cc"]), srcs = glob(["*_benchmark.cc"]),
copts = ["-std=c++11"], copts = ["-std=c++11"],
linkopts = if_android(["-ldl"]), linkopts = ["-fopenmp",] + if_android(["-ldl"]),
linkstatic = 1, linkstatic = 1,
deps = [ deps = [
":ops", ":ops",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册