提交 00a2b756 编写于 作者: L liuruilong

Merge branch 'develop' of https://github.com/codeWorm2015/paddle-mobile into develop

...@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.0) ...@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.0)
project(paddle-mobile) project(paddle-mobile)
option(DEBUGING "enable debug mode" ON) option(DEBUGING "enable debug mode" ON)
option(USE_OPENMP "openmp support" OFF) option(USE_OPENMP "openmp support" ON)
option(USE_EXCEPTION "use std exception" ON) option(USE_EXCEPTION "use std exception" ON)
option(LOG_PROFILE "log profile" ON) option(LOG_PROFILE "log profile" ON)
# select the platform to build # select the platform to build
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "io/executor.h" #include "io/executor.h"
#include <operators/math/gemm.h>
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "common/enforce.h" #include "common/enforce.h"
...@@ -25,6 +26,9 @@ limitations under the License. */ ...@@ -25,6 +26,9 @@ limitations under the License. */
#include "framework/program/var_desc.h" #include "framework/program/var_desc.h"
#include "framework/scope.h" #include "framework/scope.h"
#include "framework/tensor.h" #include "framework/tensor.h"
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#ifdef PADDLE_EXECUTOR_MULTITHREAD #ifdef PADDLE_EXECUTOR_MULTITHREAD
#include <queue> #include <queue>
#include <utility> #include <utility>
...@@ -403,6 +407,14 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict( ...@@ -403,6 +407,14 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
return result_vector; return result_vector;
} }
template <typename Dtype, Precision P>
void Executor<Dtype, P>::SetThreadNum(int num) {
#ifdef _OPENMP
// omp_set_dynamic(0);
omp_set_num_threads(num);
#endif
}
template class Executor<CPU, Precision::FP32>; template class Executor<CPU, Precision::FP32>;
template class Executor<FPGA, Precision::FP32>; template class Executor<FPGA, Precision::FP32>;
template class Executor<GPU_MALI, Precision::FP32>; template class Executor<GPU_MALI, Precision::FP32>;
......
...@@ -58,6 +58,8 @@ class Executor { ...@@ -58,6 +58,8 @@ class Executor {
std::vector<Ptype> Predict(const std::vector<Ptype> &input, std::vector<Ptype> Predict(const std::vector<Ptype> &input,
const std::vector<int64_t> &dims); const std::vector<int64_t> &dims);
void SetThreadNum(int num);
protected: protected:
Executor() = default; Executor() = default;
void InitMemory(); void InitMemory();
......
...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef LRN_OP #ifdef LRN_OP
#ifdef _OPENMP
#include <omp.h>
#endif
#include "framework/operator.h" #include "framework/operator.h"
#include "operators/op_param.h" #include "operators/op_param.h"
...@@ -47,6 +49,7 @@ struct LRNFunctor { ...@@ -47,6 +49,7 @@ struct LRNFunctor {
std::fill(sqr_buffer_ptr, sqr_buffer_ptr + sqr_buffer.numel(), 0.0); std::fill(sqr_buffer_ptr, sqr_buffer_ptr + sqr_buffer.numel(), 0.0);
for (int a = 0; a < N; a++) { for (int a = 0; a < N; a++) {
#pragma parallel for
for (int b = 0; b < C; b++) { for (int b = 0; b < C; b++) {
for (int index = start; index < end; index++) { for (int index = start; index < end; index++) {
int channel = b + index; int channel = b + index;
......
...@@ -1209,12 +1209,12 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { ...@@ -1209,12 +1209,12 @@ void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
// C = A * B, batchnorm(C) // C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale, void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) { float *bias) {
int nc1 = nc / 16; int volatile nc1 = nc / 16;
int _nc1 = nc % 16; int _nc1 = nc % 16;
int nc2 = _nc1 / 4; int volatile nc2 = _nc1 / 4;
int nc3 = 16 - 4 * (_nc1 % 4); int volatile nc3 = 16 - 4 * (_nc1 % 4);
int step = 4 * (ldc - nc); int volatile step = 4 * (ldc - nc);
int step1 = 4 * (NC - nc); int volatile step1 = 4 * (NC - nc);
asm volatile( asm volatile(
"subs %[mc], %[mc], #1 \n\t" "subs %[mc], %[mc], #1 \n\t"
......
...@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#include "operators/math/pool_3x3.h" #ifdef _OPENMP
#include <omp.h>
#endif
#include "framework/tensor.h" #include "framework/tensor.h"
#include "pool_3x3.h"
#if __ARM_NEON #if __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
#endif // __ARM_NEON #endif // __ARM_NEON
...@@ -40,46 +43,52 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { ...@@ -40,46 +43,52 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
const int w_out = output->dims()[3]; const int w_out = output->dims()[3];
const int outputdata_channel_stride = h_out * w_out; const int outputdata_channel_stride = h_out * w_out;
const int inputdata_channel_stride = h_in * w_in; const int inputdata_channel_stride = h_in * w_in;
const int input_batch_stride = output_channels * inputdata_channel_stride;
const int output_batch_stride = output_channels * outputdata_channel_stride;
float *out_data = output->data<float>(); float *out_data = output->data<float>();
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float coef = 1.0 / 9.0; const float coef = 1.0 / 9.0;
for (int k = 0; k < batch_size; ++k) { for (int k = 0; k < batch_size; ++k) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
const float *input_seg = input_data + c * inputdata_channel_stride;
float *output_seg = out_data + c * outputdata_channel_stride;
// four corner point // four corner point
out_data[0] = (input_data[0] + input_data[1] + input_data[w_in] + output_seg[0] = (input_seg[0] + input_seg[1] + input_seg[w_in] +
input_data[w_in + 1]) * input_seg[w_in + 1]) *
coef; coef;
out_data[w_out - 1] = output_seg[w_out - 1] =
(input_data[w_in - 2] + input_data[w_in - 1] + (input_seg[w_in - 2] + input_seg[w_in - 1] + input_seg[w_in * 2 - 2] +
input_data[w_in * 2 - 2] + input_data[2 * w_in - 1]) * input_seg[2 * w_in - 1]) *
coef; coef;
out_data[(h_out - 1) * w_out] = output_seg[(h_out - 1) * w_out] =
(input_data[(h_in - 2) * w_in] + input_data[(h_in - 2) * w_in + 1] + (input_seg[(h_in - 2) * w_in] + input_seg[(h_in - 2) * w_in + 1] +
input_data[(h_in - 1) * w_in] + input_data[(h_in - 1) * w_in + 1]) * input_seg[(h_in - 1) * w_in] + input_seg[(h_in - 1) * w_in + 1]) *
coef; coef;
out_data[h_out * w_out - 1] = output_seg[h_out * w_out - 1] =
(input_data[h_in * w_in - 1] + input_data[h_in * w_in - 2] + (input_seg[h_in * w_in - 1] + input_seg[h_in * w_in - 2] +
input_data[(h_in - 1) * w_in - 1] + input_seg[(h_in - 1) * w_in - 1] +
input_data[(h_in - 1) * w_in - 2]) * input_seg[(h_in - 1) * w_in - 2]) *
coef; coef;
// left side & right side // left side & right side
for (int i = 1; i < h_in - 1; ++i) { for (int i = 1; i < h_in - 1; ++i) {
out_data[i * w_out] = output_seg[i * w_out] =
(input_data[i * w_in - w_in] + input_data[i * w_in - w_in + 1] + (input_seg[i * w_in - w_in] + input_seg[i * w_in - w_in + 1] +
input_data[i * w_in] + input_data[i * w_in + 1] + input_seg[i * w_in] + input_seg[i * w_in + 1] +
input_data[i * w_in + w_in] + input_data[i * w_in + w_in + 1]) * input_seg[i * w_in + w_in] + input_seg[i * w_in + w_in + 1]) *
coef; coef;
out_data[i * w_out + w_out - 1] = output_seg[i * w_out + w_out - 1] =
(input_data[i * w_in - w_in + w_in - 2] + (input_seg[i * w_in - w_in + w_in - 2] +
input_data[i * w_in - w_in + 1 + w_in - 2] + input_seg[i * w_in - w_in + 1 + w_in - 2] +
input_data[i * w_in + w_in - 2] + input_seg[i * w_in + w_in - 2] +
input_data[i * w_in + 1 + w_in - 2] + input_seg[i * w_in + 1 + w_in - 2] +
input_data[i * w_in + w_in + w_in - 2] + input_seg[i * w_in + w_in + w_in - 2] +
input_data[i * w_in + w_in + 1 + w_in - 2]) * input_seg[i * w_in + w_in + 1 + w_in - 2]) *
coef; coef;
} }
// top 1 row & bottom 1 row // top 1 row & bottom 1 row
const float *input_tmp = input_data; const float *input_tmp = input_seg;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, sum, out0; tmp3, tmp4, tmp5, sum, out0;
...@@ -90,7 +99,7 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { ...@@ -90,7 +99,7 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
in4 = vld1q_f32(input_tmp_end); in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w_in); in6 = vld1q_f32(input_tmp_end + w_in);
int c_mid = w_out - 2; int c_mid = w_out - 2;
auto output_ptr = out_data + 1; auto output_ptr = output_seg + 1;
for (; c_mid > 3; c_mid -= 4) { for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4); in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4); in3 = vld1q_f32(input_tmp + w_in + 4);
...@@ -135,8 +144,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { ...@@ -135,8 +144,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
in6 = in7; in6 = in7;
} }
// top right remain // top right remain
float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]); float32x4_t pad0 = vdupq_n_f32(input_seg[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]); float32x4_t pad1 = vdupq_n_f32(input_seg[2 * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1); tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2); tmp1 = vextq_f32(in0, pad0, 2);
...@@ -163,8 +172,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { ...@@ -163,8 +172,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
} }
// bottom_right remain // bottom_right remain
float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]); float32x4_t pad2 = vdupq_n_f32(input_seg[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]); float32x4_t pad3 = vdupq_n_f32(input_seg[h_in * w_in - 1]);
tmp0 = vextq_f32(in4, pad2, 1); tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2); tmp1 = vextq_f32(in4, pad2, 2);
...@@ -191,8 +200,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { ...@@ -191,8 +200,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
} }
// mid // mid
for (int j = 0; j < h_out - 2; ++j) { for (int j = 0; j < h_out - 2; ++j) {
output_ptr = out_data + w_out * (j + 1) + 1; output_ptr = output_seg + w_out * (j + 1) + 1;
input_tmp = input_data + j * w_in; input_tmp = input_seg + j * w_in;
in0 = vld1q_f32(input_tmp); in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in); in2 = vld1q_f32(input_tmp + w_in);
...@@ -228,9 +237,9 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { ...@@ -228,9 +237,9 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
in4 = in5; in4 = in5;
} }
// mid remain // mid remain
float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]); float32x4_t pad0 = vdupq_n_f32(input_seg[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]); float32x4_t pad1 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]); float32x4_t pad2 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1); tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2); tmp1 = vextq_f32(in0, pad0, 2);
...@@ -261,9 +270,11 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { ...@@ -261,9 +270,11 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
} }
} }
} }
input_data += inputdata_channel_stride; // input_data += inputdata_channel_stride;
out_data += outputdata_channel_stride; // out_data += outputdata_channel_stride;
} }
input_data += input_batch_stride;
out_data += output_batch_stride;
} }
#endif #endif
} }
...@@ -282,44 +293,50 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { ...@@ -282,44 +293,50 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
const int w_out = output->dims()[3]; const int w_out = output->dims()[3];
const int outputdata_channel_stride = h_out * w_out; const int outputdata_channel_stride = h_out * w_out;
const int inputdata_channel_stride = h_in * w_in; const int inputdata_channel_stride = h_in * w_in;
const int input_batch_stride = output_channels * inputdata_channel_stride;
const int output_batch_stride = output_channels * outputdata_channel_stride;
float *out_data = output->data<float>(); float *out_data = output->data<float>();
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
for (int k = 0; k < batch_size; ++k) { for (int k = 0; k < batch_size; ++k) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
const float *input_seg = input_data + c * inputdata_channel_stride;
float *output_seg = out_data + c * outputdata_channel_stride;
// four corner point // four corner point
out_data[0] = std::max(std::max(input_data[0], input_data[1]), output_seg[0] = std::max(std::max(input_seg[0], input_seg[1]),
std::max(input_data[w_in], input_data[w_in + 1])); std::max(input_seg[w_in], input_seg[w_in + 1]));
out_data[w_out - 1] = std::max( output_seg[w_out - 1] =
std::max(input_data[w_in - 2], input_data[w_in - 1]), std::max(std::max(input_seg[w_in - 2], input_seg[w_in - 1]),
std::max(input_data[w_in * 2 - 2], input_data[2 * w_in - 1])); std::max(input_seg[w_in * 2 - 2], input_seg[2 * w_in - 1]));
out_data[(h_out - 1) * w_out] = output_seg[(h_out - 1) * w_out] =
std::max(std::max(input_data[(h_in - 2) * w_in], std::max(std::max(input_seg[(h_in - 2) * w_in],
input_data[(h_in - 2) * w_in + 1]), input_seg[(h_in - 2) * w_in + 1]),
std::max(input_data[(h_in - 1) * w_in], std::max(input_seg[(h_in - 1) * w_in],
input_data[(h_in - 1) * w_in + 1])); input_seg[(h_in - 1) * w_in + 1]));
out_data[h_out * w_out - 1] = std::max( output_seg[h_out * w_out - 1] = std::max(
std::max(input_data[(h_in - 1) * w_in - 1], std::max(input_seg[(h_in - 1) * w_in - 1],
input_data[(h_in - 1) * w_in - 2]), input_seg[(h_in - 1) * w_in - 2]),
std::max(input_data[h_in * w_in - 1], input_data[h_in * w_in - 2])); std::max(input_seg[h_in * w_in - 1], input_seg[h_in * w_in - 2]));
// left side & right side // left side & right side
for (int i = 1; i < h_in - 1; ++i) { for (int i = 1; i < h_in - 1; ++i) {
float max1 = std::max(input_data[i * w_in - w_in], float max1 = std::max(input_seg[i * w_in - w_in],
input_data[i * w_in - w_in + 1]); input_seg[i * w_in - w_in + 1]);
float max2 = std::max(input_data[i * w_in], input_data[i * w_in + 1]); float max2 = std::max(input_seg[i * w_in], input_seg[i * w_in + 1]);
float max3 = std::max(input_data[i * w_in + w_in], float max3 = std::max(input_seg[i * w_in + w_in],
input_data[i * w_in + w_in + 1]); input_seg[i * w_in + w_in + 1]);
out_data[i * w_out] = std::max(std::max(max1, max2), max3); output_seg[i * w_out] = std::max(std::max(max1, max2), max3);
max1 = std::max(input_data[i * w_in - w_in + w_in - 2], max1 = std::max(input_seg[i * w_in - w_in + w_in - 2],
input_data[i * w_in - w_in + 1 + w_in - 2]); input_seg[i * w_in - w_in + 1 + w_in - 2]);
max2 = std::max(input_data[i * w_in + w_in - 2], max2 = std::max(input_seg[i * w_in + w_in - 2],
input_data[i * w_in + 1 + w_in - 2]); input_seg[i * w_in + 1 + w_in - 2]);
max3 = std::max(input_data[i * w_in + w_in + w_in - 2], max3 = std::max(input_seg[i * w_in + w_in + w_in - 2],
input_data[i * w_in + w_in + 1 + w_in - 2]); input_seg[i * w_in + w_in + 1 + w_in - 2]);
out_data[i * w_out + w_out - 1] = std::max(std::max(max1, max2), max3); output_seg[i * w_out + w_out - 1] =
std::max(std::max(max1, max2), max3);
} }
// top 1 row & bottom 1 row // top 1 row & bottom 1 row
const float *input_tmp = input_data; const float *input_tmp = input_seg;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, max; tmp3, tmp4, tmp5, max;
...@@ -329,7 +346,7 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { ...@@ -329,7 +346,7 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
in4 = vld1q_f32(input_tmp_end); in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w_in); in6 = vld1q_f32(input_tmp_end + w_in);
int c_mid = w_out - 2; int c_mid = w_out - 2;
auto output_ptr = out_data + 1; auto output_ptr = output_seg + 1;
for (; c_mid > 3; c_mid -= 4) { for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4); in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4); in3 = vld1q_f32(input_tmp + w_in + 4);
...@@ -373,8 +390,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { ...@@ -373,8 +390,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
in6 = in7; in6 = in7;
} }
// top right remain // top right remain
float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]); float32x4_t pad0 = vdupq_n_f32(input_seg[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]); float32x4_t pad1 = vdupq_n_f32(input_seg[2 * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1); tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2); tmp1 = vextq_f32(in0, pad0, 2);
...@@ -400,8 +417,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { ...@@ -400,8 +417,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
} }
// bottom_right remain // bottom_right remain
float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]); float32x4_t pad2 = vdupq_n_f32(input_seg[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]); float32x4_t pad3 = vdupq_n_f32(input_seg[h_in * w_in - 1]);
tmp0 = vextq_f32(in4, pad2, 1); tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2); tmp1 = vextq_f32(in4, pad2, 2);
...@@ -427,8 +444,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { ...@@ -427,8 +444,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
} }
// mid // mid
for (int j = 0; j < h_out - 2; ++j) { for (int j = 0; j < h_out - 2; ++j) {
output_ptr = out_data + (j + 1) * w_out + 1; output_ptr = output_seg + (j + 1) * w_out + 1;
input_tmp = input_data + j * w_in; input_tmp = input_seg + j * w_in;
in0 = vld1q_f32(input_tmp); in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in); in2 = vld1q_f32(input_tmp + w_in);
...@@ -463,9 +480,9 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { ...@@ -463,9 +480,9 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
in4 = in5; in4 = in5;
} }
// mid remain // mid remain
float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]); float32x4_t pad0 = vdupq_n_f32(input_seg[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]); float32x4_t pad1 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(j + 3) * w_in - 1]); float32x4_t pad2 = vdupq_n_f32(input_seg[(j + 3) * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1); tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2); tmp1 = vextq_f32(in0, pad0, 2);
...@@ -495,9 +512,11 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) { ...@@ -495,9 +512,11 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
} }
} }
} }
input_data += inputdata_channel_stride; // input_data += inputdata_channel_stride;
out_data += outputdata_channel_stride; // out_data += outputdata_channel_stride;
} }
input_data += input_batch_stride;
out_data += output_batch_stride;
} }
#endif #endif
} }
...@@ -515,11 +534,11 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input, ...@@ -515,11 +534,11 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
const int output_height = output->dims()[2]; const int output_height = output->dims()[2];
const int output_width = output->dims()[3]; const int output_width = output->dims()[3];
const int _kernel_size = 3; // const int _kernel_size = 3;
const int stride_height = strides[0]; const int stride = strides[0];
const int stride_width = strides[1]; // const int stride_width = strides[1];
const int padding_height = paddings[0]; const int padding = paddings[0];
const int padding_width = paddings[1]; // const int padding_width = paddings[1];
const float negative_max = -INT_MAX; const float negative_max = -INT_MAX;
const int input_channel_stride = input_height * input_width; const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width; const int output_channel_stride = output_height * output_width;
...@@ -529,36 +548,39 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input, ...@@ -529,36 +548,39 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
const int input_batch_stride = output_channels * input_channel_stride; const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride; const int output_batch_stride = output_channels * output_channel_stride;
const float *pos1, *pos2, *pos3, *output_ptr; const float *pos1, *output_ptr;
int hstart, wstart, hend, wend; int hstart, wstart, hend, wend;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
const float *input_seg = input_data + c * input_channel_stride;
float *output_seg = output_data + c * output_channel_stride;
for (int ph = 0; ph < output_height; ph++) { for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) { for (int pw = 0; pw < output_width; pw++) {
hstart = ph * stride_height - padding_height; int hstart = ph * stride - padding;
wstart = pw * stride_width - padding_width; int wstart = pw * stride - padding;
hend = min(hstart + _kernel_size, input_height + padding_height); int hend = min(hstart + 3, input_height + padding);
wend = min(wstart + _kernel_size, input_width + padding_width); int wend = min(wstart + 3, input_width + padding);
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
hend = min(hend, input_height); hend = min(hend, input_height);
wend = min(wend, input_width); wend = min(wend, input_width);
pos1 = input_data + hstart * input_width + wstart; const float *pos1 = input_seg + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart; const float *pos2 = input_seg + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart; const float *pos3 = input_seg + (hstart + 2) * input_width + wstart;
output_ptr = output_data + ph * output_width + pw; output_ptr = output_seg + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) { if (hend - hstart != 3 || wend - wstart != 3) {
float max_value = -INT_MAX; float max_value = -INT_MAX;
for (int h = hstart; h < hend; h++) { for (int h = hstart; h < hend; h++) {
for (int w = wstart; w < wend; w++) { for (int w = wstart; w < wend; w++) {
float value = input_data[h * input_width + w]; float value = input_seg[h * input_width + w];
if (value > max_value) { if (value > max_value) {
max_value = value; max_value = value;
} }
} }
} }
output_data[ph * output_width + pw] = max_value; output_seg[ph * output_width + pw] = max_value;
} else { } else {
#if defined(ARMV7) #if defined(ARMV7)
asm volatile( asm volatile(
...@@ -572,27 +594,25 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input, ...@@ -572,27 +594,25 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
"vpmax.f32 d7, d6, d6 \n\t" "vpmax.f32 d7, d6, d6 \n\t"
"vst1.32 {d7[0]},[%[output_ptr]] \n\t" "vst1.32 {d7[0]},[%[output_ptr]] \n\t"
: :
: [input_data] "r"(input_data), [pos1] "r"(pos1), : [input_seg] "r"(input_seg), [pos1] "r"(pos1),
[pos2] "r"(pos2), [pos3] "r"(pos3), [pos2] "r"(pos2), [pos3] "r"(pos3),
[output_ptr] "r"(output_ptr), [negative_max] "r"(negative_max) [output_ptr] "r"(output_ptr), [negative_max] "r"(negative_max)
: "memory", "q1", "q2", "q3", "q4"); : "memory", "q1", "q2", "q3", "q4");
#else #else
const float32x4_t data1 = vld1q_f32(pos1); const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos2); const float32x4_t data2 = vld1q_f32(pos1 + input_width);
const float32x4_t data3 = vld1q_f32(pos3); const float32x4_t data3 = vld1q_f32(pos1 + 2 * input_width);
const float32x4_t max_data = const float32x4_t max_data =
vmaxq_f32(vmaxq_f32(data1, data3), data2); vmaxq_f32(vmaxq_f32(data1, data2), data3);
float32x2_t res = float32x2_t res =
vpmax_f32(vget_high_f32(vsetq_lane_f32(-INT_MAX, max_data, 3)), vpmax_f32(vget_high_f32(vsetq_lane_f32(-INT_MAX, max_data, 3)),
vget_low_f32(max_data)); vget_low_f32(max_data));
res = vpmax_f32(res, res); res = vpmax_f32(res, res);
output_data[ph * output_width + pw] = vget_lane_f32(res, 0); output_seg[ph * output_width + pw] = vget_lane_f32(res, 0);
#endif #endif
} }
} }
} }
input_data += input_channel_stride;
output_data += output_channel_stride;
} }
input_data += input_batch_stride; input_data += input_batch_stride;
output_data += output_batch_stride; output_data += output_batch_stride;
...@@ -613,11 +633,8 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input, ...@@ -613,11 +633,8 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
const int output_height = output->dims()[2]; const int output_height = output->dims()[2];
const int output_width = output->dims()[3]; const int output_width = output->dims()[3];
const int _kernel_size = 3; const int stride = strides[0];
const int stride_height = strides[0]; const int padding = paddings[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_channel_stride = input_height * input_width; const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width; const int output_channel_stride = output_height * output_width;
...@@ -631,30 +648,33 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input, ...@@ -631,30 +648,33 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
const int input_batch_stride = output_channels * input_channel_stride; const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride; const int output_batch_stride = output_channels * output_channel_stride;
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
const float *input_seg = input_data + c * input_channel_stride;
float *output_seg = output_data + c * output_channel_stride;
for (int ph = 0; ph < output_height; ph++) { for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) { for (int pw = 0; pw < output_width; pw++) {
int hstart = ph * stride_height - padding_height; int hstart = ph * stride - padding;
int wstart = pw * stride_width - padding_width; int wstart = pw * stride - padding;
int hend = min(hstart + _kernel_size, input_height + padding_height); int hend = min(hstart + 3, input_height + padding);
int wend = min(wstart + _kernel_size, input_width + padding_width); int wend = min(wstart + 3, input_width + padding);
hstart = max(hstart, 0); hstart = max(hstart, 0);
wstart = max(wstart, 0); wstart = max(wstart, 0);
hend = min(hend, input_height); hend = min(hend, input_height);
wend = min(wend, input_width); wend = min(wend, input_width);
const float *pos1 = input_data + hstart * input_width + wstart; const float *pos1 = input_seg + hstart * input_width + wstart;
const float *pos2 = input_data + (hstart + 1) * input_width + wstart; const float *pos2 = input_seg + (hstart + 1) * input_width + wstart;
const float *pos3 = input_data + (hstart + 2) * input_width + wstart; const float *pos3 = input_seg + (hstart + 2) * input_width + wstart;
const float *output_ptr = output_data + ph * output_width + pw; float *output_ptr = output_seg + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) { if (hend - hstart != 3 || wend - wstart != 3) {
float sum = 0; float sum = 0;
for (int h = hstart; h < hend; h++) { for (int h = hstart; h < hend; h++) {
for (int w = wstart; w < wend; w++) { for (int w = wstart; w < wend; w++) {
sum += input_data[h * input_width + w]; sum += input_seg[h * input_width + w];
} }
} }
output_data[ph * output_width + pw] = sum / 9.0; output_seg[ph * output_width + pw] = sum / 9.0;
} else { } else {
#if defined(ARMV7) #if defined(ARMV7)
...@@ -671,7 +691,7 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input, ...@@ -671,7 +691,7 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
"vmul.f32 d6,d7 \n\t" "vmul.f32 d6,d7 \n\t"
"vst1.32 {d6[0]},[%[output_ptr]] \n\t" "vst1.32 {d6[0]},[%[output_ptr]] \n\t"
: :
: [input_data] "r"(input_data), [pos1] "r"(pos1), : [input_seg] "r"(input_seg), [pos1] "r"(pos1),
[pos2] "r"(pos2), [pos3] "r"(pos3), [pos2] "r"(pos2), [pos3] "r"(pos3),
[output_ptr] "r"(output_ptr), [zero] "r"(zero), [output_ptr] "r"(output_ptr), [zero] "r"(zero),
[nine_ptr] "r"(nine_ptr) [nine_ptr] "r"(nine_ptr)
...@@ -686,13 +706,11 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input, ...@@ -686,13 +706,11 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
vpadd_f32(vget_high_f32(vsetq_lane_f32(0, sum_data, 3)), vpadd_f32(vget_high_f32(vsetq_lane_f32(0, sum_data, 3)),
vget_low_f32(sum_data)); vget_low_f32(sum_data));
res = vpadd_f32(res, res); res = vpadd_f32(res, res);
output_data[ph * output_width + pw] = vget_lane_f32(res, 0) / 9.0; output_seg[ph * output_width + pw] = vget_lane_f32(res, 0) / 9.0;
#endif #endif
} }
} }
} }
input_data += input_channel_stride;
output_data += output_channel_stride;
} }
input_data += input_batch_stride; input_data += input_batch_stride;
output_data += output_batch_stride; output_data += output_batch_stride;
......
...@@ -15,10 +15,13 @@ limitations under the License. */ ...@@ -15,10 +15,13 @@ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#pragma once #pragma once
#ifdef _OPENMP
#include <omp.h>
#endif
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "framework/tensor.h" #include "framework/tensor.h"
#ifdef __ARM_NEON #if __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
#endif // __ARM_NEON #endif // __ARM_NEON
......
...@@ -14,10 +14,11 @@ limitations under the License. */ ...@@ -14,10 +14,11 @@ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#include "operators/math/pooling.h" #include "pooling.h"
#include <algorithm>
#include <vector>
#include "common/types.h" #include "common/types.h"
#ifdef _OPENMP
#include <omp.h>
#endif
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -59,7 +60,7 @@ class PoolFunctor<CPU, PoolProcess, T> { ...@@ -59,7 +60,7 @@ class PoolFunctor<CPU, PoolProcess, T> {
T *output_data = output->mutable_data<T>(); T *output_data = output->mutable_data<T>();
for (int i = 0; i < batch_size; i++) { for (int i = 0; i < batch_size; i++) {
// #pragma omp parallel for #pragma omp parallel for
for (int c = 0; c < output_channels; ++c) { for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < output_height; ++ph) { for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height; int hstart = ph * stride_height - padding_height;
......
...@@ -26,16 +26,17 @@ int main() { ...@@ -26,16 +26,17 @@ int main() {
auto time2 = time(); auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time2) << "ms\n"; DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize); paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
executor.SetThreadNum(4);
std::vector<float> input; std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224}; std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims); GetInput<float>(g_test_image_1x3x224x224, &input, dims);
auto time3 = time(); auto time3 = time();
int count = 1;
for (int i = 0; i < 10; ++i) { for (int i = 0; i < count; ++i) {
executor.Predict(input, dims); executor.Predict(input, dims);
} }
auto time4 = time(); auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n"; DLOG << "predict cost :" << time_diff(time3, time4) / count << "ms\n";
return 0; return 0;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册