Unverified commit b49d95b1, authored by zhupengyang, committed by GitHub

[NPU] add topk, log op bridge (#3216)

Parent 327c62e0
@@ -44,5 +44,5 @@ REGISTER_LITE_KERNEL(
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Indices",
-                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
+                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
     .Finalize();
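
Note on the hunk above: the ARM top_k kernel now declares its Indices output as kInt64, which lines up with the int64_t indices the reworked test compares on ARM (see topk_compute_test.cc below). The following standalone snippet is not part of this commit; it only illustrates why the declared element type has to match what the kernel actually writes: on a little-endian machine, reading an int64 index buffer as int32 scrambles the values.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Indices as the kernel writes them: one int64_t per selected element.
  std::vector<int64_t> indices = {1, 3, 2};
  // A consumer that trusts a kInt32 declaration would read the same bytes as
  // int32_t values of half the width.
  std::vector<int32_t> misread(indices.size());
  std::memcpy(misread.data(), indices.data(), misread.size() * sizeof(int32_t));
  for (auto v : misread) std::cout << v << " ";  // prints "1 0 3", not "1 3 2"
  std::cout << "\n";
  return 0;
}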
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/op_registry.h"
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
@@ -36,13 +36,12 @@ lite_cc_library(subgraph_bridge_split_op_npu SRCS split_op.cc DEPS ${npu_subgrap
 lite_cc_library(subgraph_bridge_concat_op_npu SRCS concat_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_shuffle_channel_op_npu SRCS shuffle_channel_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_pad2d_op_npu SRCS pad2d_op.cc DEPS ${npu_subgraph_bridge_deps})
-lite_cc_library(subgraph_bridge_square_op_npu SRCS square_op.cc DEPS ${npu_subgraph_bridge_deps})
-lite_cc_library(subgraph_bridge_sqrt_op_npu SRCS sqrt_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_reduce_mean_op_npu SRCS reduce_mean_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_unsqueeze_op_npu SRCS unsqueeze_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_argmax_op_npu SRCS argmax_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_instance_norm_op_npu SRCS instance_norm_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_dropout_op_npu SRCS dropout_op.cc DEPS ${npu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_topk_op_npu SRCS topk_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_layer_norm_op_npu SRCS layer_norm_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_fill_constant_op_npu SRCS fill_constant_op.cc DEPS ${npu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_fill_constant_batch_size_like_op_npu SRCS fill_constant_batch_size_like_op.cc DEPS ${npu_subgraph_bridge_deps})
@@ -72,13 +71,12 @@ set(npu_subgraph_bridges
         subgraph_bridge_concat_op_npu
         subgraph_bridge_shuffle_channel_op_npu
         subgraph_bridge_pad2d_op_npu
-        subgraph_bridge_square_op_npu
-        subgraph_bridge_sqrt_op_npu
         subgraph_bridge_reduce_mean_op_npu
         subgraph_bridge_unsqueeze_op_npu
         subgraph_bridge_argmax_op_npu
         subgraph_bridge_instance_norm_op_npu
         subgraph_bridge_dropout_op_npu
+        subgraph_bridge_topk_op_npu
         subgraph_bridge_layer_norm_op_npu
         subgraph_bridge_fill_constant_op_npu
         subgraph_bridge_fill_constant_batch_size_like_op_npu
......
@@ -21,6 +21,7 @@ namespace lite {
 namespace subgraph {
 namespace npu {
 
+template <typename ActType>
 int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   CHECK(ctx != nullptr);
   CHECK(op != nullptr);
@@ -30,6 +31,40 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   auto scope = op->scope();
   VLOG(3) << "[NPU] Converting " + op_type + "...";
 
+  // Get input and output vars and op attributes
+  auto x_name = op_info->Input("X").front();
+  auto x = scope->FindTensor(x_name);
+  auto out_name = op_info->Output("Out").front();
+
+  // X node
+  std::shared_ptr<Node> x_node = nullptr;
+  if (graph->Has(x_name)) {
+    x_node = graph->Get(x_name);
+  } else {
+    x_node = graph->Add(x_name, *x);
+  }
+
+  // Act node
+  auto act_node = graph->template Add<ActType>(out_name);
+  auto act_op = act_node->template data<ActType>();
+  act_op->set_input_x(*x_node->data());
+
+  return SUCCESS;
+}
+
+template <>
+int ActConverter<ge::op::Activation>(void* ctx,
+                                     OpLite* op,
+                                     KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[NPU] Converting " + op_type + "...";
+
   // Get input and output vars and op attributes
   auto x_name = op_info->Input("X").front();
   auto x = scope->FindMutableTensor(x_name);
@@ -45,8 +80,8 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   }
 
   // Act node
-  auto act_node = graph->Add<ge::op::Activation>(out_name);
-  auto act_op = act_node->data<ge::op::Activation>();
+  auto act_node = graph->template Add<ge::op::Activation>(out_name);
+  auto act_op = act_node->template data<ge::op::Activation>();
   act_op->set_input_x(*x_node->data());
   // TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
   // clipped_relu etc.
@@ -74,27 +109,42 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_SUBGRAPH_BRIDGE(sigmoid,
-                         kNPU,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(relu, kNPU, paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(tanh, kNPU, paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(relu_clipped,
-                         kNPU,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(relu6,
-                         kNPU,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(leaky_relu,
-                         kNPU,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(abs, kNPU, paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(softsign,
-                         kNPU,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(softplus,
-                         kNPU,
-                         paddle::lite::subgraph::npu::ActConverter);
-REGISTER_SUBGRAPH_BRIDGE(hard_sigmoid,
-                         kNPU,
-                         paddle::lite::subgraph::npu::ActConverter);
+REGISTER_SUBGRAPH_BRIDGE(
+    sigmoid,
+    kNPU,
+    paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    relu, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    tanh, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    relu_clipped,
+    kNPU,
+    paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    relu6, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    leaky_relu,
+    kNPU,
+    paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    abs, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    softsign,
+    kNPU,
+    paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    softplus,
+    kNPU,
+    paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    hard_sigmoid,
+    kNPU,
+    paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
+REGISTER_SUBGRAPH_BRIDGE(
+    log, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Log>);
+REGISTER_SUBGRAPH_BRIDGE(
+    square, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Square>);
+REGISTER_SUBGRAPH_BRIDGE(
+    sqrt, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Sqrt>);
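
The act_op.cc hunks above collapse the former per-op sqrt/square converter files into one ActConverter function template plus an explicit specialization for ge::op::Activation, and then register one instantiation per op type. A self-contained sketch of that pattern follows; the names in it (DemoLog, DemoSqrt, Convert) are placeholders for illustration and are not the HiAI ge::op classes or the Paddle-Lite bridge registry.

#include <iostream>
#include <string>

// Stand-ins for single-input IR node types such as ge::op::Log or ge::op::Sqrt.
struct DemoLog {
  static const char* name() { return "Log"; }
};
struct DemoSqrt {
  static const char* name() { return "Sqrt"; }
};

// One converter template covers every unary op whose node only needs an x
// input, mirroring the templated ActConverter<ActType> above.
template <typename NodeType>
int Convert(const std::string& x_name, const std::string& out_name) {
  std::cout << "create " << NodeType::name() << " node: " << x_name << " -> "
            << out_name << std::endl;
  return 0;  // SUCCESS
}

int main() {
  // Each REGISTER_SUBGRAPH_BRIDGE line above effectively binds an op name to
  // one instantiation of the template, as sketched here.
  Convert<DemoLog>("x", "out");
  Convert<DemoSqrt>("x", "out");
  return 0;
}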
@@ -21,6 +21,9 @@ USE_SUBGRAPH_BRIDGE(relu_clipped, kNPU);
 USE_SUBGRAPH_BRIDGE(leaky_relu, kNPU);
 USE_SUBGRAPH_BRIDGE(softsign, kNPU);
 USE_SUBGRAPH_BRIDGE(hard_sigmoid, kNPU);
+USE_SUBGRAPH_BRIDGE(log, kNPU);
+USE_SUBGRAPH_BRIDGE(sqrt, kNPU);
+USE_SUBGRAPH_BRIDGE(square, kNPU);
 
 USE_SUBGRAPH_BRIDGE(batch_norm, kNPU);
 USE_SUBGRAPH_BRIDGE(less_than, kNPU);
@@ -58,8 +61,7 @@ USE_SUBGRAPH_BRIDGE(scale, kNPU);
 USE_SUBGRAPH_BRIDGE(shuffle_channel, kNPU);
 USE_SUBGRAPH_BRIDGE(softmax, kNPU);
 USE_SUBGRAPH_BRIDGE(split, kNPU);
-USE_SUBGRAPH_BRIDGE(sqrt, kNPU);
-USE_SUBGRAPH_BRIDGE(square, kNPU);
+// USE_SUBGRAPH_BRIDGE(top_k, kNPU);
 USE_SUBGRAPH_BRIDGE(transpose, kNPU);
 USE_SUBGRAPH_BRIDGE(transpose2, kNPU);
 USE_SUBGRAPH_BRIDGE(unsqueeze, kNPU);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void sqrt_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindTensor("x");
auto out = scope->FindMutableTensor("out_ref");
out->Resize(x->dims());
auto x_data = x->data<dtype>();
auto out_data = out->mutable_data<dtype>();
for (size_t i = 0; i < x->numel(); i++) {
out_data[i] = std::sqrtf(x_data[i]);
}
}
void test_sqrt(const std::vector<int64_t>& input_shape) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x, 0, 5);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("sqrt");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor
sqrt_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, sqrt) {
test_sqrt({2});
test_sqrt({2, 3});
test_sqrt({1, 2, 3, 4});
test_sqrt({5, 6, 7, 8});
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(sqrt);
USE_NPU_BRIDGE(sqrt);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {
int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Square node
auto square_node = graph->Add<ge::op::Square>(out_name);
auto square_op = square_node->data<ge::op::Square>();
square_op->set_input_x(*x_node->data());
return SUCCESS;
}
} // namespace npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(square,
kNPU,
paddle::lite::subgraph::npu::SquareConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void square_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindTensor("x");
auto out = scope->FindMutableTensor("out_ref");
out->Resize(x->dims());
auto x_data = x->data<dtype>();
auto out_data = out->mutable_data<dtype>();
for (size_t i = 0; i < x->numel(); i++) {
out_data[i] = x_data[i] * x_data[i];
}
}
void test_square(const std::vector<int64_t>& input_shape) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("square");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor
square_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, square) {
test_square({2});
test_square({2, 3});
test_square({1, 2, 3, 4});
test_square({5, 6, 7, 8});
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(square);
USE_NPU_BRIDGE(square);
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -21,7 +21,7 @@ namespace lite {
 namespace subgraph {
 namespace npu {
 
-int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+int TopkConverter(void* ctx, OpLite* op, KernelBase* kernel) {
   CHECK(ctx != nullptr);
   CHECK(op != nullptr);
   auto graph = static_cast<Graph*>(ctx);
@@ -32,10 +32,12 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 
   // Get input and output vars and op attributes
   auto x_name = op_info->Input("X").front();
-  auto x = scope->FindMutableTensor(x_name);
+  auto x = scope->FindTensor(x_name);
+  auto x_dims = x->dims();
   auto out_name = op_info->Output("Out").front();
+  int k = op_info->GetAttr<int>("k");
 
   // X node
   std::shared_ptr<Node> x_node = nullptr;
   if (graph->Has(x_name)) {
@@ -44,10 +46,16 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
     x_node = graph->Add(x_name, *x);
   }
 
-  // Sqrt node
-  auto sqrt_node = graph->Add<ge::op::Sqrt>(out_name);
-  auto sqrt_op = sqrt_node->data<ge::op::Sqrt>();
-  sqrt_op->set_input_x(*x_node->data());
+  // k node
+  std::shared_ptr<Node> k_node = graph->Add<int>(out_name + "/k", k);
+
+  // topk node
+  auto topk_node = graph->Add<ge::op::TopK>(out_name);
+  auto topk_op = topk_node->data<ge::op::TopK>();
+  topk_op->set_input_x(*x_node->data());
+  topk_op->set_input_k(*k_node->data());
+  topk_op->set_attr_format(0);
 
   return SUCCESS;
 }
@@ -56,6 +64,6 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
 }  // namespace lite
 }  // namespace paddle
 
-REGISTER_SUBGRAPH_BRIDGE(sqrt,
-                         kNPU,
-                         paddle::lite::subgraph::npu::SqrtConverter);
+REGISTER_SUBGRAPH_BRIDGE(top_k,
+                         kNPU,
+                         paddle::lite::subgraph::npu::TopkConverter);
@@ -20,6 +20,8 @@ namespace operators {
 
 bool TopkOp::CheckShape() const {
   CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Out);
+  CHECK_OR_FALSE(param_.Indices);
   return true;
 }
 
@@ -28,26 +30,25 @@ bool TopkOp::InferShape() const {
   out_dims[out_dims.size() - 1] = param_.K;
   auto out = param_.Out;
   out->Resize(out_dims);
-  auto out_lod = out->mutable_lod();
-  *out_lod = param_.X->lod();
+  out->set_lod(param_.X->lod());
 
-  auto ind = param_.Indices;
-  ind->Resize(out_dims);
-  auto ind_lod = out->mutable_lod();
-  *ind_lod = param_.X->lod();
+  auto indices = param_.Indices;
+  indices->Resize(out_dims);
+  indices->set_lod(param_.X->lod());
 
   return true;
 }
 
 bool TopkOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
   auto x = op_desc.Input("X").front();
-  param_.X = scope->FindVar(x)->GetMutable<Tensor>();
+  param_.X = scope->FindTensor(x);
 
-  auto outputs0 = op_desc.Output("Out").front();
-  auto outputs1 = op_desc.Output("Indices").front();
-  param_.Out = scope->FindVar(outputs0)->GetMutable<lite::Tensor>();
-  param_.Indices = scope->FindVar(outputs1)->GetMutable<lite::Tensor>();
+  auto output0 = op_desc.Output("Out").front();
+  auto output1 = op_desc.Output("Indices").front();
+  param_.Out = scope->FindMutableTensor(output0);
+  param_.Indices = scope->FindMutableTensor(output1);
 
   param_.K = op_desc.GetAttr<int>("k");
+  CHECK(param_.X);
   CHECK_GE(param_.K, 1) << "topK param is not valid";
   return true;
 }
......
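
Reading aid for the shape logic above: top_k keeps every leading dimension of X and replaces the last one with k, and both Out and Indices take that shape, with Indices holding the positions of the selected values inside the last dimension. The standalone program below is an illustration only, independent of the Paddle-Lite classes; it performs the same last-dimension selection that the partial_sort baseline in topk_compute_test.cc uses.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Think of x as a tensor of shape {2, 4} with k = 2, so Out and Indices
  // both end up with shape {2, 2}.
  const std::vector<float> x = {0.1f, 0.9f, 0.4f, 0.7f,   // row 0
                                0.3f, 0.2f, 0.8f, 0.5f};  // row 1
  const int n = 4;  // size of the last dimension
  const int k = 2;  // the "k" attribute
  const int rows = static_cast<int>(x.size()) / n;

  for (int i = 0; i < rows; ++i) {
    std::vector<std::pair<float, int64_t>> vec;
    for (int j = 0; j < n; ++j) vec.emplace_back(x[i * n + j], j);
    // Largest value first, the same ordering the test's comp_func applies.
    std::partial_sort(vec.begin(), vec.begin() + k, vec.end(),
                      [](const std::pair<float, int64_t>& a,
                         const std::pair<float, int64_t>& b) {
                        return a.first > b.first;
                      });
    for (int q = 0; q < k; ++q) {
      std::cout << "row " << i << ": value " << vec[q].first << ", index "
                << vec[q].second << "\n";
    }
  }
  // Top-2 per row: row 0 -> (0.9, 1), (0.7, 3); row 1 -> (0.8, 2), (0.5, 3).
  return 0;
}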
@@ -21,7 +21,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
     #lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     #lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-    #lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+    lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
     lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
@@ -271,25 +271,12 @@ TEST(Activation_relu, precision) {
   return;
 #endif
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-              place,
-              "def",
-              0.01,
-              6.,
-              "all",
-              0.,
-              DDim(std::vector<int64_t>({n, c, h, w})),
-              "relu",
-              RELU));
-          arena::Arena arena(std::move(tester), place, abs_error);
-          arena.TestPrecision();
-        }
-      }
-    }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., DDim(dims), "relu", RELU));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
   }
 }
@@ -306,26 +293,21 @@ TEST(Activation_leaky_relu, precision) {
   return;
 #endif
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          for (auto slope : {0.01, 0.1}) {
-            std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-                place,
-                "def",
-                slope,
-                6.,
-                "all",
-                0.,
-                DDim(std::vector<int64_t>({n, c, h, w})),
-                "leaky_relu",
-                LEAKY_RELU));
-            arena::Arena arena(std::move(tester), place, abs_error);
-            arena.TestPrecision();
-          }
-        }
-      }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    for (auto slope : {0.01, 0.1}) {
+      std::unique_ptr<arena::TestCase> tester(
+          new ActivationComputeTester(place,
+                                      "def",
+                                      slope,
+                                      6.,
+                                      "all",
+                                      0.,
+                                      DDim(dims),
+                                      "leaky_relu",
+                                      LEAKY_RELU));
+      arena::Arena arena(std::move(tester), place, abs_error);
+      arena.TestPrecision();
     }
   }
 }
@@ -343,26 +325,21 @@ TEST(Activation_relu_clipped, precision) {
   return;
 #endif
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          for (auto coef : {0.5, 6.}) {
-            std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-                place,
-                "def",
-                0.01,
-                coef,
-                "all",
-                0.,
-                DDim(std::vector<int64_t>({n, c, h, w})),
-                "relu_clipped",
-                RELU_CLIPPED));
-            arena::Arena arena(std::move(tester), place, abs_error);
-            arena.TestPrecision();
-          }
-        }
-      }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    for (auto coef : {0.5, 6.}) {
+      std::unique_ptr<arena::TestCase> tester(
+          new ActivationComputeTester(place,
+                                      "def",
+                                      0.01,
+                                      coef,
+                                      "all",
+                                      0.,
+                                      DDim(dims),
+                                      "relu_clipped",
+                                      RELU_CLIPPED));
+      arena::Arena arena(std::move(tester), place, abs_error);
+      arena.TestPrecision();
     }
   }
 }
@@ -372,26 +349,12 @@ TEST(Activation_prelu, precision) {
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          for (auto mode : {"all", "channel", "element"}) {
-            std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-                place,
-                "def",
-                0.01,
-                6,
-                mode,
-                0.,
-                DDim(std::vector<int64_t>({n, c, h, w})),
-                "prelu",
-                PRELU));
-            arena::Arena arena(std::move(tester), place, 2e-5);
-            arena.TestPrecision();
-          }
-        }
-      }
+  for (auto dims : std::vector<std::vector<int64_t>>{{1, 3, 2, 4}}) {
+    for (auto mode : {"all", "channel", "element"}) {
+      std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+          place, "def", 0.01, 6, mode, 0., DDim(dims), "prelu", PRELU));
+      arena::Arena arena(std::move(tester), place, 2e-5);
+      arena.TestPrecision();
     }
   }
 #endif
@@ -410,25 +373,12 @@ TEST(Activation_sigmoid, precision) {
   return;
 #endif
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-              place,
-              "def",
-              0.01,
-              6.,
-              "all",
-              0.,
-              DDim(std::vector<int64_t>({n, c, h, w})),
-              "sigmoid",
-              SIGMOID));
-          arena::Arena arena(std::move(tester), place, abs_error);
-          arena.TestPrecision();
-        }
-      }
-    }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., DDim(dims), "sigmoid", SIGMOID));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
   }
 }
@@ -447,25 +397,12 @@ TEST(Activation_tanh, precision) {
   return;
 #endif
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-              place,
-              "def",
-              0.01,
-              6.,
-              "all",
-              0.,
-              DDim(std::vector<int64_t>({n, c, h, w})),
-              "tanh",
-              TANH));
-          arena::Arena arena(std::move(tester), place, abs_error);
-          arena.TestPrecision();
-        }
-      }
-    }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., DDim(dims), "tanh", TANH));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
   }
 }
@@ -474,26 +411,13 @@ TEST(Activation_swish, precision) {
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          for (auto coef : {0.01, 0.1}) {
-            std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-                place,
-                "def",
-                0.01,
-                6,
-                "all",
-                coef,
-                DDim(std::vector<int64_t>({n, c, h, w})),
-                "swish",
-                SWISH));
-            arena::Arena arena(std::move(tester), place, 2e-5);
-            arena.TestPrecision();
-          }
-        }
-      }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    for (auto coef : {0.01, 0.1}) {
+      std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+          place, "def", 0.01, 6, "all", coef, DDim(dims), "swish", SWISH));
+      arena::Arena arena(std::move(tester), place, 2e-5);
+      arena.TestPrecision();
     }
   }
 #endif
@@ -504,26 +428,13 @@ TEST(Activation_relu6, precision) {
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          for (auto slope : {0.01, 0.1}) {
-            std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-                place,
-                "def",
-                0.01,
-                6.,
-                "all",
-                0.,
-                DDim(std::vector<int64_t>({n, c, h, w})),
-                "relu6",
-                RELU6));
-            arena::Arena arena(std::move(tester), place, 2e-5);
-            arena.TestPrecision();
-          }
-        }
-      }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    for (auto slope : {0.01, 0.1}) {
+      std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+          place, "def", 0.01, 6., "all", 0., DDim(dims), "relu6", RELU6));
+      arena::Arena arena(std::move(tester), place, 2e-5);
+      arena.TestPrecision();
     }
   }
 #endif
@@ -531,30 +442,24 @@ TEST(Activation_relu6, precision) {
 
 TEST(Activation_log, precision) {
   LOG(INFO) << "test log op";
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
+  Place place;
+  float abs_error = 2e-5;
+#if defined(LITE_WITH_NPU)
+  place = TARGET(kNPU);
+  abs_error = 1e-2;  // Using fp16 in NPU
+#elif defined(LITE_WITH_ARM)
+  place = TARGET(kARM);
+#else
+  return;
+#endif
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-              place,
-              "def",
-              0.01,
-              6.,
-              "all",
-              0.,
-              DDim(std::vector<int64_t>({n, c, h, w})),
-              "log",
-              LOG));
-          arena::Arena arena(std::move(tester), place, 2e-5);
-          arena.TestPrecision();
-        }
-      }
-    }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., DDim(dims), "log", LOG));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
   }
-#endif
 }
 
 TEST(Activation_exp, precision) {
@@ -562,25 +467,12 @@ TEST(Activation_exp, precision) {
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-              place,
-              "def",
-              0.01,
-              6.,
-              "all",
-              0.,
-              DDim(std::vector<int64_t>({n, c, h, w})),
-              "exp",
-              EXP));
-          arena::Arena arena(std::move(tester), place, 2e-5);
-          arena.TestPrecision();
-        }
-      }
-    }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., DDim(dims), "exp", EXP));
+    arena::Arena arena(std::move(tester), place, 2e-5);
+    arena.TestPrecision();
   }
 #endif
 }
@@ -589,26 +481,14 @@ TEST(Activation_floor, precision) {
   LOG(INFO) << "test floor op";
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
 
-  for (auto n : {1, 3}) {
-    for (auto c : {3, 6}) {
-      for (auto h : {9, 18}) {
-        for (auto w : {9, 18}) {
-          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-              place,
-              "def",
-              0.01,
-              6.,
-              "all",
-              0.,
-              DDim(std::vector<int64_t>({n, c, h, w})),
-              "floor",
-              FLOOR));
-          arena::Arena arena(std::move(tester), place, 2e-5);
-          arena.TestPrecision();
-        }
-      }
-    }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., DDim(dims), "floor", FLOOR));
+    arena::Arena arena(std::move(tester), place, 2e-5);
+    arena.TestPrecision();
   }
 #endif
 }
@@ -616,54 +496,36 @@ TEST(Activation_rsqrt, precision) {
   LOG(INFO) << "test rsqrt op";
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
 
-  for (auto n : {2}) {
-    for (auto c : {2}) {
-      for (auto h : {2}) {
-        for (auto w : {2}) {
-          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-              place,
-              "def",
-              0.01,
-              6.,
-              "all",
-              0.,
-              DDim(std::vector<int64_t>({n, c, h, w})),
-              "rsqrt",
-              RSQRT));
-          arena::Arena arena(std::move(tester), place, 2e-5);
-          arena.TestPrecision();
-        }
-      }
-    }
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., DDim(dims), "rsqrt", RSQRT));
+    arena::Arena arena(std::move(tester), place, 2e-5);
+    arena.TestPrecision();
   }
 #endif
 }
 
 TEST(Activation_square, precision) {
   LOG(INFO) << "test square op";
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
-  for (auto n : {2}) {
-    for (auto c : {2}) {
-      for (auto h : {2}) {
-        for (auto w : {2}) {
-          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-              place,
-              "def",
-              0.01,
-              6.,
-              "all",
-              0.,
-              DDim(std::vector<int64_t>({n, c, h, w})),
-              "square",
-              SQUARE));
-          arena::Arena arena(std::move(tester), place, 2e-5);
-          arena.TestPrecision();
-        }
-      }
-    }
-  }
+  Place place;
+  float abs_error = 2e-5;
+#if defined(LITE_WITH_NPU)
+  place = TARGET(kNPU);
+  abs_error = 1e-2;  // Using fp16 in NPU
+#elif defined(LITE_WITH_ARM)
+  place = TARGET(kARM);
+#else
+  return;
 #endif
+
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., DDim(dims), "square", SQUARE));
+    arena::Arena arena(std::move(tester), place, abs_error);
+    arena.TestPrecision();
+  }
 }
 
 TEST(Activation_gelu, precision) {
......
@@ -16,102 +16,109 @@
 #include "lite/api/paddle_use_kernels.h"
 #include "lite/api/paddle_use_ops.h"
 #include "lite/core/arena/framework.h"
+#include "lite/tests/utils/fill_data.h"
 
 namespace paddle {
 namespace lite {
 
-bool comp_func(std::pair<float, int> a, std::pair<float, int> b) {
+template <typename T1, typename T2>
+bool comp_func(std::pair<T1, T2> a, std::pair<T1, T2> b) {
   return (a.first > b.first);
 }
 
+template <typename T1, typename T2>
 class TopkComputeTester : public arena::TestCase {
  protected:
   // common attributes for this op.
-  std::string input_ = "x";
-  std::string out_val_ = "out_val";
-  std::string out_ind_ = "out_ind";
-  int K_ = 1;
-  DDim dims_{{3, 5, 4, 4}};
+  std::string x_ = "x";
+  std::string out_ = "out";
+  std::string indices_ = "indices";
+  DDim x_dims_{{3, 5, 4, 4}};
+  int k_ = 1;
 
  public:
   TopkComputeTester(const Place& place,
                     const std::string& alias,
-                    int K,
-                    DDim dims)
-      : TestCase(place, alias), K_(K), dims_(dims) {}
+                    DDim x_dims,
+                    int k = 1)
+      : TestCase(place, alias), x_dims_(x_dims), k_(k) {}
 
   void RunBaseline(Scope* scope) override {
-    auto* out_val = scope->NewTensor(out_val_);
-    auto* out_ind = scope->NewTensor(out_ind_);
-    CHECK(out_val);
-    CHECK(out_ind);
-    DDim out_dims = dims_;
-    out_dims[out_dims.size() - 1] = K_;
+    auto* out_val = scope->NewTensor(out_);
+    auto* out_ind = scope->NewTensor(indices_);
+    DDim out_dims = x_dims_;
+    out_dims[out_dims.size() - 1] = k_;
     out_val->Resize(out_dims);
     out_ind->Resize(out_dims);
-    auto* out_val_data = out_val->mutable_data<float>();
-    auto* out_ind_data = out_ind->mutable_data<int>();
+    auto* out_val_data = out_val->mutable_data<T1>();
+    auto* out_ind_data = out_ind->mutable_data<T2>();
 
-    auto* x = scope->FindTensor(input_);
-    const auto* x_data = x->data<float>();
-    int m = out_dims.production() / K_;
-    int n = dims_[dims_.size() - 1];
+    auto* x = scope->FindTensor(x_);
+    const auto* x_data = x->data<T1>();
+    int m = out_dims.production() / k_;
+    int n = x_dims_[x_dims_.size() - 1];
 
     for (int i = 0; i < m; i++) {
-      const float* in_tmp = x_data + i * n;
-      float* out_val_tmp = out_val_data + i * K_;
-      int* out_ind_tmp = out_ind_data + i * K_;
-      std::vector<std::pair<float, int>> vec;
+      const T1* in_tmp = x_data + i * n;
+      T1* out_val_tmp = out_val_data + i * k_;
+      T2* out_ind_tmp = out_ind_data + i * k_;
+      std::vector<std::pair<T1, T2>> vec;
       for (int j = 0; j < n; j++) {
-        vec.push_back(std::make_pair(in_tmp[j], j));
+        vec.push_back(std::make_pair(in_tmp[j], static_cast<T2>(j)));
       }
-      std::partial_sort(vec.begin(), vec.begin() + K_, vec.end(), comp_func);
-      for (int q = 0; q < K_; q++) {
+      std::partial_sort(
+          vec.begin(), vec.begin() + k_, vec.end(), comp_func<T1, T2>);
+      for (int q = 0; q < k_; q++) {
         out_val_tmp[q] = vec[q].first;
         out_ind_tmp[q] = vec[q].second;
-        LOG(INFO) << "out:" << i << " " << q << " " << out_val_tmp[q] << " "
-                  << out_ind_tmp[q];
       }
     }
   }
 
   void PrepareOpDesc(cpp::OpDesc* op_desc) {
-    op_desc->SetType("topk");
-    op_desc->SetInput("X", {input_});
-    op_desc->SetOutput("Out", {out_val_, out_ind_});
-    op_desc->SetAttr("K", K_);
+    op_desc->SetType("top_k");
+    op_desc->SetInput("X", {x_});
+    op_desc->SetOutput("Out", {out_});
+    op_desc->SetOutput("Indices", {indices_});
+    op_desc->SetAttr("k", k_);
   }
 
   void PrepareData() override {
-    std::vector<float> data(dims_.production());
-
-    for (int i = 0; i < dims_.production(); i++) {
-      data[i] = std::rand() * 1.0f / RAND_MAX;
-    }
-
-    SetCommonTensor(input_, dims_, data.data());
+    std::vector<T1> dx(x_dims_.production());
+    fill_data_rand<T1>(dx.data(), -1, 1, x_dims_.production());
+    SetCommonTensor(x_, x_dims_, dx.data());
   }
 };
 
-void test_topk(Place place) {
-  DDimLite dims_0{{3, 5}};
-  DDimLite dims_1{{8}};
-  for (int K : {1, 2}) {
-    for (auto dims : {dims_0, dims_1}) {
+template <typename T1, typename T2>
+void test_topk(Place place, float abs_error) {
+  for (auto x_shape : std::vector<std::vector<int64_t>>{
+           {2, 3, 4, 5}, {3, 4, 5}, {4, 5}, {5}}) {
+    for (int k : {2, 5}) {
       std::unique_ptr<arena::TestCase> tester(
-          new TopkComputeTester(place, "def", K, dims));
-      arena::Arena arena(std::move(tester), place, 2e-5);
+          new TopkComputeTester<T1, T2>(place, "def", DDim(x_shape), k));
+      arena::Arena arena(std::move(tester), place, abs_error);
       arena.TestPrecision();
     }
   }
 }
 
 TEST(Topk, precision) {
-  // #ifdef LITE_WITH_X86
-  // Place place(TARGET(kX86));
-  // #endif
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
-  test_topk(place);
+  Place place;
+  float abs_error = 2e-5;
+#if defined(LITE_WITH_NPU)
+  place = TARGET(kNPU);
+  abs_error = 1e-3;  // Using fp16 in NPU
+#elif defined(LITE_WITH_ARM)
+  place = TARGET(kARM);
+#else
+  return;
+#endif
+
+#if defined(LITE_WITH_NPU)
+  test_topk<float, int>(place, abs_error);
+#else
+  test_topk<float, int64_t>(place, abs_error);
 #endif
 }
......