Unverified commit 5ba9dfc1 authored by mapingshuo, committed by GitHub

add lookup_table_dequant_op (#22900)

add lookup_table_dequant_op
Parent a020a257
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lookup_table_dequant_op.h"
#include <memory>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
namespace operators {
class LookupTableDequantOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("W"), true,
platform::errors::InvalidArgument(
"Input(W) of LookupTableDequantOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Ids"), true,
platform::errors::InvalidArgument(
"Input(Ids) of LookupTableDequantOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of LookupTableDequantOp should not be null."));
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
int ids_rank = ids_dims.size();
VLOG(5) << "ids rank is " << ids_rank << std::endl;
PADDLE_ENFORCE_EQ(
table_dims.size(), 2,
platform::errors::InvalidArgument(
"ShapeError: The dimensions of the 'lookup table' must be 2. "
"But received lookup table's dimensions = %d, "
"lookup table's shape = [%s].",
table_dims.size(), table_dims));
PADDLE_ENFORCE_EQ(
ids_dims[ids_rank - 1], 1,
platform::errors::InvalidArgument(
"ShapeError: The last dimensions of the 'Ids' tensor must be 1. "
"But received Ids's last dimensions = %d, Ids's shape = [%s].",
ids_dims[ids_rank - 1], ids_dims));
auto output_dims =
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
PADDLE_ENFORCE_GE(table_dims[1], 2,
platform::errors::InvalidArgument(
"the second dim of table_dims should be "
"greater or equal to 2, but the actual shape "
"is [%s]",
table_dims));
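// Each row of W stores [min, max] in its first two floats; every remaining
// float packs 4 uint8 codes, so the dequantized width is (table_dims[1] - 2) * 4.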
output_dims.push_back((table_dims[1] - 2) * 4);
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
if (ctx->GetOutputsVarType("Out")[0] ==
framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("Ids", /*->*/ "Out");
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class LookupTableDequantOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("W",
"(Tensor) The input represents embedding tensors, "
"This tensor is a quantized tensor");
AddInput("Ids",
"An input with type int64 "
"contains the ids to be looked up in W. "
"The last dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(kNoPadding);
AddComment(R"DOC(
Lookup Table Dequant Operator.
The `W` input is a quantized parameter, stored this way to save memory.
This operator first indexes embeddings with `Ids`, then dequantizes them
and concatenates them into the output (`Out`).
The input `Ids` can optionally carry LoD (Level of Details) information,
and the output shares its LoD information with the input `Ids`.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
lookup_table_dequant, ops::LookupTableDequantOp,
ops::LookupTableDequantOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(lookup_table_dequant,
ops::LookupTableDequantKernel<float>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
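// Dequantize one embedding row: "in" holds emb_size uint8 codes, and each code
// c is mapped back to min + c * (max - min) / pow_2_bits, written to "out".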
template <typename T>
void dequant(const unsigned char *in, T *out, float min, float max,
int emb_size, int pow_2_bits) {
float scale = (max - min) / pow_2_bits;
for (int i = 0; i < emb_size; ++i) {
T x = scale * static_cast<int>(in[i]) + min;
out[i] = x;
}
}
constexpr int64_t kNoPadding = -1;
template <typename T>
class LookupTableDequantKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
auto id_name = context.InputNames("Ids").front();
auto embedding_name = context.InputNames("W").front();
auto out_name = context.OutputNames("Out").front();
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto *ids = ids_t->data<int64_t>();
int64_t ids_numel = ids_t->numel();
PADDLE_ENFORCE_GE(
table_var->Type(), framework::VarTypeTrait<LoDTensor>::kId,
platform::errors::InvalidArgument("lookup table must be LodTensor"));
auto *table_t = context.Input<LoDTensor>("W");
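// Layout of each table row: 2 floats (min, max) followed by floats whose raw
// bytes are the uint8 codes, hence row_width = (quant_number - 2) * 4.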
int64_t row_number = table_t->dims()[0];
int64_t quant_number = table_t->dims()[1];
int64_t row_width = (quant_number - 2) * 4;
auto *table = table_t->data<float>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
int pow_2_bits = static_cast<int>(pow(2, 8));
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(
ids[i], row_number,
platform::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]));
PADDLE_ENFORCE_GE(
ids[i], 0,
platform::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]));
float min = *(table + ids[i] * quant_number);
float max = *(table + ids[i] * quant_number + 1);
int offset = ids[i] * quant_number + 2;
const unsigned char *tensor_buf =
reinterpret_cast<const unsigned char *>(table + offset);
dequant(tensor_buf, output + i * row_width, min, max, row_width,
pow_2_bits);
}
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.compat as cpt
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import struct


class TestLookupTableDequantOp(OpTest):
    def setUp(self):
        self.op_type = "lookup_table_dequant"
        table = np.random.random((17, 32)).astype("float32")
        ids = np.random.randint(0, 17, 4).astype("int64")
        ids_expand = np.expand_dims(ids, axis=1)
        self.inputs = {'W': table, 'Ids': ids_expand}
        # calculate the expected output: for each looked-up row, the first two
        # floats are [min, max] and the bytes of the remaining floats are the
        # uint8 codes, each decoded as min + code * (max - min) / 256
        output = []
        for id in ids:
            tmp = []
            min, max = table[id][0], table[id][1]
            for val in table[id][2:]:
                tmp += [
                    int(x) * (max - min) / pow(2, 8) + min
                    for x in bytearray(struct.pack("f", val))
                ]
            output.append(tmp)
        self.outputs = {'Out': np.asarray(output, dtype="float32")}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
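For reference, a minimal NumPy sketch of the row layout exercised above: each float32 row of `W` stores `[min, max]` followed by floats whose raw bytes are uint8 codes, and the kernel decodes code `c` as `min + c * (max - min) / 256`. The `quantize_row` / `dequantize_row` helper names below are illustrative, not part of this PR; the encode side in particular is an assumption kept consistent with the kernel's decode and the test's unpacking.

# Illustrative helpers (not part of this PR): round-trip one embedding row
# through the quantized layout that lookup_table_dequant expects.
import numpy as np


def quantize_row(emb, num_bits=8):
    # Pack a float32 row into [min, max, packed uint8 codes]; 4 codes per float.
    emb = np.asarray(emb, dtype="float32")
    assert emb.size % 4 == 0
    r_min, r_max = float(emb.min()), float(emb.max())
    scale = (r_max - r_min) / float(2 ** num_bits)
    codes = np.clip(np.round((emb - r_min) / scale), 0, 255).astype("uint8")
    packed = codes.view("float32")  # reinterpret every 4 codes as one float32
    return np.concatenate([np.array([r_min, r_max], "float32"), packed])


def dequantize_row(row, num_bits=8):
    # Mirror of the C++ dequant(): code c decodes to min + c * scale.
    row = np.asarray(row, dtype="float32")
    r_min, r_max = float(row[0]), float(row[1])
    scale = (r_max - r_min) / float(2 ** num_bits)
    codes = np.ascontiguousarray(row[2:]).view("uint8")
    return codes.astype("float32") * scale + r_min


emb = np.random.random(120).astype("float32")
row = quantize_row(emb)          # shape (2 + 120 // 4,) == (32,)
recovered = dequantize_row(row)  # shape ((32 - 2) * 4,) == (120,)
assert recovered.shape == emb.shape
assert np.max(np.abs(recovered - emb)) <= (emb.max() - emb.min()) / 256 + 1e-6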