未验证 提交 b086e835 编写于 作者: Z zhupengyang 提交者: GitHub

enhance kernels of transformer decoder (#3110)

* enhance gather, lookup_table arm kernel uts

* enhance beam_search, beam_search_decoder, increment rigster
上级 e2a02f63
......@@ -123,5 +123,4 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_decode_bboxes_compute_arm SRCS decode_bboxes_compute_test.cc DEPS decode_bboxes_compute_arm)
lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm)
lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm)
lite_cc_test(test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm)
endif()
......@@ -20,8 +20,6 @@ namespace lite {
namespace kernels {
namespace arm {
void BeamSearchCompute::PrepareForRun() {}
void BeamSearchCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::BeamSearchParam>();
......@@ -50,11 +48,17 @@ REGISTER_LITE_KERNEL(beam_search,
kNCHW,
paddle::lite::kernels::arm::BeamSearchCompute,
def)
.BindInput("pre_ids", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("pre_scores", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("ids", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("scores", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("selected_ids", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("selected_scores", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("parent_idx", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("pre_ids",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("pre_scores",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindInput("ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("scores",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("selected_ids",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("selected_scores",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("parent_idx",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
......@@ -25,10 +25,6 @@ namespace arm {
class BeamSearchCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::BeamSearchParam;
void PrepareForRun() override;
void Run() override;
~BeamSearchCompute() {}
......
......@@ -293,8 +293,12 @@ REGISTER_LITE_KERNEL(beam_search_decode,
kNCHW,
paddle::lite::kernels::arm::BeamSearchDecodeCompute,
def)
.BindInput("Ids", {LiteType::GetTensorListTy(TARGET(kARM))})
.BindInput("Scores", {LiteType::GetTensorListTy(TARGET(kARM))})
.BindOutput("SentenceIds", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("SentenceScores", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Ids",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Scores",
{LiteType::GetTensorListTy(TARGET(kARM), PRECISION(kFloat))})
.BindOutput("SentenceIds",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("SentenceScores",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.Finalize();
......@@ -20,8 +20,6 @@ namespace lite {
namespace kernels {
namespace arm {
void GatherCompute::PrepareForRun() {}
void GatherCompute::Run() {
auto& param = this->Param<operators::GatherParam>();
......@@ -49,7 +47,7 @@ void GatherCompute::Run() {
} // namespace paddle
REGISTER_LITE_KERNEL(
gather, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
gather, kARM, kAny, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
......
......@@ -22,12 +22,9 @@ namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::GatherParam;
void PrepareForRun() override;
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
void Run() override;
~GatherCompute() {}
......
......@@ -20,8 +20,6 @@ namespace lite {
namespace kernels {
namespace arm {
void IncrementCompute::PrepareForRun() {}
void IncrementCompute::Run() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto& param = this->Param<operators::IncrementParam>();
......@@ -52,10 +50,10 @@ void IncrementCompute::Run() {
REGISTER_LITE_KERNEL(increment,
kARM,
kFloat,
kAny,
kNCHW,
paddle::lite::kernels::arm::IncrementCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
......@@ -23,12 +23,8 @@ namespace lite {
namespace kernels {
namespace arm {
class IncrementCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class IncrementCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
using param_t = operators::IncrementParam;
void PrepareForRun() override;
void Run() override;
~IncrementCompute() {}
......
......@@ -28,10 +28,8 @@ namespace arm {
void LookupTableCompute::Run() {
auto& param = this->Param<param_t>();
// inputs
auto w = param.W;
auto ids = param.Ids;
// outputs
auto out = param.Out;
auto table_dim = w->dims();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/lookup_table_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void lookup_table_compute_ref(const operators::LookupTableParam &param) {
auto *ids_t = param.Ids;
auto *output_t = param.Out;
int64_t padding_idx = param.padding_idx;
auto *ids = ids_t->data<int64_t>();
int64_t ids_numel = ids_t->dims().production();
auto *table_t = param.W;
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<float>();
auto *output = output_t->mutable_data<float>();
memset(output, 0, output_t->dims().production() * sizeof(float));
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != -1 && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(float));
} else {
CHECK_LT(ids[i], row_number);
CHECK_GE(ids[i], 0);
memcpy(output + i * row_width,
table + ids[i] * row_width,
row_width * sizeof(float));
}
}
}
TEST(lookup_table_arm, retrieve_op) {
auto lookup_table =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kAny)>(
"lookup_table");
ASSERT_FALSE(lookup_table.empty());
ASSERT_TRUE(lookup_table.front());
}
TEST(lookup_table_arm, init) {
LookupTableCompute lookup_table;
ASSERT_EQ(lookup_table.precision(), PRECISION(kAny));
ASSERT_EQ(lookup_table.target(), TARGET(kARM));
}
TEST(lookup_table_arm, compute) {
LookupTableCompute lookup_table;
operators::LookupTableParam param;
lite::Tensor w, ids, out, out_ref;
int64_t padding_idx = -1;
auto w_dim = DDim(std::vector<int64_t>({4, 5}));
auto ids_dim = DDim(std::vector<int64_t>({3, 2}));
auto out_dim = DDim(std::vector<int64_t>({3, 2, 5}));
w.Resize(w_dim);
ids.Resize(ids_dim);
out.Resize(out_dim);
out_ref.Resize(out_dim);
auto *w_data = w.mutable_data<float>();
auto *ids_data = ids.mutable_data<int64_t>();
auto *out_data = out.mutable_data<float>();
auto *out_ref_data = out_ref.mutable_data<float>();
int w_num = w_dim.production();
for (int i = 0; i < w_num; i++) {
w_data[i] = static_cast<float>(i + 1) / (w_num + 1);
}
int ids_num = ids_dim.production();
for (int i = 0; i < ids_num; i++) {
ids_data[i] = i % 4;
}
int out_num = out_dim.production();
param.W = &w;
param.Ids = &ids;
param.Out = &out;
lookup_table.SetParam(param);
lookup_table.Run();
param.Out = &out_ref;
lookup_table_compute_ref(param);
for (int i = 0; i < out_num; i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(lookup_table, kARM, kAny, kNCHW, def);
......@@ -54,8 +54,8 @@ void LookupTableCompute::Run() {
auto &param = this->Param<param_t>();
auto &ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
Tensor *w_t = param.W;
Tensor *ids_t = param.Ids;
const Tensor *w_t = param.W;
const Tensor *ids_t = param.Ids;
Tensor *out_t = param.Out;
int64_t padding_idx = param.padding_idx;
......
......@@ -40,10 +40,8 @@ bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
param_.ids = scope->FindVar(ids)->GetMutable<std::vector<lite::Tensor>>();
param_.scores =
scope->FindVar(scores)->GetMutable<std::vector<lite::Tensor>>();
param_.sentence_ids =
scope->FindVar(sentence_ids)->GetMutable<lite::Tensor>();
param_.sentence_scores =
scope->FindVar(sentence_scores)->GetMutable<lite::Tensor>();
param_.sentence_ids = scope->FindMutableTensor(sentence_ids);
param_.sentence_scores = scope->FindMutableTensor(sentence_scores);
param_.beam_size = op_desc.GetAttr<int>("beam_size");
param_.end_id = op_desc.GetAttr<int>("end_id");
......
......@@ -33,21 +33,17 @@ bool BeamSearchOp::CheckShape() const {
bool BeamSearchOp::InferShape() const { return true; }
bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.pre_ids = scope->FindVar(opdesc.Input("pre_ids").front())
->GetMutable<lite::Tensor>();
param_.pre_scores = scope->FindVar(opdesc.Input("pre_scores").front())
->GetMutable<lite::Tensor>();
param_.ids =
scope->FindVar(opdesc.Input("ids").front())->GetMutable<lite::Tensor>();
param_.scores = scope->FindVar(opdesc.Input("scores").front())
->GetMutable<lite::Tensor>();
param_.selected_ids = scope->FindVar(opdesc.Output("selected_ids").front())
->GetMutable<lite::Tensor>();
param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front());
param_.pre_scores = scope->FindTensor(opdesc.Input("pre_scores").front());
param_.ids = scope->FindTensor(opdesc.Input("ids").front());
param_.scores = scope->FindTensor(opdesc.Input("scores").front());
param_.selected_ids =
scope->FindMutableTensor(opdesc.Output("selected_ids").front());
param_.selected_scores =
scope->FindVar(opdesc.Output("selected_scores").front())
->GetMutable<lite::Tensor>();
param_.parent_idx = scope->FindVar(opdesc.Output("parent_idx").front())
->GetMutable<lite::Tensor>();
scope->FindMutableTensor(opdesc.Output("selected_scores").front());
param_.parent_idx =
scope->FindMutableTensor(opdesc.Output("parent_idx").front());
CHECK(param_.pre_ids) << "id null";
CHECK(param_.pre_scores) << "pre score null";
CHECK(param_.ids) << "ids null";
......
......@@ -39,15 +39,12 @@ bool GatherOp::InferShape() const {
}
bool GatherOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.Index =
scope->FindVar(opdesc.Input("Index").front())->GetMutable<lite::Tensor>();
param_.X = scope->FindTensor(opdesc.Input("X").front());
param_.Index = scope->FindTensor(opdesc.Input("Index").front());
param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
CHECK(param_.X) << "X is null";
CHECK(param_.Out) << "out is null";
CHECK(param_.Index) << "index is null";
CHECK(param_.Out) << "out is null";
return true;
}
......
......@@ -34,10 +34,8 @@ bool IncrementOp::InferShape() const {
}
bool IncrementOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.X = scope->FindMutableTensor(opdesc.Input("X").front());
param_.Out = scope->FindMutableTensor(opdesc.Output("Out").front());
CHECK(param_.X);
CHECK(param_.Out);
param_.step = opdesc.GetAttr<float>("step");
......
......@@ -55,9 +55,9 @@ bool LookupTableOpLite::AttachImpl(const cpp::OpDesc& op_desc,
auto ids = op_desc.Input("Ids").front();
auto out = op_desc.Output("Out").front();
param_.W = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.Ids = scope->FindVar(ids)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.W = scope->FindTensor(input);
param_.Ids = scope->FindTensor(ids);
param_.Out = scope->FindMutableTensor(out);
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
......
......@@ -52,9 +52,9 @@ bool LookupTableV2OpLite::AttachImpl(const cpp::OpDesc &op_desc,
auto ids = op_desc.Input("Ids").front();
auto out = op_desc.Output("Out").front();
param_.W = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.Ids = scope->FindVar(ids)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.W = scope->FindTensor(input);
param_.Ids = scope->FindTensor(ids);
param_.Out = scope->FindMutableTensor(out);
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
......
......@@ -655,8 +655,8 @@ struct BeamSearchDecodeParam {
/// ----------------------- LookupTable operators ----------------------f
struct LookupTableParam {
lite::Tensor* W{nullptr};
lite::Tensor* Ids{nullptr};
const lite::Tensor* W{nullptr};
const lite::Tensor* Ids{nullptr};
lite::Tensor* Out{nullptr};
int64_t padding_idx{-1};
};
......
......@@ -22,7 +22,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_
#lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_logical_xor_compute SRCS logical_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_topk_compute SRCS topk_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
......@@ -94,7 +94,9 @@ TEST(Gather, precision) {
LOG(INFO) << "test gather op";
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_XPU)
#if defined(LITE_WITH_ARM)
place = {TARGET(kARM), PRECISION(kAny)};
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
......
......@@ -16,6 +16,7 @@
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
......@@ -58,36 +59,35 @@ class IncrementComputeTester : public arena::TestCase {
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(input_, dims_, data.data());
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -5.f, 5.f, dims_.production());
SetCommonTensor(input_, dims_, din.data());
}
};
void test_increment(Place place) {
void test_increment(Place place, float abs_error) {
DDimLite dims_0{{3, 5, 4, 4}};
DDimLite dims_1{{3, 5}};
for (auto dims : {dims_0, dims_1}) {
for (float step : {1, 2}) {
std::unique_ptr<arena::TestCase> tester(
new IncrementComputeTester(place, "def", step, dims));
arena::Arena arena(std::move(tester), place, 2e-5);
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
}
TEST(Increment, precision) {
// #ifdef LITE_WITH_X86
// Place place(TARGET(kX86));
// #endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_increment(place);
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
place = {TARGET(kARM), PRECISION(kAny)};
#else
return;
#endif
test_increment(place, abs_error);
}
} // namespace lite
......
......@@ -111,7 +111,9 @@ TEST(LookupTable, precision) {
LOG(INFO) << "test lookup_table op";
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_XPU)
#if defined(LITE_WITH_ARM)
place = {TARGET(kARM), PRECISION(kAny)};
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册