diff --git a/paddle/fluid/operators/lookup_table_dequant_op.cc b/paddle/fluid/operators/lookup_table_dequant_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b5ac9eb609c863851f5ab1acc02552b3d6c43f14
--- /dev/null
+++ b/paddle/fluid/operators/lookup_table_dequant_op.cc
@@ -0,0 +1,128 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/lookup_table_dequant_op.h"
+
+#include <memory>
+
+#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
+#include "paddle/fluid/framework/var_type_inference.h"
+
+namespace paddle {
+namespace operators {
+
+class LookupTableDequantOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("W"), true,
+        platform::errors::InvalidArgument(
+            "Input(W) of LookupTableDequantOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Ids"), true,
+        platform::errors::InvalidArgument(
+            "Input(Ids) of LookupTableDequantOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("Out"), true,
+        platform::errors::InvalidArgument(
+            "Output(Out) of LookupTableDequantOp should not be null."));
+
+    auto table_dims = ctx->GetInputDim("W");
+    auto ids_dims = ctx->GetInputDim("Ids");
+    int ids_rank = ids_dims.size();
+    VLOG(5) << "ids rank is " << ids_rank;
+    PADDLE_ENFORCE_EQ(
+        table_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "ShapeError: The dimensions of the 'lookup table' must be 2. "
+            "But received lookup table's dimensions = %d, "
+            "lookup table's shape = [%s].",
+            table_dims.size(), table_dims));
+    PADDLE_ENFORCE_EQ(
+        ids_dims[ids_rank - 1], 1,
+        platform::errors::InvalidArgument(
+            "ShapeError: The last dimension of the 'Ids' tensor must be 1. "
+            "But received Ids's last dimension = %d, Ids's shape = [%s].",
+            ids_dims[ids_rank - 1], ids_dims));
+
+    auto output_dims =
+        framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
+    PADDLE_ENFORCE_GE(table_dims[1], 2,
+                      platform::errors::InvalidArgument(
+                          "The second dim of table_dims should be "
+                          "greater than or equal to 2, but the actual shape "
+                          "is [%s].",
+                          table_dims));
+
+    // Each quantized row of W stores two float32 values (min, max) followed
+    // by packed uint8 codes, so the dequantized width is (cols - 2) * 4.
+    output_dims.push_back((table_dims[1] - 2) * 4);
+    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
+
+    if (ctx->GetOutputsVarType("Out")[0] ==
+        framework::proto::VarType::LOD_TENSOR) {
+      ctx->ShareLoD("Ids", /*->*/ "Out");
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
+};
+
+class LookupTableDequantOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("W",
+             "(Tensor) The input represents embedding tensors. "
+             "This tensor is a quantized tensor.");
+    AddInput("Ids",
+             "An input with type int64 "
+             "contains the ids to be looked up in W. "
+             "The last dimension size must be 1.");
+    AddOutput("Out", "The lookup results, which have the same type as W.");
+    AddAttr<int64_t>("padding_idx",
+                     "(int64, default -1) "
+                     "If the value is -1, it makes no effect to lookup. "
+                     "Otherwise the given value indicates padding the output "
+                     "with zeros whenever lookup encounters it in Ids.")
+        .SetDefault(kNoPadding);
+    AddComment(R"DOC(
+Lookup Table Dequant Operator.
+
+The `W` input is a quantized parameter for the sake of saving memory.
+This operator first indexes embeddings with `Ids`,
+then dequantizes them and concatenates them as the output (`Out`).
+
+The input Ids can carry the LoD (Level of Details) information,
+or not. And the output only shares the LoD information with input Ids.
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    lookup_table_dequant, ops::LookupTableDequantOp,
+    ops::LookupTableDequantOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(lookup_table_dequant,
+                       ops::LookupTableDequantKernel<float>);
diff --git a/paddle/fluid/operators/lookup_table_dequant_op.h b/paddle/fluid/operators/lookup_table_dequant_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..d059d856212527d26647472c98519cb6da77b3da
--- /dev/null
+++ b/paddle/fluid/operators/lookup_table_dequant_op.h
@@ -0,0 +1,109 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <cmath>
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/framework/var_type_traits.h"
+#include "paddle/fluid/operators/math/blas.h"
+
+#ifdef PADDLE_WITH_DISTRIBUTE
+#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using SelectedRows = framework::SelectedRows;
+using DDim = framework::DDim;
+
+// Maps each uint8 code in `in` back to a float value:
+// out[i] = min + in[i] * (max - min) / pow_2_bits.
+template <typename T>
+void dequant(const unsigned char *in, T *out, float min, float max,
+             int emb_size, int pow_2_bits) {
+  float scale = (max - min) / pow_2_bits;
+  for (int i = 0; i < emb_size; ++i) {
+    T x = scale * static_cast<int>(in[i]) + min;
+    out[i] = x;
+  }
+}
+
+constexpr int64_t kNoPadding = -1;
+
+template <typename T>
+class LookupTableDequantKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *ids_t = context.Input<LoDTensor>("Ids");      // int tensor
+    auto *output_t = context.Output<LoDTensor>("Out");  // float tensor
+    auto *table_var = context.InputVar("W");
+
+    auto id_name = context.InputNames("Ids").front();
+    auto embedding_name = context.InputNames("W").front();
+    auto out_name = context.OutputNames("Out").front();
+
+    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
+    auto *ids = ids_t->data<int64_t>();
+    int64_t ids_numel = ids_t->numel();
+
+    PADDLE_ENFORCE_EQ(
+        table_var->Type(), framework::VarTypeTrait<LoDTensor>::kId,
+        platform::errors::InvalidArgument("lookup table must be LoDTensor"));
+    auto *table_t = context.Input<LoDTensor>("W");
+    int64_t row_number = table_t->dims()[0];
+    int64_t quant_number = table_t->dims()[1];
+    // Each row of W holds [min, max, packed codes...]; every remaining
+    // float32 slot packs 4 uint8 codes, hence the output row width below.
+    int64_t row_width = (quant_number - 2) * 4;
+
+    auto *table = table_t->data<float>();
+    auto *output = output_t->mutable_data<T>(context.GetPlace());
+    int pow_2_bits = static_cast<int>(pow(2, 8));
+    for (int64_t i = 0; i < ids_numel; ++i) {
+      if (padding_idx != kNoPadding && ids[i] == padding_idx) {
+        memset(output + i * row_width, 0, row_width * sizeof(T));
+      } else {
+        PADDLE_ENFORCE_LT(
+            ids[i], row_number,
+            platform::errors::InvalidArgument(
+                "Variable value (input) of OP(fluid.layers.embedding) "
+                "expected >= 0 and < %ld, but got %ld. Please check input "
+                "value.",
+                row_number, ids[i]));
+        PADDLE_ENFORCE_GE(
+            ids[i], 0,
+            platform::errors::InvalidArgument(
+                "Variable value (input) of OP(fluid.layers.embedding) "
+                "expected >= 0 and < %ld, but got %ld. Please check input "
+                "value.",
+                row_number, ids[i]));
+        float min = *(table + ids[i] * quant_number);
+        float max = *(table + ids[i] * quant_number + 1);
+        int offset = ids[i] * quant_number + 2;
+        const unsigned char *tensor_buf =
+            reinterpret_cast<const unsigned char *>(table + offset);
+        dequant(tensor_buf, output + i * row_width, min, max, row_width,
+                pow_2_bits);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
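For reference, the per-row dequantization performed by the kernel above can be reproduced in a few lines of NumPy. The sketch below is illustrative only and is not part of the patch; the helper name dequant_row is made up for this example. It assumes the layout used by the kernel: each float32 row of W is [min, max, q_0, ..., q_{k-1}], where the raw bytes of q_0..q_{k-1} are the uint8 codes and each byte b maps to min + b * (max - min) / 2**8.

import struct

import numpy as np


def dequant_row(row):
    # Reference dequantization for one quantized row of W.
    row = np.asarray(row, dtype=np.float32)
    vmin, vmax = float(row[0]), float(row[1])
    scale = (vmax - vmin) / 2 ** 8
    # Reinterpret the remaining float32 slots as raw uint8 codes.
    codes = bytearray(struct.pack("%df" % (row.size - 2), *row[2:]))
    return np.array([vmin + scale * b for b in codes], dtype=np.float32)


table = np.random.random((1, 4)).astype("float32")
print(dequant_row(table[0]).shape)  # (8,) == (4 - 2) * 4

The unit test below checks the operator output against exactly this kind of byte-by-byte reference computation.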
diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_dequant_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_dequant_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..689b9992a6d9fee8898f22d144c713e4f2f7ea67
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_lookup_table_dequant_op.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest, skip_check_grad_ci
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+import paddle.compat as cpt
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+import struct
+
+
+class TestLookupTableDequantOp(OpTest):
+    def setUp(self):
+        self.op_type = "lookup_table_dequant"
+        table = np.random.random((17, 32)).astype("float32")
+        ids = np.random.randint(0, 17, 4).astype("int64")
+        ids_expand = np.expand_dims(ids, axis=1)
+        self.inputs = {'W': table, 'Ids': ids_expand}
+
+        # calculate output
+        output = []
+        for id in ids:
+            tmp = []
+            min, max = table[id][0], table[id][1]
+            for val in table[id][2:]:
+                tmp += [
+                    int(x) * (max - min) / pow(2, 8) + min
+                    for x in bytearray(struct.pack("f", val))
+                ]
+            output.append(tmp)
+
+        self.outputs = {'Out': np.asarray(output, dtype="float32")}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == "__main__":
+    unittest.main()
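The patch does not include the code that produces the quantized W. As a rough illustration of the storage format the kernel expects, a row could be packed as sketched below. This is a hypothetical helper (quant_row is not a Paddle API), assuming 8-bit linear quantization between the row's min and max and an embedding width that is a multiple of 4 so the codes fill whole float32 slots.

import struct

import numpy as np


def quant_row(values):
    # Pack one float row into [min, max, codes...] stored as float32,
    # mirroring the layout read by LookupTableDequantKernel.
    values = np.asarray(values, dtype=np.float32)
    assert values.size % 4 == 0, "width must be a multiple of 4"
    vmin, vmax = float(values.min()), float(values.max())
    scale = (vmax - vmin) / 2 ** 8
    if scale > 0:
        codes = np.clip(np.round((values - vmin) / scale), 0, 255).astype(np.uint8)
    else:
        codes = np.zeros(values.size, dtype=np.uint8)
    # Reinterpret every 4 uint8 codes as one float32 storage slot.
    packed = struct.unpack("%df" % (values.size // 4), codes.tobytes())
    return np.array([vmin, vmax] + list(packed), dtype=np.float32)


row = np.linspace(-1.0, 1.0, 8).astype("float32")
print(quant_row(row).shape)  # (4,): min, max, and 8 codes packed into 2 floats

Note that quantization of this kind is lossy (values are rounded to 256 levels between min and max), and byte patterns that happen to form NaN payloads may not survive a round trip through float32 storage bit-exactly on every platform.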