Commit 92e2207e authored by chengduoZH

refine doc

Parent ff09b21c
@@ -34,8 +34,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
     auto ids_dims = ctx->GetInputDim("Ids");
     auto ids_var_type = ctx->GetInputsVarType("Ids").front();
-    // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
-    // Maybe near future we will add concat_rows op.
+    // The type of Ids(Input) is SelectedRows or LoDTensor. When Ids's type
+    // is LoDTensor, this tensor contains the ids to be looked up in W
+    // and it must be a column vector with rank = 2 whose 2nd dimension
+    // size is 1; when Ids's type is SelectedRows, the rows of Ids
+    // contain the ids to be looked up in W.
     if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
       PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
       PADDLE_ENFORCE_EQ(ids_dims[1], 1);
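For reference, the shape these checks lead to can be written as a tiny stand-alone sketch; the dim-setting code itself is collapsed out of this hunk, so the helper name and signature below are illustrative assumptions rather than the operator's actual code: a LoDTensor Ids of shape {N, 1} looked up in a table W of shape {V, D} yields a dense output of shape {N, D}.

```cpp
#include <cstdint>
#include <vector>

// Illustrative only: Ids of shape {N, 1}, table W of shape {V, D}
// produce a dense output of shape {N, D}.
std::vector<int64_t> LookupOutDims(const std::vector<int64_t>& ids_dims,
                                   const std::vector<int64_t>& table_dims) {
  return {ids_dims[0], table_dims[1]};
}
```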
@@ -59,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
   LookupTableOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("W",
-             "An input represents embedding tensors, "
+             "(Tensor) The input represents embedding tensors, "
              "which is a learnable parameter.");
-    AddInput("Ids",
-             "An input with type int32 or int64 "
-             "contains the ids to be looked up in W. "
-             "Ids must be a column vector with rank = 2. "
-             "The 2nd dimension size must be 1.");
-    AddOutput("Out", "The lookup results, which have the same type as W.");
+    AddInput(
+        "Ids",
+        "(Tensor or SelectedRows) Ids's type can be Tensor or "
+        "SelectedRows. When Ids's type is Tensor, this tensor contains "
+        "the ids to be looked up in W and it must be a column vector with "
+        "rank = 2 whose 2nd dimension size is 1; when Ids's type is "
+        "SelectedRows, the rows of Ids contain the ids to be looked up "
+        "in W.");
+    AddOutput("Out",
+              "(Tensor or SelectedRows) The lookup results, which have the "
+              "same type as W.");
     AddAttr<bool>("is_sparse",
                   "(boolean, default false) "
-                  "Sparse update")
+                  "Sparse update.")
         .SetDefault(false);
     AddAttr<int64_t>("padding_idx",
                      "(int64, default -1) "
@@ -81,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
 Lookup Table Operator.

 This operator is used to perform lookups on the parameter W,
-then concatenated into a dense tensor.
+then concatenates the results into a dense or sparse tensor.

-The input Ids can carry the LoD (Level of Details) information,
-or not. And the output only shares the LoD information with input Ids.
+The type of Ids(Input) is SelectedRows, Tensor or LoDTensor. When Ids's
+type is SelectedRows, the rows of Ids contain the ids to be looked up in W;
+when Ids's type is Tensor, this tensor contains the ids to be looked up in W
+and it must be a column vector with rank = 2 whose 2nd dimension size is 1.
+In that case, Ids can carry the LoD (Level of Details) information, or not,
+and the output only shares the LoD information with the input Ids.

 )DOC");
   }
......
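To make the DOC text concrete, the dense case can be illustrated with plain arrays in place of the framework's tensor types (a stand-alone example, not code from this commit): with W of shape [5, 3] and a column-vector Ids of shape [3, 1], the output is the row-wise gather W[Ids] of shape [3, 3].

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Embedding table W: 5 rows (vocabulary size) by 3 columns (embedding dim).
  const float W[5][3] = {{0, 0, 0}, {1, 1, 1}, {2, 2, 2}, {3, 3, 3}, {4, 4, 4}};
  const int64_t ids[3] = {2, 0, 2};  // rank-2 column vector, 2nd dim == 1
  float out[3][3];
  // Gather: row i of the output is row ids[i] of W.
  for (int i = 0; i < 3; ++i) {
    for (int j = 0; j < 3; ++j) {
      out[i][j] = W[ids[i]][j];
    }
  }
  std::printf("out[0][0] = %g\n", out[0][0]);  // prints 2: row 2 of W
  return 0;
}
```

With a SelectedRows Ids, the same gather runs over the ids stored in its rows instead of a tensor's data, as the kernels below show.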
@@ -75,22 +75,22 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* table_t = context.Input<LoDTensor>("W");
     int64_t padding_idx = context.Attr<int64_t>("padding_idx");
-    auto* ids_var = context.InputVar("Ids");  // int tensor
+    auto* ids_var = context.InputVar("Ids");
+    Tensor* output_t = context.Output<Tensor>("Out");
     int64_t* ids;
     int64_t K;
-    framework::Tensor* output_t;
-    // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
-    // Maybe near future we will add concat_rows op.
+    // The type of Ids(Input) is SelectedRows or LoDTensor. When Ids's type
+    // is LoDTensor, this tensor contains the ids to be looked up in W;
+    // when Ids's type is SelectedRows, the rows of Ids contain the
+    // ids to be looked up in W.
     if (ids_var->IsType<framework::LoDTensor>()) {
       auto* ids_t = context.Input<LoDTensor>("Ids");
-      output_t = context.Output<LoDTensor>("Out");  // float tensor
       ids = const_cast<int64_t*>(ids_t->data<int64_t>());
       K = ids_t->numel();
     } else if (ids_var->IsType<framework::SelectedRows>()) {
       auto* ids_t = context.Input<framework::SelectedRows>("Ids");
-      output_t = context.Output<SelectedRows>("Out")->mutable_value();
       ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
       K = ids_t->rows().size();
       output_t->Resize({K, table_t->dims()[1]});
......
@@ -30,23 +30,23 @@ template <typename T>
 class LookupTableKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* table_t = context.Input<LoDTensor>("W");  // float tensor
-    auto* ids_var = context.InputVar("Ids");        // int tensor
+    auto* table_t = context.Input<LoDTensor>("W");
+    auto* ids_var = context.InputVar("Ids");
+    Tensor* output_t = context.Output<Tensor>("Out");
     int64_t* ids;
     int64_t ids_numel;
-    Tensor* output_t;
-    // ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
-    // Maybe near future we will add concat_rows op.
+    // The type of Ids(Input) is SelectedRows or LoDTensor. When Ids's type
+    // is LoDTensor, this tensor contains the ids to be looked up in W;
+    // when Ids's type is SelectedRows, the rows of Ids contain the
+    // ids to be looked up in W.
     if (ids_var->IsType<LoDTensor>()) {
       auto* ids_t = context.Input<LoDTensor>("Ids");
-      output_t = context.Output<LoDTensor>("Out");
       ids = const_cast<int64_t*>(ids_t->data<int64_t>());
       ids_numel = ids_t->numel();
     } else if (ids_var->IsType<SelectedRows>()) {
       auto* ids_t = context.Input<SelectedRows>("Ids");
-      output_t = context.Output<SelectedRows>("Out")->mutable_value();
       ids = const_cast<int64_t*>(ids_t->rows().data());
       ids_numel = ids_t->rows().size();
       output_t->Resize({ids_numel, table_t->dims()[1]});
......
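The remainder of Compute is collapsed out of the diff above; as a rough reference for the row copy it performs, the gather can be sketched with plain pointers. The function name, the float element type, and the padding_idx handling below are assumptions for illustration, not the file's actual code.

```cpp
#include <cstdint>
#include <cstring>

// Copy table row ids[i] into output row i; zero-fill a row whose id equals
// padding_idx. table is [V, row_width] flattened, out is [ids_numel, row_width].
void LookupRowsRef(const float* table, const int64_t* ids, float* out,
                   int64_t ids_numel, int64_t row_width, int64_t padding_idx) {
  for (int64_t i = 0; i < ids_numel; ++i) {
    if (ids[i] == padding_idx) {
      std::memset(out + i * row_width, 0, row_width * sizeof(float));
    } else {
      std::memcpy(out + i * row_width, table + ids[i] * row_width,
                  row_width * sizeof(float));
    }
  }
}
```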