未验证 提交 9074a60c 编写于 作者: F fengjiayi 提交者: GitHub

Refine lookup_table_op (#5257)

1. Change some `auto` declarations of pointer variables to `auto*`
2. Change `Tensor` to `LoDTensor` in the `Input<>`/`Output<>` template arguments
上级 db3b9438
......@@ -43,7 +43,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
}
};
......@@ -93,7 +93,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
}
};
......
......@@ -61,16 +61,16 @@ template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W");
auto ids_t = context.Input<Tensor>("Ids");
auto output_t = context.Output<Tensor>("Out");
auto* table_t = context.Input<LoDTensor>("W");
auto* ids_t = context.Input<LoDTensor>("Ids");
auto* output_t = context.Output<LoDTensor>("Out");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto ids = ids_t->data<int64_t>();
auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace());
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
......@@ -87,9 +87,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse");
if (is_sparse) {
auto* ids = context.Input<Tensor>("Ids");
auto* table = context.Input<Tensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>();
......@@ -119,9 +119,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
d_output->numel(), stream);
} else {
auto ids_t = context.Input<Tensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
auto ids_t = context.Input<LoDTensor>("Ids");
auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
......
......@@ -19,22 +19,22 @@
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); // float tensor
auto ids_t = context.Input<Tensor>("Ids"); // int tensor
auto output_t = context.Output<Tensor>("Out"); // float tensor
auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
int N = table_t->dims()[0];
int D = table_t->dims()[1];
auto ids = ids_t->data<int64_t>();
auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace());
auto* ids = ids_t->data<int64_t>();
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
......@@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse");
if (is_sparse) {
auto* ids = context.Input<Tensor>("Ids");
auto* table = context.Input<Tensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>();
......@@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
auto* ids = context.Input<Tensor>("Ids");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<Tensor>(framework::GradVarName("W"));
auto* table = context.Input<Tensor>("W");
auto* ids = context.Input<LoDTensor>("Ids");
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
auto* table = context.Input<LoDTensor>("W");
auto* ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册