提交 b9397b26 编写于 作者: C chengduoZH

remove concat_rows

上级 f1c3ecb2
......@@ -34,9 +34,12 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto ids_dims = ctx->GetInputDim("Ids");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
// lookup_table and concat_rows use the same InferShape, for lookup_table,
// ids_var_type should be LoDTensor, for concat_rows, it should be
// SelectedRows.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
......@@ -60,70 +63,41 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"An input represents embedding tensors, "
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
AddInput("Ids",
"An input with type int32 or int64 "
"contains the ids to be looked up in W. "
"Ids must be a column vector with rank = 2. "
"The 2nd dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddAttr<bool>("is_sparse",
"(boolean, default false) "
"Sparse update")
.SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(-1);
AddComment(R"DOC(
Lookup Table Operator.
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
};
class ConcatRowsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ConcatRowsOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"(Tensor) The input tensor of concat_rows operator. "
"The rank of this tensor is 2.");
AddInput(
"Ids",
"(SelectedRows) The rows of Ids contains the index to be looked up "
"(Tensor or SelectedRows) Ids's type can be Tensor or "
"SelectedRows, when Ids's type is Tensor, this tensor contains "
"the ids to be looked up in W and it must be a column vector with "
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W.");
AddOutput("Out",
"(SelectedRows or Tensor) The result of concatenating, which "
"have the same type as W.");
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W.");
AddAttr<bool>("is_sparse",
"(boolean, default true) This attribution is invalid, it's "
"only used by `Lookup Table Operator`.")
.SetDefault(true);
"(boolean, default false) "
"Sparse update.")
.SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.")
.SetDefault(-1);
AddComment(R"DOC(
ConcatRows Operator.
Lookup Table Operator.
This operator is used to perform lookups on the W(dense tensor) according to
rows contained by Idx(sparse tensor), then concatenates them into a sparse
tensor or dense tensor.
This operator is used to perform lookups on the parameter W,
then concatenated into a dense or sparse tensor.
The type of Ids(Input) is SelectedRows.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
)DOC");
}
......@@ -189,8 +163,3 @@ REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
ops::LookupTableKernel<double>);
REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
ops::LookupTableGradKernel<double>);
// concat_rows is used by regularization and it doesn't have gradient operation.
REGISTER_OPERATOR(concat_rows, ops::LookupTableOp, ops::ConcatRowsOpMaker);
REGISTER_OP_CPU_KERNEL(concat_rows, ops::LookupTableKernel<float>,
ops::LookupTableKernel<double>);
......@@ -74,16 +74,16 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W");
auto* output_t = context.Output<Tensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto* ids_var = context.InputVar("Ids"); // int tensor
auto* ids_var = context.InputVar("Ids");
int64_t* ids;
int64_t K;
auto* output_t = context.Output<Tensor>("Out"); // float tensor;
// lookup_table and concat_rows use the same kernel, for lookup_table,
// ids_var_type should be LoDTensor, for concat_rows, ids_var_type and
// out_var_type should be SelectedRows.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
......
......@@ -30,15 +30,16 @@ template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_var = context.InputVar("Ids"); // int tensor
auto* table_t = context.Input<LoDTensor>("W");
auto* output_t = context.Output<Tensor>("Out");
auto* ids_var = context.InputVar("Ids");
int64_t* ids;
int64_t ids_numel;
auto* output_t = context.Output<Tensor>("Out");
// lookup_table and concat_rows use the same kernel, for lookup_table,
// ids_var_type should be LoDTensor, for concat_rows, ids_var_type and
// out_var_type should be SelectedRows.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
class TestConcatRowsOp(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Grad Variable
height = 10
rows = [0, 4, 4, 7]
row_numel = 12
ids_selected_rows = scope.var('Ids').get_selected_rows()
ids_selected_rows.set_height(height)
ids_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
ids_tensor = ids_selected_rows.get_tensor()
ids_tensor.set(np_array, place)
# create and initialize W Variable
W = scope.var('W').get_tensor()
W_array = np.full((height, row_numel), 1.0).astype("float32")
for i in range(height):
W_array[i] *= i
W.set(W_array, place)
Out = scope.var('Out').get_selected_rows()
Out_array = np.full((len(rows), row_numel), -1.0).astype("float32")
Out.set_height(height)
Out.set_rows(rows)
Out_tensor = Out.get_tensor()
Out_tensor.set(Out_array, place)
# create and run concat_rows_op operator
concat_rows_op = Operator(
"concat_rows",
W='W',
Ids='Ids',
Out='Out',
attrs={'is_sparse': True})
concat_rows_op.run(scope, place)
# get and compare result
result_array = np.array(Out_tensor)
for idx, row in enumerate(rows):
assert (row == result_array[idx]).all()
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == "__main__":
unittest.main()
......@@ -14,6 +14,8 @@
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
......@@ -47,5 +49,52 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
pass
# Testing look_up_table when Ids's type is SelectedRows.
class TestLookupTableIdsIsSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
height = 10
rows = [0, 4, 4, 7]
row_numel = 12
ids_selected_rows = scope.var('Ids').get_selected_rows()
ids_selected_rows.set_height(height)
ids_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
ids_tensor = ids_selected_rows.get_tensor()
ids_tensor.set(np_array, place)
W = scope.var('W').get_tensor()
W_array = np.full((height, row_numel), 1.0).astype("float32")
for i in range(height):
W_array[i] *= i
W.set(W_array, place)
Out = scope.var('Out').get_selected_rows()
Out_array = np.full((len(rows), row_numel), -1.0).astype("float32")
Out.set_height(height)
Out.set_rows(rows)
Out_tensor = Out.get_tensor()
Out_tensor.set(Out_array, place)
# create and run concat_rows_op operator
concat_rows_op = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
concat_rows_op.run(scope, place)
# get and compare result
result_array = np.array(Out_tensor)
for idx, row in enumerate(rows):
assert (row == result_array[idx]).all()
def test_concat_rows(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册