未验证 提交 b49d95b1 编写于 作者: Z zhupengyang 提交者: GitHub

[NPU] add topk, log op bridge (#3216)

上级 327c62e0
......@@ -44,5 +44,5 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Indices",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/op_registry.h"
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
......@@ -36,13 +36,12 @@ lite_cc_library(subgraph_bridge_split_op_npu SRCS split_op.cc DEPS ${npu_subgrap
lite_cc_library(subgraph_bridge_concat_op_npu SRCS concat_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_shuffle_channel_op_npu SRCS shuffle_channel_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_pad2d_op_npu SRCS pad2d_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_square_op_npu SRCS square_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_sqrt_op_npu SRCS sqrt_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_reduce_mean_op_npu SRCS reduce_mean_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_unsqueeze_op_npu SRCS unsqueeze_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_argmax_op_npu SRCS argmax_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_instance_norm_op_npu SRCS instance_norm_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_npu SRCS dropout_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_topk_op_npu SRCS topk_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_layer_norm_op_npu SRCS layer_norm_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_fill_constant_op_npu SRCS fill_constant_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_fill_constant_batch_size_like_op_npu SRCS fill_constant_batch_size_like_op.cc DEPS ${npu_subgraph_bridge_deps})
......@@ -72,13 +71,12 @@ set(npu_subgraph_bridges
subgraph_bridge_concat_op_npu
subgraph_bridge_shuffle_channel_op_npu
subgraph_bridge_pad2d_op_npu
subgraph_bridge_square_op_npu
subgraph_bridge_sqrt_op_npu
subgraph_bridge_reduce_mean_op_npu
subgraph_bridge_unsqueeze_op_npu
subgraph_bridge_argmax_op_npu
subgraph_bridge_instance_norm_op_npu
subgraph_bridge_dropout_op_npu
subgraph_bridge_topk_op_npu
subgraph_bridge_layer_norm_op_npu
subgraph_bridge_fill_constant_op_npu
subgraph_bridge_fill_constant_batch_size_like_op_npu
......
......@@ -21,6 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
template <typename ActType>
int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
......@@ -30,6 +31,40 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name);
auto out_name = op_info->Output("Out").front();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Act node
auto act_node = graph->template Add<ActType>(out_name);
auto act_op = act_node->template data<ActType>();
act_op->set_input_x(*x_node->data());
return SUCCESS;
}
template <>
int ActConverter<ge::op::Activation>(void* ctx,
OpLite* op,
KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
......@@ -45,8 +80,8 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
}
// Act node
auto act_node = graph->Add<ge::op::Activation>(out_name);
auto act_op = act_node->data<ge::op::Activation>();
auto act_node = graph->template Add<ge::op::Activation>(out_name);
auto act_op = act_node->template data<ge::op::Activation>();
act_op->set_input_x(*x_node->data());
// TODO(hong19860320) set the coef value for act Ops, such as leaky_relu,
// clipped_relu etc.
......@@ -74,27 +109,42 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(sigmoid,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(relu, kNPU, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(tanh, kNPU, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(relu_clipped,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(relu6,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(leaky_relu,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(abs, kNPU, paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(softsign,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(softplus,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(hard_sigmoid,
kNPU,
paddle::lite::subgraph::npu::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(
sigmoid,
kNPU,
paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
relu, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
tanh, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
relu_clipped,
kNPU,
paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
relu6, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
leaky_relu,
kNPU,
paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
abs, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
softsign,
kNPU,
paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
softplus,
kNPU,
paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
hard_sigmoid,
kNPU,
paddle::lite::subgraph::npu::ActConverter<ge::op::Activation>);
REGISTER_SUBGRAPH_BRIDGE(
log, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Log>);
REGISTER_SUBGRAPH_BRIDGE(
square, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Square>);
REGISTER_SUBGRAPH_BRIDGE(
sqrt, kNPU, paddle::lite::subgraph::npu::ActConverter<ge::op::Sqrt>);
......@@ -21,6 +21,9 @@ USE_SUBGRAPH_BRIDGE(relu_clipped, kNPU);
USE_SUBGRAPH_BRIDGE(leaky_relu, kNPU);
USE_SUBGRAPH_BRIDGE(softsign, kNPU);
USE_SUBGRAPH_BRIDGE(hard_sigmoid, kNPU);
USE_SUBGRAPH_BRIDGE(log, kNPU);
USE_SUBGRAPH_BRIDGE(sqrt, kNPU);
USE_SUBGRAPH_BRIDGE(square, kNPU);
USE_SUBGRAPH_BRIDGE(batch_norm, kNPU);
USE_SUBGRAPH_BRIDGE(less_than, kNPU);
......@@ -58,8 +61,7 @@ USE_SUBGRAPH_BRIDGE(scale, kNPU);
USE_SUBGRAPH_BRIDGE(shuffle_channel, kNPU);
USE_SUBGRAPH_BRIDGE(softmax, kNPU);
USE_SUBGRAPH_BRIDGE(split, kNPU);
USE_SUBGRAPH_BRIDGE(sqrt, kNPU);
USE_SUBGRAPH_BRIDGE(square, kNPU);
// USE_SUBGRAPH_BRIDGE(top_k, kNPU);
USE_SUBGRAPH_BRIDGE(transpose, kNPU);
USE_SUBGRAPH_BRIDGE(transpose2, kNPU);
USE_SUBGRAPH_BRIDGE(unsqueeze, kNPU);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void sqrt_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindTensor("x");
auto out = scope->FindMutableTensor("out_ref");
out->Resize(x->dims());
auto x_data = x->data<dtype>();
auto out_data = out->mutable_data<dtype>();
for (size_t i = 0; i < x->numel(); i++) {
out_data[i] = std::sqrtf(x_data[i]);
}
}
void test_sqrt(const std::vector<int64_t>& input_shape) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x, 0, 5);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("sqrt");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor
sqrt_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, sqrt) {
test_sqrt({2});
test_sqrt({2, 3});
test_sqrt({1, 2, 3, 4});
test_sqrt({5, 6, 7, 8});
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(sqrt);
USE_NPU_BRIDGE(sqrt);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {
int SquareConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto out_name = op_info->Output("Out").front();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Square node
auto square_node = graph->Add<ge::op::Square>(out_name);
auto square_op = square_node->data<ge::op::Square>();
square_op->set_input_x(*x_node->data());
return SUCCESS;
}
} // namespace npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(square,
kNPU,
paddle::lite::subgraph::npu::SquareConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void square_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindTensor("x");
auto out = scope->FindMutableTensor("out_ref");
out->Resize(x->dims());
auto x_data = x->data<dtype>();
auto out_data = out->mutable_data<dtype>();
for (size_t i = 0; i < x->numel(); i++) {
out_data[i] = x_data[i] * x_data[i];
}
}
void test_square(const std::vector<int64_t>& input_shape) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize(input_shape);
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("square");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor
square_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, square) {
test_square({2});
test_square({2, 3});
test_square({1, 2, 3, 4});
test_square({5, 6, 7, 8});
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(square);
USE_NPU_BRIDGE(square);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -21,7 +21,7 @@ namespace lite {
namespace subgraph {
namespace npu {
int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int TopkConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......@@ -32,10 +32,12 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto x = scope->FindTensor(x_name);
auto out_name = op_info->Output("Out").front();
int k = op_info->GetAttr<int>("k");
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
......@@ -44,10 +46,16 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
x_node = graph->Add(x_name, *x);
}
// Sqrt node
auto sqrt_node = graph->Add<ge::op::Sqrt>(out_name);
auto sqrt_op = sqrt_node->data<ge::op::Sqrt>();
sqrt_op->set_input_x(*x_node->data());
// k node
std::shared_ptr<Node> k_node = graph->Add<int>(out_name + "/k", k);
// topk node
auto topk_node = graph->Add<ge::op::TopK>(out_name);
auto topk_op = topk_node->data<ge::op::TopK>();
topk_op->set_input_x(*x_node->data());
topk_op->set_input_k(*k_node->data());
topk_op->set_attr_format(0);
return SUCCESS;
}
......@@ -56,6 +64,6 @@ int SqrtConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(sqrt,
REGISTER_SUBGRAPH_BRIDGE(top_k,
kNPU,
paddle::lite::subgraph::npu::SqrtConverter);
paddle::lite::subgraph::npu::TopkConverter);
......@@ -20,6 +20,8 @@ namespace operators {
bool TopkOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
CHECK_OR_FALSE(param_.Indices);
return true;
}
......@@ -28,26 +30,25 @@ bool TopkOp::InferShape() const {
out_dims[out_dims.size() - 1] = param_.K;
auto out = param_.Out;
out->Resize(out_dims);
auto out_lod = out->mutable_lod();
*out_lod = param_.X->lod();
auto ind = param_.Indices;
ind->Resize(out_dims);
auto ind_lod = out->mutable_lod();
*ind_lod = param_.X->lod();
out->set_lod(param_.X->lod());
auto indices = param_.Indices;
indices->Resize(out_dims);
indices->set_lod(param_.X->lod());
return true;
}
bool TopkOp::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto x = op_desc.Input("X").front();
param_.X = scope->FindVar(x)->GetMutable<Tensor>();
param_.X = scope->FindTensor(x);
auto outputs0 = op_desc.Output("Out").front();
auto outputs1 = op_desc.Output("Indices").front();
param_.Out = scope->FindVar(outputs0)->GetMutable<lite::Tensor>();
param_.Indices = scope->FindVar(outputs1)->GetMutable<lite::Tensor>();
auto output0 = op_desc.Output("Out").front();
auto output1 = op_desc.Output("Indices").front();
param_.Out = scope->FindMutableTensor(output0);
param_.Indices = scope->FindMutableTensor(output1);
param_.K = op_desc.GetAttr<int>("k");
CHECK(param_.X);
CHECK_GE(param_.K, 1) << "topK param is not valid";
return true;
}
......
......@@ -21,7 +21,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
#lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
......@@ -271,25 +271,12 @@ TEST(Activation_relu, precision) {
return;
#endif
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"relu",
RELU));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "relu", RELU));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
......@@ -306,26 +293,21 @@ TEST(Activation_leaky_relu, precision) {
return;
#endif
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
for (auto slope : {0.01, 0.1}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
slope,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"leaky_relu",
LEAKY_RELU));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
for (auto slope : {0.01, 0.1}) {
std::unique_ptr<arena::TestCase> tester(
new ActivationComputeTester(place,
"def",
slope,
6.,
"all",
0.,
DDim(dims),
"leaky_relu",
LEAKY_RELU));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
......@@ -343,26 +325,21 @@ TEST(Activation_relu_clipped, precision) {
return;
#endif
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
for (auto coef : {0.5, 6.}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
coef,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"relu_clipped",
RELU_CLIPPED));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
for (auto coef : {0.5, 6.}) {
std::unique_ptr<arena::TestCase> tester(
new ActivationComputeTester(place,
"def",
0.01,
coef,
"all",
0.,
DDim(dims),
"relu_clipped",
RELU_CLIPPED));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
......@@ -372,26 +349,12 @@ TEST(Activation_prelu, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
for (auto mode : {"all", "channel", "element"}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6,
mode,
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"prelu",
PRELU));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{{1, 3, 2, 4}}) {
for (auto mode : {"all", "channel", "element"}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6, mode, 0., DDim(dims), "prelu", PRELU));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
#endif
......@@ -410,25 +373,12 @@ TEST(Activation_sigmoid, precision) {
return;
#endif
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"sigmoid",
SIGMOID));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "sigmoid", SIGMOID));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
......@@ -447,25 +397,12 @@ TEST(Activation_tanh, precision) {
return;
#endif
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"tanh",
TANH));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "tanh", TANH));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
......@@ -474,26 +411,13 @@ TEST(Activation_swish, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
for (auto coef : {0.01, 0.1}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6,
"all",
coef,
DDim(std::vector<int64_t>({n, c, h, w})),
"swish",
SWISH));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
for (auto coef : {0.01, 0.1}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6, "all", coef, DDim(dims), "swish", SWISH));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
#endif
......@@ -504,26 +428,13 @@ TEST(Activation_relu6, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
for (auto slope : {0.01, 0.1}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"relu6",
RELU6));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
for (auto slope : {0.01, 0.1}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "relu6", RELU6));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
#endif
......@@ -531,30 +442,24 @@ TEST(Activation_relu6, precision) {
TEST(Activation_log, precision) {
LOG(INFO) << "test log op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
#endif
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"log",
LOG));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "log", LOG));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
#endif
}
TEST(Activation_exp, precision) {
......@@ -562,25 +467,12 @@ TEST(Activation_exp, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"exp",
EXP));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "exp", EXP));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
#endif
}
......@@ -589,26 +481,14 @@ TEST(Activation_floor, precision) {
LOG(INFO) << "test floor op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"floor",
FLOOR));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "floor", FLOOR));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
#endif
}
......@@ -616,54 +496,36 @@ TEST(Activation_rsqrt, precision) {
LOG(INFO) << "test rsqrt op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {2}) {
for (auto c : {2}) {
for (auto h : {2}) {
for (auto w : {2}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"rsqrt",
RSQRT));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "rsqrt", RSQRT));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
#endif
}
TEST(Activation_square, precision) {
LOG(INFO) << "test square op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {2}) {
for (auto c : {2}) {
for (auto h : {2}) {
for (auto w : {2}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"square",
SQUARE));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
#endif
for (auto dims : std::vector<std::vector<int64_t>>{
{1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place, "def", 0.01, 6., "all", 0., DDim(dims), "square", SQUARE));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
TEST(Activation_gelu, precision) {
......
......@@ -16,102 +16,109 @@
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
bool comp_func(std::pair<float, int> a, std::pair<float, int> b) {
template <typename T1, typename T2>
bool comp_func(std::pair<T1, T2> a, std::pair<T1, T2> b) {
return (a.first > b.first);
}
template <typename T1, typename T2>
class TopkComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string out_val_ = "out_val";
std::string out_ind_ = "out_ind";
int K_ = 1;
DDim dims_{{3, 5, 4, 4}};
std::string x_ = "x";
std::string out_ = "out";
std::string indices_ = "indices";
DDim x_dims_{{3, 5, 4, 4}};
int k_ = 1;
public:
TopkComputeTester(const Place& place,
const std::string& alias,
int K,
DDim dims)
: TestCase(place, alias), K_(K), dims_(dims) {}
DDim x_dims,
int k = 1)
: TestCase(place, alias), x_dims_(x_dims), k_(k) {}
void RunBaseline(Scope* scope) override {
auto* out_val = scope->NewTensor(out_val_);
auto* out_ind = scope->NewTensor(out_ind_);
CHECK(out_val);
CHECK(out_ind);
DDim out_dims = dims_;
out_dims[out_dims.size() - 1] = K_;
auto* out_val = scope->NewTensor(out_);
auto* out_ind = scope->NewTensor(indices_);
DDim out_dims = x_dims_;
out_dims[out_dims.size() - 1] = k_;
out_val->Resize(out_dims);
out_ind->Resize(out_dims);
auto* out_val_data = out_val->mutable_data<float>();
auto* out_ind_data = out_ind->mutable_data<int>();
auto* out_val_data = out_val->mutable_data<T1>();
auto* out_ind_data = out_ind->mutable_data<T2>();
auto* x = scope->FindTensor(input_);
const auto* x_data = x->data<float>();
int m = out_dims.production() / K_;
int n = dims_[dims_.size() - 1];
auto* x = scope->FindTensor(x_);
const auto* x_data = x->data<T1>();
int m = out_dims.production() / k_;
int n = x_dims_[x_dims_.size() - 1];
for (int i = 0; i < m; i++) {
const float* in_tmp = x_data + i * n;
float* out_val_tmp = out_val_data + i * K_;
int* out_ind_tmp = out_ind_data + i * K_;
std::vector<std::pair<float, int>> vec;
const T1* in_tmp = x_data + i * n;
T1* out_val_tmp = out_val_data + i * k_;
T2* out_ind_tmp = out_ind_data + i * k_;
std::vector<std::pair<T1, T2>> vec;
for (int j = 0; j < n; j++) {
vec.push_back(std::make_pair(in_tmp[j], j));
vec.push_back(std::make_pair(in_tmp[j], static_cast<T2>(j)));
}
std::partial_sort(vec.begin(), vec.begin() + K_, vec.end(), comp_func);
for (int q = 0; q < K_; q++) {
std::partial_sort(
vec.begin(), vec.begin() + k_, vec.end(), comp_func<T1, T2>);
for (int q = 0; q < k_; q++) {
out_val_tmp[q] = vec[q].first;
out_ind_tmp[q] = vec[q].second;
LOG(INFO) << "out:" << i << " " << q << " " << out_val_tmp[q] << " "
<< out_ind_tmp[q];
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("topk");
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {out_val_, out_ind_});
op_desc->SetAttr("K", K_);
op_desc->SetType("top_k");
op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_});
op_desc->SetOutput("Indices", {indices_});
op_desc->SetAttr("k", k_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = std::rand() * 1.0f / RAND_MAX;
}
SetCommonTensor(input_, dims_, data.data());
std::vector<T1> dx(x_dims_.production());
fill_data_rand<T1>(dx.data(), -1, 1, x_dims_.production());
SetCommonTensor(x_, x_dims_, dx.data());
}
};
void test_topk(Place place) {
DDimLite dims_0{{3, 5}};
DDimLite dims_1{{8}};
for (int K : {1, 2}) {
for (auto dims : {dims_0, dims_1}) {
template <typename T1, typename T2>
void test_topk(Place place, float abs_error) {
for (auto x_shape : std::vector<std::vector<int64_t>>{
{2, 3, 4, 5}, {3, 4, 5}, {4, 5}, {5}}) {
for (int k : {2, 5}) {
std::unique_ptr<arena::TestCase> tester(
new TopkComputeTester(place, "def", K, dims));
arena::Arena arena(std::move(tester), place, 2e-5);
new TopkComputeTester<T1, T2>(place, "def", DDim(x_shape), k));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
TEST(Topk, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_topk(place);
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-3; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
#endif
#if defined(LITE_WITH_NPU)
test_topk<float, int>(place, abs_error);
#else
test_topk<float, int64_t>(place, abs_error);
#endif
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册