未验证 提交 b086e835 编写于 作者: Z zhupengyang 提交者: GitHub

enhance kernels of transformer decoder (#3110)

* enhance gather, lookup_table arm kernel uts

* enhance beam_search, beam_search_decoder, increment registration
上级 e2a02f63
...@@ -123,5 +123,4 @@ if(LITE_BUILD_EXTRA) ...@@ -123,5 +123,4 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm) lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm)
lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm) lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm)
lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm) lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm)
lite_cc_test(test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm)
endif() endif()
...@@ -20,8 +20,6 @@ namespace lite { ...@@ -20,8 +20,6 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void BeamSearchCompute::PrepareForRun() {}
void BeamSearchCompute::Run() { void BeamSearchCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::BeamSearchParam>(); auto& param = this->Param<operators::BeamSearchParam>();
...@@ -50,11 +48,17 @@ REGISTER_LITE_KERNEL(beam_search, ...@@ -50,11 +48,17 @@ REGISTER_LITE_KERNEL(beam_search,
kNCHW, kNCHW,
paddle::lite::kernels::arm::BeamSearchCompute, paddle::lite::kernels::arm::BeamSearchCompute,
def) def)
.BindInput("pre_ids", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("pre_ids",
.BindInput("pre_scores", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("ids", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("pre_scores",
.BindInput("scores", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("selected_ids", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("selected_scores", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("scores",
.BindOutput("parent_idx", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("selected_ids",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("selected_scores",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("parent_idx",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize(); .Finalize();
...@@ -25,10 +25,6 @@ namespace arm { ...@@ -25,10 +25,6 @@ namespace arm {
class BeamSearchCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { class BeamSearchCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
using param_t = operators::BeamSearchParam;
void PrepareForRun() override;
void Run() override; void Run() override;
~BeamSearchCompute() {} ~BeamSearchCompute() {}
......
...@@ -293,8 +293,12 @@ REGISTER_LITE_KERNEL(beam_search_decode, ...@@ -293,8 +293,12 @@ REGISTER_LITE_KERNEL(beam_search_decode,
kNCHW, kNCHW,
paddle::lite::kernels::arm::BeamSearchDecodeCompute, paddle::lite::kernels::arm::BeamSearchDecodeCompute,
def) def)
.BindInput("Ids", {LiteType::GetTensorListTy(TARGET(kARM))}) .BindInput("Ids",
.BindInput("Scores", {LiteType::GetTensorListTy(TARGET(kARM))}) {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("SentenceIds", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Scores",
.BindOutput("SentenceScores", {LiteType::GetTensorTy(TARGET(kARM))}) {LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("SentenceIds",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("SentenceScores",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize(); .Finalize();
...@@ -20,8 +20,6 @@ namespace lite { ...@@ -20,8 +20,6 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void GatherCompute::PrepareForRun() {}
void GatherCompute::Run() { void GatherCompute::Run() {
auto& param = this->Param<operators::GatherParam>(); auto& param = this->Param<operators::GatherParam>();
...@@ -49,7 +47,7 @@ void GatherCompute::Run() { ...@@ -49,7 +47,7 @@ void GatherCompute::Run() {
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
gather, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, def) gather, kARM, kAny, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Index", .BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))}) {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
......
...@@ -22,12 +22,9 @@ namespace paddle { ...@@ -22,12 +22,9 @@ namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::GatherParam;
void PrepareForRun() override;
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
void Run() override; void Run() override;
~GatherCompute() {} ~GatherCompute() {}
......
...@@ -20,8 +20,6 @@ namespace lite { ...@@ -20,8 +20,6 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
void IncrementCompute::PrepareForRun() {}
void IncrementCompute::Run() { void IncrementCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>(); auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::IncrementParam>(); auto& param = this->Param<operators::IncrementParam>();
...@@ -52,10 +50,10 @@ void IncrementCompute::Run() { ...@@ -52,10 +50,10 @@ void IncrementCompute::Run() {
REGISTER_LITE_KERNEL(increment, REGISTER_LITE_KERNEL(increment,
kARM, kARM,
kFloat, kAny,
kNCHW, kNCHW,
paddle::lite::kernels::arm::IncrementCompute, paddle::lite::kernels::arm::IncrementCompute,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize(); .Finalize();
...@@ -23,12 +23,8 @@ namespace lite { ...@@ -23,12 +23,8 @@ namespace lite {
namespace kernels { namespace kernels {
namespace arm { namespace arm {
class IncrementCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> { class IncrementCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public: public:
using param_t = operators::IncrementParam;
void PrepareForRun() override;
void Run() override; void Run() override;
~IncrementCompute() {} ~IncrementCompute() {}
......
...@@ -28,10 +28,8 @@ namespace arm { ...@@ -28,10 +28,8 @@ namespace arm {
void LookupTableCompute::Run() { void LookupTableCompute::Run() {
auto& param = this->Param<param_t>(); auto& param = this->Param<param_t>();
// inputs
auto w = param.W; auto w = param.W;
auto ids = param.Ids; auto ids = param.Ids;
// outputs
auto out = param.Out; auto out = param.Out;
auto table_dim = w->dims(); auto table_dim = w->dims();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/lookup_table_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
/**
 * Reference (scalar) lookup_table implementation used to check the kernel.
 *
 * For every int64 index stored in param.Ids, the matching row of the float
 * table param.W (dims()[0] rows of dims()[1] values each) is copied into
 * param.Out at row position i. When param.padding_idx is not -1, rows whose
 * index equals padding_idx are left zero-filled instead of being copied.
 * Every index that is copied is first validated with CHECK_GE / CHECK_LT
 * against the range [0, row_number).
 *
 * @param param  Operator parameters; W and Ids are read, Out is written.
 */
void lookup_table_compute_ref(const operators::LookupTableParam &param) {
  const auto *ids_t = param.Ids;
  const auto *table_t = param.W;
  auto *output_t = param.Out;
  const int64_t padding_idx = param.padding_idx;

  const auto *ids = ids_t->data<int64_t>();
  const int64_t ids_numel = ids_t->dims().production();

  const int64_t row_number = table_t->dims()[0];
  const int64_t row_width = table_t->dims()[1];
  const auto *table = table_t->data<float>();

  auto *output = output_t->mutable_data<float>();
  // Zero the whole output once; padding rows then need no further work.
  std::memset(output, 0, output_t->dims().production() * sizeof(float));

  for (int64_t i = 0; i < ids_numel; ++i) {
    const int64_t id = ids[i];
    if (padding_idx != -1 && id == padding_idx) {
      // Padding row: already zero from the memset above.
      continue;
    }
    CHECK_GE(id, 0);
    CHECK_LT(id, row_number);
    std::memcpy(output + i * row_width,
                table + id * row_width,
                row_width * sizeof(float));
  }
}
// Checks that a "lookup_table" kernel for TARGET(kARM) / PRECISION(kAny) can
// be created through the global kernel registry.
TEST(lookup_table_arm, retrieve_op) {
  auto created_kernels =
      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kAny)>(
          "lookup_table");
  // The registry must hand back at least one entry ...
  ASSERT_FALSE(created_kernels.empty());
  // ... and the first entry must evaluate to true (i.e. be usable).
  ASSERT_TRUE(created_kernels.front());
}
// Checks the static placement reported by a directly constructed kernel:
// precision must be kAny and target must be kARM.
TEST(lookup_table_arm, init) {
  LookupTableCompute kernel;
  ASSERT_EQ(kernel.precision(), PRECISION(kAny));
  ASSERT_EQ(kernel.target(), TARGET(kARM));
}
// Runs the kernel on a small 4x5 table with a 3x2 index tensor and compares
// every output element against lookup_table_compute_ref.
TEST(lookup_table_arm, compute) {
  LookupTableCompute lookup_table;
  operators::LookupTableParam param;
  lite::Tensor w, ids, out, out_ref;

  // -1 is the value the reference implementation treats as "no padding row".
  const int64_t padding_idx = -1;

  auto w_dim = DDim(std::vector<int64_t>({4, 5}));
  auto ids_dim = DDim(std::vector<int64_t>({3, 2}));
  auto out_dim = DDim(std::vector<int64_t>({3, 2, 5}));
  w.Resize(w_dim);
  ids.Resize(ids_dim);
  out.Resize(out_dim);
  out_ref.Resize(out_dim);

  auto *w_data = w.mutable_data<float>();
  auto *ids_data = ids.mutable_data<int64_t>();
  auto *out_data = out.mutable_data<float>();
  auto *out_ref_data = out_ref.mutable_data<float>();

  // Fill the table with distinct values in (0, 1) so that any row mix-up
  // shows up in the comparison below.
  const int64_t w_num = w_dim.production();
  for (int64_t i = 0; i < w_num; i++) {
    w_data[i] = static_cast<float>(i + 1) / (w_num + 1);
  }

  // Cycle the indices through all table rows so each one stays in range.
  const int64_t row_count = w_dim[0];
  const int64_t ids_num = ids_dim.production();
  for (int64_t i = 0; i < ids_num; i++) {
    ids_data[i] = i % row_count;
  }

  param.W = &w;
  param.Ids = &ids;
  param.Out = &out;
  // Make the padding configuration explicit instead of relying on the
  // struct's default value.
  param.padding_idx = padding_idx;
  lookup_table.SetParam(param);
  lookup_table.Run();

  // Re-use the same inputs, but direct the reference result to out_ref.
  param.Out = &out_ref;
  lookup_table_compute_ref(param);

  const int64_t out_num = out_dim.production();
  for (int64_t i = 0; i < out_num; i++) {
    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
  }
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// References the lookup_table kernel registered for kARM / kAny / kNCHW under
// the alias "def". NOTE(review): presumably this forces the kernel's
// registration to be linked into the test binary so the registry lookup in
// retrieve_op can find it -- confirm against the USE_LITE_KERNEL definition.
USE_LITE_KERNEL(lookup_table, kARM, kAny, kNCHW, def);
...@@ -54,8 +54,8 @@ void LookupTableCompute::Run() { ...@@ -54,8 +54,8 @@ void LookupTableCompute::Run() {
auto &param = this->Param<param_t>(); auto &param = this->Param<param_t>();
auto &ctx = this->ctx_->template As<CUDAContext>(); auto &ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream(); auto stream = ctx.exec_stream();
Tensor *w_t = param.W; const Tensor *w_t = param.W;
Tensor *ids_t = param.Ids; const Tensor *ids_t = param.Ids;
Tensor *out_t = param.Out; Tensor *out_t = param.Out;
int64_t padding_idx = param.padding_idx; int64_t padding_idx = param.padding_idx;
......
...@@ -40,10 +40,8 @@ bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc, ...@@ -40,10 +40,8 @@ bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
param_.ids = scope->FindVar(ids)->GetMutable<std::vector<lite::Tensor>>(); param_.ids = scope->FindVar(ids)->GetMutable<std::vector<lite::Tensor>>();
param_.scores = param_.scores =
scope->FindVar(scores)->GetMutable<std::vector<lite::Tensor>>(); scope->FindVar(scores)->GetMutable<std::vector<lite::Tensor>>();
param_.sentence_ids = param_.sentence_ids = scope->FindMutableTensor(sentence_ids);
scope->FindVar(sentence_ids)->GetMutable<lite::Tensor>(); param_.sentence_scores = scope->FindMutableTensor(sentence_scores);
param_.sentence_scores =
scope->FindVar(sentence_scores)->GetMutable<lite::Tensor>();
param_.beam_size = op_desc.GetAttr<int>("beam_size"); param_.beam_size = op_desc.GetAttr<int>("beam_size");
param_.end_id = op_desc.GetAttr<int>("end_id"); param_.end_id = op_desc.GetAttr<int>("end_id");
......
...@@ -33,21 +33,17 @@ bool BeamSearchOp::CheckShape() const { ...@@ -33,21 +33,17 @@ bool BeamSearchOp::CheckShape() const {
bool BeamSearchOp::InferShape() const { return true; } bool BeamSearchOp::InferShape() const { return true; }
bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.pre_ids = scope->FindVar(opdesc.Input("pre_ids").front()) param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front());
->GetMutable<lite::Tensor>(); param_.pre_scores = scope->FindTensor(opdesc.Input("pre_scores").front());
param_.pre_scores = scope->FindVar(opdesc.Input("pre_scores").front()) param_.ids = scope->FindTensor(opdesc.Input("ids").front());
->GetMutable<lite::Tensor>(); param_.scores = scope->FindTensor(opdesc.Input("scores").front());
param_.ids = param_.selected_ids =
scope->FindVar(opdesc.Input("ids").front())->GetMutable<lite::Tensor>(); scope->FindMutableTensor(opdesc.Output("selected_ids").front());
param_.scores = scope->FindVar(opdesc.Input("scores").front())
->GetMutable<lite::Tensor>();
param_.selected_ids = scope->FindVar(opdesc.Output("selected_ids").front())
->GetMutable<lite::Tensor>();
param_.selected_scores = param_.selected_scores =
scope->FindVar(opdesc.Output("selected_scores").front()) scope->FindMutableTensor(opdesc.Output("selected_scores").front());
->GetMutable<lite::Tensor>(); param_.parent_idx =
param_.parent_idx = scope->FindVar(opdesc.Output("parent_idx").front()) scope->FindMutableTensor(opdesc.Output("parent_idx").front());
->GetMutable<lite::Tensor>();
CHECK(param_.pre_ids) << "id null"; CHECK(param_.pre_ids) << "id null";
CHECK(param_.pre_scores) << "pre score null"; CHECK(param_.pre_scores) << "pre score null";
CHECK(param_.ids) << "ids null"; CHECK(param_.ids) << "ids null";
......
...@@ -39,15 +39,12 @@ bool GatherOp::InferShape() const { ...@@ -39,15 +39,12 @@ bool GatherOp::InferShape() const {
} }
bool GatherOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool GatherOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = param_.X = scope->FindTensor(opdesc.Input("X").front());
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>(); param_.Index = scope->FindTensor(opdesc.Input("Index").front());
param_.Out = param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.Index =
scope->FindVar(opdesc.Input("Index").front())->GetMutable<lite::Tensor>();
CHECK(param_.X) << "X is null"; CHECK(param_.X) << "X is null";
CHECK(param_.Out) << "out is null";
CHECK(param_.Index) << "index is null"; CHECK(param_.Index) << "index is null";
CHECK(param_.Out) << "out is null";
return true; return true;
} }
......
...@@ -34,10 +34,8 @@ bool IncrementOp::InferShape() const { ...@@ -34,10 +34,8 @@ bool IncrementOp::InferShape() const {
} }
bool IncrementOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { bool IncrementOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = param_.X = scope->FindMutableTensor(opdesc.Input("X").front());
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>(); param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.X); CHECK(param_.X);
CHECK(param_.Out); CHECK(param_.Out);
param_.step = opdesc.GetAttr<float>("step"); param_.step = opdesc.GetAttr<float>("step");
......
...@@ -55,9 +55,9 @@ bool LookupTableOpLite::AttachImpl(const cpp::OpDesc& op_desc, ...@@ -55,9 +55,9 @@ bool LookupTableOpLite::AttachImpl(const cpp::OpDesc& op_desc,
auto ids = op_desc.Input("Ids").front(); auto ids = op_desc.Input("Ids").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
param_.W = scope->FindVar(input)->GetMutable<lite::Tensor>(); param_.W = scope->FindTensor(input);
param_.Ids = scope->FindVar(ids)->GetMutable<lite::Tensor>(); param_.Ids = scope->FindTensor(ids);
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.Out = scope->FindMutableTensor(out);
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx"); param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
......
...@@ -52,9 +52,9 @@ bool LookupTableV2OpLite::AttachImpl(const cpp::OpDesc &op_desc, ...@@ -52,9 +52,9 @@ bool LookupTableV2OpLite::AttachImpl(const cpp::OpDesc &op_desc,
auto ids = op_desc.Input("Ids").front(); auto ids = op_desc.Input("Ids").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
param_.W = scope->FindVar(input)->GetMutable<lite::Tensor>(); param_.W = scope->FindTensor(input);
param_.Ids = scope->FindVar(ids)->GetMutable<lite::Tensor>(); param_.Ids = scope->FindTensor(ids);
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.Out = scope->FindMutableTensor(out);
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx"); param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
......
...@@ -655,8 +655,8 @@ struct BeamSearchDecodeParam { ...@@ -655,8 +655,8 @@ struct BeamSearchDecodeParam {
/// ----------------------- LookupTable operators ----------------------f /// ----------------------- LookupTable operators ----------------------f
struct LookupTableParam { struct LookupTableParam {
lite::Tensor* W{nullptr}; const lite::Tensor* W{nullptr};
lite::Tensor* Ids{nullptr}; const lite::Tensor* Ids{nullptr};
lite::Tensor* Out{nullptr}; lite::Tensor* Out{nullptr};
int64_t padding_idx{-1}; int64_t padding_idx{-1};
}; };
......
...@@ -22,7 +22,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_ ...@@ -22,7 +22,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
#lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
...@@ -94,7 +94,9 @@ TEST(Gather, precision) { ...@@ -94,7 +94,9 @@ TEST(Gather, precision) {
LOG(INFO) << "test gather op"; LOG(INFO) << "test gather op";
float abs_error = 2e-5; float abs_error = 2e-5;
Place place; Place place;
#if defined(LITE_WITH_XPU) #if defined(LITE_WITH_ARM)
place = {TARGET(kARM), PRECISION(kAny)};
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU); place = TARGET(kXPU);
#else #else
return; return;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h" #include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -58,36 +59,35 @@ class IncrementComputeTester : public arena::TestCase { ...@@ -58,36 +59,35 @@ class IncrementComputeTester : public arena::TestCase {
} }
void PrepareData() override { void PrepareData() override {
std::vector<float> data(dims_.production()); std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -5.f, 5.f, dims_.production());
for (int i = 0; i < dims_.production(); i++) { SetCommonTensor(input_, dims_, din.data());
data[i] = i * 1.1;
}
SetCommonTensor(input_, dims_, data.data());
} }
}; };
void test_increment(Place place) {
void test_increment(Place place, float abs_error) {
DDimLite dims_0{{3, 5, 4, 4}}; DDimLite dims_0{{3, 5, 4, 4}};
DDimLite dims_1{{3, 5}}; DDimLite dims_1{{3, 5}};
for (auto dims : {dims_0, dims_1}) { for (auto dims : {dims_0, dims_1}) {
for (float step : {1, 2}) { for (float step : {1, 2}) {
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(
new IncrementComputeTester(place, "def", step, dims)); new IncrementComputeTester(place, "def", step, dims));
arena::Arena arena(std::move(tester), place, 2e-5); arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision(); arena.TestPrecision();
} }
} }
} }
TEST(Increment, precision) { TEST(Increment, precision) {
// #ifdef LITE_WITH_X86 Place place;
// Place place(TARGET(kX86)); float abs_error = 2e-5;
// #endif #if defined(LITE_WITH_ARM)
#ifdef LITE_WITH_ARM place = {TARGET(kARM), PRECISION(kAny)};
Place place(TARGET(kARM)); #else
test_increment(place); return;
#endif #endif
test_increment(place, abs_error);
} }
} // namespace lite } // namespace lite
......
...@@ -111,7 +111,9 @@ TEST(LookupTable, precision) { ...@@ -111,7 +111,9 @@ TEST(LookupTable, precision) {
LOG(INFO) << "test lookup_table op"; LOG(INFO) << "test lookup_table op";
float abs_error = 2e-5; float abs_error = 2e-5;
Place place; Place place;
#if defined(LITE_WITH_XPU) #if defined(LITE_WITH_ARM)
place = {TARGET(kARM), PRECISION(kAny)};
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU); place = TARGET(kXPU);
#else #else
return; return;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册