diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt
index 8949602cab00c28d03424ad7cca2387765375b80..731df6e6629826016cafc386284a17f754f83ece 100644
--- a/lite/kernels/arm/CMakeLists.txt
+++ b/lite/kernels/arm/CMakeLists.txt
@@ -104,4 +104,5 @@ lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc
 if(LITE_BUILD_EXTRA)
     lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm)
+    lite_cc_test(test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm)
 endif()
diff --git a/lite/kernels/arm/lookup_table_compute.cc b/lite/kernels/arm/lookup_table_compute.cc
index fa7e2c0c3ae4580f5d19e82f7c48c74db3058847..ba58b378f4dda22fd78ce76b80bdbca8d8f284a3 100644
--- a/lite/kernels/arm/lookup_table_compute.cc
+++ b/lite/kernels/arm/lookup_table_compute.cc
@@ -28,7 +28,6 @@ namespace arm {
 
 void LookupTableCompute::Run() {
   auto& param = this->Param<operators::LookupTableParam>();
-  auto& ctx = this->ctx_->template As<ARMContext>();
   // inputs
   auto w = param.W;
   auto ids = param.Ids;
@@ -37,7 +36,7 @@ void LookupTableCompute::Run() {
   auto table_dim = w->dims();
   int64_t ids_numel = ids->numel();
 
-  auto ids_data = ids->data<float>();
+  auto ids_data = ids->data<int64_t>();
 
   int64_t row_number = table_dim[0];
   int64_t row_width = table_dim[1];
@@ -76,3 +75,14 @@ REGISTER_LITE_KERNEL(lookup_table,
     .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
+
+REGISTER_LITE_KERNEL(lookup_table_v2,
+                     kARM,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::arm::LookupTableCompute,
+                     def)
+    .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .Finalize();
diff --git a/lite/kernels/arm/lookup_table_compute_test.cc b/lite/kernels/arm/lookup_table_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..78748edf39c43c5451f8fa3c4d63bde7405c7078
--- /dev/null
+++ b/lite/kernels/arm/lookup_table_compute_test.cc
@@ -0,0 +1,115 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/arm/lookup_table_compute.h"
+#include <gtest/gtest.h>
+#include <string>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace arm {
+
+void lookup_table_compute_ref(const operators::LookupTableParam &param) {
+  auto *ids_t = param.Ids;
+  auto *output_t = param.Out;
+  int64_t padding_idx = param.padding_idx;
+  auto *ids = ids_t->data<int64_t>();
+  int64_t ids_numel = ids_t->dims().production();
+
+  auto *table_t = param.W;
+  int64_t row_number = table_t->dims()[0];
+  int64_t row_width = table_t->dims()[1];
+
+  auto *table = table_t->data<float>();
+  auto *output = output_t->mutable_data<float>();
+  memset(output, 0, output_t->dims().production() * sizeof(float));
+  for (int64_t i = 0; i < ids_numel; ++i) {
+    if (padding_idx != -1 && ids[i] == padding_idx) {
+      memset(output + i * row_width, 0, row_width * sizeof(float));
+    } else {
+      CHECK_LT(ids[i], row_number);
+      CHECK_GE(ids[i], 0);
+      memcpy(output + i * row_width,
+             table + ids[i] * row_width,
+             row_width * sizeof(float));
+    }
+  }
+}
+
+TEST(lookup_table_arm, retrieve_op) {
+  auto lookup_table =
+      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
+          "lookup_table");
+  ASSERT_FALSE(lookup_table.empty());
+  ASSERT_TRUE(lookup_table.front());
+}
+
+TEST(lookup_table_arm, init) {
+  LookupTableCompute lookup_table;
+  ASSERT_EQ(lookup_table.precision(), PRECISION(kFloat));
+  ASSERT_EQ(lookup_table.target(), TARGET(kARM));
+}
+
+TEST(lookup_table_arm, compute) {
+  LookupTableCompute lookup_table;
+  operators::LookupTableParam param;
+  lite::Tensor w, ids, out, out_ref;
+  int64_t padding_idx = -1;
+
+  auto w_dim = DDim(std::vector<int64_t>({4, 5}));
+  auto ids_dim = DDim(std::vector<int64_t>({3, 2}));
+  auto out_dim = DDim(std::vector<int64_t>({3, 2, 5}));
+
+  w.Resize(w_dim);
+  ids.Resize(ids_dim);
+  out.Resize(out_dim);
+  out_ref.Resize(out_dim);
+
+  auto *w_data = w.mutable_data<float>();
+  auto *ids_data = ids.mutable_data<int64_t>();
+  auto *out_data = out.mutable_data<float>();
+  auto *out_ref_data = out_ref.mutable_data<float>();
+
+  int w_num = w_dim.production();
+  for (int i = 0; i < w_num; i++) {
+    w_data[i] = static_cast<float>(i + 1) / (w_num + 1);
+  }
+  int ids_num = ids_dim.production();
+  for (int i = 0; i < ids_num; i++) {
+    ids_data[i] = i % 4;
+  }
+  int out_num = out_dim.production();
+
+  param.W = &w;
+  param.Ids = &ids;
+  param.Out = &out;
+  lookup_table.SetParam(param);
+  lookup_table.Run();
+  param.Out = &out_ref;
+  lookup_table_compute_ref(param);
+  for (int i = 0; i < out_num; i++) {
+    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
+  }
+}
+
+}  // namespace arm
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(lookup_table, kARM, kFloat, kNCHW, def);
diff --git a/lite/kernels/cuda/lookup_table_compute.cu b/lite/kernels/cuda/lookup_table_compute.cu
index 34b6de0e105f8f6dbf070b4ad41a9e6c7d2a06c8..3c3bb952cac01a6d1e296085dc357b9b3a03773a 100644
--- a/lite/kernels/cuda/lookup_table_compute.cu
+++ b/lite/kernels/cuda/lookup_table_compute.cu
@@ -98,3 +98,14 @@ REGISTER_LITE_KERNEL(lookup_table,
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
     .Finalize();
+REGISTER_LITE_KERNEL(lookup_table_v2,
+                     kCUDA,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::cuda::LookupTableCompute,
+                     def)
+    .BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
+    .Finalize();
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index 5fa9e20a025784e082960e02c0e841e8d7311fe5..d40f34d726427564f069106f0e86b9482ba9f0e6 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -39,7 +39,7 @@ add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_ker
 add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps})
-add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
+add_kernel(lookup_table_compute_x86 X86 extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
 add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
 
@@ -71,4 +71,8 @@ lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x
 lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
 lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
 lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
+
+if(LITE_BUILD_EXTRA)
+    lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
+endif()
 lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
diff --git a/lite/kernels/x86/lookup_table_compute.cc b/lite/kernels/x86/lookup_table_compute.cc
index 364593251e17453011bad5b2c1057fc25d54d7c8..856a07a94cada4702d47820605436cee6523a527 100644
--- a/lite/kernels/x86/lookup_table_compute.cc
+++ b/lite/kernels/x86/lookup_table_compute.cc
@@ -32,3 +32,13 @@ REGISTER_LITE_KERNEL(lookup_table,
     .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
+REGISTER_LITE_KERNEL(lookup_table_v2,
+                     kX86,
+                     kInt64,
+                     kNCHW,
+                     paddle::lite::kernels::x86::LookupTableCompute,
+                     def)
+    .BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt64))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/lookup_table_compute.h b/lite/kernels/x86/lookup_table_compute.h
index e0d7752ca77c810700f57722c4186b4e02d6411f..019544850309f8db306857f5f2767b4baaad9bb0 100644
--- a/lite/kernels/x86/lookup_table_compute.h
+++ b/lite/kernels/x86/lookup_table_compute.h
@@ -30,7 +30,6 @@ class LookupTableCompute : public KernelLite<TARGET(kX86), PRECISION(kInt64)> {
 
   void Run() override {
     auto &param = *param_.get_mutable<operators::LookupTableParam>();
-    // auto& context = context_->As<X86Context>();
     auto *ids_t = param.Ids;
     auto *output_t = param.Out;
     int64_t padding_idx = param.padding_idx;
diff --git a/lite/kernels/x86/lookup_table_compute_test.cc b/lite/kernels/x86/lookup_table_compute_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..86b2d39186b10de6def72a217cd6c70773b59420
--- /dev/null
+++ b/lite/kernels/x86/lookup_table_compute_test.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/lookup_table_compute.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <vector>
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+TEST(lookup_table_x86, compute) {
+  LookupTableCompute lookup_table;
+  operators::LookupTableParam param;
+  lite::Tensor w, ids, out, out_ref;
+  int64_t padding_idx = -1;
+
+  int vocab_size = 40;
+  int emb_size = 50;
+  int ids_h = 30;
+  int ids_w = 20;
+
+  auto w_dim = DDim({vocab_size, emb_size});
+  auto ids_dim = DDim({ids_h, ids_w});
+  auto out_dim = DDim({ids_h, ids_w, emb_size});
+
+  w.Resize(w_dim);
+  ids.Resize(ids_dim);
+  out.Resize(out_dim);
+  out_ref.Resize(out_dim);
+
+  auto* w_data = w.mutable_data<float>();
+  auto* ids_data = ids.mutable_data<int64_t>();
+  auto* out_data = out.mutable_data<float>();
+  auto* out_ref_data = out_ref.mutable_data<float>();
+
+  int w_num = w_dim.production();
+  for (int i = 0; i < w_num; i++) {
+    w_data[i] = static_cast<float>(i + 1) / (w_num + 1);
+  }
+  int ids_num = ids_dim.production();
+  for (int i = 0; i < ids_num; i++) {
+    ids_data[i] = i % vocab_size;
+  }
+  int out_num = out_dim.production();
+  for (int i = 0; i < out_num; i++) {
+    out_ref_data[i] =
+        static_cast<float>((i % (vocab_size * emb_size)) + 1) / (w_num + 1);
+  }
+
+  param.W = &w;
+  param.Ids = &ids;
+  param.Out = &out;
+  param.padding_idx = padding_idx;
+  lookup_table.SetParam(param);
+  lookup_table.Run();
+  for (int i = 0; i < out_num; i++) {
+    EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
+  }
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(lookup_table, kX86, kInt64, kNCHW, def);
diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt
index b56606ba3c8a16985d394a86594aa1079a6eb4ba..92a0eb856a5fc6ca490012385ad4fc7c04431d00 100644
--- a/lite/operators/CMakeLists.txt
+++ b/lite/operators/CMakeLists.txt
@@ -85,6 +85,7 @@ add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_
 # for OCR specific
 add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
 add_operator(lookup_table_op extra SRCS lookup_table_op.cc DEPS ${op_DEPS})
+add_operator(lookup_table_v2_op extra SRCS lookup_table_v2_op.cc DEPS ${op_DEPS})
 add_operator(beam_search_decode_op extra SRCS beam_search_decode_op.cc DEPS ${op_DEPS})
 add_operator(graph_op_lite extra SRCS graph_op.cc DEPS ${op_DEPS})
 add_operator(logical_xor extra SRCS logical_op.cc DEPS ${op_DEPS})
diff --git a/lite/operators/lookup_table_v2_op.cc b/lite/operators/lookup_table_v2_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c783695163b1d95964ac1a8a9d79d7167811261a
--- /dev/null
+++ b/lite/operators/lookup_table_v2_op.cc
@@ -0,0 +1,68 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/operators/lookup_table_v2_op.h"
+#include "lite/core/op_lite.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool LookupTableV2OpLite::CheckShape() const {
+  CHECK_OR_FALSE(param_.W)
+  CHECK_OR_FALSE(param_.Ids)
+  CHECK_OR_FALSE(param_.Out)
+
+  auto table_dims = param_.W->dims();
+
+  CHECK_EQ_OR_FALSE(table_dims.size(), 2)
+
+  return true;
+}
+
+bool LookupTableV2OpLite::InferShape() const {
+  auto table_dims = param_.W->dims();
+  auto ids_dims = param_.Ids->dims();
+
+  std::vector<int64_t> out_dims;
+  for (size_t i = 0; i < ids_dims.size(); ++i) {
+    out_dims.push_back(ids_dims[i]);
+  }
+  out_dims.push_back(table_dims[1]);
+  param_.Out->Resize(lite::DDim{out_dims});
+  param_.Out->set_lod(param_.Ids->lod());
+  return true;
+}
+
+bool LookupTableV2OpLite::AttachImpl(const cpp::OpDesc &op_desc,
+                                     lite::Scope *scope) {
+  auto input = op_desc.Input("W").front();
+  auto ids = op_desc.Input("Ids").front();
+  auto out = op_desc.Output("Out").front();
+
+  param_.W = scope->FindVar(input)->GetMutable<lite::Tensor>();
+  param_.Ids = scope->FindVar(ids)->GetMutable<lite::Tensor>();
+  param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
+
+  param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
+
+  return true;
+}
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(lookup_table_v2, paddle::lite::operators::LookupTableV2OpLite)
diff --git a/lite/operators/lookup_table_v2_op.h b/lite/operators/lookup_table_v2_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..dabff3f0cac75cb70cde6eb6e95df34dc36901fe
--- /dev/null
+++ b/lite/operators/lookup_table_v2_op.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/op_lite.h"
+#include "lite/core/scope.h"
+#include "lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class LookupTableV2OpLite : public OpLite {
+ public:
+  LookupTableV2OpLite() {}
+  explicit LookupTableV2OpLite(const std::string &op_type) : OpLite(op_type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  std::string DebugString() const override { return "LookupTableV2"; }
+
+ private:
+  mutable LookupTableParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle