Unverified Commit 7e20d7e1 authored by: Z zhaoyang-star committed by: GitHub

Merge branch 'develop' into enable_prifile_in_tiny_publish

......@@ -44,6 +44,8 @@ sh run_benchmark.sh
3. Automatically runs another script, `benchmark.sh` (if multiple phones are connected over USB, append the test phone's `serial number` to the `adb` commands in the `benchmark.sh` script; see the sketch below);
4. Downloads the benchmark results `result_armv7.txt` and `result_armv8.txt` from the phone into the current directory and displays them.
> **Note:** If you hit an `Operation not permitted` error while running, grant permission with `sudo sh run_benchmark.sh`, toggle the phone's **USB debugging** and **file transfer mode** off and on again, or reconnect the phone over USB before rerunning the script.
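For example, pinning each `adb` call in `benchmark.sh` to one phone looks like the sketch below (the serial number `0123456789ABCDEF` and file names are placeholders; read the real serial from `adb devices`):
```shell
# List attached devices; the first column of each row is a serial number
adb devices
# Insert -s <serial number> after adb so the command targets only the test phone
adb -s 0123456789ABCDEF push benchmark_bin_v8 /data/local/tmp/
adb -s 0123456789ABCDEF shell /data/local/tmp/benchmark_bin_v8
```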
## 2. Step-by-step benchmark
### 1. Build the benchmark executable
......
......@@ -3,7 +3,7 @@
**Note: this build method only applies to versions after release/v2.6.0 (v2.6.0 included).**
With an Android build environment installed, you can download and build the Paddle-Lite source code.
If you have not yet set up the Android cross-compilation environment, first follow [Environment Setup](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#id2) to install the toolchain needed to build the Android inference library in your development environment. Before running the build script, check that the environment variable `NDK_ROOT` points to the correct Android NDK installation path; then you can download and build the Paddle-Lite source code.
```shell
# 1. Download the Paddle-Lite source code and switch to the release branch
......@@ -14,6 +14,7 @@ cd Paddle-Lite && git checkout release/v2.3
./lite/tools/build_android.sh
```
> **Tip:** If the build spends a long time downloading third-party libraries, try deleting the `<lite-repo>/third-party` directory under Paddle-Lite and rerunning the build script; the script then automatically downloads a third-party source package hosted on Baidu Cloud, which saves the time of cloning each third-party library from its git repo.
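Assuming the shell is at the Paddle-Lite repo root, the tip amounts to:
```shell
# Remove the cached third-party sources; the build script re-downloads them
rm -rf third-party
./lite/tools/build_android.sh
```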
### Build results
......
......@@ -38,7 +38,7 @@
### 2.3 Configure the calibration data generator
Post-training static quantization reads calibration data internally via asynchronous data loading; you only need to configure a sample_generator that reads data matching the model's inputs. sample_generator is a Python generator that **must return a single sample at a time**, and it is used as the data source for `DataLoader.set_sample_generator()`.
To learn how to configure the calibration data generator, see the [asynchronous data loading documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/data_preparing/use_py_reader.html) and the example in this article.
To learn how to configure the calibration data generator, see the [asynchronous data loading documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/data_preparing/static_mode/use_py_reader.html) and the example in this article.
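A minimal sketch of such a generator (the sample count, input shape, and random data are placeholders; a real generator should read and preprocess your calibration set):
```python
import numpy as np

def sample_generator():
    # Must yield ONE sample per iteration, shaped like a single model input
    # (no batch dimension); replace the random data with real inputs.
    for _ in range(100):
        image = np.random.random((3, 224, 224)).astype('float32')
        yield image

# Wired up as the calibration data source, e.g.:
# data_loader.set_sample_generator(sample_generator, batch_size=1, places=place)
```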
### 2.4 调用静态离线量化
......
# Model conversion tool X2Paddle
X2Paddle converts caffe, tensorflow, and onnx models into models supported by Paddle.
X2Paddle converts caffe, tensorflow, and onnx models into models supported by Paddle. Currently supported versions: caffe 1.0; tensorflow 1.x (1.4.0 recommended); ONNX 1.6.0 with OpSet versions 9, 10, and 11.
[X2Paddle](https://github.com/PaddlePaddle/X2Paddle) supports converting Caffe/TensorFlow models into PaddlePaddle models.
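Typical conversions with the command-line tool might look like the sketch below (assuming X2Paddle is installed via `pip install x2paddle`; the model file names are placeholders):
```shell
# TensorFlow frozen graph -> Paddle model
x2paddle --framework=tensorflow --model=frozen_graph.pb --save_dir=pd_model
# Caffe model -> Paddle model
x2paddle --framework=caffe --prototxt=deploy.prototxt \
         --weight=weights.caffemodel --save_dir=pd_model
```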
For supported models, see the **X2Paddle model test suite:**
......
......@@ -23,7 +23,9 @@ add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kerne
add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program)
add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps})
add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA)
if(LITE_BUILD_EXTRA AND LITE_WITH_X86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
lite_cc_test(test_one_hot_compute_host SRCS one_hot_compute_test.cc DEPS one_hot_compute_host)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/one_hot_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <typename T>
void OneHotKernelFunctor(const Tensor* in,
Tensor* out,
int depth,
bool allow_out_of_range = false) {
auto* p_in_data = in->data<T>();
auto numel = in->numel();
auto* p_out_data = out->mutable_data<T>();
memset(p_out_data, 0, out->numel() * sizeof(T));
if (allow_out_of_range) {
for (int i = 0; i < numel; ++i) {
if (p_in_data[i] >= 0 && p_in_data[i] < depth) {
p_out_data[i * depth + static_cast<int>(p_in_data[i])] = 1.0;
}
}
} else {
for (int i = 0; i < numel; ++i) {
CHECK_GE(p_in_data[i], 0) << "Illegal index value, Input(input) value "
"should be at least 0, but received input ("
<< p_in_data[i] << ") less than 0";
CHECK_LT(p_in_data[i], depth)
<< "Illegal index value, Input(input) value should be less than "
"Input(depth), but received input ("
<< p_in_data[i] << ") not less than depth (" << depth << ")";
p_out_data[i * depth + static_cast<int>(p_in_data[i])] = 1.0;
}
}
}
void OneHotCompute::Run() {
auto& param = this->template Param<param_t>();
switch (param.dtype) {
case static_cast<int>(lite::core::FluidType::INT64):
OneHotKernelFunctor<int64_t>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
case static_cast<int>(lite::core::FluidType::INT32):
OneHotKernelFunctor<int32_t>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
case static_cast<int>(lite::core::FluidType::FP32):
OneHotKernelFunctor<float>(
param.X, param.Out, param.depth, param.allow_out_of_range);
break;
default:
LOG(ERROR) << "Unsupported data type for one_hot op:" << param.dtype;
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
one_hot, kHost, kAny, kAny, paddle::lite::kernels::host::OneHotCompute, def)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindInput("depth_tensor",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kHost),
PRECISION(kAny),
DATALAYOUT(kAny))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class OneHotCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
using param_t = operators::OneHotParam;
void Run() override;
virtual ~OneHotCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <iostream>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/host/one_hot_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
/* note:
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]] */
TEST(one_hot, test) {
using T = float;
lite::Tensor x, out;
x.Resize({4, 1});
out.Resize({4, 4});
auto* x_data = x.mutable_data<T>();
x_data[0] = 1;
x_data[1] = 1;
x_data[2] = 3;
x_data[3] = 0;
auto* out_data = out.mutable_data<T>();
// Flattened 4x4 reference so out_ref[i] lines up with out_data[i] below
float out_ref[16] = {0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0};
OneHotCompute one_hot;
operators::OneHotParam param;
param.X = &x;
param.Out = &out;
param.depth = 4;
// static_cast<int>(lite::core::FluidType::FP32) = 5;
param.dtype = 5;
one_hot.SetParam(param);
one_hot.PrepareForRun();
one_hot.Run();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_data[i], out_ref[i], 1e-5);
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(one_hot, kHost, kAny, kAny, def);
......@@ -17,6 +17,9 @@ lite_cc_library(subgraph_bridge_batch_norm_op_huawei_ascend_npu SRCS batch_norm_
lite_cc_library(subgraph_bridge_softmax_op_huawei_ascend_npu SRCS softmax_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_huawei_ascend_npu SRCS dropout_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_fc_op_huawei_ascend_npu SRCS fc_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_reshape_op_huawei_ascend_npu SRCS reshape_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_transpose_op_huawei_ascend_npu SRCS transpose_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_flatten_op_huawei_ascend_npu SRCS flatten_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_registry
......@@ -32,4 +35,7 @@ set(huawei_ascend_npu_subgraph_bridges
subgraph_bridge_softmax_op_huawei_ascend_npu
subgraph_bridge_dropout_op_huawei_ascend_npu
subgraph_bridge_fc_op_huawei_ascend_npu
subgraph_bridge_reshape_op_huawei_ascend_npu
subgraph_bridge_transpose_op_huawei_ascend_npu
subgraph_bridge_flatten_op_huawei_ascend_npu
CACHE INTERNAL "huawei_ascend_npu_subgraph_bridges")
......@@ -132,19 +132,22 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
return FAILED;
}
// Filter node
std::shared_ptr<Node> filter_node = nullptr;
// Check depthwise mode, and decide whether to use the DepthwiseConv2D op
bool use_depthwise_conv = false;
bool is_depthwise_mode = (ic == groups && oc == groups);
if (is_depthwise_mode && dilations[0] == 1 && dilations[1] == 1) {
use_depthwise_conv = true;
// Change filter shape {oc, ic/groups = 1, kh, kw} => {K=1, oc, kh, kw}
filter->Resize({1L, oc, filter_dims[2], filter_dims[3]});
filter_node = graph->Add(
filter_name, *filter, {1L, oc, filter_dims[2], filter_dims[3]});
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] DepthwiseConv2D op is used.";
} else {
filter_node = graph->Add(filter_name, *filter);
}
// Filter node
auto filter_node = graph->Add(filter_name, *filter);
// Add bias node if exists bias
// Supports the bias nodes with the following dimensions
// 0: {oc} => 1D tensor of format ND
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto out = scope->FindMutableTensor(out_name);
auto out_dims = out->dims();
VLOG(3) << "output shape is: " << out_dims.repr();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Const Shape node
auto shape_node =
graph->Add<int64_t>(x_name + "/shape", out_dims.Vectorize());
// Reshape node
auto reshaped_x_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_op->set_input_x(*x_node->data());
reshaped_x_op->set_input_shape(*shape_node->data());
reshaped_x_op->set_attr_axis(0);
INPUT_UPDATE(reshaped_x_op, x, x_node);
INPUT_UPDATE(reshaped_x_op, shape, shape_node);
OUTPUT_UPDATE(reshaped_x_op, y, reshaped_x_node);
return SUCCESS;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
flatten,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::FlattenConverter);
REGISTER_SUBGRAPH_BRIDGE(
flatten2,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::FlattenConverter);
......@@ -42,3 +42,9 @@ USE_SUBGRAPH_BRIDGE(batch_norm, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(softmax, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(dropout, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(fc, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(reshape, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(reshape2, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(transpose, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(transpose2, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(flatten, kHuaweiAscendNPU);
USE_SUBGRAPH_BRIDGE(flatten2, kHuaweiAscendNPU);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reshape_op.h"
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Shape Const node
if (op_info->HasInput("ShapeTensor")) {
LOG(WARNING) << "[HUAWEI_ASCEND_NPU] \"Shape\" from more than one Tensor "
"is not supported.";
return FAILED;
}
std::shared_ptr<Node> actual_shape_node = nullptr;
if (op_info->HasInput("Shape")) {
auto actual_shape_name = op_info->Input("Shape").front();
if (graph->Has(actual_shape_name)) {
actual_shape_node = graph->Get(actual_shape_name);
} else {
auto actual_shape = scope->FindMutableTensor(actual_shape_name);
auto actual_shape_dims = actual_shape->dims();
auto actual_shape_data = actual_shape->mutable_data<int>();
auto shape =
std::vector<int>(actual_shape_data,
actual_shape_data + actual_shape_dims.production());
auto out_shape = lite::operators::ValidateShape(shape, x_dims);
actual_shape_node =
graph->Add<int>(actual_shape_name,
std::vector<int>(out_shape.begin(), out_shape.end()));
}
} else if (op_info->HasAttr("shape")) {
auto shape = op_info->GetAttr<std::vector<int>>("shape");
auto out_shape = lite::operators::ValidateShape(shape, x_dims);
out_shape = CvtShape(out_shape);
actual_shape_node = graph->Add<int64_t>(
out_name + "/shape",
std::vector<int64_t>(out_shape.begin(), out_shape.end()));
}
// actual_shape_node should not be nullptr
CHECK(actual_shape_node);
// Reshape node
auto reshape_node = graph->Add<ge::op::Reshape>(out_name);
auto reshape_op = reshape_node->data<ge::op::Reshape>();
reshape_op->set_input_x(*x_node->data());
reshape_op->set_input_shape(*actual_shape_node->data());
INPUT_UPDATE(reshape_op, x, x_node);
INPUT_UPDATE(reshape_op, shape, actual_shape_node);
OUTPUT_UPDATE(reshape_op, y, reshape_node);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
reshape,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ReshapeConverter);
REGISTER_SUBGRAPH_BRIDGE(
reshape2,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::ReshapeConverter);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace huawei_ascend_npu {
int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
auto axis = op_info->GetAttr<std::vector<int>>("axis");
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Transpose node
auto transpose_node = graph->Add<ge::op::TransposeD>(out_name);
auto transpose_op = transpose_node->data<ge::op::TransposeD>();
transpose_op->set_input_x(*x_node->data());
transpose_op->set_attr_perm(
ge::Operator::OpListInt(axis.begin(), axis.end()));
INPUT_UPDATE(transpose_op, x, x_node);
OUTPUT_UPDATE(transpose_op, y, transpose_node);
return SUCCESS;
}
} // namespace huawei_ascend_npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(
transpose,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::TransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(
transpose2,
kHuaweiAscendNPU,
paddle::lite::subgraph::huawei_ascend_npu::TransposeConverter);
......@@ -83,9 +83,9 @@ class VectorView {
operator std::vector<T>() const {
VLOG(5) << "Copying elements out of VectorView will degrade performance.";
std::vector<T> tmp;
tmp.reserve(size());
tmp.resize(size());
for (size_t i = 0; i < size(); ++i) {
tmp.push_back(cvec_->operator[](i));
tmp[i] = cvec_->operator[](i);
}
return tmp;
}
......
......@@ -30,13 +30,13 @@ class BlockDescView : public BlockDescAPI {
public:
explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) {
CHECK(desc_);
vars_.reserve(VarsSize());
ops_.reserve(OpsSize());
vars_.resize(VarsSize());
ops_.resize(OpsSize());
for (size_t idx = 0; idx < VarsSize(); ++idx) {
vars_.push_back(VarDescView(desc_->vars()->Get(idx)));
vars_[idx] = VarDescView(desc_->vars()->Get(idx));
}
for (size_t idx = 0; idx < OpsSize(); ++idx) {
ops_.push_back(OpDescView(desc_->ops()->Get(idx)));
ops_[idx] = OpDescView(desc_->ops()->Get(idx));
}
}
......@@ -76,7 +76,7 @@ class BlockDescView : public BlockDescAPI {
return desc_->forward_block_idx();
}
BlockDescView() { NotImplemented(); }
BlockDescView() = default;
private:
proto::BlockDesc const* desc_; // not_own
......
......@@ -19,8 +19,8 @@ namespace lite {
namespace fbs {
template <>
std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
std::string OpDescView::GetAttr<std::string>(const char* name) const {
const auto& it = desc_->attrs()->LookupByKey(name);
if (!it->s()) {
return std::string();
}
......@@ -28,56 +28,48 @@ std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
}
template <>
std::string OpDescView::GetAttr<std::string>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
if (!it->s()) {
return std::string();
}
return it->s()->str();
std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
return GetAttr<std::string>(name.c_str());
}
template <>
lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
const auto& it = desc_->attrs()->LookupByKey(name.c_str());
OpDescView::GetAttr<std::vector<std::string>>(const char* name) const {
const auto& it = desc_->attrs()->LookupByKey(name);
CHECK(it) << "Attr " << name << " does not exist.";
return VectorView<std::string>(it->strings());
}
template <>
VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const {
const auto& it = desc_->attrs()->Get(idx);
CHECK(it) << "Attr " << idx << " does not exist.";
return VectorView<std::string>(it->strings());
lite::VectorView<std::string, Flatbuffers>
OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
return GetAttr<std::vector<std::string>>(name.c_str());
}
#define GET_ATTR_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
const char* name) const { \
const auto& it = desc_->attrs()->LookupByKey(name); \
return it->fb_f__(); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return it->fb_f__(); \
const std::string& name) const { \
return GetAttr<T>(name.c_str()); \
}
#define GET_ATTRS_IMPL(T, fb_f__) \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
const std::string& name) const { \
const auto& it = desc_->attrs()->LookupByKey(name.c_str()); \
const char* name) const { \
const auto& it = desc_->attrs()->LookupByKey(name); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
} \
template <> \
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT OpDescView::GetAttr<T>( \
size_t idx) const { \
const auto& it = desc_->attrs()->Get(idx); \
return typename lite::OpDataTypeTrait<T, Flatbuffers>::RT(it->fb_f__()); \
const std::string& name) const { \
return GetAttr<T>(name.c_str()); \
}
GET_ATTR_IMPL(int32_t, i);
......
......@@ -36,57 +36,68 @@ class OpDescView : public OpDescAPI {
std::string Type() const override { return desc_->type()->str(); }
// Get the arguments of parameter called `param`
std::vector<std::string> Input(const std::string& param) const override {
const auto& var = desc_->inputs()->LookupByKey(param.c_str());
std::vector<std::string> Input(const char* param) const {
const auto& var = desc_->inputs()->LookupByKey(param);
std::vector<std::string> args_vec;
if (var->arguments()) {
args_vec.reserve(var->arguments()->size());
for (const auto& in : *var->arguments()) {
args_vec.push_back(in->str());
if (var && var->arguments()) {
args_vec.resize(var->arguments()->size());
for (size_t i = 0; i < var->arguments()->size(); ++i) {
args_vec[i] = (*var->arguments())[i]->str();
}
}
return args_vec;
}
std::vector<std::string> Input(const std::string& param) const override {
return Input(param.c_str());
}
std::vector<std::string> InputArgumentNames() const override {
const auto& vars = desc_->inputs();
std::vector<std::string> input_names_vec;
if (vars) {
input_names_vec.reserve(vars->size());
for (const auto& in : *vars) {
input_names_vec.push_back(in->parameter()->str());
input_names_vec.resize(vars->size());
for (size_t i = 0; i < vars->size(); ++i) {
input_names_vec[i] = (*vars)[i]->parameter()->str();
}
}
return input_names_vec;
}
std::vector<std::string> Output(const std::string& param) const override {
const auto& var = desc_->outputs()->LookupByKey(param.c_str());
std::vector<std::string> Output(const char* param) const {
const auto& var = desc_->outputs()->LookupByKey(param);
std::vector<std::string> args_vec;
if (var && var->arguments()) {
args_vec.reserve(var->arguments()->size());
for (const auto& out : *var->arguments()) {
args_vec.push_back(out->str());
args_vec.resize(var->arguments()->size());
for (size_t i = 0; i < var->arguments()->size(); ++i) {
args_vec[i] = (*var->arguments())[i]->str();
}
}
return args_vec;
}
std::vector<std::string> Output(const std::string& param) const override {
return Output(param.c_str());
}
std::vector<std::string> OutputArgumentNames() const override {
const auto& vars = desc_->outputs();
std::vector<std::string> output_names_vec;
if (vars) {
output_names_vec.reserve(vars->size());
for (const auto& out : *vars) {
output_names_vec.push_back(out->parameter()->str());
output_names_vec.resize(vars->size());
for (size_t i = 0; i < vars->size(); ++i) {
output_names_vec[i] = (*vars)[i]->parameter()->str();
}
}
return output_names_vec;
}
bool HasAttr(const char* name) const {
return desc_->attrs()->LookupByKey(name) != nullptr;
}
bool HasAttr(const std::string& name) const override {
return desc_->attrs()->LookupByKey(name.c_str()) != nullptr;
return HasAttr(name.c_str());
}
size_t AttrsSize() const { return desc_->attrs()->size(); }
......@@ -95,25 +106,23 @@ class OpDescView : public OpDescAPI {
return desc_->attrs()->Get(idx)->name()->str();
}
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
OpDescAPI::AttrType GetAttrType(const char* name) const {
const auto& attr = desc_->attrs()->LookupByKey(name);
CHECK(attr) << "Can not find attr: " << name;
return ConvertAttrType(attr->type());
}
OpDescAPI::AttrType GetAttrType(size_t idx) const {
const auto& attr = desc_->attrs()->Get(idx);
CHECK(attr);
return ConvertAttrType(attr->type());
OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
return GetAttrType(name.c_str());
}
std::vector<std::string> AttrNames() const override {
const auto& attrs = desc_->attrs();
std::vector<std::string> attr_names_vec;
if (attrs) {
attr_names_vec.reserve(attrs->size());
for (const auto& attr : *attrs) {
attr_names_vec.push_back(attr->name()->str());
attr_names_vec.resize(attrs->size());
for (size_t i = 0; i < attrs->size(); ++i) {
attr_names_vec[i] = (*attrs)[i]->name()->str();
}
}
return attr_names_vec;
......@@ -121,10 +130,11 @@ class OpDescView : public OpDescAPI {
template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(
const std::string& name) const;
const char* name) const;
template <typename T>
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(size_t idx) const;
typename lite::OpDataTypeTrait<T, Flatbuffers>::RT GetAttr(
const std::string& name) const;
private:
proto::OpDesc const* desc_;
......@@ -138,7 +148,7 @@ class OpDescView : public OpDescAPI {
// caused by different building options.
public:
OpDescView() { NotImplemented(); }
OpDescView() = default;
bool HasInput(const std::string& param) const {
return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
}
......
......@@ -42,9 +42,9 @@ class ParamDescView : public ParamDescReadAPI {
std::vector<int64_t> Dim() const override {
const auto& dims = tensor_desc_->dim();
std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size());
for (const auto& dim : *dims) {
dims_vec.push_back(dim);
dims_vec.resize(dims->size());
for (size_t i = 0; i < dims->size(); ++i) {
dims_vec[i] = dims->operator[](i);
}
return dims_vec;
}
......@@ -57,7 +57,7 @@ class ParamDescView : public ParamDescReadAPI {
size_t byte_size() const override { return tensor_desc_->data()->size(); }
ParamDescView() = delete;
ParamDescView() = default;
private:
proto::ParamDesc const* desc_;
......@@ -87,9 +87,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI {
void InitParams() {
desc_ = proto::GetCombinedParamsDesc(buf_.data());
size_t params_size = desc_->params()->size();
params_.reserve(params_size);
params_.resize(params_size);
for (size_t idx = 0; idx < params_size; ++idx) {
params_.push_back(ParamDescView(desc_->params()->Get(idx)));
params_[idx] = ParamDescView(desc_->params()->Get(idx));
}
}
......
......@@ -48,9 +48,9 @@ class ProgramDescView : public ProgramDescAPI {
void InitProgramDesc() {
desc_ = proto::GetProgramDesc(buf_.data());
blocks_.reserve(BlocksSize());
blocks_.resize(BlocksSize());
for (size_t idx = 0; idx < BlocksSize(); ++idx) {
blocks_.push_back(BlockDescView(desc_->blocks()->Get(idx)));
blocks_[idx] = BlockDescView(desc_->blocks()->Get(idx));
}
}
......
......@@ -42,9 +42,9 @@ class VarDescView : public VarDescAPI {
CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
const auto& dims = desc_->type()->lod_tensor()->tensor()->dims();
std::vector<int64_t> dims_vec;
dims_vec.reserve(dims->size());
for (const auto& dim : *dims) {
dims_vec.push_back(dim);
dims_vec.resize(dims->size());
for (size_t i = 0; i < dims->size(); ++i) {
dims_vec[i] = dims->operator[](i);
}
return dims_vec;
}
......@@ -66,7 +66,7 @@ class VarDescView : public VarDescAPI {
// caused by different building options.
public:
VarDescView() { NotImplemented(); }
VarDescView() = default;
void SetDataType(Type data_type) { NotImplemented(); }
void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
......
......@@ -127,9 +127,9 @@ class VectorView<std::string, Flatbuffers> {
operator std::vector<std::string>() const {
VLOG(5) << "Copying elements out of VectorView will degrade performance.";
std::vector<std::string> tmp;
tmp.reserve(size());
tmp.resize(size());
for (size_t i = 0; i < size(); ++i) {
tmp.push_back(cvec_->operator[](i)->str());
tmp[i] = cvec_->operator[](i)->str();
}
return tmp;
}
......
......@@ -147,6 +147,7 @@ add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
add_operator(retinanet_detection_output_op extra SRCS retinanet_detection_output_op.cc DEPS ${op_DEPS})
add_operator(where_index_op extra SRCS where_index_op.cc DEPS ${op_DEPS})
add_operator(one_hot_op extra SRCS one_hot_op.cc DEPS ${op_DEPS})
# for content-dnn specific
add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS})
add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
......@@ -175,7 +176,7 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
if (NOT LITE_WITH_X86)
lite_cc_test(test_fc_op SRCS fc_op_test.cc
DEPS fc_op memory
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/one_hot_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool OneHotOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool OneHotOp::InferShapeImpl() const {
auto out_dims = param_.X->dims();
CHECK_GE(out_dims.size(), 2);
int depth = param_.depth_tensor ? param_.depth_tensor->data<int32_t>()[0]
: param_.depth;
out_dims[out_dims.size() - 1] = depth;
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.X->lod());
return true;
}
bool OneHotOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto x = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
param_.X = scope->FindVar(x)->GetMutable<Tensor>();
param_.Out = scope->FindMutableTensor(out);
if (op_desc.HasInput("depth_tensor") &&
!op_desc.Input("depth_tensor").empty()) {
auto depth_tensor = op_desc.Input("depth_tensor").front();
param_.depth_tensor = scope->FindVar(depth_tensor)->GetMutable<Tensor>();
}
if (op_desc.HasAttr("depth")) {
param_.depth = op_desc.GetAttr<int>("depth");
}
if (op_desc.HasAttr("allow_out_of_range")) {
param_.allow_out_of_range = op_desc.GetAttr<bool>("allow_out_of_range");
}
param_.dtype = op_desc.GetAttr<int>("dtype");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(one_hot, paddle::lite::operators::OneHotOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
/* note:
One Hot Operator. This operator creates the one-hot representations for input
index values. The following example will help to explain the function of this
operator:
X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
X.data = [[1], [1], [3], [0]]
set depth = 4
Out is a LoDTensor:
Out.lod = [[0, 1, 4]]
Out.shape = [4, 4]
Out.data = [[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.]] */
class OneHotOp : public OpLite {
public:
OneHotOp() {}
explicit OneHotOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "one_hot"; }
#ifdef LITE_WITH_PROFILE
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
ch->input_shape = ch->DimToStr(param_.X->dims());
ch->output_shape = ch->DimToStr(param_.Out->dims());
ch->macs = param_.X->numel() * 1.f;
}
#endif
private:
mutable OneHotParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/one_hot_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(one_hot_op_lite, TestHost) {
// prepare variables
Scope scope;
auto* x = scope.Var("X")->GetMutable<Tensor>();
auto* depth_tensor = scope.Var("depth_tensor")->GetMutable<Tensor>();
auto* output = scope.Var("Out")->GetMutable<Tensor>();
depth_tensor->dims();
output->dims();
// set data
x->Resize(DDim(std::vector<int64_t>({4, 1})));
auto* x_data = x->mutable_data<int32_t>();
x_data[0] = 1;
x_data[1] = 1;
x_data[2] = 3;
x_data[3] = 0;
// prepare op desc
cpp::OpDesc desc;
desc.SetType("one_hot");
desc.SetInput("X", {"X"});
desc.SetInput("depth_tensor", {"depth_tensor"});
desc.SetOutput("Out", {"Out"});
desc.SetAttr("depth", static_cast<int>(4));
desc.SetAttr("dtype", static_cast<int>(1));
desc.SetAttr("allow_out_of_range", static_cast<bool>(0));
OneHotOp one_hot("one_hot");
one_hot.SetValidPlaces({Place{TARGET(kHost), PRECISION(kAny)}});
one_hot.Attach(desc, &scope);
auto kernels = one_hot.CreateKernels({Place{TARGET(kHost), PRECISION(kAny)}});
ASSERT_FALSE(kernels.empty());
}
} // namespace operators
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(one_hot, kHost, kAny, kAny, def);
......@@ -1824,6 +1824,15 @@ struct PrintParam : ParamBase {
bool is_forward{true};
};
struct OneHotParam : ParamBase {
const lite::Tensor* X{};
const lite::Tensor* depth_tensor{nullptr};
lite::Tensor* Out{};
int depth{-1};
int dtype{-1};
bool allow_out_of_range{false};
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -14,7 +14,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_group_norm_compute SRCS group_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -40,21 +40,21 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
lite_cc_test(test_kernel_fill_constant_batch_size_like_compute SRCS fill_constant_batch_size_like_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_sum_compute SRCS reduce_sum_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_sum_compute SRCS reduce_sum_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -73,7 +73,7 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_elementwise_grad_compute SRCS elementwise_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_pool_grad_compute SRCS sequence_pool_grad_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_pool_grad_compute SRCS sequence_pool_grad_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
endif()
......@@ -90,5 +90,6 @@ endif()
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_as_compute SRCS expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_flatten_compute SRCS flatten_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_crf_decoding_compute SRCS crf_decoding_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class FlattenComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string op_type_ = "flatten";
std::string input_ = "x";
std::string output_ = "out";
std::string xshape_ = "xshape";
DDim dims_;
int axis_;
public:
FlattenComputeTester(const Place& place,
const std::string& alias,
DDim dims,
int axis)
: TestCase(place, alias), dims_(dims), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
auto* x = scope->FindTensor(input_);
int64_t outer = 1, inner = 1;
for (size_t i = 0; i < dims_.size(); ++i) {
if (i < axis_) {
outer *= dims_[i];
} else {
inner *= dims_[i];
}
}
std::vector<int64_t> out_shape(2);
out_shape[0] = outer;
out_shape[1] = inner;
out->Resize(out_shape);
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
memcpy(out_data, x_data, sizeof(float) * dims_.production());
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(op_type_);
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
if (op_type_ == "flatten2") {
op_desc->SetOutput("XShape", {xshape_});
}
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
SetCommonTensor(input_, dims_, din.data());
}
};
void TestFlatten(Place place, float abs_error) {
DDim dims{{2, 3, 4, 5}};
std::vector<int> axes{0, 1, 2, 3};
for (auto axis : axes) {
std::unique_ptr<arena::TestCase> tester(
new FlattenComputeTester(place, "def", dims, axis));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision({"xshape"});
}
}
TEST(flatten, precision) {
LOG(INFO) << "test flatten op";
Place place;
float abs_error = 1e-5;
#if defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#else
return;
#endif
TestFlatten(place, abs_error);
}
} // namespace lite
} // namespace paddle
......@@ -208,6 +208,9 @@ TEST(Reshape, precision) {
place = TARGET(kHost);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU);
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#else
return;
#endif
......
......@@ -169,6 +169,9 @@ TEST(Transpose, precision) {
#elif defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
place = TARGET(kHuaweiAscendNPU);
abs_error = 1e-2; // precision_mode default is force_fp16
#else
return;
#endif
......
......@@ -153,6 +153,10 @@ function make_tiny_publish_so {
prepare_thirdparty
fi
if [ "${WITH_STRIP}" == "ON" ]; then
WITH_EXTRA=ON
fi
local cmake_mutable_options="
-DLITE_BUILD_EXTRA=$WITH_EXTRA \
-DLITE_WITH_LOG=$WITH_LOG \
......@@ -206,6 +210,10 @@ function make_full_publish_so {
prepare_opencl_source_code $workspace $build_dir
fi
if [ "${WITH_STRIP}" == "ON" ]; then
WITH_EXTRA=ON
fi
local cmake_mutable_options="
-DLITE_BUILD_EXTRA=$WITH_EXTRA \
-DLITE_WITH_LOG=$WITH_LOG \
......
......@@ -49,6 +49,10 @@ function make_ios {
exit 1
fi
if [ "${WITH_STRIP}" == "ON" ]; then
WITH_EXTRA=ON
fi
build_dir=$workspace/build.ios.${os}.${arch}
if [ -d $build_dir ]
then
......@@ -61,7 +65,6 @@ function make_ios {
GEN_CODE_PATH_PREFIX=lite/gen_code
mkdir -p ./${GEN_CODE_PATH_PREFIX}
touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
cmake $workspace \
-DWITH_LITE=ON \
-DLITE_WITH_ARM=ON \
......
......@@ -173,6 +173,9 @@ function make_tiny_publish_so {
if [ "${WITH_OPENCL}" = "ON" ]; then
prepare_opencl_source_code $workspace $build_dir
fi
if [ "${WITH_STRIP}" == "ON" ]; then
WITH_EXTRA=ON
fi
init_cmake_mutable_options
cmake $workspace \
......