提交 92e2207e 编写于 作者: C chengduoZH

refine doc

上级 ff09b21c
......@@ -34,8 +34,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto ids_dims = ctx->GetInputDim("Ids");
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// Maybe near future we will add concat_rows op.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
......@@ -59,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("W",
"An input represents embedding tensors, "
"(Tensor) The input represents embedding tensors, "
"which is a learnable parameter.");
AddInput("Ids",
"An input with type int32 or int64 "
"contains the ids to be looked up in W. "
"Ids must be a column vector with rank = 2. "
"The 2nd dimension size must be 1.");
AddOutput("Out", "The lookup results, which have the same type as W.");
AddInput(
"Ids",
"(Tensor or SelectedRows) Ids's type can be Tensor or "
"SelectedRows, when Ids's type is Tensor, this tensor contains "
"the ids to be looked up in W and it must be a column vector with "
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W.");
AddOutput("Out",
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W.");
AddAttr<bool>("is_sparse",
"(boolean, default false) "
"Sparse update")
"Sparse update.")
.SetDefault(false);
AddAttr<int64_t>("padding_idx",
"(int64, default -1) "
......@@ -81,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
Lookup Table Operator.
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
then concatenated into a dense or sparse tensor.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC");
}
......
......@@ -75,22 +75,22 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto* ids_var = context.InputVar("Ids"); // int tensor
auto* ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out");
int64_t* ids;
int64_t K;
framework::Tensor* output_t;
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// Maybe near future we will add concat_rows op.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<framework::LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
output_t = context.Output<LoDTensor>("Out"); // float tensor
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
K = ids_t->numel();
} else if (ids_var->IsType<framework::SelectedRows>()) {
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
output_t = context.Output<SelectedRows>("Out")->mutable_value();
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
K = ids_t->rows().size();
output_t->Resize({K, table_t->dims()[1]});
......
......@@ -30,23 +30,23 @@ template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_var = context.InputVar("Ids"); // int tensor
auto* table_t = context.Input<LoDTensor>("W");
auto* ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out");
int64_t* ids;
int64_t ids_numel;
Tensor* output_t;
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// Maybe near future we will add concat_rows op.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
output_t = context.Output<LoDTensor>("Out");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
ids_numel = ids_t->numel();
} else if (ids_var->IsType<SelectedRows>()) {
auto* ids_t = context.Input<SelectedRows>("Ids");
output_t = context.Output<SelectedRows>("Out")->mutable_value();
ids = const_cast<int64_t*>(ids_t->rows().data());
ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_t->dims()[1]});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册