diff --git a/doc/design_doc.md b/doc/design_doc.md index 3407c78443de0f0c7d9ebab848122c2e089e9e41..bf5f78e8d805465418cad8989945f2afa7ab5587 100644 --- a/doc/design_doc.md +++ b/doc/design_doc.md @@ -3,7 +3,6 @@ #### 以下是 paddle-mobile 代码的执行流程图: - ![执行流程图](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305189473720.png) @@ -15,7 +14,6 @@ 先来看一下模型, 模型分为两种结构: 一种为参数文件是散开的, 如下图, 红框为模型结构的 protobuf 文件, 其余为参数文件 - ![模型描述](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305190629577.png) @@ -23,6 +21,7 @@ ![模型描述combined](http://otkwwi4x8.bkt.clouddn.com/2018-07-02-15305191057130.png) + loader 模块的作用是将模型结构信息 load 进内存, 将红框内的 protobuf 文件 load 进内存, 并对模型结构进行优化(如将几个细粒度的 op 融合成 粗粒度的 op, 如将 conv、 add、 batchnorm、 relu 融合为 conv\_add\_batchnorm\_relu). 方便进行算法优化. diff --git a/ios/PaddleMobile.xcworkspace/xcuserdata/liuruilong.xcuserdatad/UserInterfaceState.xcuserstate b/ios/PaddleMobile.xcworkspace/xcuserdata/liuruilong.xcuserdatad/UserInterfaceState.xcuserstate index ff9a9abc7211d4c390fca9535743ee452515390e..a74810d22b023830a6e44d19984ff92302eb84a3 100644 Binary files a/ios/PaddleMobile.xcworkspace/xcuserdata/liuruilong.xcuserdatad/UserInterfaceState.xcuserstate and b/ios/PaddleMobile.xcworkspace/xcuserdata/liuruilong.xcuserdatad/UserInterfaceState.xcuserstate differ diff --git a/src/framework/scope.h b/src/framework/scope.h index d714f61af3bd443c09fcef7aacee2416b90b5e02..054f141ff68895e0879fd31e15d90c76ea038135 100644 --- a/src/framework/scope.h +++ b/src/framework/scope.h @@ -23,7 +23,17 @@ namespace framework { class Scope { public: Scope() = default; - ~Scope() = default; + + ~Scope() { + for (auto &var : vars_) { + delete var.second; + } + vars_.clear(); + for (auto kid : kids_) { + delete kid; + } + kids_.clear(); + } Scope &NewScope() const; diff --git a/src/jni/paddle_mobile_jni.cpp b/src/jni/paddle_mobile_jni.cpp index 323d1e37c1c4420f8b0d91cafefa83a98ef9328b..01d4e52a4b1308a7ff97bc672d1a15d329dbf318 100644 --- a/src/jni/paddle_mobile_jni.cpp +++ b/src/jni/paddle_mobile_jni.cpp @@ -54,13 +54,14 @@ string jstring2cppstring(JNIEnv *env, jstring jstr) { JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, jclass thiz, jstring modelPath) { + ANDROIDLOGI("load invoked"); bool optimize = true; return getPaddleMobileInstance()->Load(jstring2cppstring(env, modelPath), optimize); } -JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( - JNIEnv *env, jclass thiz, jfloatArray buf) { +JNIEXPORT jfloatArray JNICALL +Java_com_baidu_paddle_PML_predict(JNIEnv *env, jclass thiz, jfloatArray buf) { jfloatArray result = NULL; int count = 0; float *dataPointer = nullptr; @@ -78,6 +79,7 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( count = output->numel(); result = env->NewFloatArray(count); env->SetFloatArrayRegion(result, 0, count, output->data()); + ANDROIDLOGI("predict finished"); return result; } diff --git a/src/jni/paddle_mobile_jni.h b/src/jni/paddle_mobile_jni.h index 3497144999b028c927aad9a0ffa079044c3bcdf0..86caa9a273ab11124f6ea67efe27dc3529cea69f 100644 --- a/src/jni/paddle_mobile_jni.h +++ b/src/jni/paddle_mobile_jni.h @@ -31,8 +31,8 @@ JNIEXPORT jboolean JNICALL Java_com_baidu_paddle_PML_load(JNIEnv *env, /** * object detection for anroid */ -JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage( - JNIEnv *env, jclass thiz, jfloatArray buf); +JNIEXPORT jfloatArray JNICALL +Java_com_baidu_paddle_PML_predict(JNIEnv *env, jclass thiz, jfloatArray buf); /** * clear data of the net when destroy for android diff --git a/src/operators/kernel/arm/conv_add_kernel.cpp b/src/operators/kernel/arm/conv_add_kernel.cpp index 64d6dfa64dc3feae5b73a17ae5b148053df34a0b..88f839f611f1ed7f46c11a1b24feb6e29ff07ec7 100644 --- a/src/operators/kernel/arm/conv_add_kernel.cpp +++ b/src/operators/kernel/arm/conv_add_kernel.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #ifdef FUSION_CONVADD_OP #include "operators/kernel/conv_add_kernel.h" +#include "../central-arm-func/conv_add_arm_func.h" namespace paddle_mobile { namespace operators { @@ -23,111 +24,9 @@ bool ConvAddKernel::Init(FusionConvAddParam *param) { return true; } -void ConvAddBasic(const FusionConvAddParam ¶m) { - const Tensor *input = param.Input(); - Tensor filter = *param.Filter(); - Tensor bias = *param.Bias(); - int axis = param.Axis(); - Tensor *output = param.Output(); - math::expand_bias(bias, axis, output->dims()); - output->ShareDataWith(bias); - int groups = param.Groups(); - std::vector strides = param.Strides(); - std::vector paddings = param.Paddings(); - std::vector dilations = param.Dilations(); - - const int batch_size = static_cast(input->dims()[0]); - - std::vector filter_shape_vec(framework::vectorize(filter.dims())); - - std::vector output_shape_vec(framework::vectorize(output->dims())); - size_t data_dim = filter_shape_vec.size() - 2; - std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = input->dims()[1] / groups; - for (size_t j = 0; j < data_dim; ++j) { - col_shape_vec[j + 1] = filter_shape_vec[j + 2]; - col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(framework::make_ddim(col_shape_vec)); - - framework::DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, data_dim + 1); - - bool is_expand = - math::IsExpand(filter_shape_vec, strides, paddings, dilations); - Tensor col; - Tensor col_matrix; - if (is_expand) { - col.mutable_data(col_shape); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } - - framework::DDim input_shape = framework::slice_ddim( - input->dims(), 1, static_cast(input->dims().size())); - - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; - filter.Resize(filter_matrix_shape); - framework::DDim output_matrix_shape = { - output->dims()[1], - output->numel() / (output->dims()[0] * output->dims()[1])}; - - // convolution operator: im2col(or vol2col) + gemm - int in_step = static_cast(input->dims()[1]) / groups; - int out_step = static_cast(output->dims()[1]) / groups; - - math::Vol2ColFunctor vol2col; - math::Im2ColFunctor im2col; - - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (!is_expand) { - col.ShareDataWith(in_slice); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } else if (data_dim == 2U) { - // im2col - im2col(in_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, - &col); - } else if (data_dim == 3U) { - // vol2col - vol2col(in_slice, dilations, strides, paddings, &col); - } - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, - static_cast(1), &out_slice, - static_cast(1)); - } - } -} - template <> void ConvAddKernel::Compute(const FusionConvAddParam ¶m) const { - if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - param.Bias(), true); - } else if (param.Groups() == param.Input()->dims()[1] && - param.Input()->dims()[1] == param.Output()->dims()[1] && - param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3) { - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), param.Bias(), param.Output(), true); - } else { - ConvAddBasic(param); - } + ConvAddCompute(param); } template class ConvAddKernel; diff --git a/src/operators/kernel/arm/pool_kernel.cpp b/src/operators/kernel/arm/pool_kernel.cpp index 38e6c5f3f071d8eb0385d742fb819564309eeef6..be2189340f480bef80fd00a612cf32e71ea10a1c 100644 --- a/src/operators/kernel/arm/pool_kernel.cpp +++ b/src/operators/kernel/arm/pool_kernel.cpp @@ -14,27 +14,11 @@ limitations under the License. */ #ifdef POOL_OP -#include -#include "common/log.h" - +#include "operators/kernel/pool_kernel.h" +#include "../central-arm-func/pool_arm_func.h" namespace paddle_mobile { namespace operators { -inline void PoolBasic(std::string pooling_type, std::vector ksize, - std::vector strides, std::vector paddings, - const Tensor *in_x, Tensor *out) { - if (pooling_type == "max") { - math::PoolFunctor, float> pool2d_forward; - math::MaxPool pool_process; - pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out); - - } else if (pooling_type == "avg") { - math::PoolFunctor, float> pool2d_forward; - math::AvgPool pool_process; - pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out); - } -} - template <> bool PoolKernel::Init(PoolParam *param) { return true; @@ -42,54 +26,7 @@ bool PoolKernel::Init(PoolParam *param) { template <> void PoolKernel::Compute(const PoolParam ¶m) const { - const Tensor *in_x = param.Input(); - Tensor *out = param.Output(); - std::string pooling_type = param.PoolingType(); - - std::vector ksize = param.Ksize(); - - std::vector strides = param.Strides(); - - std::vector paddings = param.Paddings(); - if (ksize.size() != 2) { - LOG(paddle_mobile::LogLevel::kLOG_ERROR) - << "Pool op only supports 2D and 3D input."; - } - - if (param.isGlobalPooling()) { - for (size_t i = 0; i < ksize.size(); ++i) { - paddings[i] = 0; - ksize[i] = static_cast(in_x->dims()[i + 2]); - } - } else if (ksize[0] == 3 && ksize[0] == ksize[1]) { - if (pooling_type == "max") { - if (strides[0] == strides[1] && strides[0] == 1 && - paddings[0] == paddings[1] && paddings[1] == 1) { - math::Pool3x3Maxs1p1(in_x, out); - } else { - math::Pool3x3Max(strides, paddings, in_x, out); - } - math::Pool3x3Max(strides, paddings, in_x, out); - } else if (pooling_type == "avg") { - if (strides[0] == strides[1] && strides[0] == 1 && - paddings[0] == paddings[1] && paddings[1] == 1) { - math::Pool3x3Avgs1p1(in_x, out); - } else { - math::Pool3x3Avg(strides, paddings, in_x, out); - } - math::Pool3x3Avg(strides, paddings, in_x, out); - } - - } else if (ksize[0] == 2 && ksize[0] == ksize[1]) { - if (pooling_type == "max") { - math::Pool2x2Max(strides, paddings, in_x, out); - } else if (pooling_type == "avg") { - math::Pool2x2Avg(strides, paddings, in_x, out); - } - - } else { - PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); - } + PoolCompute(param); } } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/arm/sigmoid_kernel.cpp b/src/operators/kernel/arm/sigmoid_kernel.cpp index 5eb65cd6cebf453e46dc16c4982f81cb679bbc72..eb67de153ddb13fb48e42c28d6ec2270b0bc59b4 100644 --- a/src/operators/kernel/arm/sigmoid_kernel.cpp +++ b/src/operators/kernel/arm/sigmoid_kernel.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #ifdef SIGMOID_OP #include "../sigmoid_kernel.h" +#include "../central-arm-func/sigmoid_arm_func.h" #if __ARM_NEON #include "../../math/math_func_neon.h" #endif @@ -25,52 +26,6 @@ namespace operators { using framework::DDim; using framework::Tensor; -void sigmoid(const Tensor *X, Tensor *Y) { -#if __ARM_NEON - const float *input = X->data(); - float *output = Y->mutable_data(); - const DDim &dDim = X->dims(); - int axis_index = 1; - if (dDim.size() < 4) { - axis_index = 0; - } - DDim outer_ddim = - paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1); - DDim inner_ddim = - paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size()); - int out_size = paddle_mobile::framework::product(outer_ddim); - int inner_size = paddle_mobile::framework::product(inner_ddim); - - DLOG << "outsize=" << out_size; - DLOG << "innersize=" << inner_size; - #pragma omp parallel for - for (int i = 0; i < out_size; ++i) { - const float *input_outer_ptr = input + i * inner_size; - float *output_outer_ptr = output + i * inner_size; - int nn = inner_size >> 2; - int remain = inner_size - (nn << 2); - float32x4_t _one = vdupq_n_f32(1.f); - for (; nn > 0; nn--) { - float32x4_t data = vld1q_f32(input_outer_ptr); - data = vnegq_f32(data); - data = exp_ps(data); - data = vaddq_f32(data, _one); - float32x4_t out_data = vrecpeq_f32(data); - out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data); - vst1q_f32(output_outer_ptr, out_data); - - input_outer_ptr += 4; - output_outer_ptr += 4; - } - for (; remain > 0; remain--) { - *output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr)); - output_outer_ptr++; - input_outer_ptr++; - } - } -#endif -} - template <> bool SigmoidKernel::Init(SigmoidParam *param) { return true; @@ -78,11 +33,7 @@ bool SigmoidKernel::Init(SigmoidParam *param) { template <> void SigmoidKernel::Compute(const SigmoidParam ¶m) const { - const Tensor *in_x = param.InputX(); - Tensor *out = param.Out(); - auto x_dims = in_x->dims(); - out->Resize(x_dims); - sigmoid(in_x, out); + SigmoidCompute(param); } template class SigmoidKernel; diff --git a/src/operators/kernel/arm/softmax_kernel.cpp b/src/operators/kernel/arm/softmax_kernel.cpp index 29006d48dc00b650a725cd0a9cc3c37568e829a9..3ce763be38678319cfc23be83180450e5d3b209c 100644 --- a/src/operators/kernel/arm/softmax_kernel.cpp +++ b/src/operators/kernel/arm/softmax_kernel.cpp @@ -15,7 +15,8 @@ limitations under the License. */ #ifdef SOFTMAX_OP #include "../softmax_kernel.h" -#include "../../math/softmax.h" +#include "../central-arm-func/softmax_arm_func.h" +#include "operators/math/softmax.h" namespace paddle_mobile { namespace operators { @@ -26,11 +27,7 @@ bool SoftmaxKernel::Init(SoftmaxParam *param) { template <> void SoftmaxKernel::Compute(const SoftmaxParam ¶m) const { - const Tensor *in_x = param.InputX(); - Tensor *out = param.Out(); - auto x_dims = in_x->dims(); - out->Resize(x_dims); - math::SoftmaxFuntor()(in_x, out); + SoftmaxCompute(param); } template class SoftmaxKernel; diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..ed6dc46a90f2b6fa73555b3575f24103a34d1dda --- /dev/null +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -0,0 +1,138 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef FUSION_CONVADD_OP +#pragma once + +#include +#include "operators/math/conv_func.h" +#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +void ConvAddBasic(const FusionConvAddParam ¶m) { + const Tensor *input = param.Input(); + Tensor filter = *param.Filter(); + Tensor bias = *param.Bias(); + int axis = param.Axis(); + Tensor *output = param.Output(); + math::expand_bias(bias, axis, output->dims()); + output->ShareDataWith(bias); + int groups = param.Groups(); + std::vector strides = param.Strides(); + std::vector paddings = param.Paddings(); + std::vector dilations = param.Dilations(); + + const int batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + + std::vector output_shape_vec(framework::vectorize(output->dims())); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, data_dim + 1); + + bool is_expand = + math::IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; + Tensor col_matrix; + if (is_expand) { + col.mutable_data(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output->dims()[1]) / groups; + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { + col.ShareDataWith(in_slice); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (data_dim == 3U) { + // vol2col + vol2col(in_slice, dilations, strides, paddings, &col); + } + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + math::matmul(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(1)); + } + } +} + +template +void ConvAddCompute(const FusionConvAddParam ¶m) { + if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), + param.Bias(), true); + } else if (param.Groups() == param.Input()->dims()[1] && + param.Input()->dims()[1] == param.Output()->dims()[1] && + param.Filter()->dims()[2] == param.Filter()->dims()[3] && + param.Filter()->dims()[2] == 3) { + math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), + param.Filter(), param.Bias(), param.Output(), true); + } else { + ConvAddBasic(param); + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 6accf1937da5343a33d9dd739c125836f080f181..33caded3afaaf125bac9108f2fafeda3d3c2049f 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -15,19 +15,21 @@ limitations under the License. */ #ifdef CONV_OP #pragma once -#include #include - +#include "operators/math/conv_func.h" +#include "operators/math/depthwise_conv_3x3.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/vol2col.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { - inline void ConvBasic(const ConvParam ¶m) { const Tensor *input = param.Input(); Tensor filter = *param.Filter(); Tensor *output = param.Output(); - + output->mutable_data(); int groups = param.Groups(); std::vector strides = param.Strides(); std::vector paddings = param.Paddings(); @@ -111,20 +113,18 @@ inline void ConvBasic(const ConvParam ¶m) { template void ConvCompute(const ConvParam ¶m) { - Tensor Bias; - Bias.mutable_data({param.Groups()}); if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - &Bias, false); + nullptr, false); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3) { math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), &Bias, param.Output(), false); + param.Filter(), nullptr, param.Output(), false); } else { ConvBasic(param); } diff --git a/src/operators/kernel/central-arm-func/pool_arm_func.h b/src/operators/kernel/central-arm-func/pool_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..9eb8aceb1ab5dc1ee10f43e0632f35ef12722487 --- /dev/null +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef POOL_OP +#pragma once + +#include +#include +#include "operators/math/pooling.h" + +namespace paddle_mobile { +namespace operators { +using framework::Tensor; + +inline void PoolBasic(std::string pooling_type, std::vector ksize, + std::vector strides, std::vector paddings, + const Tensor *in_x, Tensor *out) { + if (pooling_type == "max") { + math::PoolFunctor, float> pool2d_forward; + math::MaxPool pool_process; + pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out); + + } else if (pooling_type == "avg") { + math::PoolFunctor, float> pool2d_forward; + math::AvgPool pool_process; + pool2d_forward(*in_x, ksize, strides, paddings, pool_process, out); + } +} +template +void PoolCompute(const PoolParam ¶m) { + const Tensor *in_x = param.Input(); + Tensor *out = param.Output(); + std::string pooling_type = param.PoolingType(); + + std::vector ksize = param.Ksize(); + + std::vector strides = param.Strides(); + + std::vector paddings = param.Paddings(); + if (ksize.size() != 2) { + LOG(paddle_mobile::LogLevel::kLOG_ERROR) + << "Pool op only supports 2D and 3D input."; + } + + if (param.isGlobalPooling()) { + for (size_t i = 0; i < ksize.size(); ++i) { + paddings[i] = 0; + ksize[i] = static_cast(in_x->dims()[i + 2]); + } + } else if (ksize[0] == 3 && ksize[0] == ksize[1]) { + if (pooling_type == "max") { + if (strides[0] == strides[1] && strides[0] == 1 && + paddings[0] == paddings[1] && paddings[1] == 1) { + math::Pool3x3Maxs1p1(in_x, out); + } else { + math::Pool3x3Max(strides, paddings, in_x, out); + } + } else if (pooling_type == "avg") { + if (strides[0] == strides[1] && strides[0] == 1 && + paddings[0] == paddings[1] && paddings[1] == 1) { + math::Pool3x3Avgs1p1(in_x, out); + } else { + math::Pool3x3Avg(strides, paddings, in_x, out); + } + } + + } else if (ksize[0] == 2 && ksize[0] == ksize[1]) { + if (pooling_type == "max") { + math::Pool2x2Max(strides, paddings, in_x, out); + } else if (pooling_type == "avg") { + math::Pool2x2Avg(strides, paddings, in_x, out); + } + + } else { + PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); + } +} + +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/src/operators/kernel/central-arm-func/sigmoid_arm_func.h b/src/operators/kernel/central-arm-func/sigmoid_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..eb0e4ab7e4b4f18f8ede4d85b859e68f7d58bda2 --- /dev/null +++ b/src/operators/kernel/central-arm-func/sigmoid_arm_func.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef SIGMOID_OP +#pragma once + +#include "operators/op_param.h" +#if __ARM_NEON +#include +#include "operators/math/math_func_neon.h" +#endif + +namespace paddle_mobile { +namespace operators { +using framework::DDim; +void sigmoid(const Tensor *X, Tensor *Y) { +#if __ARM_NEON + const float *input = X->data(); + float *output = Y->mutable_data(); + const DDim &dDim = X->dims(); + int axis_index = 1; + if (dDim.size() < 4) { + axis_index = 0; + } + DDim outer_ddim = + paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1); + DDim inner_ddim = + paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size()); + int out_size = paddle_mobile::framework::product(outer_ddim); + int inner_size = paddle_mobile::framework::product(inner_ddim); + + DLOG << "outsize=" << out_size; + DLOG << "innersize=" << inner_size; + #pragma omp parallel for + for (int i = 0; i < out_size; ++i) { + const float *input_outer_ptr = input + i * inner_size; + float *output_outer_ptr = output + i * inner_size; + int nn = inner_size >> 2; + int remain = inner_size - (nn << 2); + float32x4_t _one = vdupq_n_f32(1.f); + for (; nn > 0; nn--) { + float32x4_t data = vld1q_f32(input_outer_ptr); + data = vnegq_f32(data); + data = exp_ps(data); + data = vaddq_f32(data, _one); + float32x4_t out_data = vrecpeq_f32(data); + out_data = vmulq_f32(vrecpsq_f32(data, out_data), out_data); + vst1q_f32(output_outer_ptr, out_data); + + input_outer_ptr += 4; + output_outer_ptr += 4; + } + for (; remain > 0; remain--) { + *output_outer_ptr = 1.f / (1.f + exp(-*input_outer_ptr)); + output_outer_ptr++; + input_outer_ptr++; + } + } +#endif +} + +template +void SigmoidCompute(const SigmoidParam ¶m) { + const Tensor *in_x = param.InputX(); + Tensor *out = param.Out(); + auto x_dims = in_x->dims(); + out->Resize(x_dims); + sigmoid(in_x, out); +} +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/src/operators/kernel/central-arm-func/softmax_arm_func.h b/src/operators/kernel/central-arm-func/softmax_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..5a60bf88ae5d936567dc096c1f4bb31a73f0ef34 --- /dev/null +++ b/src/operators/kernel/central-arm-func/softmax_arm_func.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef SOFTMAX_OP +#pragma once +#include "../../math/softmax.h" +namespace paddle_mobile { +namespace operators { +template +void SoftmaxCompute(const SoftmaxParam ¶m) { + const Tensor *in_x = param.InputX(); + Tensor *out = param.Out(); + auto x_dims = in_x->dims(); + out->Resize(x_dims); + math::SoftmaxFuntor()(in_x, out); +} +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/src/operators/kernel/pool_kernel.h b/src/operators/kernel/pool_kernel.h index d666910b73e7a3cef2cc59d4ba32b826ae6d0876..fd9faa3d5a508084924e080f5c5ed7e7b454b5f2 100644 --- a/src/operators/kernel/pool_kernel.h +++ b/src/operators/kernel/pool_kernel.h @@ -17,7 +17,6 @@ limitations under the License. */ #pragma once #include "framework/operator.h" -#include "operators/math/pooling.h" #include "operators/op_param.h" namespace paddle_mobile { diff --git a/src/operators/kernel/softmax_kernel.h b/src/operators/kernel/softmax_kernel.h index 5a87d64dd9987d445b13a4fa9dc29a04e4ecc398..a500d9c81cce96b0f1db6d45981ad9aa02ea7c0b 100644 --- a/src/operators/kernel/softmax_kernel.h +++ b/src/operators/kernel/softmax_kernel.h @@ -23,8 +23,6 @@ namespace paddle_mobile { namespace operators { using framework::OpKernelBase; -void simoid(Tensor *X, Tensor *Y); - template class SoftmaxKernel : public OpKernelBase { public: diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index 984678e8730ea58d7dc647450dd098d265f0eb39..f23affb45107b0d2414c49843cdfbd70c953c95c 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -245,7 +245,10 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, const float *input_data = input->data(); const float *filter_data = filter->data(); float *output_data = output->data(); - const float *bias_data = bias->data(); + const float *bias_data; + if (if_bias) { + bias_data = bias->data(); + } const int h = static_cast(input->dims()[2]); const int w = static_cast(input->dims()[3]); diff --git a/src/operators/math/pool_3x3.cpp b/src/operators/math/pool_3x3.cpp index fb91528b473418849d9005a2c0a5a52d9d033e58..83d0bcb699f82b9c290080982ba6750a64d74e53 100644 --- a/src/operators/math/pool_3x3.cpp +++ b/src/operators/math/pool_3x3.cpp @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef POOL_OP -#include "operators/math/pool_3x3.h" -#include +#include "pool_3x3.h" #include "framework/tensor.h" +#if __ARM_NEON +#include +#endif // __ARM_NEON +#include namespace paddle_mobile { namespace operators { namespace math { diff --git a/src/operators/op_param.h b/src/operators/op_param.h index c0f0fbc8a9939bc4609e64359835a685dd4c67f9..892b08e6da0ce92df95e81dd9896df3ee8899fb9 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -195,8 +195,7 @@ class OpParam { class ConvParam : OpParam { public: ConvParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { + const AttributeMap &attrs, const Scope &scope) { filter_ = FilterFrom(inputs, scope); input_ = InputFrom(inputs, scope); output_ = OutputFrom(outputs, scope); @@ -237,12 +236,11 @@ Print &operator<<(Print &printer, const ConvParam &conv_param); class ElementwiseAddParam : OpParam { public: ElementwiseAddParam(const VariableNameMap &inputs, - const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_x_ = InputXFrom(inputs, scope); - input_y_ = InputYFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + const VariableNameMap &outputs, const AttributeMap &attrs, + const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + input_y_ = InputYFrom(inputs, scope); + out_ = OutFrom(outputs, scope); axis_ = GetAttr("axis", attrs); } @@ -267,11 +265,10 @@ class ElementwiseAddParam : OpParam { class MulParam : OpParam { public: MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_x_ = InputXFrom(inputs, scope); - input_y_ = InputYFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + input_y_ = InputYFrom(inputs, scope); + out_ = OutFrom(outputs, scope); x_num_col_dims_ = GetAttr("x_num_col_dims", attrs); y_num_col_dims_ = GetAttr("y_num_col_dims", attrs); } @@ -299,10 +296,9 @@ class MulParam : OpParam { class ConcatParam : public OpParam { public: ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { + const AttributeMap &attrs, const Scope &scope) { inputs_ = InputMultiFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + out_ = OutFrom(outputs, scope); axis_ = GetAttr("axis", attrs); } @@ -323,11 +319,10 @@ class ConcatParam : public OpParam { class LrnParam : public OpParam { public: LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_x_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); - mid_out_ = MidOutFrom(outputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + out_ = OutFrom(outputs, scope); + mid_out_ = MidOutFrom(outputs, scope); n_ = GetAttr("n", attrs); alpha_ = GetAttr("alpha", attrs); beta_ = GetAttr("beta", attrs); @@ -367,14 +362,13 @@ class LrnParam : public OpParam { class BatchNormParam : OpParam { public: BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_x_ = InputXFrom(inputs, scope); - output_y_ = OutputYFrom(outputs, scope); - input_bias_ = InputBiasFrom(inputs, scope); - input_mean_ = InputMeanFrom(inputs, scope); - input_scale_ = InputScaleFrom(inputs, scope); - input_variance_ = InputVarianceFrom(inputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + output_y_ = OutputYFrom(outputs, scope); + input_bias_ = InputBiasFrom(inputs, scope); + input_mean_ = InputMeanFrom(inputs, scope); + input_scale_ = InputScaleFrom(inputs, scope); + input_variance_ = InputVarianceFrom(inputs, scope); epsilon_ = GetAttr("epsilon", attrs); momentum_ = GetAttr("momentum", attrs); is_test_ = GetAttr("is_test", attrs); @@ -418,11 +412,10 @@ class BatchNormParam : OpParam { class PoolParam : public OpParam { public: PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_ = InputXFrom(inputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_ = InputXFrom(inputs, scope); - output_ = OutFrom(outputs, scope); + output_ = OutFrom(outputs, scope); pooling_type_ = GetAttr("pooling_type", attrs); ksize_ = GetAttr>("ksize", attrs); strides_ = GetAttr>("strides", attrs); @@ -464,13 +457,11 @@ class PoolParam : public OpParam { class PriorBoxParam : public OpParam { public: PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_ = InputFrom(inputs, scope); - input_image_ = InputImageFrom(inputs, scope); - output_boxes_ = OutputBoxesFrom(outputs, scope); - output_variances_ = - OutputVariancesFrom(outputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_ = InputFrom(inputs, scope); + input_image_ = InputImageFrom(inputs, scope); + output_boxes_ = OutputBoxesFrom(outputs, scope); + output_variances_ = OutputVariancesFrom(outputs, scope); min_sizes_ = GetAttr>("min_sizes", attrs); max_sizes_ = GetAttr>("max_sizes", attrs); aspect_ratios_ = GetAttr>("aspect_ratios", attrs); @@ -528,13 +519,11 @@ class PriorBoxParam : public OpParam { class BoxCoderParam : public OpParam { public: BoxCoderParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_priorbox_ = InputPriorBoxFrom(inputs, scope); - input_priorboxvar_ = - InputPriorBoxVarFrom(inputs, scope); - input_targetbox_ = InputTargetBoxFrom(inputs, scope); - output_box_ = OutputBoxFrom(outputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_priorbox_ = InputPriorBoxFrom(inputs, scope); + input_priorboxvar_ = InputPriorBoxVarFrom(inputs, scope); + input_targetbox_ = InputTargetBoxFrom(inputs, scope); + output_box_ = OutputBoxFrom(outputs, scope); code_type_ = GetAttr("code_type", attrs); } const Tensor *InputPriorBox() const { return input_priorbox_; } @@ -560,10 +549,9 @@ class BoxCoderParam : public OpParam { class SoftmaxParam : public OpParam { public: SoftmaxParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_x_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + out_ = OutFrom(outputs, scope); } const Tensor *InputX() const { return input_x_; } Tensor *Out() const { return out_; } @@ -578,10 +566,9 @@ class SoftmaxParam : public OpParam { class SigmoidParam : public OpParam { public: SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_x_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + out_ = OutFrom(outputs, scope); } const Tensor *InputX() const { return input_x_; } Tensor *Out() const { return out_; } @@ -643,9 +630,9 @@ class MultiClassNMSParam : public OpParam { class FeedParam : public OpParam { public: FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, framework::Scope &scope) { - input_x_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + const AttributeMap &attrs, Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + out_ = OutFrom(outputs, scope); auto var = scope.Var("batch_size"); batch_size = var->GetValue(); } @@ -662,10 +649,9 @@ class FeedParam : public OpParam { class FetchParam : public OpParam { public: FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - const framework::Scope &scope) { - input_x_ = InputXFrom(inputs, scope); - out_ = OutFrom(outputs, scope); + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + out_ = OutFrom(outputs, scope); } const Tensor *InputX() const { return input_x_; } Tensor *Out() const { return out_; } @@ -863,10 +849,10 @@ class FusionConvAddBNReluParam : public OpParam { paddings_ = GetAttr>("paddings", attrs); dilations_ = GetAttr>("dilations", attrs); groups = GetAttr("groups", attrs); - input_bias_ = InputBiasFrom(inputs, scope); - input_mean_ = InputMeanFrom(inputs, scope); - input_scale_ = InputScaleFrom(inputs, scope); - input_variance_ = InputVarianceFrom(inputs, scope); + input_bias_ = InputBiasFrom(inputs, scope); + input_mean_ = InputMeanFrom(inputs, scope); + input_scale_ = InputScaleFrom(inputs, scope); + input_variance_ = InputVarianceFrom(inputs, scope); epsilon_ = GetAttr("epsilon", attrs); momentum_ = GetAttr("momentum", attrs); is_test_ = GetAttr("is_test", attrs); diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index c6b99a2ed0d76a8e73c4393e67a679b435b9325f..1695995a8d60d20e0d6c5f8911c39a948426a82a 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -17,25 +17,25 @@ limitations under the License. */ #include "../test_include.h" int main() { - paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile::Loader loader; bool optimize = true; auto time1 = time(); // auto program = loader.Load(g_googlenet, optimize); - if (paddle_mobile.Load(g_googlenet_combine + "/model", - g_googlenet_combine + "/params", optimize)) { - auto time2 = time(); - DLOG << "load cost :" << time_diff(time1, time2) << "ms\n"; - std::vector input; - std::vector dims{1, 3, 224, 224}; - GetInput(g_test_image_1x3x224x224, &input, dims); - auto time3 = time(); + auto program = loader.Load(g_googlenet_combine + "/model", + g_googlenet_combine + "/params", optimize); + auto time2 = time(); + DLOG << "load cost :" << time_diff(time1, time2) << "ms\n"; + paddle_mobile::Executor executor(program, 1, optimize); + std::vector input; + std::vector dims{1, 3, 224, 224}; + GetInput(g_test_image_1x3x224x224, &input, dims); + auto time3 = time(); - for (int i = 0; i < 10; ++i) { - paddle_mobile.Predict(input, dims); - } - - auto time4 = time(); - DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n"; + for (int i = 0; i < 10; ++i) { + executor.Predict(input, dims); } + + auto time4 = time(); + DLOG << "predict cost :" << time_diff(time3, time4) << "ms\n"; return 0; } diff --git a/tools/build.sh b/tools/build.sh index 80caa4011821270549598aa898b0d31b30b437e6..43ce4eb63661bb1f9aa660653771f1cdf2cfed0d 100755 --- a/tools/build.sh +++ b/tools/build.sh @@ -32,8 +32,8 @@ build_for_mac() { build_for_android() { #rm -rf "../build" - if [ -z "${ANDROID_NDK}" ]; then - echo "ANDROID_NDK not found!" + if [ -z "${NDK_ROOT}" ]; then + echo "NDK_ROOT not found!" exit -1 fi