提交 a973b2fb 编写于 作者: xiebaiyuan's avatar xiebaiyuan

Merge remote-tracking branch 'upstream/develop' into develop

......@@ -8,46 +8,19 @@
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)-->
欢迎来到 Paddle-Mobile GitHub 项目。
Paddle-Mobile是PaddlePaddle组织下的项目,是一个致力于嵌入式平台的深度学习的框架。Paddle-Mobile设计思想和PaddlePaddle的最新版fluid版本保持了高度一致,同时针对嵌入式做了大量优化。设计之初就对嵌入式的性能、体积、能耗、硬件平台覆盖等方面做了考虑。
## 简单搜索线上效果
如下gif是简单搜索app的线上主体检测应用效果
![ezgif-1-050a733dfb](http://otkwwi4x8.bkt.clouddn.com/2018-07-05-ezgif-1-050a733dfb.gif)
## Demo目录
[点我](https://github.com/PaddlePaddle/paddle-mobile/tree/develop/demo)
欢迎来到 Paddle-Mobile GitHub 项目。Paddle-Mobile是PaddlePaddle组织下的项目,是一个致力于嵌入式平台的深度学习的框架。
## Features
- **ARM CPU**
- **Mali GPU**
- **苹果设备的GPU Metal实现**
- **FPGA**
目前已经支持 ZCU102 开发板。
- **灵活性**
* paddle-mobile cpu版不依赖任何第三库, 可进行快速集成。
* 使用泛型特化进行平台切换, 可灵活切换 cpu、gpu 和其他协处理器。
* 可根据特定的常见网络, 进行编译特定的 op, 降低编译时间, 减小包大小。
* 使用 docker 编译, 提供统一的编译环境。
* 高可拓展性, 方便拓展其他协处理器, 提供高性能 arm 算子实现, 方便其他协处理器开发者集成开发。
* 直接兼容 paddle-fluid 模型, 不需要额外的转换操作。
- **体积**
paddle-mobile从设计之初就深入考虑到移动端的包体积的问题,cpu实现中没有外部依赖。在编译过程中,如果该网络不需要的op是完全不会被打入的。同时编译选项优化也为体积压缩提供了帮助。
除了二进制体积,我们对代码体积极力避免过大。整个仓库的代码体积也非常小。
- 高性能支持ARM CPU
- 支持Mali GPU
- 支持Andreno GPU
- 支持苹果设备的GPU Metal实现
- 支持ZU5、ZU9等FPGA开发板
- 支持树莓派等arm-linux开发板
## Demo目录
[https://github.com/PaddlePaddle/paddle-mobile/tree/develop/demo](https://github.com/PaddlePaddle/paddle-mobile/tree/develop/demo)
## 文档
......@@ -74,18 +47,22 @@ Paddle-Mobile是PaddlePaddle组织下的项目,是一个致力于嵌入式平
### 1. 直接使用Paddle Fluid训练
该方式最为可靠,推荐方式
### 2. caffe转为Paddle Fluid模型
[链接](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/caffe2fluid)
[https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/caffe2fluid](https://github.com/PaddlePaddle/models/tree/develop/fluid/image_classification/caffe2fluid)
### 3. ONNX
ONNX全称为“Open Neural Network Exchange”,即“开放的神经网络切换”。该项目的目的是让不同的神经网络开发框架做到互通互用。
除直接使用PaddlePaddle训练fluid版本的模型外,还可以通过onnx转换得到个别Paddle fluid模型。
目前,百度也在做onnx支持工作。相关转换项目在这里:[paddle-onnx](https://github.com/PaddlePaddle/paddle-onnx)
![](http://7xop3k.com1.z0.glb.clouddn.com/15311951836000.jpg)
目前,百度也在做onnx支持工作。相关转换项目在这里:
[https://github.com/PaddlePaddle/paddle-onnx](https://github.com/PaddlePaddle/paddle-onnx)
### 4. 部分测试模型和测试图片下载
[下载链接](http://mms-graph.bj.bcebos.com/paddle-mobile%2FmodelsAndImages.zip)
[http://mms-graph.bj.bcebos.com/paddle-mobile%2FmodelsAndImages.zip](http://mms-graph.bj.bcebos.com/paddle-mobile%2FmodelsAndImages.zip)
<!--## 简单搜索线上效果
如下gif是简单搜索app的线上主体检测应用效果
![ezgif-1-050a733dfb](http://otkwwi4x8.bkt.clouddn.com/2018-07-05-ezgif-1-050a733dfb.gif)-->
## 问题解决
......@@ -97,5 +74,3 @@ Paddle-Mobile 提供相对宽松的Apache-2.0开源协议 [Apache-2.0 license](L
## 旧版 Mobile-Deep-Learning
原MDL(Mobile-Deep-Learning)工程被迁移到了这里 [Mobile-Deep-Learning](https://github.com/allonli/mobile-deep-learning)
......@@ -44,6 +44,7 @@ const char *G_OP_TYPE_RESHAPE2 = "reshape2";
const char *G_OP_TYPE_SIGMOID = "sigmoid";
const char *G_OP_TYPE_SOFTMAX = "softmax";
const char *G_OP_TYPE_TRANSPOSE = "transpose";
const char *G_OP_TYPE_TRANSPOSE2 = "transpose2";
const char *G_OP_TYPE_SPLIT = "split";
const char *G_OP_TYPE_FEED = "feed";
const char *G_OP_TYPE_FETCH = "fetch";
......@@ -91,6 +92,7 @@ std::unordered_map<
{G_OP_TYPE_FEED, {{"X"}, {"Out"}}},
{G_OP_TYPE_FETCH, {{"X"}, {"Out"}}},
{G_OP_TYPE_TRANSPOSE, {{"X"}, {"Out"}}},
{G_OP_TYPE_TRANSPOSE2, {{"X"}, {"Out", "XShape"}}},
{G_OP_TYPE_BOX_CODER,
{{"PriorBox", "PriorBoxVar", "TargetBox"}, {"OutputBox"}}},
{G_OP_TYPE_FUSION_CONV_ADD_BN_RELU, {{"Input"}, {"Out"}}},
......
......@@ -115,6 +115,9 @@ LOAD_OP2(reshape2, CPU, MALI_GPU);
#ifdef TRANSPOSE_OP
LOAD_OP1(transpose, CPU);
#endif
#ifdef TRANSPOSE2_OP
LOAD_OP1(transpose2, CPU);
#endif
#ifdef PRIORBOX_OP
LOAD_OP1(prior_box, CPU);
#endif
......
......@@ -35,7 +35,7 @@ template <>
void Im2SequenceKernel<CPU, float>::Compute(
const Im2SequenceParam<CPU> &param) const {
const Tensor *in_x = param.Input();
Tensor *out = param.Output();
framework::LoDTensor *out = param.Output();
out->mutable_data<float>();
std::vector<int> kernels = param.Kernels();
......@@ -52,22 +52,31 @@ void Im2SequenceKernel<CPU, float>::Compute(
paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]);
const std::vector<int> dilations({1, 1});
out->mutable_data<float>({batch_size * output_height * output_width,
img_channels * kernels[0] * kernels[1]});
const std::vector<int> dilations({1, 1});
// TODO: verify
auto out_dims = out->dims();
out->Resize({batch_size, out->numel() / batch_size});
for (int i = 0; i < batch_size; i++) {
const Tensor src =
in_x->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
Tensor dst = out->Slice(i, i + 1).Resize(
{output_height, output_width, img_channels, kernels[0], kernels[1]});
math::Im2ColFunctor<math::ColFormat::kOCF, CPU, float> f;
f(src, dilations, strides, paddings, &dst);
}
out->Resize(out_dims);
framework::LoD lod(1);
lod[0].reserve(batch_size + 1);
int offset = 0;
lod[0].push_back(offset);
for (int i = 0; i < batch_size; ++i) {
offset += output_height * output_width;
lod[0].push_back(offset);
}
out->set_lod(lod);
}
template class Im2SequenceKernel<CPU, float>;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE2_OP
#include "operators/kernel/transpose2_kernel.h"
#include "operators/kernel/central-arm-func/transpose2_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool Transpose2Kernel<CPU, float>::Init(Transpose2Param<CPU> *param) {
return true;
}
template <>
void Transpose2Kernel<CPU, float>::Compute(
const Transpose2Param<CPU> &param) const {
Transpose2Compute<float>(param);
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -29,10 +29,9 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
auto *input_z_data = input_z->data<float>();
int axis = param.Axis();
Tensor *out = param.Out();
auto *out_data = out->mutable_data<float>();
// int m = out->dims()[0];
// int n = out->dims()[1];
auto *out_data = out->mutable_data<float>();
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
......
......@@ -83,6 +83,7 @@ void PoolCompute(const PoolParam<CPU> &param) {
#if __aarch64__
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
#else
/// todo: fix bug in Pool2x2
if (pooling_type == "max") {
math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
} else if (pooling_type == "avg") {
......
......@@ -24,6 +24,7 @@ void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
Tensor *out = param.Out();
auto x_dims = in_x->dims();
out->Resize(x_dims);
out->mutable_data<float>();
math::SoftmaxFuntor<CPU, float>()(in_x, out);
}
} // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE2_OP
#pragma once
#include <vector>
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void Transpose2Compute(const Transpose2Param<CPU>& param) {
const auto* input_x = param.InputX();
const auto input_x_dims = input_x->dims();
auto* out = param.Out();
const auto axis = param.Axis();
const auto* input_x_data = input_x->data<float>();
auto* out_data = out->mutable_data<float>();
size_t ndim = axis.size();
std::vector<int> xdim(ndim);
std::vector<int> xstride(ndim);
std::vector<int> xout(ndim);
for (int i = 0; i < ndim; i++) {
int j = ndim - 1 - i;
xdim[j] = input_x_dims[axis[i]];
xstride[j] = 1;
for (int k = axis[i] + 1; k < ndim; k++) {
xstride[j] *= input_x_dims[k];
}
xout[j] = xstride[j] * xdim[j];
}
auto numel = input_x->numel();
size_t pind = 0;
std::vector<int> ind(ndim);
for (int i = 0; i < numel; i++) {
out_data[i] = input_x_data[pind];
ind[0]++;
pind += xstride[0];
for (int j = 0; j < ndim - 1; j++) {
if (ind[j] == xdim[j]) {
ind[j + 1]++;
ind[j] = 0;
pind += xstride[j + 1];
pind -= xout[j];
} else {
break;
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE2_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class Transpose2Kernel
: public framework::OpKernelBase<DeviceType, Transpose2Param<DeviceType>> {
public:
void Compute(const Transpose2Param<DeviceType>& param) const;
bool Init(Transpose2Param<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -53,7 +53,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
(((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0));
int fill = isize % 2;
if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 &&
dilation[0] == 1 && im_height > 2) {
dilation[0] == 1 && im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
......@@ -225,7 +225,7 @@ void Im2ColFunctor<ColFormat::kCFO, CPU, float>::operator()(
im_data += isize * isize;
}
} else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 &&
im_height > 2) {
im_height > 2 && im_height == im_width) {
for (int c = 0; c < im_channels; ++c) {
int oosize = osize * osize;
int nk4 = osize / 4;
......@@ -605,7 +605,6 @@ class Im2ColFunctor<ColFormat::kOCF, CPU, T> {
const T *im_data = im.data<T>();
T *col_data = col->data<T>();
for (int col_row_idx = 0; col_row_idx < col_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < col_width; ++col_col_idx) {
for (int channel = 0; channel < im_channels; ++channel) {
......@@ -617,7 +616,6 @@ class Im2ColFunctor<ColFormat::kOCF, CPU, T> {
++filter_col_idx) {
int im_col_offset =
col_col_idx * stride[1] + filter_col_idx - padding[1];
int col_offset =
((((col_row_idx)*col_width + col_col_idx) * im_channels +
channel) *
......@@ -625,7 +623,6 @@ class Im2ColFunctor<ColFormat::kOCF, CPU, T> {
filter_row_idx) *
filter_width +
filter_col_idx;
int im_offset = (channel * im_height + im_row_offset) * im_width +
im_col_offset;
col_data[col_offset] =
......
......@@ -58,7 +58,7 @@ void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings,
const float *in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width;
const float *in_ptr2 = in_ptr1 + input_width;
if (ph + 1 >= input_height) {
if (ph != input_height && ph + 1 >= input_height) {
in_ptr2 = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * input_width));
memset(static_cast<void *>(const_cast<float *>(in_ptr2)), -FLT_MAX,
......@@ -122,19 +122,30 @@ void Pool2x2Maxs2p0(vector<int> strides, vector<int> paddings,
#endif
if (_w2 != 0) {
in_ptr1 += 16 * w1 + 4 * w2;
in_ptr2 += 16 * w1 + 4 * w2;
out_ptr += 8 * w1 + 2 * w2;
in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width + 16 * w1 +
4 * w2;
in_ptr2 = in_ptr1 + input_width;
out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width + 8 * w1 +
2 * w2;
if (_w2 == 1) {
*out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
} else if (_w2 == 2) {
float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++;
float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
*out_ptr = (temp > temp1) ? temp : temp1;
} else if (_w2 == 3) {
float temp = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++;
float temp1 = (*in_ptr1++ > *in_ptr2++) ? *in_ptr1++ : *in_ptr2++;
*out_ptr++ = (temp > temp1) ? temp : temp1;
float temp = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
float temp1 = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
in_ptr1++;
in_ptr2++;
*out_ptr = (temp > temp1) ? temp : temp1;
out_ptr++;
*out_ptr = (*in_ptr1 > *in_ptr2) ? *in_ptr1 : *in_ptr2;
}
}
......@@ -173,7 +184,7 @@ void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings,
int w2 = _w1 / 4;
int _w2 = _w1 % 4;
float quarter = 1 / 4;
float quarter = 0.25;
for (int i = 0; i < batch_size; ++i) {
for (int c = 0; c < output_channels; ++c) {
for (int ph = 0; ph < input_height; ph += 2) {
......@@ -250,25 +261,32 @@ void Pool2x2Avgs2p0(vector<int> strides, vector<int> paddings,
#endif
if (_w2 != 0) {
in_ptr1 += 16 * w1 + 4 * w2;
in_ptr2 += 16 * w1 + 4 * w2;
out_ptr += 8 * w1 + 2 * w2;
in_ptr1 = input_data + i * input_batch_stride +
c * input_channel_stride + ph * input_width + 16 * w1 +
4 * w2;
in_ptr2 = in_ptr1 + input_width;
out_ptr = output_data + i * output_batch_stride +
c * output_channel_stride + ph / 2 * output_width + 8 * w1 +
2 * w2;
if (_w2 == 1) {
*out_ptr = 0.5 * (*in_ptr1 + *in_ptr2);
} else if (_w2 == 2) {
float temp = 0;
temp += *in_ptr1++;
temp += *in_ptr2++;
temp += *in_ptr1;
temp += *in_ptr2;
*out_ptr = 0.5 * temp;
in_ptr1++;
in_ptr2++;
temp += *in_ptr1;
temp += *in_ptr2;
*out_ptr = 0.25 * temp;
} else if (_w2 == 3) {
float temp = 0;
temp += *in_ptr1++;
temp += *in_ptr2++;
temp += *in_ptr1++;
temp += *in_ptr2++;
*out_ptr++ = 0.5 * temp;
*out_ptr = 0.25 * temp;
out_ptr++;
*out_ptr = 0.5 * (*in_ptr1 + *in_ptr2);
}
}
......
......@@ -1132,6 +1132,37 @@ class TransposeParam : public OpParam {
};
#endif
#ifdef TRANSPOSE2_OP
template <typename Dtype>
class Transpose2Param : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
Transpose2Param(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
output_xshape_ = OutputXShapeFrom<GType>(outputs, scope);
axis_ = GetAttr<vector<int>>("axis", attrs);
}
const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; }
RType *OutputXShape() const { return output_xshape_; }
const vector<int> &Axis() const { return axis_; }
private:
RType *input_x_;
RType *out_;
RType *output_xshape_;
vector<int> axis_;
};
#endif
#ifdef LOOKUP_OP
template <typename Dtype>
class LookupParam : public OpParam {
......@@ -2116,9 +2147,9 @@ class Im2SequenceParam : public OpParam {
paddings_ = GetAttr<vector<int>>("paddings", attrs);
}
const RType *Input() const { return input_x_; }
const GType *Input() const { return input_x_; }
RType *Output() const { return out_; }
GType *Output() const { return out_; }
const vector<int> &Kernels() const { return kernels_; }
......@@ -2127,8 +2158,8 @@ class Im2SequenceParam : public OpParam {
const vector<int> &Paddings() const { return paddings_; }
private:
RType *input_x_;
RType *out_;
GType *input_x_;
GType *out_;
vector<int> kernels_;
vector<int> strides_;
vector<int> paddings_;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE2_OP
#include <vector>
#include "common/enforce.h"
#include "operators/transpose2_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void Transpose2Op<Dtype, T>::InferShape() const {
auto input_x_dims = this->param_.InputX()->dims();
auto axis = this->param_.Axis();
size_t x_dims_size = input_x_dims.size();
size_t axis_size = axis.size();
PADDLE_MOBILE_ENFORCE((x_dims_size == axis_size),
"input_dims must "
"be equal to the axis_size. ")
std::vector<int> count(axis_size, 0);
for (size_t i = 0; i < axis_size; i++) {
PADDLE_MOBILE_ENFORCE(
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
"Each element of Attribute axis should be a unique value "
"range from 0 to (dims - 1), "
"where the dims is the axis's size");
}
framework::DDim out_dims(input_x_dims);
for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = input_x_dims[axis[i]];
}
this->param_.Out()->Resize(out_dims);
std::vector<int64_t> xshape_dims(input_x_dims.size() + 1, 0);
for (int i = 0; i < input_x_dims.size(); ++i) {
xshape_dims[i + 1] = input_x_dims[i];
}
this->param_.OutputXShape()->Resize(framework::make_ddim(xshape_dims));
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(transpose2, ops::Transpose2Op);
#endif
#endif // TRANSPOSE_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE2_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/transpose2_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class Transpose2Op : public framework::OperatorWithKernel<
DeviceType, Transpose2Param<DeviceType>,
operators::Transpose2Kernel<DeviceType, T>> {
public:
Transpose2Op(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, Transpose2Param<DeviceType>,
operators::Transpose2Kernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, Transpose2Param<DeviceType>,
operators::Transpose2Kernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -184,6 +184,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-transpose-op operators/test_transpose_op.cpp test_helper.h test_include.h)
target_link_libraries(test-transpose-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-transpose2-op operators/test_transpose2_op.cpp test_helper.h test_include.h)
target_link_libraries(test-transpose2-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h)
target_link_libraries(test-multiclassnms-op paddle-mobile)
......@@ -343,6 +347,10 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-multi-process net/test_multi_inference_predict.cpp test_helper.h test_include.h)
target_link_libraries(test-multi-process paddle-mobile)
# gen test
ADD_EXECUTABLE(test-eng net/test_eng.cpp test_helper.h test_include.h)
target_link_libraries(test-eng paddle-mobile)
#add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp)
endif ()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
int main() {
#ifdef PADDLE_MOBILE_CPU
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
#endif
// paddle_mobile.SetThreadNum(4);
auto time1 = time();
if (paddle_mobile.Load(std::string(g_eng) + "/model",
std::string(g_eng) + "/params", true, false, 1,
true)) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl;
std::vector<int64_t> dims{1, 1, 48, 400};
LoDTensor input_tensor;
SetupTensor<float>(&input_tensor, {1, 1, 48, 400}, static_cast<float>(0),
static_cast<float>(1));
std::vector<float> input(input_tensor.data<float>(),
input_tensor.data<float>() + input_tensor.numel());
// 预热十次
for (int i = 0; i < 1; ++i) {
paddle_mobile.PredictLod(input_tensor);
}
auto time3 = time();
for (int i = 0; i < 1; ++i) {
paddle_mobile.PredictLod(input_tensor);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) << "ms"
<< std::endl;
}
return 0;
}
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/batchnorm_op.h"
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/box_coder_op.h"
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/elementwise_sub_op.h"
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/fill_constant_op.h"
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <framework/program/program-optimize/program_optimize.h>
#include "../test_include.h"
#include "operators/fusion_fc_op.h"
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/im2sequence_op.h"
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/multiclass_nms_op.h"
......@@ -31,14 +30,12 @@ class TestMultiClassNMSOp {
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
// DLOG << " **block size " << blocks.size();
for (auto block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
// DLOG << " ops " << ops.size();
for (auto op : ops) {
if (op->Type() == "multiclass_nms" &&
op->Input("BBoxes")[0] == "box_coder_0.tmp_0") {
DLOG << " mul attr size: " << op->GetAttrMap().size();
DLOG << " attr size: " << op->GetAttrMap().size();
DLOG << " inputs size: " << op->GetInputs().size();
DLOG << " outputs size: " << op->GetOutputs().size();
DLOG << " BBoxes is : " << op->Input("BBoxes")[0];
......@@ -55,14 +52,6 @@ class TestMultiClassNMSOp {
<< op->GetAttrMap().at("nms_top_k").Get<int>();
DLOG << " score_threshold : "
<< op->GetAttrMap().at("score_threshold").Get<float>();
// DLOG << " variances : " <<
// op->GetAttrMap().at("variances").Get<std::vector<float>>();
// DLOG << " aspect_ratios : " <<
// op->GetAttrMap().at("aspect_ratios").Get<std::vector<float>>();
// DLOG << " min_sizes : " <<
// op->GetAttrMap().at("min_sizes").Get<std::vector<float>>();
// DLOG << " max_sizes : " <<
// op->GetAttrMap().at("max_sizes").Get<std::vector<float>>();
std::shared_ptr<operators::MultiClassNMSOp<Dtype, float>> priorbox =
std::make_shared<operators::MultiClassNMSOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
......@@ -88,16 +77,12 @@ class TestMultiClassNMSOp {
auto *output_tensor = output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>({1917, 6});
// DLOG << typeid(output_tensor).name();
// DLOG << "output_tensor dims: " << output_tensor->dims();
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict(t1, t2, 0);
return out_tensor;
// return outvars_tensor;
}
private:
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/polygon_box_transform_op.h"
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/prior_box_op.h"
......
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_include.h"
#include "operators/reshape2_op.h"
......
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "../test_helper.h"
#include "../test_include.h"
#include "operators/sum_op.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/transpose2_op.h"
namespace paddle_mobile {
namespace framework {
template <typename Dtype>
class TestTranspose2Op {
public:
explicit TestTranspose2Op(const Program<Dtype> p) : program_(p) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
to_predict_program_->Blocks();
for (auto block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (auto op : ops) {
if (op->Type() == "transpose2") {
DLOG << " attr size: " << op->GetAttrMap().size();
std::unordered_map<std::string, Attribute> attrs = op->GetAttrMap();
for (std::unordered_map<std::string, Attribute>::iterator it =
attrs.begin();
it != attrs.end(); ++it) {
DLOG << " " << it->first << " " << it->second;
}
DLOG << " inputs size: " << op->GetInputs().size();
VariableNameMap inputs = op->GetInputs();
for (VariableNameMap::iterator it = inputs.begin();
it != inputs.end(); ++it) {
DLOG << " " << it->first << " " << it->second;
}
DLOG << " outputs size: " << op->GetOutputs().size();
VariableNameMap outputs = op->GetOutputs();
for (VariableNameMap::iterator it = outputs.begin();
it != outputs.end(); ++it) {
DLOG << " " << it->first << " " << it->second;
}
input_var_name = op->Input("X")[0];
output_var_name = op->Output("Out")[0];
std::shared_ptr<operators::Transpose2Op<Dtype, float>> op_ptr =
std::make_shared<operators::Transpose2Op<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), program_.scope);
ops_of_block_[*block_desc.get()].push_back(op_ptr);
return;
}
}
}
}
std::shared_ptr<Tensor> predict(const Tensor &t) {
auto scope = program_.scope;
Variable *input_feed_value = scope->Var(input_var_name);
auto tensor_input = input_feed_value->GetMutable<LoDTensor>();
tensor_input->ShareDataWith(t);
Variable *output = scope->Var(output_var_name);
auto *output_tensor = output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>({1, 2, 8});
std::shared_ptr<Tensor> out_tensor = std::make_shared<LoDTensor>();
out_tensor.reset(output_tensor);
predict(t, 0);
return out_tensor;
}
private:
const framework::Program<Dtype> program_;
std::shared_ptr<ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
string input_var_name;
string output_var_name;
void predict(const Tensor &t, int block_id) {
std::shared_ptr<BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
for (int j = 0; j < ops_of_block_[*to_predict_block.get()].size(); ++j) {
auto op = ops_of_block_[*to_predict_block.get()][j];
op->Run();
}
}
};
template class TestTranspose2Op<CPU>;
} // namespace framework
} // namespace paddle_mobile
int main() {
DLOG << "----------**********----------";
DLOG << "begin to run Transpose2 Test";
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_ocr) + "/model",
std::string(g_ocr) + "/params");
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {1, 8, 2}, static_cast<float>(0),
static_cast<float>(1));
auto *input_ptr = input.data<float>();
for (int i = 0; i < 16; ++i) {
*(input_ptr + i) = i;
}
DLOG << "input : ";
for (int i = 0; i < input.numel(); ++i) {
DLOG << " index " << i << " : " << input_ptr[i];
}
paddle_mobile::framework::TestTranspose2Op<paddle_mobile::CPU>
testTranspose2Op(program);
auto output = testTranspose2Op.predict(input);
auto *output_ptr = output->data<float>();
DLOG << "output : ";
for (int i = 0; i < output->numel(); ++i) {
DLOG << " index " << i << " : " << output_ptr[i];
}
return 0;
}
......@@ -205,6 +205,7 @@ if(NOT FOUND_MATCH)
set(SIGMOID_OP ON)
set(SOFTMAX_OP ON)
set(TRANSPOSE_OP ON)
set(TRANSPOSE2_OP ON)
set(FUSION_CONVADDBNRELU_OP ON)
set(FUSION_CONVADDADDPRELU_OP ON)
set(FUSION_DWCONVBNRELU_OP ON)
......@@ -251,6 +252,7 @@ endif()
# option(SIGMOID_OP "" ON)
# option(SOFTMAX_OP "" ON)
# option(TRANSPOSE_OP "" ON)
# option(TRANSPOSE2_OP "" ON)
# endif ()
if (BATCHNORM_OP)
......@@ -328,6 +330,9 @@ endif()
if (TRANSPOSE_OP)
add_definitions(-DTRANSPOSE_OP)
endif()
if (TRANSPOSE2_OP)
add_definitions(-DTRANSPOSE2_OP)
endif()
if (FUSION_CONVADDBNRELU_OP)
add_definitions(-DFUSION_CONVADDBNRELU_OP)
endif()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册