diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index 60d5d67202542a9eea2713247958afcf4000e008..49078844354c8597ea489a986826edbdbcc9eb62 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -88,6 +88,7 @@ add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_k add_kernel(gru_compute_arm ARM extra SRCS gru_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(beam_search_decode_compute_arm ARM extra SRCS beam_search_decode_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(lookup_table_compute_arm ARM extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(logical_compute_arm ARM extra SRCS logical_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(less_than_arm ARM extra SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/lookup_table_dequant_compute.cc b/lite/kernels/arm/lookup_table_dequant_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..6127e4279b14c40af1cde14b267581426f9ffaa1 --- /dev/null +++ b/lite/kernels/arm/lookup_table_dequant_compute.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/lookup_table_dequant_compute.h" +#include +#include +#include "lite/api/paddle_place.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void dequant(const unsigned char *in, + float *out, + float min, + float max, + int emb_size, + int pow_2_bits) { + float scale = (max - min) / pow_2_bits; + for (int i = 0; i < emb_size; ++i) { + float x = scale * static_cast(in[i]) + min; + out[i] = x; + } +} + +void LookupTableDequantCompute::Run() { + auto ¶m = this->Param(); + // inputs + auto w = param.W; + auto ids = param.Ids; + // outputs + auto out = param.Out; + + auto table_dim = w->dims(); + int64_t ids_numel = ids->numel(); + auto ids_data = ids->data(); + + int64_t row_number = table_dim[0]; + int64_t quant_number = table_dim[1]; + int64_t row_width = (quant_number - 2) * 4; + + auto table_data = w->data(); + auto dout = out->mutable_data(); + int pow_2_bits = static_cast(pow(2, 8)); + + for (int64_t i = 0; i < ids_numel; ++i) { + int ids_int = ids_data[i]; + if (param.padding_idx != -1 && ids_data[i] == param.padding_idx) { + memset(dout + i * row_width, 0, row_width * sizeof(float)); + } else { + CHECK_LT(ids_data[i], row_number) + << "look uptable ids[i] < row_number check failed"; + CHECK_GE(ids_data[i], 0) << "lookuptable ids[i] >= 0 check failed"; + float min = *(table_data + ids_data[i] * quant_number); + float max = *(table_data + ids_data[i] * quant_number + 1); + int offset = ids_data[i] * quant_number + 2; + const unsigned char *tensor_buf = + reinterpret_cast(table_data + offset); + dequant( + tensor_buf, dout + i * row_width, min, max, row_width, pow_2_bits); + + // memcpy(dout + i * row_width, + // table_data + ids_int * row_width, + // row_width * sizeof(float)); + } + } + *(out->mutable_lod()) = ids->lod(); +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(lookup_table_dequant, + kARM, + kAny, + kNCHW, + paddle::lite::kernels::arm::LookupTableDequantCompute, + def) + .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/lookup_table_dequant_compute.h b/lite/kernels/arm/lookup_table_dequant_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..e1a41dcdf20fb287d3cc2022832cbb9a25e93eb4 --- /dev/null +++ b/lite/kernels/arm/lookup_table_dequant_compute.h @@ -0,0 +1,39 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class LookupTableDequantCompute + : public KernelLite { + public: + using param_t = operators::LookupTableDequantParam; + + LookupTableDequantCompute() = default; + + void Run() override; + + virtual ~LookupTableDequantCompute() = default; +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index d4fc74e2c5bb0cc9e23e97a73ef514be7e9b6af6..512473c7d52d0f43b01611048012a8c54e2b7244 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -109,6 +109,7 @@ add_operator(distribute_fpn_proposals_op_lite extra SRCS distribute_fpn_proposal # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) add_operator(lookup_table_op extra SRCS lookup_table_op.cc DEPS ${op_DEPS}) +add_operator(lookup_table_dequant_op extra SRCS lookup_table_dequant_op.cc DEPS ${op_DEPS}) add_operator(lookup_table_v2_op extra SRCS lookup_table_v2_op.cc DEPS ${op_DEPS}) add_operator(beam_search_decode_op extra SRCS beam_search_decode_op.cc DEPS ${op_DEPS}) add_operator(logical_xor extra SRCS logical_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/lookup_table_dequant_op.cc b/lite/operators/lookup_table_dequant_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b81043bfbfeed356e3d67065686057adfadcb25f --- /dev/null +++ b/lite/operators/lookup_table_dequant_op.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/lookup_table_dequant_op.h" +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool LookupTableDequantOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.W) + CHECK_OR_FALSE(param_.Ids) + CHECK_OR_FALSE(param_.Out) + + const auto& table_dims = param_.W->dims(); + const auto& ids_dims = param_.Ids->dims(); + + int ids_rank = ids_dims.size(); + + CHECK_EQ_OR_FALSE(table_dims.size(), 2); + CHECK_EQ_OR_FALSE(ids_dims[ids_rank - 1], 1); + CHECK_GT_OR_FALSE(table_dims[1], 2); + return true; +} + +bool LookupTableDequantOpLite::InferShape() const { + const auto& table_dims = param_.W->dims(); + const auto& ids_dims = param_.Ids->dims(); + + auto out_dims = ids_dims; + int ids_rank = ids_dims.size(); + out_dims[ids_rank - 1] = (table_dims[1] - 2) * 4; + + param_.Out->Resize(out_dims); + param_.Out->set_lod(param_.Ids->lod()); + return true; +} + +bool LookupTableDequantOpLite::AttachImpl(const cpp::OpDesc& op_desc, + lite::Scope* scope) { + auto input = op_desc.Input("W").front(); + auto ids = op_desc.Input("Ids").front(); + auto out = op_desc.Output("Out").front(); + + param_.W = scope->FindVar(input)->GetMutable(); + param_.Ids = scope->FindVar(ids)->GetMutable(); + param_.Out = scope->FindVar(out)->GetMutable(); + + param_.padding_idx = op_desc.GetAttr("padding_idx"); + + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(lookup_table_dequant, + paddle::lite::operators::LookupTableDequantOpLite) diff --git a/lite/operators/lookup_table_dequant_op.h b/lite/operators/lookup_table_dequant_op.h new file mode 100644 index 0000000000000000000000000000000000000000..3a9683d5ca0d87365cb240b91dccab07cf26ca71 --- /dev/null +++ b/lite/operators/lookup_table_dequant_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class LookupTableDequantOpLite : public OpLite { + public: + LookupTableDequantOpLite() {} + explicit LookupTableDequantOpLite(const std::string &op_type) + : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "LookupTableDequant"; } + + private: + mutable LookupTableDequantParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 25bbcc7687dd3000301f95c7ab365d2157f196dd..0ca7e3d2a832283af1582ae28626a9c4b3936005 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -661,6 +661,13 @@ struct LookupTableParam { int64_t padding_idx{-1}; }; +struct LookupTableDequantParam { + lite::Tensor* W{nullptr}; + lite::Tensor* Ids{nullptr}; + lite::Tensor* Out{nullptr}; + int64_t padding_idx{-1}; +}; + struct Im2SequenceParam { const lite::Tensor* X{}; const lite::Tensor* Y{}; diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index 54e85e14a0276fdc5543d7369bb4c7f044e87974..6d63b7054176fae6d2c88cf2e330fce2c6f7eb6f 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -58,6 +58,7 @@ if(LITE_BUILD_EXTRA) lite_cc_test(test_kernel_search_aligned_mat_mul_compute SRCS search_aligned_mat_mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_search_seq_fc_compute SRCS search_seq_fc_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_lookup_table_dequant_compute SRCS lookup_table_dequant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) endif() lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/lookup_table_dequant_compute_test.cc b/lite/tests/kernels/lookup_table_dequant_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30f551e8b7e29ef373145bdb3496c1885907531a --- /dev/null +++ b/lite/tests/kernels/lookup_table_dequant_compute_test.cc @@ -0,0 +1,150 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { + +void dequant(const unsigned char* in, + float* out, + float min, + float max, + int emb_size, + int pow_2_bits) { + float scale = (max - min) / pow_2_bits; + for (int i = 0; i < emb_size; ++i) { + float x = scale * static_cast(in[i]) + min; + out[i] = x; + } +} + +class LookupTableDequantComputeTest : public arena::TestCase { + protected: + // common attributes for this op. + std::string op_type_ = "lookup_table_dequant"; + std::string ids_ = "ids"; + std::string w_ = "w"; + std::string out_ = "out"; + DDim ids_dims_{{2, 1}}; + DDim w_dims_{{8, 4}}; + int64_t padding_idx_ = -1; + + public: + LookupTableDequantComputeTest(const Place& place, + const std::string& alias, + const DDim& ids_dims, + const DDim& w_dims, + int64_t padding_idx) + : TestCase(place, alias), + ids_dims_(ids_dims), + w_dims_(w_dims), + padding_idx_(padding_idx) {} + + void RunBaseline(Scope* scope) override { + auto ids = scope->FindTensor(ids_); + auto w = scope->FindTensor(w_); + auto ids_dims = ids->dims(); + auto w_dims = w->dims(); + + auto out = scope->NewTensor(out_); + CHECK(out); + + int ids_rank = ids_dims.size(); + CHECK_EQ(ids_dims[ids_rank - 1], 1); + CHECK_EQ(w_dims.size(), 2); + + std::vector out_dims; + for (int i = 0; i < ids_rank - 1; ++i) { + out_dims.push_back(ids_dims[i]); + } + out_dims.push_back((w_dims[1] - 2) * 4); + out->Resize(out_dims); + out->set_lod(ids->lod()); + + auto ids_data = ids->data(); + auto ids_size = ids_dims.production(); + auto w_data = w->data(); + auto w_rows = w_dims[0]; + auto quant_number = w_dims[1]; + auto w_cols = (quant_number - 2) * 4; + auto out_data = out->mutable_data(); + int pow_2_bits = static_cast(pow(2, 8)); + + for (int64_t i = 0; i < ids_size; i++) { + auto id = ids_data[i]; + if (padding_idx_ != -1 && id == padding_idx_) { + memset(out_data + i * w_cols, 0, w_cols * sizeof(float)); + } else { + CHECK_LT(id, w_rows) << "lookup_table ids[i] expected < " << w_rows + << " but got " << id; + CHECK_GE(id, 0) << "lookup_table ids[i] expected >= 0 but got " << id; + float min = *(w_data + ids_data[i] * quant_number); + float max = *(w_data + ids_data[i] * quant_number + 1); + int offset = ids_data[i] * quant_number + 2; + const unsigned char* tensor_buf = + reinterpret_cast(w_data + offset); + dequant( + tensor_buf, out_data + i * w_cols, min, max, w_cols, pow_2_bits); + } + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType(op_type_); + op_desc->SetInput("Ids", {ids_}); + op_desc->SetInput("W", {w_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetAttr("padding_idx", padding_idx_); + } + + void PrepareData() override { + std::vector ids(ids_dims_.production()); + fill_data_rand( + ids.data(), 0, w_dims_[0] - 1, ids_dims_.production()); + + std::vector w(w_dims_.production()); + fill_data_rand(w.data(), -1.f, 1.f, w_dims_.production()); + + SetCommonTensor(ids_, ids_dims_, ids.data()); + SetCommonTensor(w_, w_dims_, w.data()); + } +}; + +TEST(LookupTableDequant, precision) { +#ifdef LITE_WITH_ARM + float abs_error = 2e-5; + Place place = {TARGET(kARM), PRECISION(kAny)}; + for (auto ids_dims : + std::vector>{{5, 2, 3, 1}, {2, 3, 1}, {3, 1}}) { + for (auto w_dims : + std::vector>{{4, 3}, {6, 8}, {12, 15}}) { + for (auto padding_idx : std::vector{-1}) { + std::unique_ptr tester( + new LookupTableDequantComputeTest( + place, "def", DDim(ids_dims), DDim(w_dims), padding_idx)); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); + } + } + } +#endif +} + +} // namespace lite +} // namespace paddle