未验证 提交 9567fcad 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #503 from cocodark/develop

optimize pool 3x3s1p1
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#### 以下是 paddle-mobile 代码的执行流程图: #### 以下是 paddle-mobile 代码的执行流程图:
![执行流程图](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305189473720.png) ![执行流程图](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305189473720.png)
...@@ -15,7 +14,6 @@ ...@@ -15,7 +14,6 @@
先来看一下模型, 模型分为两种结构: 先来看一下模型, 模型分为两种结构:
一种为参数文件是散开的, 如下图, 红框为模型结构的 protobuf 文件, 其余为参数文件 一种为参数文件是散开的, 如下图, 红框为模型结构的 protobuf 文件, 其余为参数文件
![模型描述](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305190629577.png) ![模型描述](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305190629577.png)
...@@ -23,6 +21,7 @@ ...@@ -23,6 +21,7 @@
![模型描述combined](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305191057130.png) ![模型描述combined](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305191057130.png)
loader 模块的作用是将模型结构信息 load 进内存, 将红框内的 protobuf 文件 load 进内存, 并对模型结构进行优化(如将几个细粒度的 op 融合成 粗粒度的 op, 如将 conv、 add、 batchnorm、 relu 融合为 conv\_add\_batchnorm\_relu). loader 模块的作用是将模型结构信息 load 进内存, 将红框内的 protobuf 文件 load 进内存, 并对模型结构进行优化(如将几个细粒度的 op 融合成 粗粒度的 op, 如将 conv、 add、 batchnorm、 relu 融合为 conv\_add\_batchnorm\_relu).
方便进行算法优化. 方便进行算法优化.
......
### iOS&Android开发文档
# iOS开发文档 # iOS开发文档
## 编译 ## 编译
...@@ -72,14 +74,154 @@ PaddleMobile.h ...@@ -72,14 +74,154 @@ PaddleMobile.h
``` ```
#Android开发文档
用户可通过如下两种方式,交叉编译Android平台上适用的paddle-mobile库:
- 基于Docker容器编译
- 基于Linux交叉编译
## 基于Docker容器编译
### 1. 安装 docker
安装 docker 的方式,参考官方文档 [https://docs.docker.com/install/](https://docs.docker.com/install/)
### 2. 使用 docker 搭建构建环境
首先进入 paddle-mobile 的目录下,执行 `docker build`
以 Linux/Mac 为例 (windows 建议在 'Docker Quickstart Terminal' 中执行)
```
$ docker build -t paddle-mobile:dev - < Dockerfile
```
使用 `docker images` 可以看到我们新建的 image
```
$ docker images
REPOSITORY TAG IMAGE ID CREATED SIZE
paddle-mobile dev 33b146787711 45 hours ago 372MB
```
### 3. 使用 docker 构建
进入 paddle-mobile 目录,执行 docker run
```
$ docker run -it --mount type=bind,source=$PWD,target=/paddle-mobile paddle-mobile:dev
root@5affd29d4fc5:/ # cd /paddle-mobile
# 生成构建 android 产出的 Makefile
root@5affd29d4fc5:/ # rm CMakeCache.txt
root@5affd29d4fc5:/ # cmake -DCMAKE_TOOLCHAIN_FILE=tools/toolchains/arm-android-neon.cmake
# 生成构建 linux 产出的 Makefile
root@5affd29d4fc5:/ # rm CMakeCache.txt
root@5affd29d4fc5:/ # cmake -DCMAKE_TOOLCHAIN_FILE=tools/toolchains/arm-linux-gnueabi.cmake
```
### 4. 设置编译选项
可以通过 ccmake 设置编译选项
```
root@5affd29d4fc5:/ # ccmake .
Page 1 of 1
CMAKE_ASM_FLAGS
CMAKE_ASM_FLAGS_DEBUG
CMAKE_ASM_FLAGS_RELEASE
CMAKE_BUILD_TYPE
CMAKE_INSTALL_PREFIX /usr/local
CMAKE_TOOLCHAIN_FILE /paddle-mobile/tools/toolchains/arm-android-neon.cmake
CPU ON
DEBUGING ON
FPGA OFF
LOG_PROFILE ON
MALI_GPU OFF
NET googlenet
USE_EXCEPTION ON
USE_OPENMP OFF
```
修改选项后,按 `c`, `g` 更新 Makefile
### 5. 构建
使用 make 命令进行构建
```
root@5affd29d4fc5:/ # make
```
### 6. 查看构建产出
构架产出可以在 host 机器上查看,在 paddle-mobile 的目录下,build 以及 test/build 下,可以使用 adb 指令或者 scp 传输到 device 上执行
## 基于Linux交叉编译
### 交叉编译环境准备
##### 下载Android NDK
从源码交叉编译paddle-mobile,用户需要提前准备好交叉编译环境。Android平台使用的C/C++交叉编译工具链是[Android NDK](https://developer.android.com/ndk/),用户可以自行前往下载,也可以通过以下命令获取:
```
wget https://dl.google.com/android/repository/android-ndk-r17b-darwin-x86_64.zip
unzip android-ndk-r17b-darwin-x86_64.zip
```
##### 设置环境变量
工程中自带的独立工具链会根据环境变量NDK_ROOT查找NDK,因此需要配置环境变量:
```
export NDK_ROOT = "path to ndk"
```
### 执行编译
在paddle-mobile根目录中,执行以下命令:
```
cd tools
sh build.sh android
```
执行完毕后,生成的so位于build目录中,单测可执行文件位于test/build目录中。
##### Tips:
如果想要获得体积更小的库,可选择编译支持指定模型结构的库。
如执行如下命令:
```
sh build.sh android googlenet
```
会得到一个支持googlnet的体积更小的库。
##测试
在编译完成后,我们提供了自动化的测试脚本,帮助用户将运行单测文件所需要的模型及库文件push到Android设备中,执行以下命令:
```
cd tools/android-debug-script
sh run_on_android.sh (npm) 可选参数npm,用于选择是否传输模型文件到手机上
```
出现如下提示:
```
**** choose OP or NET to test ****
which to test :
```
输入名称即可运行对应的测试文件。
##部署
Android应用可通过JNI接口调用底层C/C++,paddle-mobile对外提供的JNI接口如下:
##### 1 load接口 加载模型参数
```
/*
*@param modelPath 模型文件路径
*@return jboolean
*/
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
jclass thiz,
jstring modelPath);
```
##### 2 predict接口 执行预测
```
/**
*@param buf 输入数据
*@return 输出数据
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predict(
JNIEnv *env, jclass thiz, jfloatArray buf);
```
##### 3 clear接口 销毁实例、清理内存操作
```
JNIEXPORT void JNICALL Java_com_baidu_paddle_PMLL_clear(JNIEnv *env,
jclass thiz);
```
...@@ -23,7 +23,17 @@ namespace framework { ...@@ -23,7 +23,17 @@ namespace framework {
class Scope { class Scope {
public: public:
Scope() = default; Scope() = default;
~Scope() = default;
~Scope() {
for (auto &var : vars_) {
delete var.second;
}
vars_.clear();
for (auto kid : kids_) {
delete kid;
}
kids_.clear();
}
Scope &NewScope() const; Scope &NewScope() const;
......
...@@ -54,13 +54,14 @@ string jstring2cppstring(JNIEnv *env, jstring jstr) { ...@@ -54,13 +54,14 @@ string jstring2cppstring(JNIEnv *env, jstring jstr) {
JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
jclass thiz, jclass thiz,
jstring modelPath) { jstring modelPath) {
ANDROIDLOGI("load invoked");
bool optimize = true; bool optimize = true;
return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath),
optimize); optimize);
} }
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( JNIEXPORT jfloatArray JNICALL
JNIEnv *env, jclass thiz, jfloatArray buf) { Java_com_baidu_paddle_PML_predict(JNIEnv *env, jclass thiz, jfloatArray buf) {
jfloatArray result = NULL; jfloatArray result = NULL;
int count = 0; int count = 0;
float *dataPointer = nullptr; float *dataPointer = nullptr;
...@@ -78,6 +79,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( ...@@ -78,6 +79,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
count = output->numel(); count = output->numel();
result = env->NewFloatArray(count); result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>()); env->SetFloatArrayRegion(result, 0, count, output->data<float>());
ANDROIDLOGI("predict finished");
return result; return result;
} }
......
...@@ -31,8 +31,8 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, ...@@ -31,8 +31,8 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env,
/** /**
* object detection for anroid * object detection for anroid
*/ */
JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( JNIEXPORT jfloatArray JNICALL
JNIEnv *env, jclass thiz, jfloatArray buf); Java_com_baidu_paddle_PML_predict(JNIEnv *env, jclass thiz, jfloatArray buf);
/** /**
* clear data of the net when destroy for android * clear data of the net when destroy for android
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#ifdef FUSION_CONVADD_OP #ifdef FUSION_CONVADD_OP
#include "operators/kernel/conv_add_kernel.h" #include "operators/kernel/conv_add_kernel.h"
#include "../central-arm-func/conv_add_arm_func.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -23,111 +24,9 @@ bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) { ...@@ -23,111 +24,9 @@ bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) {
return true; return true;
} }
void ConvAddBasic(const FusionConvAddParam &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
output->ShareDataWith(bias);
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1));
}
}
}
template <> template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const { void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
if (param.Groups() == param.Input()->dims()[1] && ConvAddCompute<float>(param);
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), param.Bias(), param.Output(), true);
} else {
ConvAddBasic(param);
}
} }
template class ConvAddKernel<CPU, float>; template class ConvAddKernel<CPU, float>;
......
...@@ -14,27 +14,11 @@ limitations under the License. */ ...@@ -14,27 +14,11 @@ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#include <operators/kernel/pool_kernel.h> #include "operators/kernel/pool_kernel.h"
#include "common/log.h" #include "../central-arm-func/pool_arm_func.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
const Tensor *in_x, Tensor *out) {
if (pooling_type == "max") {
math::PoolFunctor<CPU, math::MaxPool<float>, float> pool2d_forward;
math::MaxPool<float> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
} else if (pooling_type == "avg") {
math::PoolFunctor<CPU, math::AvgPool<float>, float> pool2d_forward;
math::AvgPool<float> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
}
}
template <> template <>
bool PoolKernel<CPU, float>::Init(PoolParam *param) { bool PoolKernel<CPU, float>::Init(PoolParam *param) {
return true; return true;
...@@ -42,42 +26,7 @@ bool PoolKernel<CPU, float>::Init(PoolParam *param) { ...@@ -42,42 +26,7 @@ bool PoolKernel<CPU, float>::Init(PoolParam *param) {
template <> template <>
void PoolKernel<CPU, float>::Compute(const PoolParam &param) const { void PoolKernel<CPU, float>::Compute(const PoolParam &param) const {
const Tensor *in_x = param.Input(); PoolCompute<float>(param);
Tensor *out = param.Output();
std::string pooling_type = param.PoolingType();
std::vector<int> ksize = param.Ksize();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
if (ksize.size() != 2) {
LOG(paddle_mobile::LogLevel::kLOG_ERROR)
<< "Pool op only supports 2D and 3D input.";
}
if (param.isGlobalPooling()) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
math::Pool3x3Max(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
math::Pool3x3Avg(strides, paddings, in_x, out);
}
} else if (ksize[0] == 2 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
math::Pool2x2Max(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
math::Pool2x2Avg(strides, paddings, in_x, out);
}
} else {
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
}
} }
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef SIGMOID_OP #ifdef SIGMOID_OP
#include "../sigmoid_kernel.h" #include "../sigmoid_kernel.h"
#include "../central-arm-func/sigmoid_arm_func.h"
#if __ARM_NEON #if __ARM_NEON
#include "../../math/math_func_neon.h" #include "../../math/math_func_neon.h"
#endif #endif
...@@ -25,52 +26,6 @@ namespace operators { ...@@ -25,52 +26,6 @@ namespace operators {
using framework::DDim; using framework::DDim;
using framework::Tensor; using framework::Tensor;
void sigmoid(const Tensor *X, Tensor *Y) {
#if __ARM_NEON
const float *input = X->data<float>();
float *output = Y->mutable_data<float>();
const DDim &dDim = X->dims();
int axis_index = 1;
if (dDim.size() < 4) {
axis_index = 0;
}
DDim outer_ddim =
paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1);
DDim inner_ddim =
paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size());
int out_size = paddle_mobile::framework::product(outer_ddim);
int inner_size = paddle_mobile::framework::product(inner_ddim);
DLOG << "outsize=" << out_size;
DLOG << "innersize=" << inner_size;
#pragma omp parallel for
for (int i = 0; i < out_size; ++i) {
const float *input_outer_ptr = input + i * inner_size;
float *output_outer_ptr = output + i * inner_size;
int nn = inner_size >> 2;
int remain = inner_size - (nn << 2);
float32x4_t _one = vdupq_n_f32(1.f);
for (; nn > 0; nn--) {
float32x4_t data = vld1q_f32(input_outer_ptr);
data = vnegq_f32(data);
data = exp_ps(data);
data = vaddq_f32(data, _one);
float32x4_t out_data = vrecpeq_f32(data);
out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data);
vst1q_f32(output_outer_ptr, out_data);
input_outer_ptr += 4;
output_outer_ptr += 4;
}
for (; remain > 0; remain--) {
*output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr));
output_outer_ptr++;
input_outer_ptr++;
}
}
#endif
}
template <> template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) { bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) {
return true; return true;
...@@ -78,11 +33,7 @@ bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) { ...@@ -78,11 +33,7 @@ bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) {
template <> template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const { void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const {
const Tensor *in_x = param.InputX(); SigmoidCompute<float>(param);
Tensor *out = param.Out();
auto x_dims = in_x->dims();
out->Resize(x_dims);
sigmoid(in_x, out);
} }
template class SigmoidKernel<CPU, float>; template class SigmoidKernel<CPU, float>;
......
...@@ -15,7 +15,8 @@ limitations under the License. */ ...@@ -15,7 +15,8 @@ limitations under the License. */
#ifdef SOFTMAX_OP #ifdef SOFTMAX_OP
#include "../softmax_kernel.h" #include "../softmax_kernel.h"
#include "../../math/softmax.h" #include "../central-arm-func/softmax_arm_func.h"
#include "operators/math/softmax.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -26,11 +27,7 @@ bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) { ...@@ -26,11 +27,7 @@ bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) {
template <> template <>
void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const { void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const {
const Tensor *in_x = param.InputX(); SoftmaxCompute<float>(param);
Tensor *out = param.Out();
auto x_dims = in_x->dims();
out->Resize(x_dims);
math::SoftmaxFuntor<CPU, float>()(in_x, out);
} }
template class SoftmaxKernel<CPU, float>; template class SoftmaxKernel<CPU, float>;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADD_OP
#pragma once
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvAddBasic(const FusionConvAddParam &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
int axis = param.Axis();
Tensor *output = param.Output();
math::expand_bias(bias, axis, output->dims());
output->ShareDataWith(bias);
int groups = param.Groups();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
std::vector<int> dilations = param.Dilations();
const int batch_size = static_cast<int>(input->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
col_shape_vec[0] = input->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
bool is_expand =
math::IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col.mutable_data<float>(col_shape);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
framework::DDim input_shape = framework::slice_ddim(
input->dims(), 1, static_cast<int>(input->dims().size()));
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
output->dims()[1],
output->numel() / (output->dims()[0] * output->dims()[1])};
// convolution operator: im2col(or vol2col) + gemm
int in_step = static_cast<int>(input->dims()[1]) / groups;
int out_step = static_cast<int>(output->dims()[1]) / groups;
math::Vol2ColFunctor<CPU, float> vol2col;
math::Im2ColFunctor<math::ColFormat::kCFO, CPU, float> im2col;
for (int i = 0; i < batch_size; i++) {
Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape);
Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; g++) {
Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(in_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(in_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(in_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1));
}
}
}
template <typename P>
void ConvAddCompute(const FusionConvAddParam &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
param.Bias(), true);
} else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), param.Bias(), param.Output(), true);
} else {
ConvAddBasic(param);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -15,19 +15,21 @@ limitations under the License. */ ...@@ -15,19 +15,21 @@ limitations under the License. */
#ifdef CONV_OP #ifdef CONV_OP
#pragma once #pragma once
#include <operators/math/depthwise_conv_3x3.h>
#include <vector> #include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/depthwise_conv_3x3.h"
#include "operators/math/im2col.h"
#include "operators/math/math_function.h"
#include "operators/math/vol2col.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
inline void ConvBasic(const ConvParam &param) { inline void ConvBasic(const ConvParam &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor *output = param.Output(); Tensor *output = param.Output();
output->mutable_data<float>();
int groups = param.Groups(); int groups = param.Groups();
std::vector<int> strides = param.Strides(); std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings(); std::vector<int> paddings = param.Paddings();
...@@ -111,20 +113,18 @@ inline void ConvBasic(const ConvParam &param) { ...@@ -111,20 +113,18 @@ inline void ConvBasic(const ConvParam &param) {
template <typename P> template <typename P>
void ConvCompute(const ConvParam &param) { void ConvCompute(const ConvParam &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) {
math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(),
&Bias, false); nullptr, false);
} else if (param.Groups() == param.Input()->dims()[1] && } else if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { param.Filter()->dims()[2] == 3) {
math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(),
param.Filter(), &Bias, param.Output(), false); param.Filter(), nullptr, param.Output(), false);
} else { } else {
ConvBasic(param); ConvBasic(param);
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef POOL_OP
#pragma once
#include <string>
#include <vector>
#include "operators/math/pooling.h"
namespace paddle_mobile {
namespace operators {
using framework::Tensor;
inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
std::vector<int> strides, std::vector<int> paddings,
const Tensor *in_x, Tensor *out) {
if (pooling_type == "max") {
math::PoolFunctor<CPU, math::MaxPool<float>, float> pool2d_forward;
math::MaxPool<float> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
} else if (pooling_type == "avg") {
math::PoolFunctor<CPU, math::AvgPool<float>, float> pool2d_forward;
math::AvgPool<float> pool_process;
pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out);
}
}
template <typename P>
void PoolCompute(const PoolParam &param) {
const Tensor *in_x = param.Input();
Tensor *out = param.Output();
std::string pooling_type = param.PoolingType();
std::vector<int> ksize = param.Ksize();
std::vector<int> strides = param.Strides();
std::vector<int> paddings = param.Paddings();
if (ksize.size() != 2) {
LOG(paddle_mobile::LogLevel::kLOG_ERROR)
<< "Pool op only supports 2D and 3D input.";
}
if (param.isGlobalPooling()) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Maxs1p1(in_x, out);
} else {
math::Pool3x3Max(strides, paddings, in_x, out);
}
} else if (pooling_type == "avg") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
math::Pool3x3Avgs1p1(in_x, out);
} else {
math::Pool3x3Avg(strides, paddings, in_x, out);
}
}
} else if (ksize[0] == 2 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
math::Pool2x2Max(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
math::Pool2x2Avg(strides, paddings, in_x, out);
}
} else {
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SIGMOID_OP
#pragma once
#include "operators/op_param.h"
#if __ARM_NEON
#include <arm_neon.h>
#include "operators/math/math_func_neon.h"
#endif
namespace paddle_mobile {
namespace operators {
using framework::DDim;
void sigmoid(const Tensor *X, Tensor *Y) {
#if __ARM_NEON
const float *input = X->data<float>();
float *output = Y->mutable_data<float>();
const DDim &dDim = X->dims();
int axis_index = 1;
if (dDim.size() < 4) {
axis_index = 0;
}
DDim outer_ddim =
paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1);
DDim inner_ddim =
paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size());
int out_size = paddle_mobile::framework::product(outer_ddim);
int inner_size = paddle_mobile::framework::product(inner_ddim);
DLOG << "outsize=" << out_size;
DLOG << "innersize=" << inner_size;
#pragma omp parallel for
for (int i = 0; i < out_size; ++i) {
const float *input_outer_ptr = input + i * inner_size;
float *output_outer_ptr = output + i * inner_size;
int nn = inner_size >> 2;
int remain = inner_size - (nn << 2);
float32x4_t _one = vdupq_n_f32(1.f);
for (; nn > 0; nn--) {
float32x4_t data = vld1q_f32(input_outer_ptr);
data = vnegq_f32(data);
data = exp_ps(data);
data = vaddq_f32(data, _one);
float32x4_t out_data = vrecpeq_f32(data);
out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data);
vst1q_f32(output_outer_ptr, out_data);
input_outer_ptr += 4;
output_outer_ptr += 4;
}
for (; remain > 0; remain--) {
*output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr));
output_outer_ptr++;
input_outer_ptr++;
}
}
#endif
}
template <typename P>
void SigmoidCompute(const SigmoidParam &param) {
const Tensor *in_x = param.InputX();
Tensor *out = param.Out();
auto x_dims = in_x->dims();
out->Resize(x_dims);
sigmoid(in_x, out);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SOFTMAX_OP
#pragma once
#include "../../math/softmax.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void SoftmaxCompute(const SoftmaxParam &param) {
const Tensor *in_x = param.InputX();
Tensor *out = param.Out();
auto x_dims = in_x->dims();
out->Resize(x_dims);
math::SoftmaxFuntor<CPU, float>()(in_x, out);
}
} // namespace operators
} // namespace paddle_mobile
#endif
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#pragma once #pragma once
#include "framework/operator.h" #include "framework/operator.h"
#include "operators/math/pooling.h"
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
......
...@@ -23,8 +23,6 @@ namespace paddle_mobile { ...@@ -23,8 +23,6 @@ namespace paddle_mobile {
namespace operators { namespace operators {
using framework::OpKernelBase; using framework::OpKernelBase;
void simoid(Tensor *X, Tensor *Y);
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> { class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> {
public: public:
......
...@@ -245,7 +245,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, ...@@ -245,7 +245,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter,
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
float *output_data = output->data<float>(); float *output_data = output->data<float>();
const float *bias_data = bias->data<float>(); const float *bias_data;
if (if_bias) {
bias_data = bias->data<float>();
}
const int h = static_cast<int>(input->dims()[2]); const int h = static_cast<int>(input->dims()[2]);
const int w = static_cast<int>(input->dims()[3]); const int w = static_cast<int>(input->dims()[3]);
......
...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#define __ARM_NEON true
#include "pool_3x3.h" #include "pool_3x3.h"
#include "framework/tensor.h" #include "framework/tensor.h"
#if __ARM_NEON #if __ARM_NEON
...@@ -27,6 +26,481 @@ using framework::Tensor; ...@@ -27,6 +26,481 @@ using framework::Tensor;
using std::max; using std::max;
using std::min; using std::min;
using std::vector; using std::vector;
void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int h_in = input->dims()[2];
const int w_in = input->dims()[3];
const int output_channels = output->dims()[1];
const int h_out = output->dims()[2];
const int w_out = output->dims()[3];
const int outputdata_channel_stride = h_out * w_out;
const int inputdata_channel_stride = h_in * w_in;
float *out_data = output->data<float>();
const float *input_data = input->data<float>();
const float coef = 1.0 / 9.0;
for (int k = 0; k < batch_size; ++k) {
for (int c = 0; c < output_channels; ++c) {
// four corner point
out_data[0] = (input_data[0] + input_data[1] + input_data[w_in] +
input_data[w_in + 1]) *
coef;
out_data[w_out - 1] =
(input_data[w_in - 2] + input_data[w_in - 1] +
input_data[w_in * 2 - 2] + input_data[2 * w_in - 1]) *
coef;
out_data[(h_out - 1) * w_out] =
(input_data[(h_in - 2) * w_in] + input_data[(h_in - 2) * w_in + 1] +
input_data[(h_in - 1) * w_in] + input_data[(h_in - 1) * w_in + 1]) *
coef;
out_data[h_out * w_out - 1] =
(input_data[h_in * w_in - 1] + input_data[h_in * w_in - 2] +
input_data[(h_in - 1) * w_in - 1] +
input_data[(h_in - 1) * w_in - 2]) *
coef;
// left side & right side
for (int i = 1; i < h_in - 1; ++i) {
out_data[i * w_out] =
(input_data[i * w_in - w_in] + input_data[i * w_in - w_in + 1] +
input_data[i * w_in] + input_data[i * w_in + 1] +
input_data[i * w_in + w_in] + input_data[i * w_in + w_in + 1]) *
coef;
out_data[i * w_out + w_out - 1] =
(input_data[i * w_in - w_in + w_in - 2] +
input_data[i * w_in - w_in + 1 + w_in - 2] +
input_data[i * w_in + w_in - 2] +
input_data[i * w_in + 1 + w_in - 2] +
input_data[i * w_in + w_in + w_in - 2] +
input_data[i * w_in + w_in + 1 + w_in - 2]) *
coef;
}
// top 1 row & bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, sum, out0;
float32x4_t v_coef = vdupq_n_f32(coef);
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
const float *input_tmp_end = input_tmp + (h_in - 2) * w_in;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w_in);
int c_mid = w_out - 2;
auto output_ptr = out_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
vst1q_f32(output_ptr, vmulq_f32(sum, v_coef));
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w_in + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
vst1q_f32(output_ptr + (h_out - 1) * w_out, vmulq_f32(sum, v_coef));
// can optimize to each 8 stride.
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right remain
float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 2);
tmp3 = vextq_f32(in2, pad1, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
out0 = vmulq_f32(sum, v_coef);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
// bottom_right remain
float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 2);
tmp3 = vextq_f32(in6, pad3, 2);
sum = vaddq_f32(in4, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in6);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
out0 = vmulq_f32(sum, v_coef);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 2);
}
}
// mid
for (int j = 0; j < h_out - 2; ++j) {
output_ptr = out_data + w_out * (j + 1) + 1;
input_tmp = input_data + j * w_in;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
in4 = vld1q_f32(input_tmp + 2 * w_in);
c_mid = w_out - 2;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
in5 = vld1q_f32(input_tmp + 2 * w_in + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
sum = vaddq_f32(sum, in4);
sum = vaddq_f32(sum, tmp4);
sum = vaddq_f32(sum, tmp5);
out0 = vmulq_f32(sum, v_coef);
vst1q_f32(output_ptr, out0);
output_ptr += 4;
input_tmp += 4;
in0 = in1;
in2 = in3;
in4 = in5;
}
// mid remain
float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
tmp4 = vextq_f32(in4, pad2, 1);
tmp5 = vextq_f32(in4, pad2, 2);
sum = vaddq_f32(in0, tmp0);
sum = vaddq_f32(sum, tmp1);
sum = vaddq_f32(sum, in2);
sum = vaddq_f32(sum, tmp2);
sum = vaddq_f32(sum, tmp3);
sum = vaddq_f32(sum, in4);
sum = vaddq_f32(sum, tmp4);
sum = vaddq_f32(sum, tmp5);
out0 = vmulq_f32(sum, v_coef);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, out0, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, out0, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, out0, 2);
}
}
}
input_data += inputdata_channel_stride;
out_data += outputdata_channel_stride;
}
}
#endif
}
void Pool3x3Maxs1p1(const Tensor *input, Tensor *output) {
#if __ARM_NEON
const int batch_size = input->dims()[0];
const int h_in = input->dims()[2];
const int w_in = input->dims()[3];
const int output_channels = output->dims()[1];
const int h_out = output->dims()[2];
const int w_out = output->dims()[3];
const int outputdata_channel_stride = h_out * w_out;
const int inputdata_channel_stride = h_in * w_in;
float *out_data = output->data<float>();
const float *input_data = input->data<float>();
for (int k = 0; k < batch_size; ++k) {
for (int c = 0; c < output_channels; ++c) {
// four corner point
out_data[0] = std::max(std::max(input_data[0], input_data[1]),
std::max(input_data[w_in], input_data[w_in + 1]));
out_data[w_out - 1] = std::max(
std::max(input_data[w_in - 2], input_data[w_in - 1]),
std::max(input_data[w_in * 2 - 2], input_data[2 * w_in - 1]));
out_data[(h_out - 1) * w_out] =
std::max(std::max(input_data[(h_in - 2) * w_in],
input_data[(h_in - 2) * w_in + 1]),
std::max(input_data[(h_in - 1) * w_in],
input_data[(h_in - 1) * w_in + 1]));
out_data[h_out * w_out - 1] = std::max(
std::max(input_data[(h_in - 1) * w_in - 1],
input_data[(h_in - 1) * w_in - 2]),
std::max(input_data[h_in * w_in - 1], input_data[h_in * w_in - 2]));
// left side & right side
for (int i = 1; i < h_in - 1; ++i) {
float max1 = std::max(input_data[i * w_in - w_in],
input_data[i * w_in - w_in + 1]);
float max2 = std::max(input_data[i * w_in], input_data[i * w_in + 1]);
float max3 = std::max(input_data[i * w_in + w_in],
input_data[i * w_in + w_in + 1]);
out_data[i * w_out] = std::max(std::max(max1, max2), max3);
max1 = std::max(input_data[i * w_in - w_in + w_in - 2],
input_data[i * w_in - w_in + 1 + w_in - 2]);
max2 = std::max(input_data[i * w_in + w_in - 2],
input_data[i * w_in + 1 + w_in - 2]);
max3 = std::max(input_data[i * w_in + w_in + w_in - 2],
input_data[i * w_in + w_in + 1 + w_in - 2]);
out_data[i * w_out + w_out - 1] = std::max(std::max(max1, max2), max3);
}
// top 1 row & bottom 1 row
const float *input_tmp = input_data;
float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
tmp3, tmp4, tmp5, max;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
const float *input_tmp_end = input_tmp + (h_in - 2) * w_in;
in4 = vld1q_f32(input_tmp_end);
in6 = vld1q_f32(input_tmp_end + w_in);
int c_mid = w_out - 2;
auto output_ptr = out_data + 1;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
max = vmaxq_f32(in0, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in2);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
vst1q_f32(output_ptr, max);
in5 = vld1q_f32(input_tmp_end + 4);
in7 = vld1q_f32(input_tmp_end + w_in + 4);
tmp0 = vextq_f32(in4, in5, 1);
tmp1 = vextq_f32(in4, in5, 2);
tmp2 = vextq_f32(in6, in7, 1);
tmp3 = vextq_f32(in6, in7, 2);
max = vmaxq_f32(in4, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in6);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
vst1q_f32(output_ptr + (h_out - 1) * w_out, max);
input_tmp += 4;
input_tmp_end += 4;
output_ptr += 4;
in0 = in1;
in2 = in3;
in4 = in5;
in6 = in7;
}
// top right remain
float32x4_t pad0 = vdupq_n_f32(input_data[w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[2 * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
max = vmaxq_f32(in0, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in2);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, max, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, max, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, max, 2);
}
}
// bottom_right remain
float32x4_t pad2 = vdupq_n_f32(input_data[(h_in - 1) * w_in - 1]);
float32x4_t pad3 = vdupq_n_f32(input_data[h_in * w_in - 1]);
tmp0 = vextq_f32(in4, pad2, 1);
tmp1 = vextq_f32(in4, pad2, 2);
tmp2 = vextq_f32(in6, pad3, 1);
tmp3 = vextq_f32(in6, pad3, 2);
max = vmaxq_f32(in4, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in6);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, max, 2);
}
}
// mid
for (int j = 0; j < h_out - 2; ++j) {
output_ptr = out_data + (j + 1) * w_out + 1;
input_tmp = input_data + j * w_in;
in0 = vld1q_f32(input_tmp);
in2 = vld1q_f32(input_tmp + w_in);
in4 = vld1q_f32(input_tmp + 2 * w_in);
c_mid = w_out - 2;
for (; c_mid > 3; c_mid -= 4) {
in1 = vld1q_f32(input_tmp + 4);
in3 = vld1q_f32(input_tmp + w_in + 4);
in5 = vld1q_f32(input_tmp + 2 * w_in + 4);
tmp0 = vextq_f32(in0, in1, 1);
tmp1 = vextq_f32(in0, in1, 2);
tmp2 = vextq_f32(in2, in3, 1);
tmp3 = vextq_f32(in2, in3, 2);
tmp4 = vextq_f32(in4, in5, 1);
tmp5 = vextq_f32(in4, in5, 2);
max = vmaxq_f32(in0, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in2);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
max = vmaxq_f32(max, in4);
max = vmaxq_f32(max, tmp4);
max = vmaxq_f32(max, tmp5);
vst1q_f32(output_ptr, max);
output_ptr += 4;
input_tmp += 4;
in0 = in1;
in2 = in3;
in4 = in5;
}
// mid remain
float32x4_t pad0 = vdupq_n_f32(input_data[(j + 1) * w_in - 1]);
float32x4_t pad1 = vdupq_n_f32(input_data[(j + 2) * w_in - 1]);
float32x4_t pad2 = vdupq_n_f32(input_data[(j + 3) * w_in - 1]);
tmp0 = vextq_f32(in0, pad0, 1);
tmp1 = vextq_f32(in0, pad0, 2);
tmp2 = vextq_f32(in2, pad1, 1);
tmp3 = vextq_f32(in2, pad1, 2);
tmp4 = vextq_f32(in4, pad2, 1);
tmp5 = vextq_f32(in4, pad2, 2);
max = vmaxq_f32(in0, tmp0);
max = vmaxq_f32(max, tmp1);
max = vmaxq_f32(max, in2);
max = vmaxq_f32(max, tmp2);
max = vmaxq_f32(max, tmp3);
max = vmaxq_f32(max, in4);
max = vmaxq_f32(max, tmp4);
max = vmaxq_f32(max, tmp5);
for (int i = 0; i < c_mid; ++i) {
if (i == 0) {
vst1q_lane_f32(output_ptr + i, max, 0);
}
if (i == 1) {
vst1q_lane_f32(output_ptr + i, max, 1);
}
if (i == 2) {
vst1q_lane_f32(output_ptr + i, max, 2);
}
}
}
input_data += inputdata_channel_stride;
out_data += outputdata_channel_stride;
}
}
#endif
}
void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input, void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output) { Tensor *output) {
......
...@@ -15,7 +15,8 @@ limitations under the License. */ ...@@ -15,7 +15,8 @@ limitations under the License. */
#ifdef POOL_OP #ifdef POOL_OP
#pragma once #pragma once
#include <algorithm>
#include <vector>
#include "framework/tensor.h" #include "framework/tensor.h"
#if __ARM_NEON #if __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
...@@ -26,7 +27,8 @@ namespace operators { ...@@ -26,7 +27,8 @@ namespace operators {
namespace math { namespace math {
using framework::Tensor; using framework::Tensor;
using std::vector; using std::vector;
void Pool3x3Avgs1p1(const Tensor *input, Tensor *output);
void Pool3x3Maxs1p1(const Tensor *input, Tensor *output);
void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input, void Pool3x3Max(vector<int> strides, vector<int> paddings, const Tensor *input,
Tensor *output); Tensor *output);
......
...@@ -195,8 +195,7 @@ class OpParam { ...@@ -195,8 +195,7 @@ class OpParam {
class ConvParam : OpParam { class ConvParam : OpParam {
public: public:
ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) {
filter_ = FilterFrom<LoDTensor>(inputs, scope); filter_ = FilterFrom<LoDTensor>(inputs, scope);
input_ = InputFrom<LoDTensor>(inputs, scope); input_ = InputFrom<LoDTensor>(inputs, scope);
output_ = OutputFrom<LoDTensor>(outputs, scope); output_ = OutputFrom<LoDTensor>(outputs, scope);
...@@ -237,12 +236,11 @@ Print &operator<<(Print &printer, const ConvParam &conv_param); ...@@ -237,12 +236,11 @@ Print &operator<<(Print &printer, const ConvParam &conv_param);
class ElementwiseAddParam : OpParam { class ElementwiseAddParam : OpParam {
public: public:
ElementwiseAddParam(const VariableNameMap &inputs, ElementwiseAddParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs, const AttributeMap &attrs,
const framework::AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope); input_y_ = InputYFrom<LoDTensor>(inputs, scope);
input_y_ = InputYFrom<framework::LoDTensor>(inputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
axis_ = GetAttr<int>("axis", attrs); axis_ = GetAttr<int>("axis", attrs);
} }
...@@ -267,11 +265,10 @@ class ElementwiseAddParam : OpParam { ...@@ -267,11 +265,10 @@ class ElementwiseAddParam : OpParam {
class MulParam : OpParam { class MulParam : OpParam {
public: public:
MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs, MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope); input_y_ = InputYFrom<LoDTensor>(inputs, scope);
input_y_ = InputYFrom<framework::LoDTensor>(inputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs); x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs); y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
} }
...@@ -299,10 +296,9 @@ class MulParam : OpParam { ...@@ -299,10 +296,9 @@ class MulParam : OpParam {
class ConcatParam : public OpParam { class ConcatParam : public OpParam {
public: public:
ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs, ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) {
inputs_ = InputMultiFrom<LoDTensor>(inputs, scope); inputs_ = InputMultiFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
axis_ = GetAttr<int>("axis", attrs); axis_ = GetAttr<int>("axis", attrs);
} }
...@@ -323,11 +319,10 @@ class ConcatParam : public OpParam { ...@@ -323,11 +319,10 @@ class ConcatParam : public OpParam {
class LrnParam : public OpParam { class LrnParam : public OpParam {
public: public:
LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs, LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope); mid_out_ = MidOutFrom<LoDTensor>(outputs, scope);
mid_out_ = MidOutFrom<framework::LoDTensor>(outputs, scope);
n_ = GetAttr<int>("n", attrs); n_ = GetAttr<int>("n", attrs);
alpha_ = GetAttr<float>("alpha", attrs); alpha_ = GetAttr<float>("alpha", attrs);
beta_ = GetAttr<float>("beta", attrs); beta_ = GetAttr<float>("beta", attrs);
...@@ -367,14 +362,13 @@ class LrnParam : public OpParam { ...@@ -367,14 +362,13 @@ class LrnParam : public OpParam {
class BatchNormParam : OpParam { class BatchNormParam : OpParam {
public: public:
BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs, BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope); output_y_ = OutputYFrom<LoDTensor>(outputs, scope);
output_y_ = OutputYFrom<framework::LoDTensor>(outputs, scope); input_bias_ = InputBiasFrom<LoDTensor>(inputs, scope);
input_bias_ = InputBiasFrom<framework::LoDTensor>(inputs, scope); input_mean_ = InputMeanFrom<LoDTensor>(inputs, scope);
input_mean_ = InputMeanFrom<framework::LoDTensor>(inputs, scope); input_scale_ = InputScaleFrom<LoDTensor>(inputs, scope);
input_scale_ = InputScaleFrom<framework::LoDTensor>(inputs, scope); input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
input_variance_ = InputVarianceFrom<framework::LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs); epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs); momentum_ = GetAttr<float>("momentum", attrs);
is_test_ = GetAttr<bool>("is_test", attrs); is_test_ = GetAttr<bool>("is_test", attrs);
...@@ -418,11 +412,10 @@ class BatchNormParam : OpParam { ...@@ -418,11 +412,10 @@ class BatchNormParam : OpParam {
class PoolParam : public OpParam { class PoolParam : public OpParam {
public: public:
PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs, PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_ = InputXFrom<LoDTensor>(inputs, scope);
input_ = InputXFrom<framework::LoDTensor>(inputs, scope);
output_ = OutFrom<framework::LoDTensor>(outputs, scope); output_ = OutFrom<LoDTensor>(outputs, scope);
pooling_type_ = GetAttr<string>("pooling_type", attrs); pooling_type_ = GetAttr<string>("pooling_type", attrs);
ksize_ = GetAttr<vector<int>>("ksize", attrs); ksize_ = GetAttr<vector<int>>("ksize", attrs);
strides_ = GetAttr<vector<int>>("strides", attrs); strides_ = GetAttr<vector<int>>("strides", attrs);
...@@ -464,13 +457,11 @@ class PoolParam : public OpParam { ...@@ -464,13 +457,11 @@ class PoolParam : public OpParam {
class PriorBoxParam : public OpParam { class PriorBoxParam : public OpParam {
public: public:
PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs, PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_ = InputFrom<LoDTensor>(inputs, scope);
input_ = InputFrom<framework::LoDTensor>(inputs, scope); input_image_ = InputImageFrom<LoDTensor>(inputs, scope);
input_image_ = InputImageFrom<framework::LoDTensor>(inputs, scope); output_boxes_ = OutputBoxesFrom<LoDTensor>(outputs, scope);
output_boxes_ = OutputBoxesFrom<framework::LoDTensor>(outputs, scope); output_variances_ = OutputVariancesFrom<LoDTensor>(outputs, scope);
output_variances_ =
OutputVariancesFrom<framework::LoDTensor>(outputs, scope);
min_sizes_ = GetAttr<vector<float>>("min_sizes", attrs); min_sizes_ = GetAttr<vector<float>>("min_sizes", attrs);
max_sizes_ = GetAttr<vector<float>>("max_sizes", attrs); max_sizes_ = GetAttr<vector<float>>("max_sizes", attrs);
aspect_ratios_ = GetAttr<vector<float>>("aspect_ratios", attrs); aspect_ratios_ = GetAttr<vector<float>>("aspect_ratios", attrs);
...@@ -528,13 +519,11 @@ class PriorBoxParam : public OpParam { ...@@ -528,13 +519,11 @@ class PriorBoxParam : public OpParam {
class BoxCoderParam : public OpParam { class BoxCoderParam : public OpParam {
public: public:
BoxCoderParam(const VariableNameMap &inputs, const VariableNameMap &outputs, BoxCoderParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_priorbox_ = InputPriorBoxFrom<LoDTensor>(inputs, scope);
input_priorbox_ = InputPriorBoxFrom<framework::LoDTensor>(inputs, scope); input_priorboxvar_ = InputPriorBoxVarFrom<LoDTensor>(inputs, scope);
input_priorboxvar_ = input_targetbox_ = InputTargetBoxFrom<LoDTensor>(inputs, scope);
InputPriorBoxVarFrom<framework::LoDTensor>(inputs, scope); output_box_ = OutputBoxFrom<LoDTensor>(outputs, scope);
input_targetbox_ = InputTargetBoxFrom<framework::LoDTensor>(inputs, scope);
output_box_ = OutputBoxFrom<framework::LoDTensor>(outputs, scope);
code_type_ = GetAttr<std::string>("code_type", attrs); code_type_ = GetAttr<std::string>("code_type", attrs);
} }
const Tensor *InputPriorBox() const { return input_priorbox_; } const Tensor *InputPriorBox() const { return input_priorbox_; }
...@@ -560,10 +549,9 @@ class BoxCoderParam : public OpParam { ...@@ -560,10 +549,9 @@ class BoxCoderParam : public OpParam {
class SoftmaxParam : public OpParam { class SoftmaxParam : public OpParam {
public: public:
SoftmaxParam(const VariableNameMap &inputs, const VariableNameMap &outputs, SoftmaxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
} }
const Tensor *InputX() const { return input_x_; } const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; } Tensor *Out() const { return out_; }
...@@ -578,10 +566,9 @@ class SoftmaxParam : public OpParam { ...@@ -578,10 +566,9 @@ class SoftmaxParam : public OpParam {
class SigmoidParam : public OpParam { class SigmoidParam : public OpParam {
public: public:
SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs, SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
} }
const Tensor *InputX() const { return input_x_; } const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; } Tensor *Out() const { return out_; }
...@@ -643,9 +630,9 @@ class MultiClassNMSParam : public OpParam { ...@@ -643,9 +630,9 @@ class MultiClassNMSParam : public OpParam {
class FeedParam : public OpParam { class FeedParam : public OpParam {
public: public:
FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs, FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, framework::Scope &scope) { const AttributeMap &attrs, Scope &scope) {
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope); input_x_ = InputXFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
auto var = scope.Var("batch_size"); auto var = scope.Var("batch_size");
batch_size = var->GetValue<int>(); batch_size = var->GetValue<int>();
} }
...@@ -662,10 +649,9 @@ class FeedParam : public OpParam { ...@@ -662,10 +649,9 @@ class FeedParam : public OpParam {
class FetchParam : public OpParam { class FetchParam : public OpParam {
public: public:
FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs, FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const AttributeMap &attrs, const Scope &scope) {
const framework::Scope &scope) { input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope); out_ = OutFrom<LoDTensor>(outputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
} }
const Tensor *InputX() const { return input_x_; } const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; } Tensor *Out() const { return out_; }
...@@ -863,10 +849,10 @@ class FusionConvAddBNReluParam : public OpParam { ...@@ -863,10 +849,10 @@ class FusionConvAddBNReluParam : public OpParam {
paddings_ = GetAttr<vector<int>>("paddings", attrs); paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs); dilations_ = GetAttr<vector<int>>("dilations", attrs);
groups = GetAttr<int>("groups", attrs); groups = GetAttr<int>("groups", attrs);
input_bias_ = InputBiasFrom<framework::LoDTensor>(inputs, scope); input_bias_ = InputBiasFrom<LoDTensor>(inputs, scope);
input_mean_ = InputMeanFrom<framework::LoDTensor>(inputs, scope); input_mean_ = InputMeanFrom<LoDTensor>(inputs, scope);
input_scale_ = InputScaleFrom<framework::LoDTensor>(inputs, scope); input_scale_ = InputScaleFrom<LoDTensor>(inputs, scope);
input_variance_ = InputVarianceFrom<framework::LoDTensor>(inputs, scope); input_variance_ = InputVarianceFrom<LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs); epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs); momentum_ = GetAttr<float>("momentum", attrs);
is_test_ = GetAttr<bool>("is_test", attrs); is_test_ = GetAttr<bool>("is_test", attrs);
......
...@@ -17,25 +17,25 @@ limitations under the License. */ ...@@ -17,25 +17,25 @@ limitations under the License. */
#include "../test_include.h" #include "../test_include.h"
int main() { int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile; paddle_mobile::Loader<paddle_mobile::CPU> loader;
bool optimize = true; bool optimize = true;
auto time1 = time(); auto time1 = time();
// auto program = loader.Load(g_googlenet, optimize); // auto program = loader.Load(g_googlenet, optimize);
if (paddle_mobile.Load(g_googlenet_combine + "/model", auto program = loader.Load(g_googlenet_combine + "/model",
g_googlenet_combine + "/params", optimize)) { g_googlenet_combine + "/params", optimize);
auto time2 = time(); auto time2 = time();
DLOG << "load cost :" << time_diff(time1, time2) << "ms\n"; DLOG << "load cost :" << time_diff(time1, time2) << "ms\n";
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1, optimize);
std::vector<float> input; std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224}; std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224, &input, dims); GetInput<float>(g_test_image_1x3x224x224, &input, dims);
auto time3 = time(); auto time3 = time();
for (int i = 0; i < 10; ++i) { for (int i = 0; i < 10; ++i) {
paddle_mobile.Predict(input, dims); executor.Predict(input, dims);
} }
auto time4 = time(); auto time4 = time();
DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n"; DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n";
}
return 0; return 0;
} }
...@@ -32,8 +32,8 @@ build_for_mac() { ...@@ -32,8 +32,8 @@ build_for_mac() {
build_for_android() { build_for_android() {
#rm -rf "../build" #rm -rf "../build"
if [ -z "${ANDROID_NDK}" ]; then if [ -z "${NDK_ROOT}" ]; then
echo "ANDROID_NDK not found!" echo "NDK_ROOT not found!"
exit -1 exit -1
fi fi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册