diff --git a/README.md b/README.md index 818e6e31989d5256c82fd7b1e8ae2ccd18a09386..4322548185305c291838d7aee4b57d1826c08915 100644 --- a/README.md +++ b/README.md @@ -8,46 +8,19 @@ [![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)--> -欢迎来到 Paddle-Mobile GitHub 项目。 - -Paddle-Mobile是PaddlePaddle组织下的项目,是一个致力于嵌入式平台的深度学习的框架。Paddle-Mobile设计思想和PaddlePaddle的最新版fluid版本保持了高度一致,同时针对嵌入式做了大量优化。设计之初就对嵌入式的性能、体积、能耗、硬件平台覆盖等方面做了考虑。 - -## 简单搜索线上效果 - -如下gif是简单搜索app的线上主体检测应用效果 - -![ezgif-1-050a733dfb](http://otkwwi4x8.bkt.clouddn.com/2018-07-05-ezgif-1-050a733dfb.gif) - -## Demo目录 - -[点我](https://github.com/PaddlePaddle/paddle-mobile/tree/develop/demo) +欢迎来到 Paddle-Mobile GitHub 项目。Paddle-Mobile是PaddlePaddle组织下的项目,是一个致力于嵌入式平台的深度学习的框架。 ## Features -- **ARM CPU** - -- **Mali GPU** - -- **苹果设备的GPU Metal实现** - -- **FPGA** - - 目前已经支持 ZCU102 开发板。 - -- **灵活性** - - * paddle-mobile cpu版不依赖任何第三库, 可进行快速集成。 - * 使用泛型特化进行平台切换, 可灵活切换 cpu、gpu 和其他协处理器。 - * 可根据特定的常见网络, 进行编译特定的 op, 降低编译时间, 减小包大小。 - * 使用 docker 编译, 提供统一的编译环境。 - * 高可拓展性, 方便拓展其他协处理器, 提供高性能 arm 算子实现, 方便其他协处理器开发者集成开发。 - * 直接兼容 paddle-fluid 模型, 不需要额外的转换操作。 - -- **体积** - - paddle-mobile从设计之初就深入考虑到移动端的包体积的问题,cpu实现中没有外部依赖。在编译过程中,如果该网络不需要的op是完全不会被打入的。同时编译选项优化也为体积压缩提供了帮助。 - 除了二进制体积,我们对代码体积极力避免过大。整个仓库的代码体积也非常小。 +- 高性能支持ARM CPU +- 支持Mali GPU +- 支持Adreno GPU +- 支持苹果设备的GPU Metal实现 +- 支持ZU5、ZU9等FPGA开发板 +- 支持树莓派等arm-linux开发板 +## Demo目录 +[https://github.com/PaddlePaddle/paddle-mobile/tree/develop/demo](https://github.com/PaddlePaddle/paddle-mobile/tree/develop/demo) ## 文档 @@ -74,18 +47,22 @@ Paddle-Mobile是PaddlePaddle组织下的项目,是一个致力于嵌入式平 ### 1. 直接使用Paddle Fluid训练 该方式最为可靠,推荐方式 ### 2. caffe转为Paddle Fluid模型 -[链接](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/caffe2fluid) +[https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/caffe2fluid](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/caffe2fluid) ### 3. 
ONNX ONNX全称为“Open Neural Network Exchange”,即“开放的神经网络切换”。该项目的目的是让不同的神经网络开发框架做到互通互用。 除直接使用PaddlePaddle训练fluid版本的模型外,还可以通过onnx转换得到个别Paddle fluid模型。 -目前,百度也在做onnx支持工作。相关转换项目在这里:[paddle-onnx](https://github.com/PaddlePaddle/paddle-onnx)。 - -![](http://7xop3k.com1.z0.glb.clouddn.com/15311951836000.jpg) +目前,百度也在做onnx支持工作。相关转换项目在这里: +[https://github.com/PaddlePaddle/paddle-onnx](https://github.com/PaddlePaddle/paddle-onnx) ### 4. 部分测试模型和测试图片下载 -[下载链接](http://mms-graph.bj.bcebos.com/paddle-mobile%2FmodelsAndImages.zip) +[http://mms-graph.bj.bcebos.com/paddle-mobile%2FmodelsAndImages.zip](http://mms-graph.bj.bcebos.com/paddle-mobile%2FmodelsAndImages.zip) + + ## 问题解决 @@ -97,5 +74,3 @@ Paddle-Mobile 提供相对宽松的Apache-2.0开源协议 [Apache-2.0 license](L ## 旧版 Mobile-Deep-Learning 原MDL(Mobile-Deep-Learning)工程被迁移到了这里 [Mobile-Deep-Learning](https://github.com/allonli/mobile-deep-learning) - - diff --git a/src/common/types.cpp b/src/common/types.cpp index dc70a6ffe485e1ca420c29b6e6f17125451b4df1..b90fb70f2a81b365f049632cc7281a69ec58e18d 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -44,6 +44,7 @@ const char *G_OP_TYPE_RESHAPE2 = "reshape2"; const char *G_OP_TYPE_SIGMOID = "sigmoid"; const char *G_OP_TYPE_SOFTMAX = "softmax"; const char *G_OP_TYPE_TRANSPOSE = "transpose"; +const char *G_OP_TYPE_TRANSPOSE2 = "transpose2"; const char *G_OP_TYPE_SPLIT = "split"; const char *G_OP_TYPE_FEED = "feed"; const char *G_OP_TYPE_FETCH = "fetch"; @@ -91,6 +92,7 @@ std::unordered_map< {G_OP_TYPE_FEED, {{"X"}, {"Out"}}}, {G_OP_TYPE_FETCH, {{"X"}, {"Out"}}}, {G_OP_TYPE_TRANSPOSE, {{"X"}, {"Out"}}}, + {G_OP_TYPE_TRANSPOSE2, {{"X"}, {"Out", "XShape"}}}, {G_OP_TYPE_BOX_CODER, {{"PriorBox", "PriorBoxVar", "TargetBox"}, {"OutputBox"}}}, {G_OP_TYPE_FUSION_CONV_ADD_BN_RELU, {{"Input"}, {"Out"}}}, diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 552c7154149984cd8c42a1782c11ae5198b71586..982f1c0f3525afde8475866c0121343fafc9d5a0 100644 --- a/src/framework/load_ops.h +++ 
b/src/framework/load_ops.h @@ -115,6 +115,9 @@ LOAD_OP2(reshape2, CPU, MALI_GPU); #ifdef TRANSPOSE_OP LOAD_OP1(transpose, CPU); #endif +#ifdef TRANSPOSE2_OP +LOAD_OP1(transpose2, CPU); +#endif #ifdef PRIORBOX_OP LOAD_OP1(prior_box, CPU); #endif diff --git a/src/operators/kernel/arm/im2sequence_kernel.cpp b/src/operators/kernel/arm/im2sequence_kernel.cpp index 8295fd94a31db2ad1c10d32a8c639b067e422f45..cc6ae2ae8bc7cde9b365817ba9cafc19776da913 100644 --- a/src/operators/kernel/arm/im2sequence_kernel.cpp +++ b/src/operators/kernel/arm/im2sequence_kernel.cpp @@ -35,7 +35,7 @@ template <> void Im2SequenceKernel::Compute( const Im2SequenceParam ¶m) const { const Tensor *in_x = param.Input(); - Tensor *out = param.Output(); + framework::LoDTensor *out = param.Output(); out->mutable_data(); std::vector kernels = param.Kernels(); @@ -52,22 +52,31 @@ void Im2SequenceKernel::Compute( paddings[2], strides[0]); int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1], paddings[3], strides[1]); - const std::vector dilations({1, 1}); + out->mutable_data({batch_size * output_height * output_width, + img_channels * kernels[0] * kernels[1]}); + const std::vector dilations({1, 1}); // TODO: verify auto out_dims = out->dims(); out->Resize({batch_size, out->numel() / batch_size}); - for (int i = 0; i < batch_size; i++) { const Tensor src = in_x->Slice(i, i + 1).Resize({img_channels, img_height, img_width}); Tensor dst = out->Slice(i, i + 1).Resize( {output_height, output_width, img_channels, kernels[0], kernels[1]}); - math::Im2ColFunctor f; f(src, dilations, strides, paddings, &dst); } out->Resize(out_dims); + framework::LoD lod(1); + lod[0].reserve(batch_size + 1); + int offset = 0; + lod[0].push_back(offset); + for (int i = 0; i < batch_size; ++i) { + offset += output_height * output_width; + lod[0].push_back(offset); + } + out->set_lod(lod); } template class Im2SequenceKernel; diff --git a/src/operators/kernel/arm/transpose2_kernel.cpp 
b/src/operators/kernel/arm/transpose2_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..656d2768840a52f50c42d3797018aa9aec037783 --- /dev/null +++ b/src/operators/kernel/arm/transpose2_kernel.cpp @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef TRANSPOSE2_OP + +#include "operators/kernel/transpose2_kernel.h" +#include "operators/kernel/central-arm-func/transpose2_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool Transpose2Kernel::Init(Transpose2Param *param) { + return true; +} + +template <> +void Transpose2Kernel::Compute( + const Transpose2Param ¶m) const { + Transpose2Compute(param); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h index 42c01d2825e052a52e7021a1b2a97997fb9c915b..45d5dc76d1e95668638706a252cc24d7ff2dec40 100644 --- a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h +++ b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h @@ -29,10 +29,9 @@ void FusionFcCompute(const FusionFcParam ¶m) { auto *input_z_data = input_z->data(); int axis = param.Axis(); Tensor *out = param.Out(); - auto *out_data = out->mutable_data(); // int m = out->dims()[0]; // int n = out->dims()[1]; - + auto *out_data = out->mutable_data(); const Tensor x_matrix = 
input_x->dims().size() > 2 ? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) diff --git a/src/operators/kernel/central-arm-func/pool_arm_func.h b/src/operators/kernel/central-arm-func/pool_arm_func.h index 37479c22efe95b6506054cf3ded5855aa766c34c..941c237865707bce854aedba56029a4f5de9b2bf 100644 --- a/src/operators/kernel/central-arm-func/pool_arm_func.h +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -83,6 +83,7 @@ void PoolCompute(const PoolParam ¶m) { #if __aarch64__ PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); #else + /// todo: fix bug in Pool2x2 if (pooling_type == "max") { math::Pool2x2Maxs2p0(strides, paddings, in_x, out); } else if (pooling_type == "avg") { diff --git a/src/operators/kernel/central-arm-func/softmax_arm_func.h b/src/operators/kernel/central-arm-func/softmax_arm_func.h index d311d97984a7207df9075befe71a9806092966e1..a94c8299c514bc9e2937daf57b1a845d7be56b16 100644 --- a/src/operators/kernel/central-arm-func/softmax_arm_func.h +++ b/src/operators/kernel/central-arm-func/softmax_arm_func.h @@ -24,6 +24,7 @@ void SoftmaxCompute(const SoftmaxParam ¶m) { Tensor *out = param.Out(); auto x_dims = in_x->dims(); out->Resize(x_dims); + out->mutable_data(); math::SoftmaxFuntor()(in_x, out); } } // namespace operators diff --git a/src/operators/kernel/central-arm-func/transpose2_arm_func.h b/src/operators/kernel/central-arm-func/transpose2_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..dea90e863b20f19820d60d9cce67b6849d3c467b --- /dev/null +++ b/src/operators/kernel/central-arm-func/transpose2_arm_func.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef TRANSPOSE2_OP +#pragma once + +#include +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +void Transpose2Compute(const Transpose2Param& param) { + const auto* input_x = param.InputX(); + const auto input_x_dims = input_x->dims(); + auto* out = param.Out(); + const auto axis = param.Axis(); + const auto* input_x_data = input_x->data(); + auto* out_data = out->mutable_data(); + + size_t ndim = axis.size(); + std::vector xdim(ndim); + std::vector xstride(ndim); + std::vector xout(ndim); + for (int i = 0; i < ndim; i++) { + int j = ndim - 1 - i; + xdim[j] = input_x_dims[axis[i]]; + xstride[j] = 1; + for (int k = axis[i] + 1; k < ndim; k++) { + xstride[j] *= input_x_dims[k]; + } + xout[j] = xstride[j] * xdim[j]; + } + + auto numel = input_x->numel(); + size_t pind = 0; + std::vector ind(ndim); + for (int i = 0; i < numel; i++) { + out_data[i] = input_x_data[pind]; + ind[0]++; + pind += xstride[0]; + for (int j = 0; j < ndim - 1; j++) { + if (ind[j] == xdim[j]) { + ind[j + 1]++; + ind[j] = 0; + pind += xstride[j + 1]; + pind -= xout[j]; + } else { + break; + } + } + } +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/transpose2_kernel.h b/src/operators/kernel/transpose2_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..8ae75ea483ddb887d9c53b32228ff72b41c76097 --- /dev/null +++ b/src/operators/kernel/transpose2_kernel.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef TRANSPOSE2_OP + +#pragma once + +#include + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class Transpose2Kernel + : public framework::OpKernelBase> { + public: + void Compute(const Transpose2Param& param) const; + bool Init(Transpose2Param* param); +}; +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index b6cf28a9ca665a1496ee8032f87c013137deade8..fac3b95e27f4b95b395e84ad87cc3fd380b3c4dd 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -257,8 +257,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, const int h = static_cast(input->dims()[2]); const int w = static_cast(input->dims()[3]); - const int l = h; - + // const int l = h; const int batch_size = static_cast(input->dims()[0]); const int c = static_cast(input->dims()[1]); const int hxw = h * w; @@ -271,7 +270,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, vbias = vdupq_n_f32(bias_data[j]); } - int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 + int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0 float w00 = filter_data_tmp[0]; float w01 = filter_data_tmp[1]; float w02 = filter_data_tmp[2]; @@ -283,39 +282,38 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor 
*filter, float w22 = filter_data_tmp[8]; output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[l] + w22 * input_data[l + 1]; - output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + - w20 * input_data[2 * l - 2] + - w21 * input_data[2 * l - 1]; - output_data[(l - 1) * l] = - w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + - w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; - output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + - w01 * input_data[(l - 2) * (l + 1) + 1] + - w10 * input_data[l * l - 2] + - w11 * input_data[l * l - 1]; + w21 * input_data[w] + w22 * input_data[w + 1]; + output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + + w20 * input_data[2 * w - 2] + + w21 * input_data[2 * w - 1]; + output_data[(h - 1) * w] = + w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] + + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; + output_data[h * w - 1] = + w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] + + w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; if (if_bias) { output_data[0] += bias_data[j]; - output_data[l - 1] += bias_data[j]; - output_data[(l - 1) * l] += bias_data[j]; - output_data[l * l - 1] += bias_data[j]; + output_data[w - 1] += bias_data[j]; + output_data[(h - 1) * w] += bias_data[j]; + output_data[h * w - 1] += bias_data[j]; } - for (int i = 1; i < l - 1; ++i) { - output_data[i * l] = - w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + - w11 * input_data[i * l] + w12 * input_data[i * l + 1] + - w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; - - output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + - w01 * input_data[i * l + l - 1 - l] + - w10 * input_data[i * l + l - 1 - 1] + - w11 * input_data[i * l + l - 1] + - w20 * input_data[i * l + l - 1 + l - 1] + - w21 * input_data[i * l + l - 1 + l]; + for (int i = 1; i < h - 1; ++i) { + 
output_data[i * w] = + w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + + w11 * input_data[i * w] + w12 * input_data[i * w + 1] + + w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1]; + + output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + + w01 * input_data[i * w + w - 1 - w] + + w10 * input_data[i * w + w - 1 - 1] + + w11 * input_data[i * w + w - 1] + + w20 * input_data[i * w + w - 1 + w - 1] + + w21 * input_data[i * w + w - 1 + w]; if (if_bias) { - output_data[i * l] += bias_data[j]; - output_data[i * l + l - 1] += bias_data[j]; + output_data[i * w] += bias_data[j]; + output_data[i * w + w - 1] += bias_data[j]; } } @@ -325,15 +323,15 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + l); - const float *input_tmp_end = input_tmp + (l - 2) * l; + in2 = vld1q_f32(input_tmp + w); + const float *input_tmp_end = input_tmp + (h - 2) * w; in4 = vld1q_f32(input_tmp_end); - in6 = vld1q_f32(input_tmp_end + l); - int c_mid = l_mid; + in6 = vld1q_f32(input_tmp_end + w); + int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid > 3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + l + 4); + in3 = vld1q_f32(input_tmp + w + 4); tmp0 = vextq_f32(in0, in1, 1); tmp1 = vextq_f32(in0, in1, 2); @@ -352,7 +350,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, vst1q_f32(output_ptr, out0); in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + l + 4); + in7 = vld1q_f32(input_tmp_end + w + 4); tmp0 = vextq_f32(in4, in5, 1); tmp1 = vextq_f32(in4, in5, 2); @@ -367,7 +365,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vaddq_f32(out0, vbias); - vst1q_f32(output_ptr + (l - 1) * l, out0); + vst1q_f32(output_ptr + (h - 1) * w, out0); 
// can optimize to each 8 stride. input_tmp += 4; @@ -380,8 +378,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, } // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); + float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]); tmp0 = vextq_f32(in0, pad0, 1); tmp1 = vextq_f32(in0, pad0, 2); @@ -409,8 +407,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, } // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); - float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]); + float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]); tmp0 = vextq_f32(in4, pad2, 1); tmp1 = vextq_f32(in4, pad2, 2); @@ -427,28 +425,28 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, for (int i = 0; i < c_mid; ++i) { if (i == 0) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0); } if (i == 1) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1); } if (i == 2) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2); } } // mid - for (int i = 0; i < l - 2; ++i) { - auto output_ptr = output_data + (i + 1) * l + 1; - input_tmp = input_data + i * l; + for (int i = 0; i < h - 2; ++i) { + auto output_ptr = output_data + (i + 1) * w + 1; + input_tmp = input_data + i * w; auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + l); - auto in4_tmp = vld1q_f32(input_tmp + l + l); - c_mid = l_mid; + auto in2_tmp = vld1q_f32(input_tmp + w); + auto in4_tmp = vld1q_f32(input_tmp + w + w); + c_mid = w_mid; for (; c_mid > 3; c_mid -= 4) { auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + l + 4); - auto 
in5_tmp = vld1q_f32(input_tmp + l + l + 4); + auto in3_tmp = vld1q_f32(input_tmp + w + 4); + auto in5_tmp = vld1q_f32(input_tmp + w + w + 4); tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); @@ -477,9 +475,9 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, in4_tmp = in5_tmp; } - float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); + float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); tmp0 = vextq_f32(in0_tmp, pad0, 1); tmp1 = vextq_f32(in0_tmp, pad0, 2); @@ -539,8 +537,9 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, const int hxw = input_height * input_width; - const int l = input_height; - + // const int l = input_height; + const int h = input_height; + const int w = input_width; float32x4_t vzero = vdupq_n_f32(0); for (int b = 0; b < batch_size; b++) { @@ -626,54 +625,53 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, } output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[l] + w22 * input_data[l + 1]; - output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + - w20 * input_data[2 * l - 2] + - w21 * input_data[2 * l - 1]; - output_data[(l - 1) * l] = - w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + - w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; - output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + - w01 * input_data[(l - 2) * (l + 1) + 1] + - w10 * input_data[l * l - 2] + - w11 * input_data[l * l - 1]; + w21 * input_data[w] + w22 * input_data[w + 1]; + output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + + w20 * input_data[2 * w - 2] + + w21 * input_data[2 
* w - 1]; + output_data[(h - 1) * w] = + w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] + + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; + output_data[h * w - 1] = + w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] + + w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c]; - output_data[l - 1] = - output_data[l - 1] * newscale_data[c] + newbias_data[c]; - output_data[(l - 1) * l] = - output_data[(l - 1) * l] * newscale_data[c] + newbias_data[c]; - output_data[l * l - 1] = - output_data[l * l - 1] * newscale_data[c] + newbias_data[c]; + output_data[w - 1] = + output_data[w - 1] * newscale_data[c] + newbias_data[c]; + output_data[(h - 1) * w] = + output_data[(h - 1) * w] * newscale_data[c] + newbias_data[c]; + output_data[h * w - 1] = + output_data[h * w - 1] * newscale_data[c] + newbias_data[c]; if (if_relu) { output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; - output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1]; - output_data[(l - 1) * l] = - output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l]; - output_data[l * l - 1] = - output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1]; + output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - 1]; + output_data[(h - 1) * w] = + output_data[(h - 1) * w] < 0 ? 0 : output_data[(h - 1) * w]; + output_data[h * w - 1] = + output_data[h * w - 1] < 0 ? 
0 : output_data[h * w - 1]; } - for (int i = 1; i < l - 1; ++i) { - output_data[i * l] = - w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + - w11 * input_data[i * l] + w12 * input_data[i * l + 1] + - w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; - - output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + - w01 * input_data[i * l + l - 1 - l] + - w10 * input_data[i * l + l - 1 - 1] + - w11 * input_data[i * l + l - 1] + - w20 * input_data[i * l + l - 1 + l - 1] + - w21 * input_data[i * l + l - 1 + l]; - output_data[i * l] = - output_data[i * l] * newscale_data[c] + newbias_data[c]; - output_data[i * l + l - 1] = - output_data[i * l + l - 1] * newscale_data[c] + newbias_data[c]; + for (int i = 1; i < h - 1; ++i) { + output_data[i * w] = + w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + + w11 * input_data[i * w] + w12 * input_data[i * w + 1] + + w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1]; + + output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + + w01 * input_data[i * w + w - 1 - w] + + w10 * input_data[i * w + w - 1 - 1] + + w11 * input_data[i * w + w - 1] + + w20 * input_data[i * w + w - 1 + w - 1] + + w21 * input_data[i * w + w - 1 + w]; + output_data[i * w] = + output_data[i * w] * newscale_data[c] + newbias_data[c]; + output_data[i * w + w - 1] = + output_data[i * w + w - 1] * newscale_data[c] + newbias_data[c]; if (if_relu) { - output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l]; - output_data[i * l + l - 1] = - output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1]; + output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i * w]; + output_data[i * w + w - 1] = + output_data[i * w + w - 1] < 0 ? 
0 : output_data[i * w + w - 1]; } } @@ -776,7 +774,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, const int h = static_cast(input->dims()[2]); const int w = static_cast(input->dims()[3]); - const int l = h; +// const int l = h; const int batch_size = static_cast(input->dims()[0]); const int c = static_cast(input->dims()[1]); @@ -792,7 +790,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, vnewbias = vdupq_n_f32(newbias_data[j]); vnewscale = vdupq_n_f32(newscale_data[j]); - int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 + int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0 float w00 = filter_data_tmp[0]; float w01 = filter_data_tmp[1]; float w02 = filter_data_tmp[2]; @@ -804,49 +802,49 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, float w22 = filter_data_tmp[8]; output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[l] + w22 * input_data[l + 1]; - - output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - - 1] + w20 * input_data[2 * l - 2] + w21 * input_data[2 * l - 1]; - - output_data[(l - 1) * l] = - w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + - 1] + w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; - output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + - w01 * input_data[(l - 2) * (l + 1) + 1] + - w10 * input_data[l * l - 2] + - w11 * input_data[l * l - 1]; + w21 * input_data[w] + w22 * input_data[w + 1]; + + output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - + 1] + w20 * input_data[2 * w - 2] + w21 * input_data[2 * w - 1]; + + output_data[(h - 1) * w] = + w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + + 1] + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; + output_data[h * w - 1] = w00 * input_data[h*w-w-2] + + w01 * input_data[h*w-w-1] + + w10 * input_data[h * w - 2] + + w11 * input_data[h * w - 1]; output_data[0] = 
output_data[0] * newscale_data[j] + - newbias_data[j]; output_data[l - 1] = output_data[l - 1] * - newscale_data[j] + newbias_data[j]; output_data[(l - 1) * l] = - output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j]; - output_data[l * l - 1] = - output_data[l * l - 1] * newscale_data[j] + newbias_data[j]; + newbias_data[j]; output_data[w - 1] = output_data[w - 1] * + newscale_data[j] + newbias_data[j]; output_data[(h - 1) * w] = + output_data[(h - 1) * w] * newscale_data[j] + newbias_data[j]; + output_data[h * w - 1] = + output_data[h * w - 1] * newscale_data[j] + newbias_data[j]; if (if_relu) { output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; - output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - - 1]; output_data[(l - 1) * l] = output_data[(l - 1) * l] < 0 ? 0 : - output_data[(l - 1) * l]; output_data[l * l - 1] = output_data[l * l - 1] - < 0 ? 0 : output_data[l * l - 1]; + output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - + 1]; output_data[(h - 1) * w] = output_data[(h - 1) * w] < 0 ? 0 : + output_data[(h - 1) * w]; output_data[h * w - 1] = output_data[h * w - 1] + < 0 ? 
0 : output_data[h * w - 1]; } - for (int i = 1; i < l - 1; ++i) { - output_data[i * l] = - w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] - + w11 * input_data[i * l] + w12 * input_data[i * l + 1] + w21 * - input_data[i * l + l] + w22 * input_data[i * l + l + 1]; output_data[i * - l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + w01 * input_data[i - * l + l - 1 - l] + w10 * input_data[i * l + l - 1 - 1] + w11 * - input_data[i * l + l - 1] + w20 * input_data[i * l + l - 1 + l - 1] + w21 - * input_data[i * l + l - 1 + l]; output_data[i * l] = output_data[i * l] - * newscale_data[j] + newbias_data[j]; output_data[i * l + l - 1] = - output_data[i * l + l - 1] * newscale_data[j] + + for (int i = 1; i < h - 1; ++i) { + output_data[i * w] = + w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + + w11 * input_data[i * w] + w12 * input_data[i * w + 1] + w21 * + input_data[i * w + w] + w22 * input_data[i * w + w + 1]; output_data[i * + w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + w01 * input_data[i + * w + w - 1 - w] + w10 * input_data[i * w + w - 1 - 1] + w11 * + input_data[i * w + w - 1] + w20 * input_data[i * w + w - 1 + w - 1] + w21 + * input_data[i * w + w - 1 + w]; output_data[i * w] = output_data[i * w] + * newscale_data[j] + newbias_data[j]; output_data[i * w + w - 1] = + output_data[i * w + w - 1] * newscale_data[j] + newbias_data[j]; if (if_relu) { - output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i - * l]; output_data[i * l + l - 1] = output_data[i * l + l - 1] < 0 ? 0 : - output_data[i * l + l - 1]; + output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i + * w]; output_data[i * w + w - 1] = output_data[i * w + w - 1] < 0 ? 
0 : + output_data[i * w + w - 1]; } } @@ -855,11 +853,11 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); in2 = - vld1q_f32(input_tmp + l); const float *input_tmp_end = input_tmp + (l - - 2) * l; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end + - l); int c_mid = l_mid; auto output_ptr = output_data + 1; for (; c_mid > + vld1q_f32(input_tmp + w); const float *input_tmp_end = input_tmp + (h - + 2) * w; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end + + w); int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid > 3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); in3 = - vld1q_f32(input_tmp + l + 4); + vld1q_f32(input_tmp + w + 4); tmp0 = vextq_f32(in0, in1, 1); tmp1 = vextq_f32(in0, in1, 2); @@ -880,7 +878,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, vst1q_f32(output_ptr, out0); in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + l + 4); + in7 = vld1q_f32(input_tmp_end + w + 4); tmp0 = vextq_f32(in4, in5, 1); tmp1 = vextq_f32(in4, in5, 2); @@ -897,7 +895,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, if (if_relu) { out0 = vmaxq_f32(out0, vzero); } - vst1q_f32(output_ptr + (l - 1) * l, out0); + vst1q_f32(output_ptr + (h - 1) * w, out0); // can optimize to each 8 stride. 
input_tmp += 4; @@ -910,8 +908,8 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, } // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); + float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]); tmp0 = vextq_f32(in0, pad0, 1); tmp1 = vextq_f32(in0, pad0, 2); @@ -941,8 +939,8 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, } // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); - float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]); + float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]); tmp0 = vextq_f32(in4, pad2, 1); tmp1 = vextq_f32(in4, pad2, 2); @@ -961,29 +959,29 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, } for (int i = 0; i < c_mid; ++i) { if (i == 0) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0); } if (i == 1) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1); } if (i == 2) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2); } } // mid - for (int i = 0; i < l - 2; ++i) { - auto output_ptr = output_data + (i + 1) * l + 1; - input_tmp = input_data + i * l; + for (int i = 0; i < h - 2; ++i) { + auto output_ptr = output_data + (i + 1) * w + 1; + input_tmp = input_data + i * w; auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + l); - auto in4_tmp = vld1q_f32(input_tmp + l + l); - c_mid = l_mid; + auto in2_tmp = vld1q_f32(input_tmp + w); + auto in4_tmp = vld1q_f32(input_tmp + w + w); + c_mid = w_mid; for (; c_mid > 3; c_mid -= 4) { auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + l + 4); - auto in5_tmp 
= vld1q_f32(input_tmp + l + l + 4); + auto in3_tmp = vld1q_f32(input_tmp + w + 4); + auto in5_tmp = vld1q_f32(input_tmp + w + w + 4); tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); @@ -1014,9 +1012,9 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, in4_tmp = in5_tmp; } - float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); + float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); tmp0 = vextq_f32(in0_tmp, pad0, 1); tmp1 = vextq_f32(in0_tmp, pad0, 2); @@ -1060,6 +1058,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, #endif } +/// w!=h not fix void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, Tensor *output, const Tensor *new_scale, const Tensor *new_bias, bool if_relu) { @@ -1275,7 +1274,8 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, const int in_l = in_h; const int inhxw = in_h * in_w; const int outhxw = out_h * out_w; - const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0; + /// todo : fix if_pad when w != h + const int if_pad = in_w - 1 == (out_w - 1) * 2 ? 
1 : 0; const int batch_size = static_cast(input->dims()[0]); const int c = static_cast(input->dims()[1]); const float *input_row_ptr; @@ -1381,9 +1381,9 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, if ((w4 != w_times)) { vst1q_f32(output_row_ptr, res3); } else { - if (out_l - 2 - w_times * 3 == 1) { + if (out_w - 2 - w_times * 3 == 1) { vst1q_lane_f32(output_row_ptr, res3, 0); - } else if (out_l - 2 - w_times * 3 == 2) { + } else if (out_w - 2 - w_times * 3 == 2) { vst1q_lane_f32(output_row_ptr, res3, 0); vst1q_lane_f32(output_row_ptr + 1, res3, 1); } @@ -1393,28 +1393,28 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, } output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + - input_const[in_l] * w21 + - input_const[in_l + 1] * w22; + input_const[in_w] * w21 + + input_const[in_w + 1] * w22; - out2in_mid = (out_l - 1) * 2; - output_data_tmp[out_l - 1] = + out2in_mid = (out_w - 1) * 2; + output_data_tmp[out_w - 1] = w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + w20 * input_const[out2in_mid + in_w - 1] + w21 * input_const[out2in_mid + in_w] + (1 - if_pad) * (w12 * input_const[out2in_mid + 1] + w22 * input_const[out2in_mid + in_w + 1]); - out2in_mid = (out_l - 1) * 2 * in_w; + out2in_mid = (out_h - 1) * 2 * in_w; - output_data_tmp[out_l * (out_l - 1)] = + output_data_tmp[out_w * (out_h - 1)] = w01 * input_const[out2in_mid - in_w] + w02 * input_const[out2in_mid - in_w + 1] + w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + (1 - if_pad) * (w21 * input_const[out2in_mid + in_w] + w22 * input_const[out2in_mid + in_w + 1]); - out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2; + out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2; - output_data_tmp[out_l * out_l - 1] = + output_data_tmp[out_h * out_w - 1] = w00 * input_const[out2in_mid - in_w - 1] + w01 * input_const[out2in_mid - in_w] + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + @@ 
-1425,21 +1425,21 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, w22 * input_const[out2in_mid + in_w + 1]); if (if_bias) { output_data_tmp[0] += bias_data[j]; - output_data_tmp[out_l - 1] += bias_data[j]; - output_data_tmp[out_l * (out_l - 1)] += bias_data[j]; - output_data_tmp[out_l * out_l - 1] += bias_data[j]; + output_data_tmp[out_w - 1] += bias_data[j]; + output_data_tmp[out_w * (out_h - 1)] += bias_data[j]; + output_data_tmp[out_h * out_w - 1] += bias_data[j]; } for (int i = 1; i < out_h - 1; i++) { out2in_mid = i * 2 * in_w; - output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] + + output_data_tmp[i * out_w] = w01 * input_const[out2in_mid - in_w] + w02 * input_const[out2in_mid - in_w + 1] + w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + w21 * input_const[out2in_mid + in_w] + w22 * input_const[out2in_mid + in_w + 1]; - out2in_mid = i * 2 * in_w + (out_l - 1) * 2; - output_data_tmp[i * out_l + out_l - 1] = + out2in_mid = i * 2 * in_w + (out_w - 1) * 2; + output_data_tmp[i * out_w + out_w - 1] = w00 * input_const[out2in_mid - in_w - 1] + w01 * input_const[out2in_mid - in_w] + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + @@ -1449,8 +1449,8 @@ void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, w12 * input_const[out2in_mid + 1] + w22 * input_const[out2in_mid + in_w + 1]); if (if_bias) { - output_data_tmp[i * out_l] += bias_data[j]; - output_data_tmp[i * out_l + out_l - 1] += bias_data[j]; + output_data_tmp[i * out_w] += bias_data[j]; + output_data_tmp[i * out_w + out_w - 1] += bias_data[j]; } } filter_data_tmp += 9; @@ -1657,11 +1657,12 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, const int in_w = static_cast(input->dims()[3]); const int out_h = static_cast(output->dims()[2]); const int out_w = static_cast(output->dims()[3]); - const int out_l = out_h; - const int in_l = in_h; + // const int out_l = out_h; + // const int 
in_l = in_h; const int inhxw = in_h * in_w; const int outhxw = out_h * out_w; - const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0; + /// todo : fix if_pad when w != h + const int if_pad = in_w - 1 == (out_w - 1) * 2 ? 1 : 0; const int batch_size = static_cast(input->dims()[0]); const int c = static_cast(input->dims()[1]); const int w_times = (out_w - 2) / 3; @@ -1775,9 +1776,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, vst1q_lane_f32(output_row_ptr + 1, res3, 1); vst1q_lane_f32(output_row_ptr + 2, res3, 2); } else { - if (out_l - 2 - w_times * 3 == 1) { + if (out_w - 2 - w_times * 3 == 1) { vst1q_lane_f32(output_row_ptr, res3, 0); - } else if (out_l - 2 - w_times * 3 == 2) { + } else if (out_w - 2 - w_times * 3 == 2) { vst1q_lane_f32(output_row_ptr, res3, 0); vst1q_lane_f32(output_row_ptr + 1, res3, 1); } @@ -1787,28 +1788,28 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, } output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + - input_const[in_l] * w21 + - input_const[in_l + 1] * w22; + input_const[in_w] * w21 + + input_const[in_w + 1] * w22; - out2in_mid = (out_l - 1) * 2; - output_data_tmp[out_l - 1] = + out2in_mid = (out_w - 1) * 2; + output_data_tmp[out_w - 1] = w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + w20 * input_const[out2in_mid + in_w - 1] + w21 * input_const[out2in_mid + in_w] + (1 - if_pad) * (w12 * input_const[out2in_mid + 1] + w22 * input_const[out2in_mid + in_w + 1]); - out2in_mid = (out_l - 1) * 2 * in_w; + out2in_mid = (out_h - 1) * 2 * in_w; - output_data_tmp[out_l * (out_l - 1)] = + output_data_tmp[out_w * (out_h - 1)] = w01 * input_const[out2in_mid - in_w] + w02 * input_const[out2in_mid - in_w + 1] + w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + (1 - if_pad) * (w21 * input_const[out2in_mid + in_w] + w22 * input_const[out2in_mid + in_w + 1]); - out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2; + 
out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2; - output_data_tmp[out_l * out_l - 1] = + output_data_tmp[out_h * out_w - 1] = w00 * input_const[out2in_mid - in_w - 1] + w01 * input_const[out2in_mid - in_w] + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + @@ -1819,38 +1820,38 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, w22 * input_const[out2in_mid + in_w + 1]); output_data_tmp[0] = output_data_tmp[0] * newscale_data[j] + newbias_data[j]; - output_data_tmp[out_l - 1] = - output_data_tmp[out_l - 1] * newscale_data[j] + newbias_data[j]; - output_data_tmp[out_l * (out_l - 1)] = - output_data_tmp[out_l * (out_l - 1)] * newscale_data[j] + + output_data_tmp[out_w - 1] = + output_data_tmp[out_w - 1] * newscale_data[j] + newbias_data[j]; + output_data_tmp[out_w * (out_h - 1)] = + output_data_tmp[out_w * (out_h - 1)] * newscale_data[j] + newbias_data[j]; - output_data_tmp[out_l * out_l - 1] = - output_data_tmp[out_l * out_l - 1] * newscale_data[j] + + output_data_tmp[out_h * out_w - 1] = + output_data_tmp[out_h * out_w - 1] * newscale_data[j] + newbias_data[j]; if (if_relu) { output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0]; - output_data_tmp[out_l - 1] = - output_data_tmp[out_l - 1] < 0 ? 0 : output_data_tmp[out_l - 1]; - output_data_tmp[out_l * (out_l - 1)] = - output_data_tmp[out_l * (out_l - 1)] < 0 + output_data_tmp[out_w - 1] = + output_data_tmp[out_w - 1] < 0 ? 0 : output_data_tmp[out_w - 1]; + output_data_tmp[out_w * (out_h - 1)] = + output_data_tmp[out_w * (out_h - 1)] < 0 ? 0 - : output_data_tmp[out_l * (out_l - 1)]; - output_data_tmp[out_l * out_l - 1] = - output_data_tmp[out_l * out_l - 1] < 0 + : output_data_tmp[out_w * (out_h - 1)]; + output_data_tmp[out_h * out_w - 1] = + output_data_tmp[out_h * out_w - 1] < 0 ? 
0 - : output_data_tmp[out_l * out_l - 1]; + : output_data_tmp[out_h * out_w - 1]; } for (int i = 1; i < out_h - 1; i++) { out2in_mid = i * 2 * in_w; - output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] + + output_data_tmp[i * out_w] = w01 * input_const[out2in_mid - in_w] + w02 * input_const[out2in_mid - in_w + 1] + w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + w21 * input_const[out2in_mid + in_w] + w22 * input_const[out2in_mid + in_w + 1]; - out2in_mid = i * 2 * in_w + (out_l - 1) * 2; - output_data_tmp[i * out_l + out_l - 1] = + out2in_mid = i * 2 * in_w + (out_w - 1) * 2; + output_data_tmp[i * out_w + out_w - 1] = w00 * input_const[out2in_mid - in_w - 1] + w01 * input_const[out2in_mid - in_w] + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + @@ -1859,18 +1860,18 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, (1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] + w12 * input_const[out2in_mid + 1] + w22 * input_const[out2in_mid + in_w + 1]); - output_data_tmp[i * out_l] = - output_data_tmp[i * out_l] * newscale_data[j] + newbias_data[j]; - output_data_tmp[i * out_l + out_l - 1] = - output_data_tmp[i * out_l + out_l - 1] * newscale_data[j] + + output_data_tmp[i * out_w] = + output_data_tmp[i * out_w] * newscale_data[j] + newbias_data[j]; + output_data_tmp[i * out_w + out_w - 1] = + output_data_tmp[i * out_w + out_w - 1] * newscale_data[j] + newbias_data[j]; if (if_relu) { - output_data_tmp[i * out_l] = - output_data_tmp[i * out_l] < 0 ? 0 : output_data_tmp[i * out_l]; - output_data_tmp[i * out_l + out_l - 1] = - output_data_tmp[i * out_l + out_l - 1] < 0 + output_data_tmp[i * out_w] = + output_data_tmp[i * out_w] < 0 ? 0 : output_data_tmp[i * out_w]; + output_data_tmp[i * out_w + out_w - 1] = + output_data_tmp[i * out_w + out_w - 1] < 0 ? 
0 - : output_data_tmp[i * out_l + out_l - 1]; + : output_data_tmp[i * out_w + out_w - 1]; } } } diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 4c81e7fa3bd4e5ea36f04b453d4f84468745f919..47055ec4f24e5b5b226c1f084bb2253d2ebb77c7 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -53,7 +53,7 @@ void Im2ColFunctor::operator()( (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0)); int fill = isize % 2; if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && - dilation[0] == 1 && im_height > 2) { + dilation[0] == 1 && im_height > 2 && im_height == im_width) { for (int c = 0; c < im_channels; ++c) { int oosize = osize * osize; int nk4 = osize / 4; @@ -225,7 +225,7 @@ void Im2ColFunctor::operator()( im_data += isize * isize; } } else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 && - im_height > 2) { + im_height > 2 && im_height == im_width) { for (int c = 0; c < im_channels; ++c) { int oosize = osize * osize; int nk4 = osize / 4; @@ -605,7 +605,6 @@ class Im2ColFunctor { const T *im_data = im.data(); T *col_data = col->data(); - for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) { for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) { for (int channel = 0; channel < im_channels; ++channel) { @@ -617,7 +616,6 @@ class Im2ColFunctor { ++filter_col_idx) { int im_col_offset = col_col_idx * stride[1] + filter_col_idx - padding[1]; - int col_offset = ((((col_row_idx)*col_width + col_col_idx) * im_channels + channel) * @@ -625,7 +623,6 @@ class Im2ColFunctor { filter_row_idx) * filter_width + filter_col_idx; - int im_offset = (channel * im_height + im_row_offset) * im_width + im_col_offset; col_data[col_offset] = diff --git a/src/operators/math/pool_2x2.cpp b/src/operators/math/pool_2x2.cpp index 9dc3dbafed990de2f4057d98a2accdd8ce2fd7db..88bf866b73f6f06d28f6e1868031ae1a25b9b31c 100644 --- a/src/operators/math/pool_2x2.cpp +++ 
b/src/operators/math/pool_2x2.cpp @@ -58,7 +58,7 @@ void Pool2x2Maxs2p0(vector strides, vector paddings, const float *in_ptr1 = input_data + i * input_batch_stride + c * input_channel_stride + ph * input_width; const float *in_ptr2 = in_ptr1 + input_width; - if (ph + 1 >= input_height) { + if (ph != input_height && ph + 1 >= input_height) { in_ptr2 = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * input_width)); memset(static_cast(const_cast(in_ptr2)), -FLT_MAX, @@ -122,19 +122,30 @@ void Pool2x2Maxs2p0(vector strides, vector paddings, #endif if (_w2 != 0) { - in_ptr1 += 16 * w1 + 4 * w2; - in_ptr2 += 16 * w1 + 4 * w2; - out_ptr += 8 * w1 + 2 * w2; + in_ptr1 = input_data + i * input_batch_stride + + c * input_channel_stride + ph * input_width + 16 * w1 + + 4 * w2; + in_ptr2 = in_ptr1 + input_width; + out_ptr = output_data + i * output_batch_stride + + c * output_channel_stride + ph / 2 * output_width + 8 * w1 + + 2 * w2; if (_w2 == 1) { *out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; } else if (_w2 == 2) { - float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; + float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; + in_ptr1++; + in_ptr2++; float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; *out_ptr = (temp > temp1) ? temp : temp1; } else if (_w2 == 3) { - float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; - float temp1 = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++; - *out_ptr++ = (temp > temp1) ? temp : temp1; + float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; + in_ptr1++; + in_ptr2++; + float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2; + in_ptr1++; + in_ptr2++; + *out_ptr = (temp > temp1) ? temp : temp1; + out_ptr++; *out_ptr = (*in_ptr1 > *in_ptr2) ? 
*in_ptr1 : *in_ptr2; } } @@ -173,7 +184,7 @@ void Pool2x2Avgs2p0(vector strides, vector paddings, int w2 = _w1 / 4; int _w2 = _w1 % 4; - float quarter = 1 / 4; + float quarter = 0.25; for (int i = 0; i < batch_size; ++i) { for (int c = 0; c < output_channels; ++c) { for (int ph = 0; ph < input_height; ph += 2) { @@ -250,25 +261,32 @@ void Pool2x2Avgs2p0(vector strides, vector paddings, #endif if (_w2 != 0) { - in_ptr1 += 16 * w1 + 4 * w2; - in_ptr2 += 16 * w1 + 4 * w2; - out_ptr += 8 * w1 + 2 * w2; + in_ptr1 = input_data + i * input_batch_stride + + c * input_channel_stride + ph * input_width + 16 * w1 + + 4 * w2; + in_ptr2 = in_ptr1 + input_width; + out_ptr = output_data + i * output_batch_stride + + c * output_channel_stride + ph / 2 * output_width + 8 * w1 + + 2 * w2; if (_w2 == 1) { *out_ptr = 0.5 * (*in_ptr1 + *in_ptr2); } else if (_w2 == 2) { float temp = 0; - temp += *in_ptr1++; - temp += *in_ptr2++; temp += *in_ptr1; temp += *in_ptr2; - *out_ptr = 0.5 * temp; + in_ptr1++; + in_ptr2++; + temp += *in_ptr1; + temp += *in_ptr2; + *out_ptr = 0.25 * temp; } else if (_w2 == 3) { float temp = 0; temp += *in_ptr1++; temp += *in_ptr2++; temp += *in_ptr1++; temp += *in_ptr2++; - *out_ptr++ = 0.5 * temp; + *out_ptr = 0.25 * temp; + out_ptr++; *out_ptr = 0.5 * (*in_ptr1 + *in_ptr2); } } diff --git a/src/operators/op_param.h b/src/operators/op_param.h index ecf125e8012e10daf08159dd34494b31e6415759..2c0075271a92cb66ef95603965dd18d0dd3c5faf 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1132,6 +1132,37 @@ class TransposeParam : public OpParam { }; #endif +#ifdef TRANSPOSE2_OP +template +class Transpose2Param : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + Transpose2Param(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + out_ = OutFrom(outputs, scope); + 
output_xshape_ = OutputXShapeFrom(outputs, scope); + axis_ = GetAttr>("axis", attrs); + } + + const RType *InputX() const { return input_x_; } + + RType *Out() const { return out_; } + + RType *OutputXShape() const { return output_xshape_; } + + const vector &Axis() const { return axis_; } + + private: + RType *input_x_; + RType *out_; + RType *output_xshape_; + vector axis_; +}; +#endif + #ifdef LOOKUP_OP template class LookupParam : public OpParam { @@ -2116,9 +2147,9 @@ class Im2SequenceParam : public OpParam { paddings_ = GetAttr>("paddings", attrs); } - const RType *Input() const { return input_x_; } + const GType *Input() const { return input_x_; } - RType *Output() const { return out_; } + GType *Output() const { return out_; } const vector &Kernels() const { return kernels_; } @@ -2127,8 +2158,8 @@ class Im2SequenceParam : public OpParam { const vector &Paddings() const { return paddings_; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; vector kernels_; vector strides_; vector paddings_; diff --git a/src/operators/transpose2_op.cpp b/src/operators/transpose2_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..64d07991f60b4057e3d2841afa1bfe6483f31a88 --- /dev/null +++ b/src/operators/transpose2_op.cpp @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef TRANSPOSE2_OP + +#include + +#include "common/enforce.h" +#include "operators/transpose2_op.h" +namespace paddle_mobile { +namespace operators { + +template +void Transpose2Op::InferShape() const { + auto input_x_dims = this->param_.InputX()->dims(); + auto axis = this->param_.Axis(); + + size_t x_dims_size = input_x_dims.size(); + size_t axis_size = axis.size(); + + PADDLE_MOBILE_ENFORCE((x_dims_size == axis_size), + "input_dims must " + "be equal to the axis_size. ") + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + PADDLE_MOBILE_ENFORCE( + axis[i] < static_cast(axis_size) && ++count[axis[i]] == 1, + "Each element of Attribute axis should be a unique value " + "range from 0 to (dims - 1), " + "where the dims is the axis's size"); + } + framework::DDim out_dims(input_x_dims); + for (size_t i = 0; i < axis_size; i++) { + out_dims[i] = input_x_dims[axis[i]]; + } + this->param_.Out()->Resize(out_dims); + std::vector xshape_dims(input_x_dims.size() + 1, 0); + for (int i = 0; i < input_x_dims.size(); ++i) { + xshape_dims[i + 1] = input_x_dims[i]; + } + this->param_.OutputXShape()->Resize(framework::make_ddim(xshape_dims)); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(transpose2, ops::Transpose2Op); +#endif + +#endif // TRANSPOSE2_OP diff --git a/src/operators/transpose2_op.h b/src/operators/transpose2_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f1339cc59e0c71a232eddd5dcef47f62994b80da --- /dev/null +++ b/src/operators/transpose2_op.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef TRANSPOSE2_OP + +#pragma once + +#include + +#include "framework/operator.h" +#include "operators/kernel/transpose2_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using paddle_mobile::framework::Tensor; + +template +class Transpose2Op : public framework::OperatorWithKernel< + DeviceType, Transpose2Param, + operators::Transpose2Kernel> { + public: + Transpose2Op(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel< + DeviceType, Transpose2Param, + operators::Transpose2Kernel>(type, inputs, outputs, + attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, Transpose2Param, + operators::Transpose2Kernel>::OperatorWithKernel; + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c534123952eb5c33173abddb4ca1700c57fd103a..2bd7169533f637add2a752feaceca8df132cb262 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -184,6 +184,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-transpose-op operators/test_transpose_op.cpp test_helper.h test_include.h) target_link_libraries(test-transpose-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-transpose2-op operators/test_transpose2_op.cpp test_helper.h test_include.h) + target_link_libraries(test-transpose2-op paddle-mobile) + # gen test ADD_EXECUTABLE(test-multiclassnms-op 
operators/test_multiclass_nms_op.cpp test_helper.h test_include.h) target_link_libraries(test-multiclassnms-op paddle-mobile) @@ -343,6 +347,10 @@ if (NOT FOUND_MATCH) ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h) target_link_libraries(test-multi-process paddle-mobile) + # gen test + ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h) + target_link_libraries(test-eng paddle-mobile) + #add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp) endif () diff --git a/test/net/test_eng.cpp b/test/net/test_eng.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b99a6c927a44ca4032b352731b3971b63cf26b4f --- /dev/null +++ b/test/net/test_eng.cpp @@ -0,0 +1,50 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include "../test_helper.h" +#include "../test_include.h" + +int main() { +#ifdef PADDLE_MOBILE_CPU + paddle_mobile::PaddleMobile paddle_mobile; +#endif + // paddle_mobile.SetThreadNum(4); + auto time1 = time(); + if (paddle_mobile.Load(std::string(g_eng) + "/model", + std::string(g_eng) + "/params", true, false, 1, + true)) { + auto time2 = time(); + std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl; + std::vector dims{1, 1, 48, 400}; + LoDTensor input_tensor; + SetupTensor(&input_tensor, {1, 1, 48, 400}, static_cast(0), + static_cast(1)); + + std::vector input(input_tensor.data(), + input_tensor.data() + input_tensor.numel()); + // warm-up run (the loop currently executes once) + for (int i = 0; i < 1; ++i) { + paddle_mobile.PredictLod(input_tensor); + } + auto time3 = time(); + for (int i = 0; i < 1; ++i) { + paddle_mobile.PredictLod(input_tensor); + } + auto time4 = time(); + std::cout << "predict cost :" << time_diff(time3, time4) << "ms" + << std::endl; + } + return 0; +} diff --git a/test/operators/test_batchnorm_op.cpp b/test/operators/test_batchnorm_op.cpp index 4ccad8c1512036c2400a09575b3775e75b26acce..5f064d27f3f3f9cca5428467557c9412f76735c7 100644 --- a/test/operators/test_batchnorm_op.cpp +++ b/test/operators/test_batchnorm_op.cpp @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - #include "../test_helper.h" #include "../test_include.h" #include "operators/batchnorm_op.h" diff --git a/test/operators/test_box_coder_op.cpp b/test/operators/test_box_coder_op.cpp index 92cba3995c866c67c00491ad5cc38fb094594ad3..aeef10be9650623767af4d2de8913ce53b1d2c59 100644 --- a/test/operators/test_box_coder_op.cpp +++ b/test/operators/test_box_coder_op.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "../test_include.h" #include "operators/box_coder_op.h" diff --git a/test/operators/test_elementwise_sub_op.cpp b/test/operators/test_elementwise_sub_op.cpp index cfac83eff7a012d52d47f96e088bd8519603cadc..e27361b21c3146675ea856d02d70878e73e8912f 100644 --- a/test/operators/test_elementwise_sub_op.cpp +++ b/test/operators/test_elementwise_sub_op.cpp @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - #include "../test_helper.h" #include "../test_include.h" #include "operators/elementwise_sub_op.h" diff --git a/test/operators/test_fill_constant_op.cpp b/test/operators/test_fill_constant_op.cpp index b099217d1641eb221b3d0d86d780fb6ecfa929bd..99c65ed821c0a90691070b661a6967a11d4694f7 100644 --- a/test/operators/test_fill_constant_op.cpp +++ b/test/operators/test_fill_constant_op.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "../test_include.h" #include "operators/fill_constant_op.h" diff --git a/test/operators/test_fusion_fc_op.cpp b/test/operators/test_fusion_fc_op.cpp index a23bde45cb74f0f75e655821b15e66b1cef4c081..aaa2d7b578dbda4c6919210eb4a2fb42ba243e53 100644 --- a/test/operators/test_fusion_fc_op.cpp +++ b/test/operators/test_fusion_fc_op.cpp @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#pragma once - #include #include "../test_include.h" #include "operators/fusion_fc_op.h" diff --git a/test/operators/test_im2sequence_op.cpp b/test/operators/test_im2sequence_op.cpp index b45e437e12f95cd9f7050247fc03a152246d8122..6c69d1cc9d94ffd958251ee4ed783d6b5531c455 100644 --- a/test/operators/test_im2sequence_op.cpp +++ b/test/operators/test_im2sequence_op.cpp @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - #include "../test_helper.h" #include "../test_include.h" #include "operators/im2sequence_op.h" diff --git a/test/operators/test_multiclass_nms_op.cpp b/test/operators/test_multiclass_nms_op.cpp index d1b98d4965fd182ab1adc480279f38cea53974be..3447bbdd10b64d2c2f497bdb4d5af15958a9a95b 100644 --- a/test/operators/test_multiclass_nms_op.cpp +++ b/test/operators/test_multiclass_nms_op.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#pragma once #include "../test_include.h" #include "operators/multiclass_nms_op.h" @@ -31,14 +30,12 @@ class TestMultiClassNMSOp { const std::vector> blocks = to_predict_program_->Blocks(); - // DLOG << " **block size " << blocks.size(); for (auto block_desc : blocks) { std::vector> ops = block_desc->Ops(); - // DLOG << " ops " << ops.size(); for (auto op : ops) { if (op->Type() == "multiclass_nms" && op->Input("BBoxes")[0] == "box_coder_0.tmp_0") { - DLOG << " mul attr size: " << op->GetAttrMap().size(); + DLOG << " attr size: " << op->GetAttrMap().size(); DLOG << " inputs size: " << op->GetInputs().size(); DLOG << " outputs size: " << op->GetOutputs().size(); DLOG << " BBoxes is : " << op->Input("BBoxes")[0]; @@ -55,14 +52,6 @@ class TestMultiClassNMSOp { << op->GetAttrMap().at("nms_top_k").Get(); DLOG << " score_threshold : " << op->GetAttrMap().at("score_threshold").Get(); - // DLOG << " variances : " << - // op->GetAttrMap().at("variances").Get>(); - // DLOG << " aspect_ratios : " << - // op->GetAttrMap().at("aspect_ratios").Get>(); - // DLOG << " min_sizes : " << - // op->GetAttrMap().at("min_sizes").Get>(); - // DLOG << " max_sizes : " << - // op->GetAttrMap().at("max_sizes").Get>(); std::shared_ptr> priorbox = std::make_shared>( op->Type(), op->GetInputs(), op->GetOutputs(), @@ -88,16 +77,12 @@ class TestMultiClassNMSOp { auto *output_tensor = output->GetMutable(); output_tensor->mutable_data({1917, 6}); - // DLOG << typeid(output_tensor).name(); - // DLOG << "output_tensor dims: " << output_tensor->dims(); - std::shared_ptr out_tensor = std::make_shared(); out_tensor.reset(output_tensor); predict(t1, t2, 0); return out_tensor; - // return outvars_tensor; } private: diff --git a/test/operators/test_polygon_box_transform_op.cpp b/test/operators/test_polygon_box_transform_op.cpp index a71177ddbd8e4d8b0f204fd6ec9c948882499cbd..5b30ce1ebfd59db972953e16e4506fa2595b8f04 100644 --- a/test/operators/test_polygon_box_transform_op.cpp +++ 
b/test/operators/test_polygon_box_transform_op.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "../test_include.h" #include "operators/polygon_box_transform_op.h" diff --git a/test/operators/test_prior_box_op.cpp b/test/operators/test_prior_box_op.cpp index 8c697a9a7982f05b71caa5bb5f4d12e50dc9d418..2c75d01df297030b4633829ac4b29f7592aaf5c4 100644 --- a/test/operators/test_prior_box_op.cpp +++ b/test/operators/test_prior_box_op.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "../test_include.h" #include "operators/prior_box_op.h" diff --git a/test/operators/test_reshape2_op.cpp b/test/operators/test_reshape2_op.cpp index 564b8bcb4db8bdc2c97d4bbc9635262a8a28a6e4..42c348a6274592eb23332620131faa0784a71d28 100644 --- a/test/operators/test_reshape2_op.cpp +++ b/test/operators/test_reshape2_op.cpp @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once #include "../test_include.h" #include "operators/reshape2_op.h" diff --git a/test/operators/test_sum_op.cpp b/test/operators/test_sum_op.cpp index e51d1cff5e99c5d9c444db046e78eee6a03f9243..467529d8d3877fcb9ac5527daf5f037aea6d18fc 100644 --- a/test/operators/test_sum_op.cpp +++ b/test/operators/test_sum_op.cpp @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#pragma once - #include "../test_helper.h" #include "../test_include.h" #include "operators/sum_op.h" diff --git a/test/operators/test_transpose2_op.cpp b/test/operators/test_transpose2_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b75a957cd5c1cd08dc09895e9e2448761e822274 --- /dev/null +++ b/test/operators/test_transpose2_op.cpp @@ -0,0 +1,143 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "../test_include.h" +#include "operators/transpose2_op.h" + +namespace paddle_mobile { +namespace framework { + +template +class TestTranspose2Op { + public: + explicit TestTranspose2Op(const Program p) : program_(p) { + if (use_optimize_) { + to_predict_program_ = program_.optimizeProgram; + } else { + to_predict_program_ = program_.originProgram; + } + const std::vector> blocks = + to_predict_program_->Blocks(); + for (auto block_desc : blocks) { + std::vector> ops = block_desc->Ops(); + for (auto op : ops) { + if (op->Type() == "transpose2") { + DLOG << " attr size: " << op->GetAttrMap().size(); + std::unordered_map attrs = op->GetAttrMap(); + for (std::unordered_map::iterator it = + attrs.begin(); + it != attrs.end(); ++it) { + DLOG << " " << it->first << " " << it->second; + } + + DLOG << " inputs size: " << op->GetInputs().size(); + VariableNameMap inputs = op->GetInputs(); + for (VariableNameMap::iterator it = inputs.begin(); + it != inputs.end(); ++it) { + DLOG << " " << 
it->first << " " << it->second; + } + + DLOG << " outputs size: " << op->GetOutputs().size(); + VariableNameMap outputs = op->GetOutputs(); + for (VariableNameMap::iterator it = outputs.begin(); + it != outputs.end(); ++it) { + DLOG << " " << it->first << " " << it->second; + } + + input_var_name = op->Input("X")[0]; + output_var_name = op->Output("Out")[0]; + std::shared_ptr> op_ptr = + std::make_shared>( + op->Type(), op->GetInputs(), op->GetOutputs(), + op->GetAttrMap(), program_.scope); + ops_of_block_[*block_desc.get()].push_back(op_ptr); + return; + } + } + } + } + + std::shared_ptr predict(const Tensor &t) { + auto scope = program_.scope; + Variable *input_feed_value = scope->Var(input_var_name); + auto tensor_input = input_feed_value->GetMutable(); + tensor_input->ShareDataWith(t); + + Variable *output = scope->Var(output_var_name); + auto *output_tensor = output->GetMutable(); + output_tensor->mutable_data({1, 2, 8}); + + std::shared_ptr out_tensor = std::make_shared(); + out_tensor.reset(output_tensor); + + predict(t, 0); + + return out_tensor; + } + + private: + const framework::Program program_; + std::shared_ptr to_predict_program_; + std::map>>> + ops_of_block_; + bool use_optimize_ = false; + string input_var_name; + string output_var_name; + + void predict(const Tensor &t, int block_id) { + std::shared_ptr to_predict_block = + to_predict_program_->Block(block_id); + for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) { + auto op = ops_of_block_[*to_predict_block.get()][j]; + op->Run(); + } + } +}; + +template class TestTranspose2Op; +} // namespace framework +} // namespace paddle_mobile + +int main() { + DLOG << "----------**********----------"; + DLOG << "begin to run Transpose2 Test"; + paddle_mobile::Loader loader; + auto program = loader.Load(std::string(g_ocr) + "/model", + std::string(g_ocr) + "/params"); + + paddle_mobile::framework::Tensor input; + SetupTensor(&input, {1, 8, 2}, static_cast(0), + static_cast(1)); + auto 
*input_ptr = input.data(); + for (int i = 0; i < 16; ++i) { + *(input_ptr + i) = i; + } + DLOG << "input : "; + for (int i = 0; i < input.numel(); ++i) { + DLOG << " index " << i << " : " << input_ptr[i]; + } + + paddle_mobile::framework::TestTranspose2Op + testTranspose2Op(program); + + auto output = testTranspose2Op.predict(input); + auto *output_ptr = output->data(); + + DLOG << "output : "; + for (int i = 0; i < output->numel(); ++i) { + DLOG << " index " << i << " : " << output_ptr[i]; + } + return 0; +} diff --git a/tools/op.cmake b/tools/op.cmake index 16973eb8875033a681f861eb46816563e05c9cef..2e1e311a2c96bac02257cfdce2d2fbebcd962dfb 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -205,6 +205,7 @@ if(NOT FOUND_MATCH) set(SIGMOID_OP ON) set(SOFTMAX_OP ON) set(TRANSPOSE_OP ON) + set(TRANSPOSE2_OP ON) set(FUSION_CONVADDBNRELU_OP ON) set(FUSION_CONVADDADDPRELU_OP ON) set(FUSION_DWCONVBNRELU_OP ON) @@ -251,6 +252,7 @@ endif() # option(SIGMOID_OP "" ON) # option(SOFTMAX_OP "" ON) # option(TRANSPOSE_OP "" ON) + # option(TRANSPOSE2_OP "" ON) # endif () if (BATCHNORM_OP) @@ -328,6 +330,9 @@ endif() if (TRANSPOSE_OP) add_definitions(-DTRANSPOSE_OP) endif() +if (TRANSPOSE2_OP) + add_definitions(-DTRANSPOSE2_OP) +endif() if (FUSION_CONVADDBNRELU_OP) add_definitions(-DFUSION_CONVADDBNRELU_OP) endif()