Commit 24029728 authored by Pei Yang, committed by GitHub

Update lookup_table op on arm x86, add lookup_table_v2_op (#2405)

* update lookup_table arm x86, test=develop

* add lookup_table_v2_op for compatibility, test=develop
Parent 3132ad03
......@@ -104,4 +104,5 @@ lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc
if(LITE_BUILD_EXTRA)
lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm)
lite_cc_test(test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm)
endif()
......@@ -28,7 +28,6 @@ namespace arm {
void LookupTableCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
// inputs
auto w = param.W;
auto ids = param.Ids;
......@@ -37,7 +36,7 @@ void LookupTableCompute::Run() {
auto table_dim = w->dims();
int64_t ids_numel = ids->numel();
auto ids_data = ids->data<float>();
auto ids_data = ids->data<int64_t>();
int64_t row_number = table_dim[0];
int64_t row_width = table_dim[1];
......@@ -76,3 +75,14 @@ REGISTER_LITE_KERNEL(lookup_table,
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(lookup_table_v2,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::LookupTableCompute,
def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/lookup_table_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void lookup_table_compute_ref(const operators::LookupTableParam &param) {
auto *ids_t = param.Ids;
auto *output_t = param.Out;
int64_t padding_idx = param.padding_idx;
auto *ids = ids_t->data<int64_t>();
int64_t ids_numel = ids_t->dims().production();
auto *table_t = param.W;
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<float>();
auto *output = output_t->mutable_data<float>();
memset(output, 0, output_t->dims().production() * sizeof(float));
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != -1 && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(float));
} else {
CHECK_LT(ids[i], row_number);
CHECK_GE(ids[i], 0);
memcpy(output + i * row_width,
table + ids[i] * row_width,
row_width * sizeof(float));
}
}
}
TEST(lookup_table_arm, retrieve_op) {
auto lookup_table =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"lookup_table");
ASSERT_FALSE(lookup_table.empty());
ASSERT_TRUE(lookup_table.front());
}
TEST(lookup_table_arm, init) {
LookupTableCompute lookup_table;
ASSERT_EQ(lookup_table.precision(), PRECISION(kFloat));
ASSERT_EQ(lookup_table.target(), TARGET(kARM));
}
TEST(lookup_table_arm, compute) {
LookupTableCompute lookup_table;
operators::LookupTableParam param;
lite::Tensor w, ids, out, out_ref;
int64_t padding_idx = -1;
auto w_dim = DDim(std::vector<int64_t>({4, 5}));
auto ids_dim = DDim(std::vector<int64_t>({3, 2}));
auto out_dim = DDim(std::vector<int64_t>({3, 2, 5}));
w.Resize(w_dim);
ids.Resize(ids_dim);
out.Resize(out_dim);
out_ref.Resize(out_dim);
auto *w_data = w.mutable_data<float>();
auto *ids_data = ids.mutable_data<int64_t>();
auto *out_data = out.mutable_data<float>();
auto *out_ref_data = out_ref.mutable_data<float>();
int w_num = w_dim.production();
for (int i = 0; i < w_num; i++) {
w_data[i] = static_cast<float>(i + 1) / (w_num + 1);
}
int ids_num = ids_dim.production();
for (int i = 0; i < ids_num; i++) {
ids_data[i] = i % 4;
}
int out_num = out_dim.production();
param.W = &w;
param.Ids = &ids;
param.Out = &out;
lookup_table.SetParam(param);
lookup_table.Run();
param.Out = &out_ref;
lookup_table_compute_ref(param);
for (int i = 0; i < out_num; i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(lookup_table, kARM, kFloat, kNCHW, def);
......@@ -98,3 +98,14 @@ REGISTER_LITE_KERNEL(lookup_table,
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.Finalize();
REGISTER_LITE_KERNEL(lookup_table_v2,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::LookupTableCompute,
def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.Finalize();
......@@ -39,7 +39,7 @@ add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_ker
add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_x86 X86 extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
......@@ -71,4 +71,8 @@ lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
if(LITE_BUILD_EXTRA)
lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
endif()
lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
......@@ -32,3 +32,13 @@ REGISTER_LITE_KERNEL(lookup_table,
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(lookup_table_v2,
kX86,
kInt64,
kNCHW,
paddle::lite::kernels::x86::LookupTableCompute<float>,
def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -30,7 +30,6 @@ class LookupTableCompute : public KernelLite<TARGET(kX86), PRECISION(kInt64)> {
void Run() override {
auto &param = *param_.get_mutable<operators::LookupTableParam>();
// auto& context = context_->As<X86Context>();
auto *ids_t = param.Ids;
auto *output_t = param.Out;
int64_t padding_idx = param.padding_idx;
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/lookup_table_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(lookup_table_x86, compute) {
LookupTableCompute<float> lookup_table;
operators::LookupTableParam param;
lite::Tensor w, ids, out, out_ref;
int64_t padding_idx = -1;
int vocab_size = 40;
int emb_size = 50;
int ids_h = 30;
int ids_w = 20;
auto w_dim = DDim({vocab_size, emb_size});
auto ids_dim = DDim({ids_h, ids_w});
auto out_dim = DDim({ids_h, ids_w, emb_size});
w.Resize(w_dim);
ids.Resize(ids_dim);
out.Resize(out_dim);
out_ref.Resize(out_dim);
auto* w_data = w.mutable_data<float>();
auto* ids_data = ids.mutable_data<int64_t>();
auto* out_data = out.mutable_data<float>();
auto* out_ref_data = out_ref.mutable_data<float>();
int w_num = w_dim.production();
for (int i = 0; i < w_num; i++) {
w_data[i] = static_cast<float>(i + 1) / (w_num + 1);
}
int ids_num = ids_dim.production();
for (int i = 0; i < ids_num; i++) {
ids_data[i] = i % vocab_size;
}
int out_num = out_dim.production();
for (int i = 0; i < out_num; i++) {
out_ref_data[i] =
static_cast<float>((i % (vocab_size * emb_size)) + 1) / (w_num + 1);
}
param.W = &w;
param.Ids = &ids;
param.Out = &out;
param.padding_idx = padding_idx;
lookup_table.SetParam(param);
lookup_table.Run();
for (int i = 0; i < out_num; i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(lookup_table, kX86, kInt64, kNCHW, def);
......@@ -85,6 +85,7 @@ add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
add_operator(lookup_table_op extra SRCS lookup_table_op.cc DEPS ${op_DEPS})
add_operator(lookup_table_v2_op extra SRCS lookup_table_v2_op.cc DEPS ${op_DEPS})
add_operator(beam_search_decode_op extra SRCS beam_search_decode_op.cc DEPS ${op_DEPS})
add_operator(graph_op_lite extra SRCS graph_op.cc DEPS ${op_DEPS})
add_operator(logical_xor extra SRCS logical_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/lookup_table_v2_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool LookupTableV2OpLite::CheckShape() const {
CHECK_OR_FALSE(param_.W)
CHECK_OR_FALSE(param_.Ids)
CHECK_OR_FALSE(param_.Out)
auto table_dims = param_.W->dims();
CHECK_EQ_OR_FALSE(table_dims.size(), 2)
return true;
}
bool LookupTableV2OpLite::InferShape() const {
auto table_dims = param_.W->dims();
auto ids_dims = param_.Ids->dims();
std::vector<int64_t> out_dims;
for (int i = 0; i < ids_dims.size(); ++i) {
out_dims.push_back(ids_dims[i]);
}
out_dims.push_back(table_dims[1]);
param_.Out->Resize(lite::DDim{out_dims});
param_.Out->set_lod(param_.Ids->lod());
return true;
}
bool LookupTableV2OpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto input = op_desc.Input("W").front();
auto ids = op_desc.Input("Ids").front();
auto out = op_desc.Output("Out").front();
param_.W = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.Ids = scope->FindVar(ids)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(lookup_table_v2, paddle::lite::operators::LookupTableV2OpLite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class LookupTableV2OpLite : public OpLite {
public:
LookupTableV2OpLite() {}
explicit LookupTableV2OpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "LookupTable"; }
private:
mutable LookupTableParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册