未验证 提交 788c600e 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #8932 from chengduoZH/feature/add_concat_rows

Enhance look_up_table op 
......@@ -33,8 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
......@@ -54,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"An input represents embedding tensors, "
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
AddInput("Ids",
"An input with type int32 or int64 "
"contains the ids to be looked up in W. "
"Ids must be a column vector with rank = 2. "
"The 2nd dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddInput(
"Ids",
"(Tensor or SelectedRows) Ids's type can be Tensor or "
"SelectedRows, when Ids's type is Tensor, this tensor contains "
"the ids to be looked up in W and it must be a column vector with "
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W.");
AddOutput("Out",
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W.");
AddAttr<bool>("is_sparse",
"(boolean, default false) "
"Sparse update")
"Sparse update.")
.SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
......@@ -76,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
Lookup Table Operator.
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
then concatenated into a dense or sparse tensor.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
......
......@@ -74,14 +74,32 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W");
auto* ids_t = context.Input<LoDTensor>("Ids");
auto* output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto* ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out");
int64_t* ids;
int64_t K;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<framework::LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
K = ids_t->numel();
} else if (ids_var->IsType<framework::SelectedRows>()) {
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
K = ids_t->rows().size();
output_t->Resize({K, table_t->dims()[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
......
......@@ -22,6 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
......@@ -29,25 +30,45 @@ template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
auto* table_t = context.Input<LoDTensor>("W");
auto* ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out");
int64_t* ids;
int64_t ids_numel;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
ids_numel = ids_t->numel();
} else if (ids_var->IsType<SelectedRows>()) {
auto* ids_t = context.Input<SelectedRows>("Ids");
ids = const_cast<int64_t*>(ids_t->rows().data());
ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_t->dims()[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int N = table_t->dims()[0];
int D = table_t->dims()[1];
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1) {
for (int64_t i = 0; i < ids_t->numel(); ++i) {
for (int64_t i = 0; i < ids_numel; ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
} else {
for (int64_t i = 0; i < ids_t->numel(); ++i) {
for (int64_t i = 0; i < ids_numel; ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
} else {
......
......@@ -15,6 +15,8 @@
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestLookupTableOp(OpTest):
......@@ -47,5 +49,52 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
pass
class TestLookupTableIdsIsSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Variable
height = 10
rows = [0, 4, 4, 7]
row_numel = 12
# create and initialize W Variable
W = scope.var('W').get_tensor()
W_array = np.full((height, row_numel), 1.0).astype("float32")
for i in range(height):
W_array[i] *= i
W.set(W_array, place)
# create and initialize Ids Variable
ids_selected_rows = scope.var('Ids').get_selected_rows()
ids_selected_rows.set_height(len(rows))
ids_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
ids_tensor = ids_selected_rows.get_tensor()
ids_tensor.set(np_array, place)
# create Out Variable
Out = scope.var('Out').get_selected_rows()
# create and run lookup_table operator
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
concat_rows_op.run(scope, place)
# get result from Out
Out_tensor = Out.get_tensor()
result_array = np.array(Out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(rows):
assert (row == result_array[idx]).all()
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册