diff --git a/lite/kernels/arm/topk_compute.cc b/lite/kernels/arm/topk_compute.cc index 55c667c3067e74bf3938ececf5e9290da0a7d49b..c55bf2aa7861071770c4993800b5a2536d27511f 100644 --- a/lite/kernels/arm/topk_compute.cc +++ b/lite/kernels/arm/topk_compute.cc @@ -44,5 +44,5 @@ REGISTER_LITE_KERNEL( .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Indices", - {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) + {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) .Finalize(); diff --git a/lite/kernels/host/use_kernels.h b/lite/kernels/host/use_kernels.h deleted file mode 100644 index b5bab46a7191fc6732ea515b22e175141b87dc48..0000000000000000000000000000000000000000 --- a/lite/kernels/host/use_kernels.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include "lite/core/op_registry.h" - -USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); -USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); -USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def); -USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def); diff --git a/lite/kernels/npu/bridges/CMakeLists.txt b/lite/kernels/npu/bridges/CMakeLists.txt index 8d62b7630e96416320f2b211b222dafb7338654c..bcf6ba63eb820ee187dd26b2722686a768f78c98 100644 --- a/lite/kernels/npu/bridges/CMakeLists.txt +++ b/lite/kernels/npu/bridges/CMakeLists.txt @@ -36,13 +36,12 @@ lite_cc_library(subgraph_bridge_split_op_npu SRCS split_op.cc DEPS ${npu_subgrap lite_cc_library(subgraph_bridge_concat_op_npu SRCS concat_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_shuffle_channel_op_npu SRCS shuffle_channel_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_pad2d_op_npu SRCS pad2d_op.cc DEPS ${npu_subgraph_bridge_deps}) -lite_cc_library(subgraph_bridge_square_op_npu SRCS square_op.cc DEPS ${npu_subgraph_bridge_deps}) -lite_cc_library(subgraph_bridge_sqrt_op_npu SRCS sqrt_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_reduce_mean_op_npu SRCS reduce_mean_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_unsqueeze_op_npu SRCS unsqueeze_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_argmax_op_npu SRCS argmax_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_instance_norm_op_npu SRCS instance_norm_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_dropout_op_npu SRCS dropout_op.cc DEPS ${npu_subgraph_bridge_deps}) +lite_cc_library(subgraph_bridge_topk_op_npu SRCS topk_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_layer_norm_op_npu SRCS layer_norm_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_fill_constant_op_npu SRCS fill_constant_op.cc DEPS ${npu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_fill_constant_batch_size_like_op_npu SRCS fill_constant_batch_size_like_op.cc DEPS ${npu_subgraph_bridge_deps}) @@ -72,13 +71,12 @@ set(npu_subgraph_bridges subgraph_bridge_concat_op_npu subgraph_bridge_shuffle_channel_op_npu subgraph_bridge_pad2d_op_npu - subgraph_bridge_square_op_npu - subgraph_bridge_sqrt_op_npu subgraph_bridge_reduce_mean_op_npu subgraph_bridge_unsqueeze_op_npu subgraph_bridge_argmax_op_npu subgraph_bridge_instance_norm_op_npu subgraph_bridge_dropout_op_npu + subgraph_bridge_topk_op_npu subgraph_bridge_layer_norm_op_npu subgraph_bridge_fill_constant_op_npu subgraph_bridge_fill_constant_batch_size_like_op_npu diff --git a/lite/kernels/npu/bridges/act_op.cc b/lite/kernels/npu/bridges/act_op.cc index f3fd75f2d6728df2490a1ecf5c4e43c3ad08ab6c..db9a652b6c1b4055e09a70e1f407b1027fd1b1e8 100644 --- a/lite/kernels/npu/bridges/act_op.cc +++ b/lite/kernels/npu/bridges/act_op.cc @@ -21,6 +21,7 @@ namespace lite { namespace subgraph { namespace npu { +template int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { CHECK(ctx != nullptr); CHECK(op != nullptr); @@ -30,6 +31,40 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { auto scope = op->scope(); VLOG(3) << "[NPU] Converting " + op_type + "..."; + // Get input and output vars and op attributes + auto x_name = op_info->Input("X").front(); + auto x = scope->FindTensor(x_name); + + auto out_name = op_info->Output("Out").front(); + + // X node + std::shared_ptr x_node = nullptr; + if (graph->Has(x_name)) { + x_node = graph->Get(x_name); + } else { + x_node = graph->Add(x_name, *x); + } + + // Act node + auto act_node = graph->template Add(out_name); + auto act_op = act_node->template data(); + act_op->set_input_x(*x_node->data()); + + return SUCCESS; +} + +template <> +int ActConverter(void* ctx, + OpLite* op, + KernelBase* kernel) { + CHECK(ctx != nullptr); + CHECK(op != nullptr); + auto graph = static_cast(ctx); + auto op_info = op->op_info(); + auto op_type = op_info->Type(); + auto scope = op->scope(); + VLOG(3) << "[NPU] Converting " + op_type + "..."; + // Get input and output vars and op attributes auto x_name = op_info->Input("X").front(); auto x = scope->FindMutableTensor(x_name); @@ -45,8 +80,8 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { } // Act node - auto act_node = graph->Add(out_name); - auto act_op = act_node->data(); + auto act_node = graph->template Add(out_name); + auto act_op = act_node->template data(); act_op->set_input_x(*x_node->data()); // TODO(hong19860320) set the coef value for act Ops, such as leaky_relu, // clipped_relu etc. @@ -74,27 +109,42 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) { } // namespace lite } // namespace paddle -REGISTER_SUBGRAPH_BRIDGE(sigmoid, - kNPU, - paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(relu, kNPU, paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(tanh, kNPU, paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(relu_clipped, - kNPU, - paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(relu6, - kNPU, - paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(leaky_relu, - kNPU, - paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(abs, kNPU, paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(softsign, - kNPU, - paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(softplus, - kNPU, - paddle::lite::subgraph::npu::ActConverter); -REGISTER_SUBGRAPH_BRIDGE(hard_sigmoid, - kNPU, - paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + sigmoid, + kNPU, + paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + relu, kNPU, paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + tanh, kNPU, paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + relu_clipped, + kNPU, + paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + relu6, kNPU, paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + leaky_relu, + kNPU, + paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + abs, kNPU, paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + softsign, + kNPU, + paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + softplus, + kNPU, + paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + hard_sigmoid, + kNPU, + paddle::lite::subgraph::npu::ActConverter); + +REGISTER_SUBGRAPH_BRIDGE( + log, kNPU, paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + square, kNPU, paddle::lite::subgraph::npu::ActConverter); +REGISTER_SUBGRAPH_BRIDGE( + sqrt, kNPU, paddle::lite::subgraph::npu::ActConverter); diff --git a/lite/kernels/npu/bridges/paddle_use_bridges.h b/lite/kernels/npu/bridges/paddle_use_bridges.h index 49d1116fce75c3501ee1892f49b8c3d33a58e309..6c406302212640ec41d0701f530c0c1f32229539 100644 --- a/lite/kernels/npu/bridges/paddle_use_bridges.h +++ b/lite/kernels/npu/bridges/paddle_use_bridges.h @@ -21,6 +21,9 @@ USE_SUBGRAPH_BRIDGE(relu_clipped, kNPU); USE_SUBGRAPH_BRIDGE(leaky_relu, kNPU); USE_SUBGRAPH_BRIDGE(softsign, kNPU); USE_SUBGRAPH_BRIDGE(hard_sigmoid, kNPU); +USE_SUBGRAPH_BRIDGE(log, kNPU); +USE_SUBGRAPH_BRIDGE(sqrt, kNPU); +USE_SUBGRAPH_BRIDGE(square, kNPU); USE_SUBGRAPH_BRIDGE(batch_norm, kNPU); USE_SUBGRAPH_BRIDGE(less_than, kNPU); @@ -58,8 +61,7 @@ USE_SUBGRAPH_BRIDGE(scale, kNPU); USE_SUBGRAPH_BRIDGE(shuffle_channel, kNPU); USE_SUBGRAPH_BRIDGE(softmax, kNPU); USE_SUBGRAPH_BRIDGE(split, kNPU); -USE_SUBGRAPH_BRIDGE(sqrt, kNPU); -USE_SUBGRAPH_BRIDGE(square, kNPU); +// USE_SUBGRAPH_BRIDGE(top_k, kNPU); USE_SUBGRAPH_BRIDGE(transpose, kNPU); USE_SUBGRAPH_BRIDGE(transpose2, kNPU); USE_SUBGRAPH_BRIDGE(unsqueeze, kNPU); diff --git a/lite/kernels/npu/bridges/sqrt_op_test.cc b/lite/kernels/npu/bridges/sqrt_op_test.cc deleted file mode 100644 index 015d61685b2d99c3df55269442d61b4a137a2ca3..0000000000000000000000000000000000000000 --- a/lite/kernels/npu/bridges/sqrt_op_test.cc +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include "lite/core/op_registry.h" -#include "lite/kernels/npu/bridges/registry.h" -#include "lite/kernels/npu/bridges/test_helper.h" -#include "lite/operators/activation_ops.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace npu { -namespace bridges { - -template -void sqrt_ref(const std::shared_ptr op) { - Scope* scope = op->scope(); - const OpInfo* op_info = op->op_info(); - - auto x = scope->FindTensor("x"); - auto out = scope->FindMutableTensor("out_ref"); - out->Resize(x->dims()); - auto x_data = x->data(); - auto out_data = out->mutable_data(); - - for (size_t i = 0; i < x->numel(); i++) { - out_data[i] = std::sqrtf(x_data[i]); - } -} - -void test_sqrt(const std::vector& input_shape) { - // prepare input&output variables - Scope scope; - std::string x_var_name = "x"; - std::string out_var_name = "out"; - std::string out_ref_var_name = "out_ref"; - auto* x = scope.NewTensor(x_var_name); - auto* out = scope.NewTensor(out_var_name); - auto* out_ref = scope.NewTensor(out_ref_var_name); - x->Resize(input_shape); - - // initialize input&output data - FillTensor(x, 0, 5); - - // initialize op desc - cpp::OpDesc opdesc; - opdesc.SetType("sqrt"); - opdesc.SetInput("X", {x_var_name}); - opdesc.SetOutput("Out", {out_var_name}); - - // create and convert op to NPU model, then run it on NPU - auto op = CreateOp(opdesc, &scope); - LauchOp(op, {x_var_name}, {out_var_name}); - - // execute reference implementation and save to output tensor - sqrt_ref(op); - - // compare results - auto* out_data = out->mutable_data(); - auto* out_ref_data = out_ref->mutable_data(); - for (int i = 0; i < out->dims().production(); i++) { - EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2); - } -} - -TEST(NPUBridges, sqrt) { - test_sqrt({2}); - test_sqrt({2, 3}); - test_sqrt({1, 2, 3, 4}); - test_sqrt({5, 6, 7, 8}); -} - -} // namespace bridges -} // namespace npu -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_OP(sqrt); -USE_NPU_BRIDGE(sqrt); diff --git a/lite/kernels/npu/bridges/square_op.cc b/lite/kernels/npu/bridges/square_op.cc deleted file mode 100644 index a25d255de6f8e7e2af12514ad075e83df57f2a7e..0000000000000000000000000000000000000000 --- a/lite/kernels/npu/bridges/square_op.cc +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "lite/kernels/npu/bridges/graph.h" -#include "lite/kernels/npu/bridges/registry.h" -#include "lite/kernels/npu/bridges/utility.h" - -namespace paddle { -namespace lite { -namespace subgraph { -namespace npu { - -int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) { - CHECK(ctx != nullptr); - CHECK(op != nullptr); - auto graph = static_cast(ctx); - auto op_info = op->op_info(); - auto op_type = op_info->Type(); - auto scope = op->scope(); - VLOG(3) << "[NPU] Converting " + op_type + "..."; - - // Get input and output vars and op attributes - auto x_name = op_info->Input("X").front(); - auto x = scope->FindMutableTensor(x_name); - auto x_dims = x->dims(); - auto out_name = op_info->Output("Out").front(); - - // X node - std::shared_ptr x_node = nullptr; - if (graph->Has(x_name)) { - x_node = graph->Get(x_name); - } else { - x_node = graph->Add(x_name, *x); - } - - // Square node - auto square_node = graph->Add(out_name); - auto square_op = square_node->data(); - square_op->set_input_x(*x_node->data()); - return SUCCESS; -} - -} // namespace npu -} // namespace subgraph -} // namespace lite -} // namespace paddle - -REGISTER_SUBGRAPH_BRIDGE(square, - kNPU, - paddle::lite::subgraph::npu::SquareConverter); diff --git a/lite/kernels/npu/bridges/square_op_test.cc b/lite/kernels/npu/bridges/square_op_test.cc deleted file mode 100644 index d715c11430096a0b6503fbe6047a40c3c29ba8f5..0000000000000000000000000000000000000000 --- a/lite/kernels/npu/bridges/square_op_test.cc +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "lite/core/op_registry.h" -#include "lite/kernels/npu/bridges/registry.h" -#include "lite/kernels/npu/bridges/test_helper.h" -#include "lite/operators/activation_ops.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace npu { -namespace bridges { - -template -void square_ref(const std::shared_ptr op) { - Scope* scope = op->scope(); - const OpInfo* op_info = op->op_info(); - - auto x = scope->FindTensor("x"); - auto out = scope->FindMutableTensor("out_ref"); - out->Resize(x->dims()); - auto x_data = x->data(); - auto out_data = out->mutable_data(); - - for (size_t i = 0; i < x->numel(); i++) { - out_data[i] = x_data[i] * x_data[i]; - } -} - -void test_square(const std::vector& input_shape) { - // prepare input&output variables - Scope scope; - std::string x_var_name = "x"; - std::string out_var_name = "out"; - std::string out_ref_var_name = "out_ref"; - auto* x = scope.NewTensor(x_var_name); - auto* out = scope.NewTensor(out_var_name); - auto* out_ref = scope.NewTensor(out_ref_var_name); - x->Resize(input_shape); - - // initialize input&output data - FillTensor(x); - - // initialize op desc - cpp::OpDesc opdesc; - opdesc.SetType("square"); - opdesc.SetInput("X", {x_var_name}); - opdesc.SetOutput("Out", {out_var_name}); - - // create and convert op to NPU model, then run it on NPU - auto op = CreateOp(opdesc, &scope); - LauchOp(op, {x_var_name}, {out_var_name}); - - // execute reference implementation and save to output tensor - square_ref(op); - - // compare results - auto* out_data = out->mutable_data(); - auto* out_ref_data = out_ref->mutable_data(); - for (int i = 0; i < out->dims().production(); i++) { - EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2); - } -} - -TEST(NPUBridges, square) { - test_square({2}); - test_square({2, 3}); - test_square({1, 2, 3, 4}); - test_square({5, 6, 7, 8}); -} - -} // namespace bridges -} // namespace npu -} // namespace kernels -} // namespace lite -} // namespace paddle - -USE_LITE_OP(square); -USE_NPU_BRIDGE(square); diff --git a/lite/kernels/npu/bridges/sqrt_op.cc b/lite/kernels/npu/bridges/topk_op.cc similarity index 70% rename from lite/kernels/npu/bridges/sqrt_op.cc rename to lite/kernels/npu/bridges/topk_op.cc index 85fe7bd8c83f4262e81fb615fb8a78978a71a2e9..1cc662e054d3c70a21c49ce00bd8f2e836e64883 100644 --- a/lite/kernels/npu/bridges/sqrt_op.cc +++ b/lite/kernels/npu/bridges/topk_op.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ namespace lite { namespace subgraph { namespace npu { -int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) { +int TopkConverter(void* ctx, OpLite* op, KernelBase* kernel) { CHECK(ctx != nullptr); CHECK(op != nullptr); auto graph = static_cast(ctx); @@ -32,10 +32,12 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) { // Get input and output vars and op attributes auto x_name = op_info->Input("X").front(); - auto x = scope->FindMutableTensor(x_name); - auto x_dims = x->dims(); + auto x = scope->FindTensor(x_name); + auto out_name = op_info->Output("Out").front(); + int k = op_info->GetAttr("k"); + // X node std::shared_ptr x_node = nullptr; if (graph->Has(x_name)) { @@ -44,10 +46,16 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) { x_node = graph->Add(x_name, *x); } - // Sqrt node - auto sqrt_node = graph->Add(out_name); - auto sqrt_op = sqrt_node->data(); - sqrt_op->set_input_x(*x_node->data()); + // k node + std::shared_ptr k_node = graph->Add(out_name + "/k", k); + + // topk node + auto topk_node = graph->Add(out_name); + auto topk_op = topk_node->data(); + topk_op->set_input_x(*x_node->data()); + topk_op->set_input_k(*k_node->data()); + topk_op->set_attr_format(0); + return SUCCESS; } @@ -56,6 +64,6 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) { } // namespace lite } // namespace paddle -REGISTER_SUBGRAPH_BRIDGE(sqrt, +REGISTER_SUBGRAPH_BRIDGE(top_k, kNPU, - paddle::lite::subgraph::npu::SqrtConverter); + paddle::lite::subgraph::npu::TopkConverter); diff --git a/lite/operators/topk_op.cc b/lite/operators/topk_op.cc index a15c3c7e41f9b53d3f8996b405a50c5e4005b1dd..fbfb825544870dfaf3e18d1595f2824970b7352b 100644 --- a/lite/operators/topk_op.cc +++ b/lite/operators/topk_op.cc @@ -20,6 +20,8 @@ namespace operators { bool TopkOp::CheckShape() const { CHECK_OR_FALSE(param_.X); + CHECK_OR_FALSE(param_.Out); + CHECK_OR_FALSE(param_.Indices); return true; } @@ -28,26 +30,25 @@ bool TopkOp::InferShape() const { out_dims[out_dims.size() - 1] = param_.K; auto out = param_.Out; out->Resize(out_dims); - auto out_lod = out->mutable_lod(); - *out_lod = param_.X->lod(); - auto ind = param_.Indices; - ind->Resize(out_dims); - auto ind_lod = out->mutable_lod(); - *ind_lod = param_.X->lod(); + out->set_lod(param_.X->lod()); + + auto indices = param_.Indices; + indices->Resize(out_dims); + indices->set_lod(param_.X->lod()); + return true; } bool TopkOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { auto x = op_desc.Input("X").front(); - param_.X = scope->FindVar(x)->GetMutable(); + param_.X = scope->FindTensor(x); - auto outputs0 = op_desc.Output("Out").front(); - auto outputs1 = op_desc.Output("Indices").front(); - param_.Out = scope->FindVar(outputs0)->GetMutable(); - param_.Indices = scope->FindVar(outputs1)->GetMutable(); + auto output0 = op_desc.Output("Out").front(); + auto output1 = op_desc.Output("Indices").front(); + param_.Out = scope->FindMutableTensor(output0); + param_.Indices = scope->FindMutableTensor(output1); param_.K = op_desc.GetAttr("k"); - CHECK(param_.X); CHECK_GE(param_.K, 1) << "topK param is not valid"; return true; } diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index 4ecab783a1ae672ca06b19b5913c25f8642209b1..eae674a53cbd38dab3de886fcf5821395cba27bc 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -21,7 +21,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_ #lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) - #lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/activation_compute_test.cc b/lite/tests/kernels/activation_compute_test.cc index 50b38369d4fa36abe92a1ed572c50f61c67c5d56..afbf194976c6e524c05e95f9273748ed70b96277 100644 --- a/lite/tests/kernels/activation_compute_test.cc +++ b/lite/tests/kernels/activation_compute_test.cc @@ -271,25 +271,12 @@ TEST(Activation_relu, precision) { return; #endif - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "relu", - RELU)); - arena::Arena arena(std::move(tester), place, abs_error); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "relu", RELU)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); } } @@ -306,26 +293,21 @@ TEST(Activation_leaky_relu, precision) { return; #endif - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - for (auto slope : {0.01, 0.1}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - slope, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "leaky_relu", - LEAKY_RELU)); - arena::Arena arena(std::move(tester), place, abs_error); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + for (auto slope : {0.01, 0.1}) { + std::unique_ptr tester( + new ActivationComputeTester(place, + "def", + slope, + 6., + "all", + 0., + DDim(dims), + "leaky_relu", + LEAKY_RELU)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); } } } @@ -343,26 +325,21 @@ TEST(Activation_relu_clipped, precision) { return; #endif - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - for (auto coef : {0.5, 6.}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - coef, - "all", - 0., - DDim(std::vector({n, c, h, w})), - "relu_clipped", - RELU_CLIPPED)); - arena::Arena arena(std::move(tester), place, abs_error); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + for (auto coef : {0.5, 6.}) { + std::unique_ptr tester( + new ActivationComputeTester(place, + "def", + 0.01, + coef, + "all", + 0., + DDim(dims), + "relu_clipped", + RELU_CLIPPED)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); } } } @@ -372,26 +349,12 @@ TEST(Activation_prelu, precision) { #ifdef LITE_WITH_ARM Place place(TARGET(kARM)); - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - for (auto mode : {"all", "channel", "element"}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6, - mode, - 0., - DDim(std::vector({n, c, h, w})), - "prelu", - PRELU)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{{1, 3, 2, 4}}) { + for (auto mode : {"all", "channel", "element"}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6, mode, 0., DDim(dims), "prelu", PRELU)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); } } #endif @@ -410,25 +373,12 @@ TEST(Activation_sigmoid, precision) { return; #endif - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "sigmoid", - SIGMOID)); - arena::Arena arena(std::move(tester), place, abs_error); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "sigmoid", SIGMOID)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); } } @@ -447,25 +397,12 @@ TEST(Activation_tanh, precision) { return; #endif - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "tanh", - TANH)); - arena::Arena arena(std::move(tester), place, abs_error); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "tanh", TANH)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); } } @@ -474,26 +411,13 @@ TEST(Activation_swish, precision) { #ifdef LITE_WITH_ARM Place place(TARGET(kARM)); - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - for (auto coef : {0.01, 0.1}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6, - "all", - coef, - DDim(std::vector({n, c, h, w})), - "swish", - SWISH)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + for (auto coef : {0.01, 0.1}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6, "all", coef, DDim(dims), "swish", SWISH)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); } } #endif @@ -504,26 +428,13 @@ TEST(Activation_relu6, precision) { #ifdef LITE_WITH_ARM Place place(TARGET(kARM)); - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - for (auto slope : {0.01, 0.1}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "relu6", - RELU6)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + for (auto slope : {0.01, 0.1}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "relu6", RELU6)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); } } #endif @@ -531,30 +442,24 @@ TEST(Activation_relu6, precision) { TEST(Activation_log, precision) { LOG(INFO) << "test log op"; -#ifdef LITE_WITH_ARM - Place place(TARGET(kARM)); + Place place; + float abs_error = 2e-5; +#if defined(LITE_WITH_NPU) + place = TARGET(kNPU); + abs_error = 1e-2; // Using fp16 in NPU +#elif defined(LITE_WITH_ARM) + place = TARGET(kARM); +#else + return; +#endif - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "log", - LOG)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "log", LOG)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); } -#endif } TEST(Activation_exp, precision) { @@ -562,25 +467,12 @@ TEST(Activation_exp, precision) { #ifdef LITE_WITH_ARM Place place(TARGET(kARM)); - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "exp", - EXP)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "exp", EXP)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); } #endif } @@ -589,26 +481,14 @@ TEST(Activation_floor, precision) { LOG(INFO) << "test floor op"; #ifdef LITE_WITH_ARM Place place(TARGET(kARM)); - for (auto n : {1, 3}) { - for (auto c : {3, 6}) { - for (auto h : {9, 18}) { - for (auto w : {9, 18}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "floor", - FLOOR)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "floor", FLOOR)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); } + #endif } @@ -616,54 +496,36 @@ TEST(Activation_rsqrt, precision) { LOG(INFO) << "test rsqrt op"; #ifdef LITE_WITH_ARM Place place(TARGET(kARM)); - for (auto n : {2}) { - for (auto c : {2}) { - for (auto h : {2}) { - for (auto w : {2}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "rsqrt", - RSQRT)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } - } - } + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "rsqrt", RSQRT)); + arena::Arena arena(std::move(tester), place, 2e-5); + arena.TestPrecision(); } #endif } TEST(Activation_square, precision) { LOG(INFO) << "test square op"; -#ifdef LITE_WITH_ARM - Place place(TARGET(kARM)); - for (auto n : {2}) { - for (auto c : {2}) { - for (auto h : {2}) { - for (auto w : {2}) { - std::unique_ptr tester(new ActivationComputeTester( - place, - "def", - 0.01, - 6., - "all", - 0., - DDim(std::vector({n, c, h, w})), - "square", - SQUARE)); - arena::Arena arena(std::move(tester), place, 2e-5); - arena.TestPrecision(); - } - } - } - } + Place place; + float abs_error = 2e-5; +#if defined(LITE_WITH_NPU) + place = TARGET(kNPU); + abs_error = 1e-2; // Using fp16 in NPU +#elif defined(LITE_WITH_ARM) + place = TARGET(kARM); +#else + return; #endif + + for (auto dims : std::vector>{ + {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) { + std::unique_ptr tester(new ActivationComputeTester( + place, "def", 0.01, 6., "all", 0., DDim(dims), "square", SQUARE)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); + } } TEST(Activation_gelu, precision) { diff --git a/lite/tests/kernels/topk_compute_test.cc b/lite/tests/kernels/topk_compute_test.cc index 3c5540e48f3ba63508c511051265810bb9cf234b..699dd000fd49080e7b722754c6c515fb2b77a40c 100644 --- a/lite/tests/kernels/topk_compute_test.cc +++ b/lite/tests/kernels/topk_compute_test.cc @@ -16,102 +16,109 @@ #include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_ops.h" #include "lite/core/arena/framework.h" +#include "lite/tests/utils/fill_data.h" namespace paddle { namespace lite { -bool comp_func(std::pair a, std::pair b) { + +template +bool comp_func(std::pair a, std::pair b) { return (a.first > b.first); } +template class TopkComputeTester : public arena::TestCase { protected: // common attributes for this op. - std::string input_ = "x"; - std::string out_val_ = "out_val"; - std::string out_ind_ = "out_ind"; - int K_ = 1; - DDim dims_{{3, 5, 4, 4}}; + std::string x_ = "x"; + std::string out_ = "out"; + std::string indices_ = "indices"; + DDim x_dims_{{3, 5, 4, 4}}; + int k_ = 1; public: TopkComputeTester(const Place& place, const std::string& alias, - int K, - DDim dims) - : TestCase(place, alias), K_(K), dims_(dims) {} + DDim x_dims, + int k = 1) + : TestCase(place, alias), x_dims_(x_dims), k_(k) {} void RunBaseline(Scope* scope) override { - auto* out_val = scope->NewTensor(out_val_); - auto* out_ind = scope->NewTensor(out_ind_); - CHECK(out_val); - CHECK(out_ind); - DDim out_dims = dims_; - out_dims[out_dims.size() - 1] = K_; + auto* out_val = scope->NewTensor(out_); + auto* out_ind = scope->NewTensor(indices_); + DDim out_dims = x_dims_; + out_dims[out_dims.size() - 1] = k_; out_val->Resize(out_dims); out_ind->Resize(out_dims); - auto* out_val_data = out_val->mutable_data(); - auto* out_ind_data = out_ind->mutable_data(); + auto* out_val_data = out_val->mutable_data(); + auto* out_ind_data = out_ind->mutable_data(); - auto* x = scope->FindTensor(input_); - const auto* x_data = x->data(); - int m = out_dims.production() / K_; - int n = dims_[dims_.size() - 1]; + auto* x = scope->FindTensor(x_); + const auto* x_data = x->data(); + int m = out_dims.production() / k_; + int n = x_dims_[x_dims_.size() - 1]; for (int i = 0; i < m; i++) { - const float* in_tmp = x_data + i * n; - float* out_val_tmp = out_val_data + i * K_; - int* out_ind_tmp = out_ind_data + i * K_; - std::vector> vec; + const T1* in_tmp = x_data + i * n; + T1* out_val_tmp = out_val_data + i * k_; + T2* out_ind_tmp = out_ind_data + i * k_; + std::vector> vec; for (int j = 0; j < n; j++) { - vec.push_back(std::make_pair(in_tmp[j], j)); + vec.push_back(std::make_pair(in_tmp[j], static_cast(j))); } - std::partial_sort(vec.begin(), vec.begin() + K_, vec.end(), comp_func); - for (int q = 0; q < K_; q++) { + std::partial_sort( + vec.begin(), vec.begin() + k_, vec.end(), comp_func); + for (int q = 0; q < k_; q++) { out_val_tmp[q] = vec[q].first; out_ind_tmp[q] = vec[q].second; - LOG(INFO) << "out:" << i << " " << q << " " << out_val_tmp[q] << " " - << out_ind_tmp[q]; } } } void PrepareOpDesc(cpp::OpDesc* op_desc) { - op_desc->SetType("topk"); - op_desc->SetInput("X", {input_}); - op_desc->SetOutput("Out", {out_val_, out_ind_}); - op_desc->SetAttr("K", K_); + op_desc->SetType("top_k"); + op_desc->SetInput("X", {x_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetOutput("Indices", {indices_}); + op_desc->SetAttr("k", k_); } void PrepareData() override { - std::vector data(dims_.production()); - - for (int i = 0; i < dims_.production(); i++) { - data[i] = std::rand() * 1.0f / RAND_MAX; - } - - SetCommonTensor(input_, dims_, data.data()); + std::vector dx(x_dims_.production()); + fill_data_rand(dx.data(), -1, 1, x_dims_.production()); + SetCommonTensor(x_, x_dims_, dx.data()); } }; -void test_topk(Place place) { - DDimLite dims_0{{3, 5}}; - DDimLite dims_1{{8}}; - for (int K : {1, 2}) { - for (auto dims : {dims_0, dims_1}) { +template +void test_topk(Place place, float abs_error) { + for (auto x_shape : std::vector>{ + {2, 3, 4, 5}, {3, 4, 5}, {4, 5}, {5}}) { + for (int k : {2, 5}) { std::unique_ptr tester( - new TopkComputeTester(place, "def", K, dims)); - arena::Arena arena(std::move(tester), place, 2e-5); + new TopkComputeTester(place, "def", DDim(x_shape), k)); + arena::Arena arena(std::move(tester), place, abs_error); arena.TestPrecision(); } } } TEST(Topk, precision) { -// #ifdef LITE_WITH_X86 -// Place place(TARGET(kX86)); -// #endif -#ifdef LITE_WITH_ARM - Place place(TARGET(kARM)); - test_topk(place); + Place place; + float abs_error = 2e-5; +#if defined(LITE_WITH_NPU) + place = TARGET(kNPU); + abs_error = 1e-3; // Using fp16 in NPU +#elif defined(LITE_WITH_ARM) + place = TARGET(kARM); +#else + return; +#endif + +#if defined(LITE_WITH_NPU) + test_topk(place, abs_error); +#else + test_topk(place, abs_error); #endif }