Commit cab2d143 authored by WangLiu, committed by GitHub

Merge pull request #552 from smilejames/develop

Revert: accelerate with openmp
......@@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.0)
project(paddle-mobile)
option(DEBUGING "enable debug mode" ON)
option(USE_OPENMP "openmp support" ON)
option(USE_OPENMP "openmp support" OFF)
option(USE_EXCEPTION "use std exception" ON)
option(LOG_PROFILE "log profile" ON)
# select the platform to build
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "io/executor.h"
#include <operators/math/gemm.h>
#include <algorithm>
#include <vector>
#include "common/enforce.h"
......@@ -26,9 +25,6 @@ limitations under the License. */
#include "framework/program/var_desc.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#ifdef PADDLE_EXECUTOR_MULTITHREAD
#include <queue>
#include <utility>
......@@ -407,17 +403,6 @@ std::vector<typename Executor<Dtype, P>::Ptype> Executor<Dtype, P>::Predict(
return result_vector;
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::SetThreadNum(int num) {
for (int k = 0; k < std::max(num, 3); ++k) {
operators::math::Gemmer::gemmers.push_back(new operators::math::Gemmer());
}
#ifdef _OPENMP
// omp_set_dynamic(0);
omp_set_num_threads(num);
#endif
}
template class Executor<CPU, Precision::FP32>;
template class Executor<FPGA, Precision::FP32>;
template class Executor<GPU_MALI, Precision::FP32>;
......
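For context, the SetThreadNum member removed above did two things: it grew a global pool of per-thread Gemmer scratch objects and pinned the OpenMP worker count. A minimal self-contained sketch of that pattern, with Gemmer reduced to an empty placeholder:

#ifdef _OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <vector>

struct Gemmer {};  // placeholder for the per-thread GEMM scratch state

std::vector<Gemmer *> gemmers;

void SetThreadNumSketch(int num) {
  // Keep at least 3 scratch objects, one per worker thread.
  for (int k = 0; k < std::max(num, 3); ++k) {
    gemmers.push_back(new Gemmer());
  }
#ifdef _OPENMP
  omp_set_num_threads(num);  // cap the OpenMP thread pool
#endif
}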
......@@ -58,8 +58,6 @@ class Executor {
std::vector<Ptype> Predict(const std::vector<Ptype> &input,
const std::vector<int64_t> &dims);
void SetThreadNum(int num);
protected:
Executor() = default;
void InitMemory();
......
......@@ -14,14 +14,10 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP
#pragma once
#if _OPENMP
#include <omp.h>
#endif
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/gemm.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
......@@ -110,33 +106,9 @@ void ConvAddBasic(const FusionConvAddParam &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
auto dim_a = filter_slice.dims();
auto dim_b = col_matrix.dims();
auto dim_out = out_slice.dims();
int m = dim_out[0];
int n = dim_out[1];
int k = dim_a[1];
float *output_data = out_slice.data<float>();
int thread_num = 4;
int m1 = m / thread_num;
int m2 = m % thread_num;
#pragma omp parallel for
for (int j = 0; j < thread_num; ++j) {
int row_count = m1;
if (j == thread_num - 1) {
row_count = m1 + m2;
}
math::Gemmer::gemmers[j]->Sgemm(
row_count, n, k, 1, filter_slice.data<float>() + j * m1 * k, k,
col_matrix.data<float>(), n, 1, output_data + j * m1 * n, n, false);
}
// math::matmul<float>(filter_slice, false, col_matrix, false,
// static_cast<float>(1), &out_slice,
// static_cast<float>(1));
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1));
}
}
}
......
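The reverted ConvAddBasic code split the GEMM across threads by partitioning the M rows of the output, with the last thread absorbing the remainder. A hedged, runnable sketch of that scheme, with a naive accumulating kernel standing in for Sgemm (names are illustrative):

#ifdef _OPENMP
#include <omp.h>
#endif

// beta is fixed to 1 (accumulate into C), matching the reverted Sgemm call.
static void NaiveGemmAccum(int m, int n, int k, const float *A, int lda,
                           const float *B, int ldb, float *C, int ldc) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = C[i * ldc + j];
      for (int p = 0; p < k; ++p) acc += A[i * lda + p] * B[p * ldb + j];
      C[i * ldc + j] = acc;
    }
  }
}

void ParallelGemmByRows(int m, int n, int k, const float *A, const float *B,
                        float *C) {
  const int thread_num = 4;
  const int m1 = m / thread_num;  // rows per thread
  const int m2 = m % thread_num;  // remainder goes to the last thread
#pragma omp parallel for
  for (int j = 0; j < thread_num; ++j) {
    const int row_count = (j == thread_num - 1) ? m1 + m2 : m1;
    NaiveGemmAccum(row_count, n, k, A + j * m1 * k, k, B, n,
                   C + j * m1 * n, n);
  }
}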
......@@ -13,9 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LRN_OP
#ifdef _OPENMP
#include <omp.h>
#endif
#include "framework/operator.h"
#include "operators/op_param.h"
......@@ -49,7 +47,6 @@ struct LRNFunctor {
std::fill(sqr_buffer_ptr, sqr_buffer_ptr + sqr_buffer.numel(), 0.0);
for (int a = 0; a < N; a++) {
#pragma parallel for
for (int b = 0; b < C; b++) {
for (int index = start; index < end; index++) {
int channel = b + index;
......
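One detail worth flagging in the hunk above: the removed directive reads "#pragma parallel for", missing the "omp" keyword, so compilers treat it as an unknown pragma and silently ignore it; the channel loop was never actually parallelized. A sketch of the working spelling:

#ifdef _OPENMP
#include <omp.h>
#endif

void LrnChannelLoopSketch(int N, int C) {
  for (int a = 0; a < N; ++a) {
#pragma omp parallel for  // "#pragma parallel for" would be a silent no-op
    for (int b = 0; b < C; ++b) {
      // per-channel squared-window accumulation (body elided)
    }
  }
}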
......@@ -22,10 +22,16 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
std::vector<Gemmer *> Gemmer::gemmers;
int MC = 0;
int KC = 0;
int NC = 0;
float *packedA;
float *packedB;
float *packedC;
float *zero;
// Pack blocks of matrix A into contiguous memory (ColMajor)
void Gemmer::PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
int i, j;
const float *Aij;
......@@ -52,7 +58,7 @@ void Gemmer::PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
}
// Pack blocks of matrix A into contiguous memory (RowMajor)
void Gemmer::PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
float *buffer) {
const float *a0, *a1, *a2, *a3;
for (int i = 0; i < m - m_tail; i += MR) {
......@@ -92,7 +98,7 @@ void Gemmer::PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
}
// Pack blocks of matrix B into contiguous memory (ColMajor)
void Gemmer::PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
int i, j;
const float *Bj, *Bj1, *Bj2, *Bj3;
......@@ -121,7 +127,7 @@ void Gemmer::PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
}
// Pack blocks of matrix B into contiguous memory (RowMajor)
void Gemmer::PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) {
const float *b0;
for (int j = 0; j < n - n_tail; j += NR) {
......@@ -150,9 +156,8 @@ void Gemmer::PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
}
// Blocked matrix multiplication
void Gemmer::InnerKernel(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu) {
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu) {
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
......@@ -179,10 +184,9 @@ void Gemmer::InnerKernel(int mc, int nc, float alpha, const float *a,
}
// Blocked matrix multiplication
void Gemmer::InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale,
float *new_bias) {
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
for (int j = 0; j < nc; j += NR) {
for (int i = 0; i < mc; i += MR) {
// AddDot4x4(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
......@@ -198,8 +202,7 @@ void Gemmer::InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
}
#if defined(IOS)
void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *C,
int ldc) {
void AddDot4x4(int k, const float *a, const float *b, float *C, int ldc) {
// init C
float32x4_t cv0 = vdupq_n_f32(0.0);
float32x4_t cv1 = vdupq_n_f32(0.0);
......@@ -250,8 +253,7 @@ void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *C,
} // namespace math
#elif defined(ARMV7)
void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *c,
int ldc) {
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
......@@ -322,8 +324,7 @@ void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *c,
}
#else
void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *c,
int ldc) {
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) {
float *c0, *c1, *c2, *c3;
c0 = c;
c1 = c + ldc;
......@@ -362,9 +363,8 @@ void Gemmer::AddDot4x4(int k, const float *a, const float *b, float *c,
#endif
// 32-bit float matrix multiplication
void Gemmer::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu) {
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu) {
// L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int L1 = 30 * 1024;
......@@ -415,10 +415,9 @@ void Gemmer::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
paddle_mobile::memory::Free(zero);
}
void Gemmer::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale,
float *new_bias) {
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
// L1 data cache is 32 KiB (per Cortex-A57, Cortex-A72, Cortex-A73)
// L2 cache is 0.5~4 MiB (Cortex-A72 cluster)
int L1 = 30 * 1024;
......@@ -469,9 +468,9 @@ void Gemmer::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
paddle_mobile::memory::Free(zero);
}
void Gemmer::VectorKernel(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu) {
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu) {
float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
const float *a0, *b0, *b1, *b2, *b3;
......@@ -691,10 +690,9 @@ void Gemmer::VectorKernel(int m, int n, int k, float alpha, const float *A,
}
}
void Gemmer::VectorKernelWithBn(int m, int n, int k, float alpha,
const float *A, int lda, const float *B,
int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias) {
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias) {
float *bufferC = static_cast<float *>(memory::Alloc(sizeof(float) * n));
const float *a0, *b0, *b1, *b2, *b3;
......@@ -903,8 +901,7 @@ void Gemmer::VectorKernelWithBn(int m, int n, int k, float alpha,
}
}
void Gemmer::AddDot4x8(int k, const float *a, const float *b, float *c,
int ldc) {
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) {
const float *a_ptr, *b_ptr;
a_ptr = a;
b_ptr = b;
......@@ -1012,7 +1009,7 @@ void Gemmer::AddDot4x8(int k, const float *a, const float *b, float *c,
}
// C = A * B
void Gemmer::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
void WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int step = 4 * ldc;
......@@ -1069,10 +1066,10 @@ void Gemmer::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {
}
// C = alpha * A * B + beta * C
void Gemmer::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {}
// C = A * B + C
void Gemmer::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int step = 4 * ldc;
......@@ -1136,7 +1133,7 @@ void Gemmer::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {
}
// C = A * B + C, relu(C)
void Gemmer::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int step = 4 * ldc;
......@@ -1210,8 +1207,8 @@ void Gemmer::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {
}
// C = A * B, batchnorm(C)
void Gemmer::WriteWithBn(int mc, int nc, float *c, float *C, int ldc,
float *scale, float *bias) {
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int nc2 = _nc1 / 4;
......@@ -1296,8 +1293,8 @@ void Gemmer::WriteWithBn(int mc, int nc, float *c, float *C, int ldc,
}
// C = A * B, batchnorm(C), relu(C)
void Gemmer::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *scale, float *bias) {
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, float *scale,
float *bias) {
int nc1 = nc / 16;
int _nc1 = nc % 16;
int nc2 = _nc1 / 4;
......@@ -1389,7 +1386,7 @@ void Gemmer::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
}
// C = A * B
void Gemmer::VecWriteBasic(int n, float *c, float *C, int ldc) {
void VecWriteBasic(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
int nc2 = _nc1 / 4;
......@@ -1435,10 +1432,10 @@ void Gemmer::VecWriteBasic(int n, float *c, float *C, int ldc) {
}
// C = alpha * A * B + beta * C
void Gemmer::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {}
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {}
// C = A * B + C
void Gemmer::VecWriteWithAdd(int n, float *c, float *C, int ldc) {
void VecWriteWithAdd(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
......@@ -1476,7 +1473,7 @@ void Gemmer::VecWriteWithAdd(int n, float *c, float *C, int ldc) {
}
// C = A * B + C, relu(C)
void Gemmer::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
int nc1 = n / 16;
int _nc1 = n % 16;
......@@ -1524,7 +1521,7 @@ void Gemmer::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) {
}
// C = A * B, batchnorm(C)
void Gemmer::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
float *bias) {
int nc1 = n / 16;
int _nc1 = n % 16;
......@@ -1591,8 +1588,8 @@ void Gemmer::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale,
}
// C = A * B, batchnorm(C), relu(C)
void Gemmer::VecWriteWithBnRelu(int n, float *c, float *C, int ldc,
float *scale, float *bias) {
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *scale,
float *bias) {
int nc1 = n / 16;
int _nc1 = n % 16;
int nc2 = _nc1 / 4;
......
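The gemm.cpp changes above only demote Gemmer methods to free functions; the algorithm stays the classic pack-then-microkernel blocked GEMM. As a reference for what the pack routines and InnerKernel do together, a scalar stand-in for the NEON AddDot4x4 and its driving tile sweep, assuming MR = NR = 4 and the usual panel-interleaved packed layout (an assumption; the exact layout lives in the pack routines):

// Scalar sketch, assuming packed A stores each MR-row panel column-by-column
// (a[p * MR + i] == A(i, p)) and packed B stores each NR-column panel
// row-by-row (b[p * NR + j] == B(p, j)).
constexpr int MR = 4;
constexpr int NR = 4;

void AddDot4x4Scalar(int k, const float *a, const float *b, float *C,
                     int ldc) {
  for (int p = 0; p < k; ++p) {
    for (int i = 0; i < MR; ++i) {
      for (int j = 0; j < NR; ++j) {
        C[i * ldc + j] += a[p * MR + i] * b[p * NR + j];
      }
    }
  }
}

// Mirrors InnerKernel's sweep: a + i * kc and b + j * kc address whole packed
// panels, c is the mc x nc accumulation buffer with row stride ldc.
void InnerKernelScalar(int mc, int nc, int kc, const float *packedA,
                       const float *packedB, float *c, int ldc) {
  for (int j = 0; j < nc; j += NR) {
    for (int i = 0; i < mc; i += MR) {
      AddDot4x4Scalar(kc, packedA + i * kc, packedB + j * kc,
                      c + i * ldc + j, ldc);
    }
  }
}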
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
// Matrix element access macros, assuming row-major storage
#define A(i, j) A[(i)*lda + (j)]
......@@ -28,111 +27,88 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
namespace math {
struct Gemmer {
int MC = 0;
int KC = 0;
int NC = 0;
float *packedA;
float *packedB;
float *packedC;
float *zero;
static std::vector<Gemmer *> gemmers;
// Pack blocks of matrix A into contiguous memory (ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
// Pack blocks of matrix A into contiguous memory (ColMajor)
void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// Pack blocks of matrix B into contiguous memory (ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
// Pack blocks of matrix B into contiguous memory (ColMajor)
void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// Pack blocks of matrix A into contiguous memory (RowMajor)
void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
// Pack blocks of matrix A into contiguous memory (RowMajor)
void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
float *buffer);
// Pack blocks of matrix B into contiguous memory (RowMajor)
void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
// Pack blocks of matrix B into contiguous memory (RowMajor)
void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer);
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
// Blocked matrix multiplication
void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b,
float beta, float *c, float *C, int ldc, bool relu);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C,
int ldc, bool relu, float *new_scale, float *new_bias);
void InnerKernelWithBn(int mc, int nc, float alpha, const float *a,
const float *b, float beta, float *c, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// Vector-matrix multiplication (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
// Vector-matrix multiplication (M = 1)
void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta,
float *C, int ldc, bool relu, float *new_scale,
float *new_bias);
// Compute a smaller block of matrix C
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
// Write back blocked matrix multiplication results
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A,
int lda, const float *B, int ldb, float beta, float *C,
int ldc, bool relu, float *new_scale, float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
// Compute a smaller block of matrix C
void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc);
void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc);
// Write back blocked matrix multiplication results
// C = A * B
void WriteBasic(int mc, int nc, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C
void WriteWithAdd(int mc, int nc, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void WriteWithBn(int mc, int nc, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc,
float *new_scale, float *new_bias);
// Write back vector-matrix multiplication results
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
// Write back vector-matrix multiplication results
// C = A * B
void VecWriteBasic(int n, float *c, float *C, int ldc);
// C = alpha * A * B + beta * C
void VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc);
// C = A * B + C
void VecWriteWithAdd(int n, float *c, float *C, int ldc);
// C = A * B + C, relu(C)
void VecWriteWithAddRelu(int n, float *c, float *C, int ldc);
// C = A * B, batchnorm(C)
void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
// C = A * B, batchnorm(C), relu(C)
void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale,
float *new_bias);
// 32-bit float matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
// 32-bit float matrix multiplication
void Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, bool relu);
// 32-bit float matrix multiplication, with batchnorm applied to the result
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
// 32-bit float matrix multiplication, with batchnorm applied to the result
void SgemmWithBn(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc,
bool relu, float *new_scale, float *new_bias);
// 64-bit double matrix multiplication
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
// 64-bit double matrix multiplication
void dgemm(int m, int n, int k, float alpha, const double *A, int lda,
const double *B, int ldb, float beta, double *C, int ldc);
};
} // namespace math
} // namespace operators
......
......@@ -26,14 +26,23 @@ void matmul<float>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Gemmer::gemmers[0]->Sgemm(M, N, K, alpha, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu);
Sgemm(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu);
}
template <>
......@@ -45,15 +54,24 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
// PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 &&
// dim_out.size() ==
// 2,
// "The input and output of matmul be matrix");
//
// PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
// platform::is_cpu_place(matrix_b.place())
// &&
// platform::is_cpu_place(matrix_out->place()),
// "Matrix must all be in CPUPlace");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Gemmer::gemmers[0]->SgemmWithBn(
M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(), N,
beta, matrix_out->data<float>(), N, relu, new_scale->data<float>(),
new_bias->data<float>());
SgemmWithBn(M, N, K, alpha, matrix_a.data<float>(), K, matrix_b.data<float>(),
N, beta, matrix_out->data<float>(), N, relu,
new_scale->data<float>(), new_bias->data<float>());
}
} // namespace math
......
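The matmul/matmulWithBn wrappers above reduce the Tensor call to plain row-major Sgemm arguments. The shape bookkeeping they apply, pulled out as a hedged standalone helper:

#include <cstdint>
#include <vector>

// Row-major, dense: A is M x K (lda = K), B is K x N (ldb = N), C is M x N
// (ldc = N); K comes from dim_a[1] unless A is transposed.
struct MatmulShapes {
  int M, N, K, lda, ldb, ldc;
};

MatmulShapes DeriveMatmulShapes(const std::vector<int64_t> &dim_a,
                                const std::vector<int64_t> &dim_out,
                                bool trans_a) {
  MatmulShapes s;
  s.M = static_cast<int>(dim_out[0]);
  s.N = static_cast<int>(dim_out[1]);
  s.K = static_cast<int>(!trans_a ? dim_a[1] : dim_a[0]);
  s.lda = s.K;
  s.ldb = s.N;
  s.ldc = s.N;
  return s;
}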
......@@ -13,12 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#ifdef _OPENMP
#include <omp.h>
#endif
#include "operators/math/pool_3x3.h"
#include "framework/tensor.h"
#include "pool_3x3.h"
#ifdef __ARM_NEON
#if __ARM_NEON
#include <arm_neon.h>
#endif // __ARM_NEON
#include <climits>
......@@ -30,7 +27,7 @@ using std::max;
using std::min;
using std::vector;
void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
#ifdef __ARM_NEON
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int h_in = input->dims()[2];
......@@ -43,52 +40,46 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
const int w_out = output->dims()[3];
const int outputdata_channel_stride = h_out * w_out;
const int inputdata_channel_stride = h_in * w_in;
const int input_batch_stride = output_channels * inputdata_channel_stride;
const int output_batch_stride = output_channels * outputdata_channel_stride;
float *out_data = output->data<float>();
const float *input_data = input->data<float>();
const float coef = 1.0 / 9.0;
for (int k = 0; k < batch_size; ++k) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
const float *input_seg = input_data + c * inputdata_channel_stride;
float *output_seg = out_data + c * outputdata_channel_stride;
// four corner point
output_seg[0] = (input_seg[0] + input_seg[1] + input_seg[w_in] +
input_seg[w_in + 1]) *
out_data[0] = (input_data[0] + input_data[1] + input_data[w_in] +
input_data[w_in + 1]) *
coef;
output_seg[w_out - 1] =
(input_seg[w_in - 2] + input_seg[w_in - 1] + input_seg[w_in * 2 - 2] +
input_seg[2 * w_in - 1]) *
out_data[w_out - 1] =
(input_data[w_in - 2] + input_data[w_in - 1] +
input_data[w_in * 2 - 2] + input_data[2 * w_in - 1]) *
coef;
output_seg[(h_out - 1) * w_out] =
(input_seg[(h_in - 2) * w_in] + input_seg[(h_in - 2) * w_in + 1] +
input_seg[(h_in - 1) * w_in] + input_seg[(h_in - 1) * w_in + 1]) *
out_data[(h_out - 1) * w_out] =
(input_data[(h_in - 2) * w_in] + input_data[(h_in - 2) * w_in + 1] +
input_data[(h_in - 1) * w_in] + input_data[(h_in - 1) * w_in + 1]) *
coef;
output_seg[h_out * w_out - 1] =
(input_seg[h_in * w_in - 1] + input_seg[h_in * w_in - 2] +
input_seg[(h_in - 1) * w_in - 1] +
input_seg[(h_in - 1) * w_in - 2]) *
out_data[h_out * w_out - 1] =
(input_data[h_in * w_in - 1] + input_data[h_in * w_in - 2] +
input_data[(h_in - 1) * w_in - 1] +
input_data[(h_in - 1) * w_in - 2]) *
coef;
// left side & right side
for (int i = 1; i < h_in - 1; ++i) {
output_seg[i * w_out] =
(input_seg[i * w_in - w_in] + input_seg[i * w_in - w_in + 1] +
input_seg[i * w_in] + input_seg[i * w_in + 1] +
input_seg[i * w_in + w_in] + input_seg[i * w_in + w_in + 1]) *
out_data[i * w_out] =
(input_data[i * w_in - w_in] + input_data[i * w_in - w_in + 1] +
input_data[i * w_in] + input_data[i * w_in + 1] +
input_data[i * w_in + w_in] + input_data[i * w_in + w_in + 1]) *
coef;
output_seg[i * w_out + w_out - 1] =
(input_seg[i * w_in - w_in + w_in - 2] +
input_seg[i * w_in - w_in + 1 + w_in - 2] +
input_seg[i * w_in + w_in - 2] +
input_seg[i * w_in + 1 + w_in - 2] +
input_seg[i * w_in + w_in + w_in - 2] +
input_seg[i * w_in + w_in + 1 + w_in - 2]) *
out_data[i * w_out + w_out - 1] =
(input_data[i * w_in - w_in + w_in - 2] +
input_data[i * w_in - w_in + 1 + w_in - 2] +
input_data[i * w_in + w_in - 2] +
input_data[i * w_in + 1 + w_in - 2] +
input_data[i * w_in + w_in + w_in - 2] +
input_data[i * w_in + w_in + 1 + w_in - 2]) *
coef;
}
// top 1 row & bottom 1 row
const float *input_tmp = input_seg;
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, sum, out0;
......@@ -99,7 +90,7 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w_in);
int c_mid = w_out - 2;
auto output_ptr = output_seg + 1;
auto output_ptr = out_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
......@@ -144,8 +135,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
in6 = in7;
}
// top right remain
float32x4_t pad0 = vdupq_n_f32(input_seg[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_seg[2 * w_in - 1]);
float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
......@@ -172,8 +163,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
}
// bottom_right remain
float32x4_t pad2 = vdupq_n_f32(input_seg[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_seg[h_in * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
......@@ -200,8 +191,8 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
}
// mid
for (int j = 0; j < h_out - 2; ++j) {
output_ptr = output_seg + w_out * (j + 1) + 1;
input_tmp = input_seg + j * w_in;
output_ptr = out_data + w_out * (j + 1) + 1;
input_tmp = input_data + j * w_in;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
......@@ -237,9 +228,9 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
in4 = in5;
}
// mid remain
float32x4_t pad0 = vdupq_n_f32(input_seg[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
......@@ -270,17 +261,15 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
}
}
}
// input_data += inputdata_channel_stride;
// out_data += outputdata_channel_stride;
input_data += inputdata_channel_stride;
out_data += outputdata_channel_stride;
}
input_data += input_batch_stride;
out_data += output_batch_stride;
}
#endif
}
void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
#ifdef __ARM_NEON
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int h_in = input->dims()[2];
......@@ -293,50 +282,44 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
const int w_out = output->dims()[3];
const int outputdata_channel_stride = h_out * w_out;
const int inputdata_channel_stride = h_in * w_in;
const int input_batch_stride = output_channels * inputdata_channel_stride;
const int output_batch_stride = output_channels * outputdata_channel_stride;
float *out_data = output->data<float>();
const float *input_data = input->data<float>();
for (int k = 0; k < batch_size; ++k) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
const float *input_seg = input_data + c * inputdata_channel_stride;
float *output_seg = out_data + c * outputdata_channel_stride;
// four corner point
output_seg[0] = std::max(std::max(input_seg[0], input_seg[1]),
std::max(input_seg[w_in], input_seg[w_in + 1]));
output_seg[w_out - 1] =
std::max(std::max(input_seg[w_in - 2], input_seg[w_in - 1]),
std::max(input_seg[w_in * 2 - 2], input_seg[2 * w_in - 1]));
output_seg[(h_out - 1) * w_out] =
std::max(std::max(input_seg[(h_in - 2) * w_in],
input_seg[(h_in - 2) * w_in + 1]),
std::max(input_seg[(h_in - 1) * w_in],
input_seg[(h_in - 1) * w_in + 1]));
output_seg[h_out * w_out - 1] = std::max(
std::max(input_seg[(h_in - 1) * w_in - 1],
input_seg[(h_in - 1) * w_in - 2]),
std::max(input_seg[h_in * w_in - 1], input_seg[h_in * w_in - 2]));
out_data[0] = std::max(std::max(input_data[0], input_data[1]),
std::max(input_data[w_in], input_data[w_in + 1]));
out_data[w_out - 1] = std::max(
std::max(input_data[w_in - 2], input_data[w_in - 1]),
std::max(input_data[w_in * 2 - 2], input_data[2 * w_in - 1]));
out_data[(h_out - 1) * w_out] =
std::max(std::max(input_data[(h_in - 2) * w_in],
input_data[(h_in - 2) * w_in + 1]),
std::max(input_data[(h_in - 1) * w_in],
input_data[(h_in - 1) * w_in + 1]));
out_data[h_out * w_out - 1] = std::max(
std::max(input_data[(h_in - 1) * w_in - 1],
input_data[(h_in - 1) * w_in - 2]),
std::max(input_data[h_in * w_in - 1], input_data[h_in * w_in - 2]));
// left side & right side
for (int i = 1; i < h_in - 1; ++i) {
float max1 = std::max(input_seg[i * w_in - w_in],
input_seg[i * w_in - w_in + 1]);
float max2 = std::max(input_seg[i * w_in], input_seg[i * w_in + 1]);
float max3 = std::max(input_seg[i * w_in + w_in],
input_seg[i * w_in + w_in + 1]);
output_seg[i * w_out] = std::max(std::max(max1, max2), max3);
max1 = std::max(input_seg[i * w_in - w_in + w_in - 2],
input_seg[i * w_in - w_in + 1 + w_in - 2]);
max2 = std::max(input_seg[i * w_in + w_in - 2],
input_seg[i * w_in + 1 + w_in - 2]);
max3 = std::max(input_seg[i * w_in + w_in + w_in - 2],
input_seg[i * w_in + w_in + 1 + w_in - 2]);
output_seg[i * w_out + w_out - 1] =
std::max(std::max(max1, max2), max3);
float max1 = std::max(input_data[i * w_in - w_in],
input_data[i * w_in - w_in + 1]);
float max2 = std::max(input_data[i * w_in], input_data[i * w_in + 1]);
float max3 = std::max(input_data[i * w_in + w_in],
input_data[i * w_in + w_in + 1]);
out_data[i * w_out] = std::max(std::max(max1, max2), max3);
max1 = std::max(input_data[i * w_in - w_in + w_in - 2],
input_data[i * w_in - w_in + 1 + w_in - 2]);
max2 = std::max(input_data[i * w_in + w_in - 2],
input_data[i * w_in + 1 + w_in - 2]);
max3 = std::max(input_data[i * w_in + w_in + w_in - 2],
input_data[i * w_in + w_in + 1 + w_in - 2]);
out_data[i * w_out + w_out - 1] = std::max(std::max(max1, max2), max3);
}
// top 1 row & bottom 1 row
const float *input_tmp = input_seg;
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, max;
......@@ -346,7 +329,7 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w_in);
int c_mid = w_out - 2;
auto output_ptr = output_seg + 1;
auto output_ptr = out_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
......@@ -390,8 +373,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
in6 = in7;
}
// top right remain
float32x4_t pad0 = vdupq_n_f32(input_seg[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_seg[2 * w_in - 1]);
float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
......@@ -417,8 +400,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
}
// bottom_right remain
float32x4_t pad2 = vdupq_n_f32(input_seg[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_seg[h_in * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
......@@ -444,8 +427,8 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
}
// mid
for (int j = 0; j < h_out - 2; ++j) {
output_ptr = output_seg + (j + 1) * w_out + 1;
input_tmp = input_seg + j * w_in;
output_ptr = out_data + (j + 1) * w_out + 1;
input_tmp = input_data + j * w_in;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
......@@ -480,9 +463,9 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
in4 = in5;
}
// mid remain
float32x4_t pad0 = vdupq_n_f32(input_seg[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_seg[(j + 3) * w_in - 1]);
float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(j + 3) * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
......@@ -512,18 +495,16 @@ void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
}
}
}
// input_data += inputdata_channel_stride;
// out_data += outputdata_channel_stride;
input_data += inputdata_channel_stride;
out_data += outputdata_channel_stride;
}
input_data += input_batch_stride;
out_data += output_batch_stride;
}
#endif
}
void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output) {
#ifdef __ARM_NEON
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
......@@ -534,11 +515,11 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
// const int _kernel_size = 3;
const int stride = strides[0];
// const int stride_width = strides[1];
const int padding = paddings[0];
// const int padding_width = paddings[1];
const int _kernel_size = 3;
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const float negative_max = -INT_MAX;
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
......@@ -548,41 +529,38 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride;
const float *pos1, *output_ptr;
const float *pos1, *pos2, *pos3, *output_ptr;
int hstart, wstart, hend, wend;
for (int i = 0; i < batch_size; ++i) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
const float *input_seg = input_data + c * input_channel_stride;
float *output_seg = output_data + c * output_channel_stride;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
int hstart = ph * stride - padding;
int wstart = pw * stride - padding;
int hend = min(hstart + 3, input_height + padding);
int wend = min(wstart + 3, input_width + padding);
hstart = ph * stride_height - padding_height;
wstart = pw * stride_width - padding_width;
hend = min(hstart + _kernel_size, input_height + padding_height);
wend = min(wstart + _kernel_size, input_width + padding_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, input_height);
wend = min(wend, input_width);
const float *pos1 = input_seg + hstart * input_width + wstart;
const float *pos2 = input_seg + (hstart + 1) * input_width + wstart;
const float *pos3 = input_seg + (hstart + 2) * input_width + wstart;
output_ptr = output_seg + ph * output_width + pw;
pos1 = input_data + hstart * input_width + wstart;
pos2 = input_data + (hstart + 1) * input_width + wstart;
pos3 = input_data + (hstart + 2) * input_width + wstart;
output_ptr = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
float max_value = -INT_MAX;
for (int h = hstart; h < hend; h++) {
for (int w = wstart; w < wend; w++) {
float value = input_seg[h * input_width + w];
float value = input_data[h * input_width + w];
if (value > max_value) {
max_value = value;
}
}
}
output_seg[ph * output_width + pw] = max_value;
output_data[ph * output_width + pw] = max_value;
} else {
#ifdef ARMV7
#if defined(ARMV7)
asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t"
"vld1.32 {q2}, [%[pos2]] \n\t"
......@@ -594,25 +572,27 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
"vpmax.f32 d7, d6, d6 \n\t"
"vst1.32 {d7[0]},[%[output_ptr]] \n\t"
:
: [input_seg] "r"(input_seg), [pos1] "r"(pos1),
: [input_data] "r"(input_data), [pos1] "r"(pos1),
[pos2] "r"(pos2), [pos3] "r"(pos3),
[output_ptr] "r"(output_ptr), [negative_max] "r"(negative_max)
: "memory", "q1", "q2", "q3", "q4");
#else
const float32x4_t data1 = vld1q_f32(pos1);
const float32x4_t data2 = vld1q_f32(pos1 + input_width);
const float32x4_t data3 = vld1q_f32(pos1 + 2 * input_width);
const float32x4_t data2 = vld1q_f32(pos2);
const float32x4_t data3 = vld1q_f32(pos3);
const float32x4_t max_data =
vmaxq_f32(vmaxq_f32(data1, data2), data3);
vmaxq_f32(vmaxq_f32(data1, data3), data2);
float32x2_t res =
vpmax_f32(vget_high_f32(vsetq_lane_f32(-INT_MAX, max_data, 3)),
vget_low_f32(max_data));
res = vpmax_f32(res, res);
output_seg[ph * output_width + pw] = vget_lane_f32(res, 0);
output_data[ph * output_width + pw] = vget_lane_f32(res, 0);
#endif
}
}
}
input_data += input_channel_stride;
output_data += output_channel_stride;
}
input_data += input_batch_stride;
output_data += output_batch_stride;
......@@ -622,7 +602,7 @@ void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output) {
#ifdef __ARM_NEON
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int input_height = input->dims()[2];
......@@ -633,8 +613,11 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
const int output_height = output->dims()[2];
const int output_width = output->dims()[3];
const int stride = strides[0];
const int padding = paddings[0];
const int _kernel_size = 3;
const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
const int input_channel_stride = input_height * input_width;
const int output_channel_stride = output_height * output_width;
......@@ -648,35 +631,32 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
const int input_batch_stride = output_channels * input_channel_stride;
const int output_batch_stride = output_channels * output_channel_stride;
for (int i = 0; i < batch_size; ++i) {
#pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
const float *input_seg = input_data + c * input_channel_stride;
float *output_seg = output_data + c * output_channel_stride;
for (int ph = 0; ph < output_height; ph++) {
for (int pw = 0; pw < output_width; pw++) {
int hstart = ph * stride - padding;
int wstart = pw * stride - padding;
int hend = min(hstart + 3, input_height + padding);
int wend = min(wstart + 3, input_width + padding);
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + _kernel_size, input_height + padding_height);
int wend = min(wstart + _kernel_size, input_width + padding_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
hend = min(hend, input_height);
wend = min(wend, input_width);
const float *pos1 = input_seg + hstart * input_width + wstart;
const float *pos2 = input_seg + (hstart + 1) * input_width + wstart;
const float *pos3 = input_seg + (hstart + 2) * input_width + wstart;
float *output_ptr = output_seg + ph * output_width + pw;
const float *pos1 = input_data + hstart * input_width + wstart;
const float *pos2 = input_data + (hstart + 1) * input_width + wstart;
const float *pos3 = input_data + (hstart + 2) * input_width + wstart;
const float *output_ptr = output_data + ph * output_width + pw;
if (hend - hstart != 3 || wend - wstart != 3) {
float sum = 0;
for (int h = hstart; h < hend; h++) {
for (int w = wstart; w < wend; w++) {
sum += input_seg[h * input_width + w];
sum += input_data[h * input_width + w];
}
}
output_seg[ph * output_width + pw] = sum / 9.0;
output_data[ph * output_width + pw] = sum / 9.0;
} else {
#ifdef ARMV7
#if defined(ARMV7)
asm volatile(
"vld1.32 {q1}, [%[pos1]] \n\t"
......@@ -691,7 +671,7 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
"vmul.f32 d6,d7 \n\t"
"vst1.32 {d6[0]},[%[output_ptr]] \n\t"
:
: [input_seg] "r"(input_seg), [pos1] "r"(pos1),
: [input_data] "r"(input_data), [pos1] "r"(pos1),
[pos2] "r"(pos2), [pos3] "r"(pos3),
[output_ptr] "r"(output_ptr), [zero] "r"(zero),
[nine_ptr] "r"(nine_ptr)
......@@ -706,11 +686,13 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
vpadd_f32(vget_high_f32(vsetq_lane_f32(0, sum_data, 3)),
vget_low_f32(sum_data));
res = vpadd_f32(res, res);
output_seg[ph * output_width + pw] = vget_lane_f32(res, 0) / 9.0;
output_data[ph * output_width + pw] = vget_lane_f32(res, 0) / 9.0;
#endif
}
}
}
input_data += input_channel_stride;
output_data += output_channel_stride;
}
input_data += input_batch_stride;
output_data += output_batch_stride;
......
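The Pool3x3Max changes above restore per-axis strides and paddings and separate row pointers (pos1/pos2/pos3) for the NEON path. A scalar reference for the same window logic, which is what the non-3x3 fallback branch computes:

#include <algorithm>
#include <climits>

// Max over a 3x3 window clamped to the input, matching the fallback branch
// (negative_max = -INT_MAX seeds the reduction, as in the code above).
float MaxPool3x3At(const float *input, int input_height, int input_width,
                   int ph, int pw, int stride_height, int stride_width,
                   int padding_height, int padding_width) {
  int hstart = std::max(ph * stride_height - padding_height, 0);
  int wstart = std::max(pw * stride_width - padding_width, 0);
  int hend = std::min(ph * stride_height - padding_height + 3, input_height);
  int wend = std::min(pw * stride_width - padding_width + 3, input_width);
  float max_value = -static_cast<float>(INT_MAX);
  for (int h = hstart; h < hend; ++h) {
    for (int w = wstart; w < wend; ++w) {
      max_value = std::max(max_value, input[h * input_width + w]);
    }
  }
  return max_value;
}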
......@@ -15,9 +15,6 @@ limitations under the License. */
#ifdef POOL_OP
#pragma once
#ifdef _OPENMP
#include <omp.h>
#endif
#include <algorithm>
#include <vector>
#include "framework/tensor.h"
......
......@@ -14,11 +14,10 @@ limitations under the License. */
#ifdef POOL_OP
#include "pooling.h"
#include "operators/math/pooling.h"
#include <algorithm>
#include <vector>
#include "common/types.h"
#ifdef _OPENMP
#include <omp.h>
#endif
namespace paddle_mobile {
namespace operators {
......@@ -60,8 +59,8 @@ class PoolFunctor<CPU, PoolProcess, T> {
T *output_data = output->mutable_data<T>();
for (int i = 0; i < batch_size; i++) {
// #pragma omp parallel for
for (int c = 0; c < output_channels; ++c) {
#pragma omp parallel for
for (int ph = 0; ph < output_height; ++ph) {
int hstart = ph * stride_height - padding_height;
int hend = std::min(hstart + ksize_height, input_height);
......
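The net effect of the PoolFunctor hunk is to move the OpenMP directive one loop deeper: threads now split output rows within a channel rather than whole channels. Schematically:

#ifdef _OPENMP
#include <omp.h>
#endif

void PoolLoopNestSketch(int batch_size, int output_channels,
                        int output_height) {
  for (int i = 0; i < batch_size; ++i) {
    for (int c = 0; c < output_channels; ++c) {
#pragma omp parallel for  // previously sat (commented out) on the c loop
      for (int ph = 0; ph < output_height; ++ph) {
        // compute one output row of channel c (body elided)
      }
    }
  }
}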
......@@ -26,17 +26,16 @@ int main() {
auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
executor.SetThreadNum(4);
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
auto time3 = time();
int count = 1;
for (int i = 0; i < count; ++i) {
for (int i = 0; i < 10; ++i) {
executor.Predict(input, dims);
}
auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) / count << "ms\n";
DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n";
return 0;
}
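Note the test now logs the total wall time of 10 Predict calls rather than a per-call average. If a per-call figure is wanted, a hypothetical variant reusing the test's own time()/time_diff helpers (not what the test above does):

const int runs = 10;
auto t_start = time();
for (int i = 0; i < runs; ++i) {
  executor.Predict(input, dims);
}
auto t_end = time();
DLOG << "avg predict cost :" << time_diff(t_start, t_end) / runs << "ms\n";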