diff --git a/docs/benchmark/benchmark_tools.md b/docs/benchmark/benchmark_tools.md
index 96a67931c91f1323508bdd4d2fda6d3a55bbb307..301f7ebd4a1597edeb131bd6418ed6b6ef731dce 100644
--- a/docs/benchmark/benchmark_tools.md
+++ b/docs/benchmark/benchmark_tools.md
@@ -44,6 +44,8 @@ sh run_benchmark.sh
 3. Automatically runs another script, `benchmark.sh` (if several phones are connected over USB, append the test phone's `serial number` to the `adb` commands in `benchmark.sh`);
 4. Downloads the benchmark results `result_armv7.txt` and `result_armv8.txt` from the phone to the current directory and displays them.
 
+> **Note:** if you run into an `Operation not permitted` error, run `sudo sh run_benchmark.sh` to grant the required permissions, try toggling **USB debugging** and **file transfer mode** off and on again on the phone, or re-connect the phone over USB and then run the script again.
+
 ## 2. Step-by-step Benchmark
 
 ### 1. Build the benchmark executable
diff --git a/docs/user_guides/Compile/Android.md b/docs/user_guides/Compile/Android.md
index 5ff0525f2eec8ef5fe6e49835b6a92447799b46c..bd920338e527fe3dcade37e51f628dfbb9777a09 100644
--- a/docs/user_guides/Compile/Android.md
+++ b/docs/user_guides/Compile/Android.md
@@ -3,7 +3,7 @@
 
 **Note: this build method only applies to versions after release/v2.6.0 (v2.6.0 included).**
 
-With an Android build environment installed, you can download and build the Paddle-Lite source code.
+If you have not yet set up an Android cross-compilation environment, first follow [Environment Setup](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#id2) to install the toolchain required for building the Android inference library on your development machine. Before running the build script, check that the environment variable `NDK_ROOT` points to a valid Android NDK installation; you can then download and build the Paddle-Lite source code.
 
 ```shell
 # 1. Download the Paddle-Lite source code and switch to the release branch
@@ -14,6 +14,7 @@ cd Paddle-Lite && git checkout release/v2.3
 ./lite/tools/build_android.sh
 ```
+> **Tip:** if the build spends a long time downloading third-party libraries, try deleting the `/third-party` directory under Paddle-Lite and re-running the build script; it will then fetch a third-party source bundle hosted on Baidu Cloud, which is faster than cloning each library from its git repo.
 
 ### Build results
diff --git a/docs/user_guides/post_quant_with_data.md b/docs/user_guides/post_quant_with_data.md
index a861a9e95aa2dc79573d79037695d4864bb3a7ba..f49faa97aed16545031c6041ec25a9e00bb92b36 100644
--- a/docs/user_guides/post_quant_with_data.md
+++ b/docs/user_guides/post_quant_with_data.md
@@ -38,7 +38,7 @@
 ### 2.3 Configure the calibration data generator
 Post-training static quantization reads calibration data through asynchronous data loading; you only need to configure a sample_generator that matches the model inputs. sample_generator is a Python generator that **must yield one sample at a time** and is used as the data source of `DataLoader.set_sample_generator()`.
 
-We recommend the [asynchronous data loading documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/data_preparing/use_py_reader.html) together with the example in this document to learn how to configure the calibration data generator.
+We recommend the [asynchronous data loading documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/data_preparing/static_mode/use_py_reader.html) together with the example in this document to learn how to configure the calibration data generator.
 
 ### 2.4 Run post-training static quantization
diff --git a/docs/user_guides/x2paddle.md b/docs/user_guides/x2paddle.md
index e007caa437552627f785fc6e2ccff06dda54c244..1bc9ade65248d95715d10c035ee939e2c2ef14b9 100644
--- a/docs/user_guides/x2paddle.md
+++ b/docs/user_guides/x2paddle.md
@@ -1,6 +1,6 @@
 # Model Conversion Tool X2Paddle
 
-X2Paddle converts caffe, tensorflow, and onnx models into models supported by Paddle.
+X2Paddle converts caffe, tensorflow, and onnx models into models supported by Paddle. Currently supported versions: caffe 1.0; tensorflow 1.x (1.4.0 recommended); ONNX 1.6.0 with OpSet versions 9, 10, and 11.
 
 [X2Paddle](https://github.com/PaddlePaddle/X2Paddle) supports converting Caffe/TensorFlow models into PaddlePaddle models.
 For supported models, see the **X2Paddle model test suite:**
diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt
index 4b082a92e9119deef74ceea3889730159ffdaf9d..1c0914de731c215df9c64712d1fc8cfbc7ce08dd 100644
--- a/lite/kernels/host/CMakeLists.txt
+++ b/lite/kernels/host/CMakeLists.txt
@@ -23,7 +23,9 @@ add_kernel(print_compute_host Host extra SRCS print_compute.cc DEPS ${lite_kerne
 add_kernel(while_compute_host Host extra SRCS while_compute.cc DEPS ${lite_kernel_deps} program)
 add_kernel(conditional_block_compute_host Host extra SRCS conditional_block_compute.cc DEPS ${lite_kernel_deps} program)
 add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
 
-if(LITE_BUILD_EXTRA)
+if(LITE_BUILD_EXTRA AND LITE_WITH_X86)
     lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
+    lite_cc_test(test_one_hot_compute_host SRCS one_hot_compute_test.cc DEPS one_hot_compute_host)
 endif()
diff --git a/lite/kernels/host/one_hot_compute.cc b/lite/kernels/host/one_hot_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6880de39ae6b76c93fd3331964ae64c2e8d7b745
--- /dev/null
+++ b/lite/kernels/host/one_hot_compute.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/host/one_hot_compute.h"
+
+#include <cstring>  // memset
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+template <typename T>
+void OneHotKernelFunctor(const Tensor* in,
+                         Tensor* out,
+                         int depth,
+                         bool allow_out_of_range = false) {
+  auto* p_in_data = in->data<int64_t>();
+  auto numel = in->numel();
+  auto* p_out_data = out->mutable_data<T>();
+  memset(p_out_data, 0, out->numel() * sizeof(T));
+  if (allow_out_of_range) {
+    for (int i = 0; i < numel; ++i) {
+      if (p_in_data[i] >= 0 && p_in_data[i] < depth) {
+        p_out_data[i * depth + static_cast<int>(p_in_data[i])] = 1.0;
+      }
+    }
+  } else {
+    for (int i = 0; i < numel; ++i) {
+      CHECK_GE(p_in_data[i], 0) << "Illegal index value, Input(input) value "
+                                   "should be at least 0, but received input ("
+                                << p_in_data[i] << ") less than 0";
+      CHECK_LT(p_in_data[i], depth)
+          << "Illegal index value, Input(input) value should be less than "
+             "Input(depth), but received input ("
+          << p_in_data[i] << ") not less than depth (" << depth << ")";
+      p_out_data[i * depth + static_cast<int>(p_in_data[i])] = 1.0;
+    }
+  }
+}
+
+void OneHotCompute::Run() {
+  auto& param = this->template Param<param_t>();
+  switch (param.dtype) {
+    case static_cast<int>(lite::core::FluidType::INT64):
+      OneHotKernelFunctor<int64_t>(
+          param.X, param.Out, param.depth, param.allow_out_of_range);
+      break;
+    case static_cast<int>(lite::core::FluidType::INT32):
+      OneHotKernelFunctor<int32_t>(
+          param.X, param.Out, param.depth, param.allow_out_of_range);
+      break;
+    case static_cast<int>(lite::core::FluidType::FP32):
+      OneHotKernelFunctor<float>(
+          param.X, param.Out, param.depth, param.allow_out_of_range);
+      break;
+    default:
+      LOG(ERROR) << "Unsupported data type for one_hot op: " << param.dtype;
+  }
+}
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(
+    one_hot, kHost, kAny, kAny, paddle::lite::kernels::host::OneHotCompute, def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindInput("depth_tensor",
+               {LiteType::GetTensorTy(TARGET(kHost),
+                                      PRECISION(kAny),
+                                      DATALAYOUT(kAny))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kHost),
+                                       PRECISION(kAny),
+                                       DATALAYOUT(kAny))})
+    .Finalize();
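For orientation, the row-major indexing the kernel relies on (`out[i * depth + index[i]] = 1`) can be sketched outside the Lite framework. This is a standalone illustration only, not Paddle-Lite API:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// One-hot expansion of integer indices: row i of the output has a single 1
// in column indices[i], everything else stays 0.
int main() {
  const std::vector<int64_t> indices = {1, 1, 3, 0};
  const int depth = 4;
  std::vector<float> out(indices.size() * depth, 0.f);
  for (size_t i = 0; i < indices.size(); ++i) {
    out[i * depth + static_cast<int>(indices[i])] = 1.f;
  }
  for (size_t i = 0; i < indices.size(); ++i) {
    for (int j = 0; j < depth; ++j) std::printf("%.0f ", out[i * depth + j]);
    std::printf("\n");  // prints the 4x4 one-hot matrix for {1, 1, 3, 0}
  }
  return 0;
}
```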
diff --git a/lite/kernels/host/one_hot_compute.h b/lite/kernels/host/one_hot_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..6c94d900a729f6deaef5307a628166376ea83ecc
--- /dev/null
+++ b/lite/kernels/host/one_hot_compute.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+class OneHotCompute
+    : public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
+ public:
+  using param_t = operators::OneHotParam;
+
+  void Run() override;
+
+  virtual ~OneHotCompute() = default;
+};
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/host/one_hot_compute_test.cc b/lite/kernels/host/one_hot_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d767b358518c0402905b185f7acac5799eaa849c
--- /dev/null
+++ b/lite/kernels/host/one_hot_compute_test.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include <algorithm>
+#include <iostream>
+#include <memory>
+#include <random>
+#include <utility>
+#include <vector>
+
+#include "lite/core/op_registry.h"
+#include "lite/kernels/host/one_hot_compute.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace host {
+
+/* note:
+One Hot Operator. This operator creates the one-hot representations for input
+index values.
+The following example will help to explain the function of this operator:
+X is a LoDTensor:
+  X.lod = [[0, 1, 4]]
+  X.shape = [4, 1]
+  X.data = [[1], [1], [3], [0]]
+set depth = 4
+Out is a LoDTensor:
+  Out.lod = [[0, 1, 4]]
+  Out.shape = [4, 4]
+  Out.data = [[0., 1., 0., 0.],
+              [0., 1., 0., 0.],
+              [0., 0., 0., 1.],
+              [1., 0., 0., 0.]] */
+TEST(one_hot, test) {
+  using T = float;
+
+  lite::Tensor x, out;
+  x.Resize({4, 1});
+  out.Resize({4, 4});
+
+  auto* x_data = x.mutable_data<int64_t>();
+  x_data[0] = 1;
+  x_data[1] = 1;
+  x_data[2] = 3;
+  x_data[3] = 0;
+  auto* out_data = out.mutable_data<T>();
+  float out_ref[4][4] = {
+      {0, 1, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 1}, {1, 0, 0, 0}};
+
+  OneHotCompute one_hot;
+  operators::OneHotParam param;
+
+  param.X = &x;
+  param.Out = &out;
+  param.depth = 4;
+  // static_cast<int>(lite::core::FluidType::FP32) == 5
+  param.dtype = 5;
+
+  one_hot.SetParam(param);
+  one_hot.PrepareForRun();
+
+  one_hot.Run();
+
+  for (int i = 0; i < out.numel(); i++) {
+    EXPECT_NEAR(out_data[i], out_ref[i / 4][i % 4], 1e-5);
+  }
+}
+
+}  // namespace host
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(one_hot, kHost, kAny, kAny, def);
diff --git a/lite/kernels/huawei_ascend_npu/bridges/CMakeLists.txt b/lite/kernels/huawei_ascend_npu/bridges/CMakeLists.txt
index 11e87833bbc4dd9b97a173a19baa4a1a98f9e96e..86705e95d4b4c7b6c8aa4c1965d2d989a9137fee 100644
--- a/lite/kernels/huawei_ascend_npu/bridges/CMakeLists.txt
+++ b/lite/kernels/huawei_ascend_npu/bridges/CMakeLists.txt
@@ -17,6 +17,9 @@ lite_cc_library(subgraph_bridge_batch_norm_op_huawei_ascend_npu SRCS batch_norm_
 lite_cc_library(subgraph_bridge_softmax_op_huawei_ascend_npu SRCS softmax_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_dropout_op_huawei_ascend_npu SRCS dropout_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_fc_op_huawei_ascend_npu SRCS fc_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_reshape_op_huawei_ascend_npu SRCS reshape_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_transpose_op_huawei_ascend_npu SRCS transpose_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_flatten_op_huawei_ascend_npu SRCS flatten_op.cc DEPS ${huawei_ascend_npu_subgraph_bridge_deps})
 
 set(huawei_ascend_npu_subgraph_bridges
     subgraph_bridge_registry
@@ -32,4 +35,7 @@ set(huawei_ascend_npu_subgraph_bridges
     subgraph_bridge_softmax_op_huawei_ascend_npu
     subgraph_bridge_dropout_op_huawei_ascend_npu
     subgraph_bridge_fc_op_huawei_ascend_npu
+    subgraph_bridge_reshape_op_huawei_ascend_npu
+    subgraph_bridge_transpose_op_huawei_ascend_npu
+    subgraph_bridge_flatten_op_huawei_ascend_npu
     CACHE INTERNAL "huawei_ascend_npu_subgraph_bridges")
diff --git a/lite/kernels/huawei_ascend_npu/bridges/conv_op.cc b/lite/kernels/huawei_ascend_npu/bridges/conv_op.cc
index 23e6fd0211cb21d38e9b176140ff4e45220c9db7..1b3eb143b1a59b754f0d1d0b3c32c9b16aca0154 100644
--- a/lite/kernels/huawei_ascend_npu/bridges/conv_op.cc
+++ b/lite/kernels/huawei_ascend_npu/bridges/conv_op.cc
@@ -132,19 +132,22 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     return FAILED;
   }
 
+  // Filter node
+  std::shared_ptr<Node> filter_node = nullptr;
+
   // Check depthwise mode, and decide whether use DepthwiseConv2D Op
   bool use_depthwise_conv = false;
   bool is_depthwise_mode = (ic == groups && oc == groups);
   if (is_depthwise_mode && dilations[0] == 1 && dilations[1] == 1) {
     use_depthwise_conv = true;
     // Change filter shape {oc, ic/groups = 1, kh, kw} => {K=1, oc, kh, kw}
-    filter->Resize({1L, oc, filter_dims[2], filter_dims[3]});
+    filter_node = graph->Add(
+        filter_name, *filter, {1L, oc, filter_dims[2], filter_dims[3]});
     LOG(WARNING) << "[HUAWEI_ASCEND_NPU] DepthwiseConv2D op is used.";
+  } else {
+    filter_node = graph->Add(filter_name, *filter);
   }
 
-  // Filter node
-  auto filter_node = graph->Add(filter_name, *filter);
-
   // Add bias node if exists bias
   // Supports the bias nodes with the following dimensions
   // 0: {oc} => 1D tensor of format ND
diff --git a/lite/kernels/huawei_ascend_npu/bridges/flatten_op.cc b/lite/kernels/huawei_ascend_npu/bridges/flatten_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7e43915df5052919e98e81f8d0b8248088c9f07d
--- /dev/null
+++ b/lite/kernels/huawei_ascend_npu/bridges/flatten_op.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
+#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace huawei_ascend_npu {
+
+int FlattenConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
+
+  // Get input and output vars and op attributes
+  auto x_name = op_info->Input("X").front();
+  auto x = scope->FindMutableTensor(x_name);
+  auto x_dims = x->dims();
+
+  auto out_name = op_info->Output("Out").front();
+  auto out = scope->FindMutableTensor(out_name);
+  auto out_dims = out->dims();
+
+  VLOG(3) << "output shape is: " << out_dims.repr();
+
+  // X node
+  std::shared_ptr<Node> x_node = nullptr;
+  if (graph->Has(x_name)) {
+    x_node = graph->Get(x_name);
+  } else {
+    x_node = graph->Add(x_name, *x);
+  }
+
+  // Const Shape node
+  auto shape_node = graph->Add(x_name + "/shape", out_dims.Vectorize());
+  // Reshape node
+  auto reshaped_x_node = graph->Add<ge::op::Reshape>(out_name);
+  auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
+  reshaped_x_op->set_input_x(*x_node->data());
+  reshaped_x_op->set_input_shape(*shape_node->data());
+  reshaped_x_op->set_attr_axis(0);
+  INPUT_UPDATE(reshaped_x_op, x, x_node);
+  INPUT_UPDATE(reshaped_x_op, shape, shape_node);
+  OUTPUT_UPDATE(reshaped_x_op, y, reshaped_x_node);
+
+  return SUCCESS;
+}
+
+}  // namespace huawei_ascend_npu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(
+    flatten,
+    kHuaweiAscendNPU,
+    paddle::lite::subgraph::huawei_ascend_npu::FlattenConverter);
+REGISTER_SUBGRAPH_BRIDGE(
+    flatten2,
+    kHuaweiAscendNPU,
+    paddle::lite::subgraph::huawei_ascend_npu::FlattenConverter);
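The converter above relies on `out_dims` having already been computed by the op's shape inference: flatten collapses everything before `axis` into one dimension and everything from `axis` on into a second. A minimal standalone sketch of that computation (hypothetical helper, not the Lite implementation):

```cpp
#include <cstdint>
#include <vector>

// Flatten a shape to 2-D at `axis`: dims before axis multiply into the first
// dimension, the rest into the second, e.g. {2, 3, 4, 5} with axis=2 -> {6, 20}.
std::vector<int64_t> FlattenTo2D(const std::vector<int64_t>& dims, int axis) {
  int64_t outer = 1, inner = 1;
  for (size_t i = 0; i < dims.size(); ++i) {
    (static_cast<int>(i) < axis ? outer : inner) *= dims[i];
  }
  return {outer, inner};
}
```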
diff --git a/lite/kernels/huawei_ascend_npu/bridges/paddle_use_bridges.h b/lite/kernels/huawei_ascend_npu/bridges/paddle_use_bridges.h
index ec6069beb5d6d238dd0018ef4cb4d4b1d0cf9658..f7cfe39468bc34c6d93e3d97d4b270e77cc29a33 100644
--- a/lite/kernels/huawei_ascend_npu/bridges/paddle_use_bridges.h
+++ b/lite/kernels/huawei_ascend_npu/bridges/paddle_use_bridges.h
@@ -42,3 +42,9 @@ USE_SUBGRAPH_BRIDGE(batch_norm, kHuaweiAscendNPU);
 USE_SUBGRAPH_BRIDGE(softmax, kHuaweiAscendNPU);
 USE_SUBGRAPH_BRIDGE(dropout, kHuaweiAscendNPU);
 USE_SUBGRAPH_BRIDGE(fc, kHuaweiAscendNPU);
+USE_SUBGRAPH_BRIDGE(reshape, kHuaweiAscendNPU);
+USE_SUBGRAPH_BRIDGE(reshape2, kHuaweiAscendNPU);
+USE_SUBGRAPH_BRIDGE(transpose, kHuaweiAscendNPU);
+USE_SUBGRAPH_BRIDGE(transpose2, kHuaweiAscendNPU);
+USE_SUBGRAPH_BRIDGE(flatten, kHuaweiAscendNPU);
+USE_SUBGRAPH_BRIDGE(flatten2, kHuaweiAscendNPU);
diff --git a/lite/kernels/huawei_ascend_npu/bridges/reshape_op.cc b/lite/kernels/huawei_ascend_npu/bridges/reshape_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7fc48d4bdfba4875fc7b32329dbea72de7a3dfec
--- /dev/null
+++ b/lite/kernels/huawei_ascend_npu/bridges/reshape_op.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/reshape_op.h"
+#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
+#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace huawei_ascend_npu {
+
+int ReshapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
+
+  // Get input and output vars and op attributes
+  auto x_name = op_info->Input("X").front();
+  auto x = scope->FindMutableTensor(x_name);
+  auto x_dims = x->dims();
+
+  auto out_name = op_info->Output("Out").front();
+
+  // X node
+  std::shared_ptr<Node> x_node = nullptr;
+  if (graph->Has(x_name)) {
+    x_node = graph->Get(x_name);
+  } else {
+    x_node = graph->Add(x_name, *x);
+  }
+
+  // Shape Const node
+  if (op_info->HasInput("ShapeTensor")) {
+    LOG(WARNING) << "[HUAWEI_ASCEND_NPU] not support \"Shape\" from more than "
+                    "one Tensor.";
+    return FAILED;
+  }
+
+  std::shared_ptr<Node> actual_shape_node = nullptr;
+  if (op_info->HasInput("Shape")) {
+    auto actual_shape_name = op_info->Input("Shape").front();
+    if (graph->Has(actual_shape_name)) {
+      actual_shape_node = graph->Get(actual_shape_name);
+    } else {
+      auto actual_shape = scope->FindMutableTensor(actual_shape_name);
+      auto actual_shape_dims = actual_shape->dims();
+      auto actual_shape_data = actual_shape->mutable_data<int>();
+      auto shape =
+          std::vector<int>(actual_shape_data,
+                           actual_shape_data + actual_shape_dims.production());
+      auto out_shape = lite::operators::ValidateShape(shape, x_dims);
+      actual_shape_node =
+          graph->Add(actual_shape_name,
+                     std::vector<int64_t>(out_shape.begin(), out_shape.end()));
+    }
+  } else if (op_info->HasAttr("shape")) {
+    auto shape = op_info->GetAttr<std::vector<int>>("shape");
+    auto out_shape = lite::operators::ValidateShape(shape, x_dims);
+    out_shape = CvtShape(out_shape);
+    actual_shape_node = graph->Add(
+        out_name + "/shape",
+        std::vector<int64_t>(out_shape.begin(), out_shape.end()));
+  }
+  // actual_shape_node should not be nullptr
+  CHECK(actual_shape_node);
+
+  // Reshape node
+  auto reshape_node = graph->Add<ge::op::Reshape>(out_name);
+  auto reshape_op = reshape_node->data<ge::op::Reshape>();
+  reshape_op->set_input_x(*x_node->data());
+  reshape_op->set_input_shape(*actual_shape_node->data());
+  INPUT_UPDATE(reshape_op, x, x_node);
+  INPUT_UPDATE(reshape_op, shape, actual_shape_node);
+  OUTPUT_UPDATE(reshape_op, y, reshape_node);
+
+  return REBUILD_WHEN_SHAPE_CHANGED;
+}
+
+}  // namespace huawei_ascend_npu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(
+    reshape,
+    kHuaweiAscendNPU,
+    paddle::lite::subgraph::huawei_ascend_npu::ReshapeConverter);
+REGISTER_SUBGRAPH_BRIDGE(
+    reshape2,
+    kHuaweiAscendNPU,
+    paddle::lite::subgraph::huawei_ascend_npu::ReshapeConverter);
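For readers unfamiliar with `ValidateShape`: Paddle's reshape accepts `0` (copy the matching input dimension) and a single `-1` (infer the value from the remaining element count) in the target shape. A simplified sketch of that rule, not the actual `lite::operators::ValidateShape` implementation:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified reshape target-shape resolution: 0 copies the corresponding
// input dimension, a single -1 absorbs whatever element count remains.
std::vector<int64_t> ResolveShape(const std::vector<int>& shape,
                                  const std::vector<int64_t>& in_dims) {
  int64_t in_numel = 1;
  for (auto d : in_dims) in_numel *= d;
  std::vector<int64_t> out(shape.size());
  int64_t known = 1;
  int infer_idx = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      assert(infer_idx < 0 && "only one -1 is allowed");
      infer_idx = static_cast<int>(i);
    } else if (shape[i] == 0) {
      out[i] = in_dims[i];  // requires i < in_dims.size()
      known *= out[i];
    } else {
      out[i] = shape[i];
      known *= out[i];
    }
  }
  if (infer_idx >= 0) out[infer_idx] = in_numel / known;
  return out;
}
```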
diff --git a/lite/kernels/huawei_ascend_npu/bridges/transpose_op.cc b/lite/kernels/huawei_ascend_npu/bridges/transpose_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e90106e2da6f9f894141824784cd24a41f295d41
--- /dev/null
+++ b/lite/kernels/huawei_ascend_npu/bridges/transpose_op.cc
@@ -0,0 +1,74 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/huawei_ascend_npu/bridges/graph.h"
+#include "lite/kernels/huawei_ascend_npu/bridges/utility.h"
+#include "lite/kernels/npu/bridges/registry.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace huawei_ascend_npu {
+
+int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[HUAWEI_ASCEND_NPU] Converting " + op_type + "...";
+
+  // Get input and output vars and op attributes
+  auto x_name = op_info->Input("X").front();
+  auto x = scope->FindMutableTensor(x_name);
+  auto x_dims = x->dims();
+
+  auto out_name = op_info->Output("Out").front();
+
+  auto axis = op_info->GetAttr<std::vector<int>>("axis");
+
+  // X node
+  std::shared_ptr<Node> x_node = nullptr;
+  if (graph->Has(x_name)) {
+    x_node = graph->Get(x_name);
+  } else {
+    x_node = graph->Add(x_name, *x);
+  }
+
+  // Transpose node
+  auto transpose_node = graph->Add<ge::op::TransposeD>(out_name);
+  auto transpose_op = transpose_node->data<ge::op::TransposeD>();
+  transpose_op->set_input_x(*x_node->data());
+  transpose_op->set_attr_perm(
+      ge::Operator::OpListInt(axis.begin(), axis.end()));
+  INPUT_UPDATE(transpose_op, x, x_node);
+  OUTPUT_UPDATE(transpose_op, y, transpose_node);
+
+  return SUCCESS;
+}
+
+}  // namespace huawei_ascend_npu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(
+    transpose,
+    kHuaweiAscendNPU,
+    paddle::lite::subgraph::huawei_ascend_npu::TransposeConverter);
+REGISTER_SUBGRAPH_BRIDGE(
+    transpose2,
+    kHuaweiAscendNPU,
+    paddle::lite::subgraph::huawei_ascend_npu::TransposeConverter);
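As a reminder of what the `perm` attribute encodes: output dimension `i` takes input dimension `perm[i]`. A standalone illustration (hypothetical helper name):

```cpp
#include <cstdint>
#include <vector>

// Apply a permutation to a shape: out_dims[i] = in_dims[perm[i]],
// e.g. in_dims = {2, 3, 4} with perm = {0, 2, 1} -> {2, 4, 3}.
std::vector<int64_t> PermuteShape(const std::vector<int64_t>& in_dims,
                                  const std::vector<int>& perm) {
  std::vector<int64_t> out_dims(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out_dims[i] = in_dims[perm[i]];
  }
  return out_dims;
}
```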
diff --git a/lite/model_parser/base/vector_view.h b/lite/model_parser/base/vector_view.h
index e4149d9c5acae83472904a86c47659355972855e..45a27a59ce0a53331d4dd76d3a6d96fc6b301cbf 100644
--- a/lite/model_parser/base/vector_view.h
+++ b/lite/model_parser/base/vector_view.h
@@ -83,9 +83,9 @@ class VectorView {
   operator std::vector<T>() const {
     VLOG(5) << "Copying elements out of VectorView will damage performance.";
     std::vector<T> tmp;
-    tmp.reserve(size());
+    tmp.resize(size());
     for (size_t i = 0; i < size(); ++i) {
-      tmp.push_back(cvec_->operator[](i));
+      tmp[i] = cvec_->operator[](i);
     }
     return tmp;
   }
diff --git a/lite/model_parser/flatbuffers/block_desc.h b/lite/model_parser/flatbuffers/block_desc.h
index ed7932847ea253b83f95136a5f02a9a39e2107e0..21ac9a1f4e9795ef6f3433ee2e704453264a3044 100644
--- a/lite/model_parser/flatbuffers/block_desc.h
+++ b/lite/model_parser/flatbuffers/block_desc.h
@@ -30,13 +30,13 @@ class BlockDescView : public BlockDescAPI {
  public:
   explicit BlockDescView(proto::BlockDesc const* desc) : desc_(desc) {
     CHECK(desc_);
-    vars_.reserve(VarsSize());
-    ops_.reserve(OpsSize());
+    vars_.resize(VarsSize());
+    ops_.resize(OpsSize());
     for (size_t idx = 0; idx < VarsSize(); ++idx) {
-      vars_.push_back(VarDescView(desc_->vars()->Get(idx)));
+      vars_[idx] = VarDescView(desc_->vars()->Get(idx));
     }
     for (size_t idx = 0; idx < OpsSize(); ++idx) {
-      ops_.push_back(OpDescView(desc_->ops()->Get(idx)));
+      ops_[idx] = OpDescView(desc_->ops()->Get(idx));
     }
   }
 
@@ -76,7 +76,7 @@ class BlockDescView : public BlockDescAPI {
     return desc_->forward_block_idx();
   }
 
-  BlockDescView() { NotImplemented(); }
+  BlockDescView() = default;
 
  private:
   proto::BlockDesc const* desc_;  // not_own
diff --git a/lite/model_parser/flatbuffers/op_desc.cc b/lite/model_parser/flatbuffers/op_desc.cc
index 7e40b8c4537836e1c586bd433fddfa6b1ecdbae2..fcf6f84972423e9ce640dea217bae51f687d3fd9 100644
--- a/lite/model_parser/flatbuffers/op_desc.cc
+++ b/lite/model_parser/flatbuffers/op_desc.cc
@@ -19,8 +19,8 @@ namespace lite {
 namespace fbs {
 
 template <>
-std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
-  const auto& it = desc_->attrs()->LookupByKey(name.c_str());
+std::string OpDescView::GetAttr<std::string>(const char* name) const {
+  const auto& it = desc_->attrs()->LookupByKey(name);
   if (!it->s()) {
     return std::string();
   }
@@ -28,56 +28,48 @@ std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
 }
 
 template <>
-std::string OpDescView::GetAttr<std::string>(size_t idx) const {
-  const auto& it = desc_->attrs()->Get(idx);
-  if (!it->s()) {
-    return std::string();
-  }
-  return it->s()->str();
+std::string OpDescView::GetAttr<std::string>(const std::string& name) const {
+  return GetAttr<std::string>(name.c_str());
 }
 
 template <>
 lite::VectorView<std::string>
-OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
-  const auto& it = desc_->attrs()->LookupByKey(name.c_str());
+OpDescView::GetAttr<std::vector<std::string>>(const char* name) const {
+  const auto& it = desc_->attrs()->LookupByKey(name);
   CHECK(it) << "Attr " << name << " does not exist.";
   return VectorView<std::string>(it->strings());
 }
 
 template <>
-VectorView<std::string>
-OpDescView::GetAttr<std::vector<std::string>>(size_t idx) const {
-  const auto& it = desc_->attrs()->Get(idx);
-  CHECK(it) << "Attr " << idx << " does not exist.";
-  return VectorView<std::string>(it->strings());
+lite::VectorView<std::string>
+OpDescView::GetAttr<std::vector<std::string>>(const std::string& name) const {
+  return GetAttr<std::vector<std::string>>(name.c_str());
 }
 
 #define GET_ATTR_IMPL(T, fb_f__)                                   \
   template <>                                                      \
   typename lite::OpDataTypeTrait<T>::RT OpDescView::GetAttr<T>(    \
-      const std::string& name) const {                             \
-    const auto& it = desc_->attrs()->LookupByKey(name.c_str());    \
+      const char* name) const {                                    \
+    const auto& it = desc_->attrs()->LookupByKey(name);            \
     return it->fb_f__();                                           \
   }                                                                \
   template <>                                                      \
   typename lite::OpDataTypeTrait<T>::RT OpDescView::GetAttr<T>(    \
-      size_t idx) const {                                          \
-    const auto& it = desc_->attrs()->Get(idx);                     \
-    return it->fb_f__();                                           \
+      const std::string& name) const {                             \
+    return GetAttr<T>(name.c_str());                               \
   }
 
 #define GET_ATTRS_IMPL(T, fb_f__)                                        \
   template <>                                                            \
   typename lite::OpDataTypeTrait<T>::RT OpDescView::GetAttr<T>(          \
-      const std::string& name) const {                                   \
-    const auto& it = desc_->attrs()->LookupByKey(name.c_str());          \
+      const char* name) const {                                          \
+    const auto& it = desc_->attrs()->LookupByKey(name);                  \
     return typename lite::OpDataTypeTrait<T>::RT(it->fb_f__());          \
   }                                                                      \
   template <>                                                            \
   typename lite::OpDataTypeTrait<T>::RT OpDescView::GetAttr<T>(          \
-      size_t idx) const {                                                \
-    const auto& it = desc_->attrs()->Get(idx);                           \
-    return typename lite::OpDataTypeTrait<T>::RT(it->fb_f__());          \
+      const std::string& name) const {                                   \
+    return GetAttr<T>(name.c_str());                                     \
   }
 
 GET_ATTR_IMPL(int32_t, i);
diff --git a/lite/model_parser/flatbuffers/op_desc.h b/lite/model_parser/flatbuffers/op_desc.h
index 94d99be4616a8ee89f7a1d0999526f4306a30df3..0fc1237f8a8f14e9d5edc8e7fb18fd213735ed3b 100644
--- a/lite/model_parser/flatbuffers/op_desc.h
+++ b/lite/model_parser/flatbuffers/op_desc.h
@@ -36,57 +36,68 @@ class OpDescView : public OpDescAPI {
   std::string Type() const override { return desc_->type()->str(); }
 
-  // Get the arguments of parameter called `param`
-  std::vector<std::string> Input(const std::string& param) const override {
-    const auto& var = desc_->inputs()->LookupByKey(param.c_str());
+  std::vector<std::string> Input(const char* param) const {
+    const auto& var = desc_->inputs()->LookupByKey(param);
     std::vector<std::string> args_vec;
-    if (var->arguments()) {
-      args_vec.reserve(var->arguments()->size());
-      for (const auto& in : *var->arguments()) {
-        args_vec.push_back(in->str());
+    if (var && var->arguments()) {
+      args_vec.resize(var->arguments()->size());
+      for (size_t i = 0; i < var->arguments()->size(); ++i) {
+        args_vec[i] = (*var->arguments())[i]->str();
       }
     }
     return args_vec;
   }
 
+  std::vector<std::string> Input(const std::string& param) const override {
+    return Input(param.c_str());
+  }
+
   std::vector<std::string> InputArgumentNames() const override {
     const auto& vars = desc_->inputs();
     std::vector<std::string> input_names_vec;
     if (vars) {
-      input_names_vec.reserve(vars->size());
-      for (const auto& in : *vars) {
-        input_names_vec.push_back(in->parameter()->str());
+      input_names_vec.resize(vars->size());
+      for (size_t i = 0; i < vars->size(); ++i) {
+        input_names_vec[i] = (*vars)[i]->parameter()->str();
       }
     }
     return input_names_vec;
   }
 
-  std::vector<std::string> Output(const std::string& param) const override {
-    const auto& var = desc_->outputs()->LookupByKey(param.c_str());
+  std::vector<std::string> Output(const char* param) const {
+    const auto& var = desc_->outputs()->LookupByKey(param);
     std::vector<std::string> args_vec;
     if (var && var->arguments()) {
-      args_vec.reserve(var->arguments()->size());
-      for (const auto& out : *var->arguments()) {
-        args_vec.push_back(out->str());
+      args_vec.resize(var->arguments()->size());
+      for (size_t i = 0; i < var->arguments()->size(); ++i) {
+        args_vec[i] = (*var->arguments())[i]->str();
       }
     }
     return args_vec;
   }
 
+  std::vector<std::string> Output(const std::string& param) const override {
+    return Output(param.c_str());
+  }
+
   std::vector<std::string> OutputArgumentNames() const override {
     const auto& vars = desc_->outputs();
     std::vector<std::string> output_names_vec;
     if (vars) {
-      output_names_vec.reserve(vars->size());
-      for (const auto& out : *vars) {
-        output_names_vec.push_back(out->parameter()->str());
+      output_names_vec.resize(vars->size());
+      for (size_t i = 0; i < vars->size(); ++i) {
+        output_names_vec[i] = (*vars)[i]->parameter()->str();
       }
     }
     return output_names_vec;
   }
 
+  bool HasAttr(const char* name) const {
+    return desc_->attrs()->LookupByKey(name) != nullptr;
+  }
+
   bool HasAttr(const std::string& name) const override {
-    return desc_->attrs()->LookupByKey(name.c_str()) != nullptr;
+    return HasAttr(name.c_str());
   }
 
   size_t AttrsSize() const { return desc_->attrs()->size(); }
@@ -95,25 +106,23 @@ class OpDescView : public OpDescAPI {
     return desc_->attrs()->Get(idx)->name()->str();
   }
 
-  OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
-    const auto& attr = desc_->attrs()->LookupByKey(name.c_str());
+  OpDescAPI::AttrType GetAttrType(const char* name) const {
+    const auto& attr = desc_->attrs()->LookupByKey(name);
     CHECK(attr) << "Can not find attr: " << name;
     return ConvertAttrType(attr->type());
   }
 
-  OpDescAPI::AttrType GetAttrType(size_t idx) const {
-    const auto& attr = desc_->attrs()->Get(idx);
-    CHECK(attr);
-    return ConvertAttrType(attr->type());
+  OpDescAPI::AttrType GetAttrType(const std::string& name) const override {
+    return GetAttrType(name.c_str());
   }
 
   std::vector<std::string> AttrNames() const override {
     const auto& attrs = desc_->attrs();
     std::vector<std::string> attr_names_vec;
     if (attrs) {
-      attr_names_vec.reserve(attrs->size());
-      for (const auto& attr : *attrs) {
-        attr_names_vec.push_back(attr->name()->str());
+      attr_names_vec.resize(attrs->size());
+      for (size_t i = 0; i < attrs->size(); ++i) {
+        attr_names_vec[i] = (*attrs)[i]->name()->str();
       }
     }
     return attr_names_vec;
@@ -121,10 +130,11 @@ class OpDescView : public OpDescAPI {
 
   template <typename T>
   typename lite::OpDataTypeTrait<T>::RT GetAttr(
-      const std::string& name) const;
+      const char* name) const;
 
   template <typename T>
-  typename lite::OpDataTypeTrait<T>::RT GetAttr(size_t idx) const;
+  typename lite::OpDataTypeTrait<T>::RT GetAttr(
+      const std::string& name) const;
 
  private:
   proto::OpDesc const* desc_;
@@ -138,7 +148,7 @@ class OpDescView : public OpDescAPI {
   // caused by different building options.
 
  public:
-  OpDescView() { NotImplemented(); }
+  OpDescView() = default;
   bool HasInput(const std::string& param) const {
     return desc_->inputs()->LookupByKey(param.c_str()) != nullptr;
   }
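The pattern applied throughout op_desc.cc and this header: the flatbuffers `LookupByKey` wants a `const char*`, so the `const char*` overload does the real work and the `std::string` overload forwards to it. A minimal sketch of the idiom with a hypothetical `HasKey`, not the Lite API:

```cpp
#include <cstring>
#include <string>

// The const char* overload performs the lookup; the std::string overload
// simply forwards, so call sites holding string literals avoid constructing
// (and then immediately unwrapping) a temporary std::string.
bool HasKey(const char* key) {
  static const char* kKeys[] = {"X", "Out", "depth_tensor"};
  for (const char* k : kKeys) {
    if (std::strcmp(k, key) == 0) return true;
  }
  return false;
}

inline bool HasKey(const std::string& key) { return HasKey(key.c_str()); }
```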
diff --git a/lite/model_parser/flatbuffers/param_desc.h b/lite/model_parser/flatbuffers/param_desc.h
index c6fb00126792045c246da017168c1bd49e109729..0ed0bb5631b92d9c73f771c9ddf5d2624eb567cf 100644
--- a/lite/model_parser/flatbuffers/param_desc.h
+++ b/lite/model_parser/flatbuffers/param_desc.h
@@ -42,9 +42,9 @@ class ParamDescView : public ParamDescReadAPI {
   std::vector<int64_t> Dim() const override {
     const auto& dims = tensor_desc_->dim();
     std::vector<int64_t> dims_vec;
-    dims_vec.reserve(dims->size());
-    for (const auto& dim : *dims) {
-      dims_vec.push_back(dim);
+    dims_vec.resize(dims->size());
+    for (size_t i = 0; i < dims->size(); ++i) {
+      dims_vec[i] = dims->operator[](i);
     }
     return dims_vec;
   }
@@ -57,7 +57,7 @@ class ParamDescView : public ParamDescReadAPI {
 
   size_t byte_size() const override { return tensor_desc_->data()->size(); }
 
-  ParamDescView() = delete;
+  ParamDescView() = default;
 
 private:
   proto::ParamDesc const* desc_;
@@ -87,9 +87,9 @@ class CombinedParamsDescView : public CombinedParamsDescReadAPI {
   void InitParams() {
     desc_ = proto::GetCombinedParamsDesc(buf_.data());
     size_t params_size = desc_->params()->size();
-    params_.reserve(params_size);
+    params_.resize(params_size);
     for (size_t idx = 0; idx < params_size; ++idx) {
-      params_.push_back(ParamDescView(desc_->params()->Get(idx)));
+      params_[idx] = ParamDescView(desc_->params()->Get(idx));
     }
   }
diff --git a/lite/model_parser/flatbuffers/program_desc.h b/lite/model_parser/flatbuffers/program_desc.h
index 3432fbc154f38e133f77f72972fc10343e54e15f..57894c32f9905c5ccbbd722f58738c5f95bb31bf 100644
--- a/lite/model_parser/flatbuffers/program_desc.h
+++ b/lite/model_parser/flatbuffers/program_desc.h
@@ -48,9 +48,9 @@ class ProgramDescView : public ProgramDescAPI {
   void InitProgramDesc() {
     desc_ = proto::GetProgramDesc(buf_.data());
-    blocks_.reserve(BlocksSize());
+    blocks_.resize(BlocksSize());
     for (size_t idx = 0; idx < BlocksSize(); ++idx) {
-      blocks_.push_back(BlockDescView(desc_->blocks()->Get(idx)));
+      blocks_[idx] = BlockDescView(desc_->blocks()->Get(idx));
     }
   }
diff --git a/lite/model_parser/flatbuffers/var_desc.h b/lite/model_parser/flatbuffers/var_desc.h
index a538258b0754d03374d1cd136b62c8ae6f22a54e..d29b1dc27366537d6640deebf3536aa856282e15 100644
--- a/lite/model_parser/flatbuffers/var_desc.h
+++ b/lite/model_parser/flatbuffers/var_desc.h
@@ -42,9 +42,9 @@ class VarDescView : public VarDescAPI {
     CHECK(GetType() == VarDescAPI::Type::LOD_TENSOR);
     const auto& dims = desc_->type()->lod_tensor()->tensor()->dims();
     std::vector<int64_t> dims_vec;
-    dims_vec.reserve(dims->size());
-    for (const auto& dim : *dims) {
-      dims_vec.push_back(dim);
+    dims_vec.resize(dims->size());
+    for (size_t i = 0; i < dims->size(); ++i) {
+      dims_vec[i] = dims->operator[](i);
     }
     return dims_vec;
   }
@@ -66,7 +66,7 @@ class VarDescView : public VarDescAPI {
   // caused by different building options.
 
  public:
-  VarDescView() { NotImplemented(); }
+  VarDescView() = default;
   void SetDataType(Type data_type) { NotImplemented(); }
   void SetShape(const std::vector<int64_t>& dims) { NotImplemented(); }
diff --git a/lite/model_parser/flatbuffers/vector_view.h b/lite/model_parser/flatbuffers/vector_view.h
index bb1331823a2dce79d2b3a6784f1f2d5b5864281d..0c9dc306de03b8611af491fdfe8758c9ca09bde0 100644
--- a/lite/model_parser/flatbuffers/vector_view.h
+++ b/lite/model_parser/flatbuffers/vector_view.h
@@ -127,9 +127,9 @@ class VectorView {
   operator std::vector<std::string>() const {
     VLOG(5) << "Copying elements out of VectorView will damage performance.";
     std::vector<std::string> tmp;
-    tmp.reserve(size());
+    tmp.resize(size());
     for (size_t i = 0; i < size(); ++i) {
-      tmp.push_back(cvec_->operator[](i)->str());
+      tmp[i] = cvec_->operator[](i)->str();
    }
     return tmp;
   }
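A note on the reserve→resize changes above and the constructors switched from `NotImplemented()`/`= delete` to `= default`: `std::vector<T>::resize` value-initializes its new elements before they are assigned, so the element types must now be default-constructible. The constraint in isolation, with illustrative types only:

```cpp
#include <vector>

struct View {
  View() = default;  // required by resize(): slots are value-initialized first
  explicit View(const int* p) : ptr(p) {}
  const int* ptr{nullptr};
};

int main() {
  const int raw[3] = {1, 2, 3};
  std::vector<View> views;
  views.resize(3);  // would not compile if View had no default constructor
  for (size_t i = 0; i < views.size(); ++i) {
    views[i] = View(&raw[i]);  // overwrite the default-constructed slots
  }
  return 0;
}
```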
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index c8d8f2133be8a420f452ef2827257b3224b7d9ea..f84ce5cf6784e3beb1706de9c4dc8de8f3ad4541 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -147,6 +147,7 @@ add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
 add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
 add_operator(retinanet_detection_output_op extra SRCS retinanet_detection_output_op.cc DEPS ${op_DEPS})
 add_operator(where_index_op extra SRCS where_index_op.cc DEPS ${op_DEPS})
+add_operator(one_hot_op extra SRCS one_hot_op.cc DEPS ${op_DEPS})
 # for content-dnn specific
 add_operator(search_aligned_mat_mul_op extra SRCS search_aligned_mat_mul_op.cc DEPS ${op_DEPS})
 add_operator(search_seq_fc_op extra SRCS search_seq_fc_op.cc DEPS ${op_DEPS})
@@ -175,7 +176,7 @@ add_operator(__xpu__fc_op extra SRCS __xpu__fc_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__resnet_cbam_op extra SRCS __xpu__resnet_cbam_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__search_attention_op extra SRCS __xpu__search_attention_op.cc DEPS ${op_DEPS})
 add_operator(__xpu__mmdnn_op extra SRCS __xpu__mmdnn_op.cc DEPS ${op_DEPS})
-
+lite_cc_test(test_one_hot_op SRCS one_hot_op_test.cc DEPS one_hot_op memory scope ${op_deps} one_hot_compute_host)
 if (NOT LITE_WITH_X86)
   lite_cc_test(test_fc_op SRCS fc_op_test.cc DEPS fc_op memory
diff --git a/lite/operators/one_hot_op.cc b/lite/operators/one_hot_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..88b939a0de3b68567ddad1e00a539a955c011425
--- /dev/null
+++ b/lite/operators/one_hot_op.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/one_hot_op.h"
+#include "lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool OneHotOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Out);
+  return true;
+}
+
+bool OneHotOp::InferShapeImpl() const {
+  auto out_dims = param_.X->dims();
+  CHECK_GE(out_dims.size(), 2);
+  // depth_tensor, when given, overrides the `depth` attribute
+  int depth = param_.depth_tensor ? param_.depth_tensor->data<int32_t>()[0]
+                                  : param_.depth;
+  out_dims[out_dims.size() - 1] = depth;
+  param_.Out->Resize(out_dims);
+  param_.Out->set_lod(param_.X->lod());
+  return true;
+}
+
+bool OneHotOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+  auto x = op_desc.Input("X").front();
+  auto out = op_desc.Output("Out").front();
+  param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
+  param_.Out = scope->FindMutableTensor(out);
+
+  if (op_desc.HasInput("depth_tensor") &&
+      !op_desc.Input("depth_tensor").empty()) {
+    auto depth_tensor = op_desc.Input("depth_tensor").front();
+    param_.depth_tensor =
+        scope->FindVar(depth_tensor)->GetMutable<lite::Tensor>();
+  }
+
+  if (op_desc.HasAttr("depth")) {
+    param_.depth = op_desc.GetAttr<int>("depth");
+  }
+  if (op_desc.HasAttr("allow_out_of_range")) {
+    param_.allow_out_of_range = op_desc.GetAttr<bool>("allow_out_of_range");
+  }
+  param_.dtype = op_desc.GetAttr<int>("dtype");
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(one_hot, paddle::lite::operators::OneHotOp);
diff --git a/lite/operators/one_hot_op.h b/lite/operators/one_hot_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..bd0aefc33080b532761509e77032c39417758a59
--- /dev/null
+++ b/lite/operators/one_hot_op.h
@@ -0,0 +1,71 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <memory>
+#include <string>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+/* note:
+One Hot Operator. This operator creates the one-hot representations for input
+index values.
+The following example will help to explain the function of this operator:
+X is a LoDTensor:
+  X.lod = [[0, 1, 4]]
+  X.shape = [4, 1]
+  X.data = [[1], [1], [3], [0]]
+set depth = 4
+Out is a LoDTensor:
+  Out.lod = [[0, 1, 4]]
+  Out.shape = [4, 4]
+  Out.data = [[0., 1., 0., 0.],
+              [0., 1., 0., 0.],
+              [0., 0., 0., 1.],
+              [1., 0., 0., 0.]] */
+
+class OneHotOp : public OpLite {
+ public:
+  OneHotOp() {}
+  explicit OneHotOp(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShapeImpl() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+
+  std::string DebugString() const override { return "one_hot"; }
+
+#ifdef LITE_WITH_PROFILE
+  void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) {
+    ch->input_shape = ch->DimToStr(param_.X->dims());
+    ch->output_shape = ch->DimToStr(param_.Out->dims());
+    ch->macs = param_.X->numel() * 1.f;
+  }
+#endif
+
+ private:
+  mutable OneHotParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/operators/one_hot_op_test.cc b/lite/operators/one_hot_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5daa8378868a21ef4b466cf8d2c82f681d123545
--- /dev/null
+++ b/lite/operators/one_hot_op_test.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/one_hot_op.h"
+#include <gtest/gtest.h>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+TEST(one_hot_op_lite, TestHost) {
+  // prepare variables
+  Scope scope;
+  auto* x = scope.Var("X")->GetMutable<lite::Tensor>();
+  auto* depth_tensor = scope.Var("depth_tensor")->GetMutable<lite::Tensor>();
+  auto* output = scope.Var("Out")->GetMutable<lite::Tensor>();
+  depth_tensor->dims();
+  output->dims();
+
+  // set data
+  x->Resize(DDim(std::vector<int64_t>({4, 1})));
+  auto* x_data = x->mutable_data<int64_t>();
+  x_data[0] = 1;
+  x_data[1] = 1;
+  x_data[2] = 3;
+  x_data[3] = 0;
+
+  // prepare op desc
+  cpp::OpDesc desc;
+  desc.SetType("one_hot");
+  desc.SetInput("X", {"X"});
+  desc.SetInput("depth_tensor", {"depth_tensor"});
+  desc.SetOutput("Out", {"Out"});
+  desc.SetAttr("depth", static_cast<int>(4));
+  desc.SetAttr("dtype", static_cast<int>(1));
+  desc.SetAttr("allow_out_of_range", static_cast<bool>(0));
+
+  OneHotOp one_hot("one_hot");
+  one_hot.SetValidPlaces({Place{TARGET(kHost), PRECISION(kAny)}});
+  one_hot.Attach(desc, &scope);
+  auto kernels = one_hot.CreateKernels({Place{TARGET(kHost), PRECISION(kAny)}});
+  ASSERT_FALSE(kernels.empty());
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+USE_LITE_KERNEL(one_hot, kHost, kAny, kAny, def);
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index 1f37aaa5e76e21b85a24994a363674fee85e14d8..dffa15188431460fbf8a41f76f98258f213f6493 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -1824,6 +1824,15 @@ struct PrintParam : ParamBase {
   bool is_forward{true};
 };
 
+struct OneHotParam : ParamBase {
+  const lite::Tensor* X{};
+  const lite::Tensor* depth_tensor{nullptr};
+  lite::Tensor* Out{};
+  int depth{-1};
+  int dtype{-1};
+  bool allow_out_of_range{false};
+};
+
 }  // namespace operators
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index 621b67d430565540bd0c90d2e0f78367ed903b6d..9b16db16327bf67221866e623edbcbaa6aadb389 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -14,7 +14,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
   lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
  lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
  lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
  lite_cc_test(test_kernel_group_norm_compute SRCS group_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
@@ -40,21 +40,21 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
   lite_cc_test(test_kernel_fill_constant_batch_size_like_compute SRCS fill_constant_batch_size_like_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 
 if(LITE_BUILD_EXTRA)
-  lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  #lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_reduce_sum_compute SRCS reduce_sum_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  #lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_sequence_conv_compute SRCS sequence_conv_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_reduce_sum_compute SRCS reduce_sum_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 
   #lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   #lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
@@ -73,7 +73,7 @@ if(LITE_BUILD_EXTRA)
   lite_cc_test(test_kernel_elementwise_grad_compute SRCS elementwise_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_mul_grad_compute SRCS mul_grad_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_sgd_compute SRCS sgd_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_sequence_pool_grad_compute SRCS sequence_pool_grad_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_sequence_pool_grad_compute SRCS sequence_pool_grad_compute_test.cc DEPS ${bm_kernels} arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
 endif()
 endif()
 
@@ -90,5 +90,6 @@ endif()
   lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_expand_as_compute SRCS expand_as_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_flatten_compute SRCS flatten_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
diff --git a/lite/tests/kernels/flatten_compute_test.cc b/lite/tests/kernels/flatten_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cecf98fdb437ff28b71bcf719f690d39d65daff9
--- /dev/null
+++ b/lite/tests/kernels/flatten_compute_test.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "lite/api/paddle_use_kernels.h"
+#include "lite/api/paddle_use_ops.h"
+#include "lite/core/arena/framework.h"
+#include "lite/tests/utils/fill_data.h"
+
+namespace paddle {
+namespace lite {
+
+class FlattenComputeTester : public arena::TestCase {
+ protected:
+  // common attributes for this op.
+  std::string op_type_ = "flatten";
+  std::string input_ = "x";
+  std::string output_ = "out";
+  std::string xshape_ = "xshape";
+  DDim dims_;
+  int axis_;
+
+ public:
+  FlattenComputeTester(const Place& place,
+                       const std::string& alias,
+                       DDim dims,
+                       int axis)
+      : TestCase(place, alias), dims_(dims), axis_(axis) {}
+
+  void RunBaseline(Scope* scope) override {
+    auto* out = scope->NewTensor(output_);
+    CHECK(out);
+
+    auto* x = scope->FindTensor(input_);
+
+    int64_t outer = 1, inner = 1;
+    for (size_t i = 0; i < dims_.size(); ++i) {
+      if (i < static_cast<size_t>(axis_)) {
+        outer *= dims_[i];
+      } else {
+        inner *= dims_[i];
+      }
+    }
+    std::vector<int64_t> out_shape(2);
+    out_shape[0] = outer;
+    out_shape[1] = inner;
+    out->Resize(out_shape);
+
+    auto x_data = x->data<float>();
+    auto out_data = out->mutable_data<float>();
+    memcpy(out_data, x_data, sizeof(float) * dims_.production());
+  }
+
+  void PrepareOpDesc(cpp::OpDesc* op_desc) {
+    op_desc->SetType(op_type_);
+    op_desc->SetInput("X", {input_});
+    op_desc->SetOutput("Out", {output_});
+    if (op_type_ == "flatten2") {
+      op_desc->SetOutput("XShape", {xshape_});
+    }
+    op_desc->SetAttr("axis", axis_);
+  }
+
+  void PrepareData() override {
+    std::vector<float> din(dims_.production());
+    fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
+    SetCommonTensor(input_, dims_, din.data());
+  }
+};
+
+void TestFlatten(Place place, float abs_error) {
+  DDim dims{{2, 3, 4, 5}};
+  std::vector<int> axes{0, 1, 2, 3};
+  for (auto axis : axes) {
+    std::unique_ptr<arena::TestCase> tester(
+        new FlattenComputeTester(place, "def", dims, axis));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision({"xshape"});
+  }
+}
+
+TEST(flatten, precision) {
+  LOG(INFO) << "test flatten op";
+  Place place;
+  float abs_error = 1e-5;
+#if defined(LITE_WITH_HUAWEI_ASCEND_NPU)
+  place = TARGET(kHuaweiAscendNPU);
+  abs_error = 1e-2;  // precision_mode default is force_fp16
+#else
+  return;
+#endif
+
+  TestFlatten(place, abs_error);
+}
+
+}  // namespace lite
+}  // namespace paddle
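The reshape and transpose hunks that follow register the same place in existing precision tests. Because the Huawei Ascend NPU backend defaults to force_fp16, outputs carry roughly three significant decimal digits, so the tests widen the absolute tolerance from 1e-5 to 1e-2. Conceptually the arena harness accepts an output element when it lies within `abs_error` of the fp32 reference; a minimal sketch, assuming a plain absolute-error comparison (the helper name is illustrative):

```cpp
#include <cmath>

// Hypothetical stand-in for the arena's per-element check: pass when the
// backend output is within abs_error of the fp32 baseline.
inline bool WithinAbsError(float ref, float got, float abs_error) {
  return std::fabs(ref - got) <= abs_error;
}
```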
diff --git a/lite/tests/kernels/reshape_compute_test.cc b/lite/tests/kernels/reshape_compute_test.cc
index f3fcc0bad5418624c86897bafc52dbf3a7ec0d8e..5e7cd953e740c952bde12c6d13a76b9ad9c344c3 100644
--- a/lite/tests/kernels/reshape_compute_test.cc
+++ b/lite/tests/kernels/reshape_compute_test.cc
@@ -208,6 +208,9 @@ TEST(Reshape, precision) {
   place = TARGET(kHost);
 #elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
   place = TARGET(kXPU);
+#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
+  place = TARGET(kHuaweiAscendNPU);
+  abs_error = 1e-2;  // precision_mode default is force_fp16
 #else
   return;
 #endif
diff --git a/lite/tests/kernels/transpose_compute_test.cc b/lite/tests/kernels/transpose_compute_test.cc
index 933e9f8ec5fc7b1d9b510c71f57fda309a5477dc..22c73e73c1422e0886c0f63ea726cbea6ba03e84 100644
--- a/lite/tests/kernels/transpose_compute_test.cc
+++ b/lite/tests/kernels/transpose_compute_test.cc
@@ -169,6 +169,9 @@ TEST(Transpose, precision) {
 #elif defined(LITE_WITH_NPU)
   place = TARGET(kNPU);
   abs_error = 1e-2;  // Using fp16 in NPU
+#elif defined(LITE_WITH_HUAWEI_ASCEND_NPU)
+  place = TARGET(kHuaweiAscendNPU);
+  abs_error = 1e-2;  // precision_mode default is force_fp16
 #else
   return;
 #endif
diff --git a/lite/tools/build_android.sh b/lite/tools/build_android.sh
index 13092593a86e1486448236aab1361b706b9a0e61..32e1a31ee3b589911866919e7c4ad2e9d02748cf 100755
--- a/lite/tools/build_android.sh
+++ b/lite/tools/build_android.sh
@@ -153,6 +153,10 @@ function make_tiny_publish_so {
     prepare_thirdparty
   fi
 
+  if [ "${WITH_STRIP}" == "ON" ]; then
+    WITH_EXTRA=ON
+  fi
+
   local cmake_mutable_options="
       -DLITE_BUILD_EXTRA=$WITH_EXTRA \
       -DLITE_WITH_LOG=$WITH_LOG \
@@ -206,6 +210,10 @@ function make_full_publish_so {
     prepare_opencl_source_code $workspace $build_dir
   fi
 
+  if [ "${WITH_STRIP}" == "ON" ]; then
+    WITH_EXTRA=ON
+  fi
+
   local cmake_mutable_options="
       -DLITE_BUILD_EXTRA=$WITH_EXTRA \
       -DLITE_WITH_LOG=$WITH_LOG \
diff --git a/lite/tools/build_ios.sh b/lite/tools/build_ios.sh
index 4eea073a058ba9e1e821e9f0746687baa0c38d5f..f4232d0d2c2c5120ccc352f0d4b5f956137bef75 100755
--- a/lite/tools/build_ios.sh
+++ b/lite/tools/build_ios.sh
@@ -49,6 +49,10 @@ function make_ios {
     exit 1
   fi
 
+  if [ "${WITH_STRIP}" == "ON" ]; then
+    WITH_EXTRA=ON
+  fi
+
   build_dir=$workspace/build.ios.${os}.${arch}
   if [ -d $build_dir ]
   then
@@ -61,7 +65,6 @@ function make_ios {
   GEN_CODE_PATH_PREFIX=lite/gen_code
   mkdir -p ./${GEN_CODE_PATH_PREFIX}
   touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
-
   cmake $workspace \
       -DWITH_LITE=ON \
       -DLITE_WITH_ARM=ON \
diff --git a/lite/tools/build_linux.sh b/lite/tools/build_linux.sh
index 208f6aa31a150920c8a766859d4237cb745deb99..36306ad5807006492733cebe81e5ebb18f84751f 100755
--- a/lite/tools/build_linux.sh
+++ b/lite/tools/build_linux.sh
@@ -173,6 +173,9 @@ function make_tiny_publish_so {
   if [ "${WITH_OPENCL}" = "ON" ]; then
     prepare_opencl_source_code $workspace $build_dir
   fi
+  if [ "${WITH_STRIP}" == "ON" ]; then
+    WITH_EXTRA=ON
+  fi
 
   init_cmake_mutable_options
   cmake $workspace \