提交 1509ce66 编写于 作者: C chengduoZH

enhancement look_up_table

上级 0d49b921
......@@ -33,8 +33,13 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W");
auto ids_dims = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// Maybe near future we will add concat_rows op.
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
}
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
......
......@@ -74,14 +74,34 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W");
auto* ids_t = context.Input<LoDTensor>("Ids");
auto* output_t = context.Output<LoDTensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
auto* ids_var = context.InputVar("Ids"); // int tensor
int64_t* ids;
int64_t K;
framework::Tensor* output_t;
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// Maybe near future we will add concat_rows op.
if (ids_var->IsType<framework::LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
output_t = context.Output<LoDTensor>("Out"); // float tensor
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
K = ids_t->numel();
} else if (ids_var->IsType<framework::SelectedRows>()) {
auto* ids_t = context.Input<framework::SelectedRows>("Ids");
output_t = const_cast<framework::Tensor*>(
&(context.Output<framework::SelectedRows>("Out")
->value())); // float tensor
ids = const_cast<int64_t*>(ids_t->rows().CUDAData(context.GetPlace()));
K = ids_t->rows().size();
output_t->Resize({K, table_t->dims()[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
......
......@@ -22,6 +22,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
......@@ -29,25 +30,46 @@ template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_var = context.InputVar("Ids"); // int tensor
int64_t* ids;
int64_t ids_numel;
Tensor* output_t;
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// Maybe near future we will add concat_rows op.
if (ids_var->IsType<LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids");
output_t = context.Output<LoDTensor>("Out");
ids = const_cast<int64_t*>(ids_t->data<int64_t>());
ids_numel = ids_t->numel();
} else if (ids_var->IsType<SelectedRows>()) {
auto* ids_t = context.Input<SelectedRows>("Ids");
output_t =
const_cast<Tensor*>(&(context.Output<SelectedRows>("Out")->value()));
ids = const_cast<int64_t*>(ids_t->rows().data());
ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_t->dims()[1]});
} else {
PADDLE_THROW("Unsupported Variable Type of Ids");
}
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int N = table_t->dims()[0];
int D = table_t->dims()[1];
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1) {
for (int64_t i = 0; i < ids_t->numel(); ++i) {
for (int64_t i = 0; i < ids_numel; ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
}
} else {
for (int64_t i = 0; i < ids_t->numel(); ++i) {
for (int64_t i = 0; i < ids_numel; ++i) {
if (ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T));
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册