未验证 提交 9074a60c 编写于 作者: F fengjiayi 提交者: GitHub

Refine lookup_table_op (#5257)

1. Change some `auto` to `auto*`
2. Change `Tensor` to `LoDTensor`
上级 db3b9438
...@@ -43,7 +43,7 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -43,7 +43,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
protected: protected:
framework::DataType IndicateDataType( framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type()); return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
} }
}; };
...@@ -93,7 +93,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -93,7 +93,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::DataType IndicateDataType( framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type()); return framework::ToDataType(ctx.Input<LoDTensor>("W")->type());
} }
}; };
......
...@@ -61,16 +61,16 @@ template <typename T> ...@@ -61,16 +61,16 @@ template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel<T> { class LookupTableCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); auto* table_t = context.Input<LoDTensor>("W");
auto ids_t = context.Input<Tensor>("Ids"); auto* ids_t = context.Input<LoDTensor>("Ids");
auto output_t = context.Output<Tensor>("Out"); auto* output_t = context.Output<LoDTensor>("Out");
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
size_t K = ids_t->numel(); size_t K = ids_t->numel();
auto ids = ids_t->data<int64_t>(); auto* ids = ids_t->data<int64_t>();
auto table = table_t->data<T>(); auto* table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto* output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(8, 1); dim3 grids(8, 1);
...@@ -87,9 +87,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -87,9 +87,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<Tensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<Tensor>("W"); auto* table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>(); auto* ids_data = ids->data<int64_t>();
...@@ -119,9 +119,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> { ...@@ -119,9 +119,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
d_output->numel(), stream); d_output->numel(), stream);
} else { } else {
auto ids_t = context.Input<Tensor>("Ids"); auto ids_t = context.Input<LoDTensor>("Ids");
auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out")); auto d_output_t = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto d_table_t = context.Output<Tensor>(framework::GradVarName("W")); auto d_table_t = context.Output<LoDTensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0]; int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1]; int D = d_table_t->dims()[1];
......
...@@ -19,22 +19,22 @@ ...@@ -19,22 +19,22 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
template <typename T> template <typename T>
class LookupTableKernel : public framework::OpKernel<T> { class LookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); // float tensor auto* table_t = context.Input<LoDTensor>("W"); // float tensor
auto ids_t = context.Input<Tensor>("Ids"); // int tensor auto* ids_t = context.Input<LoDTensor>("Ids"); // int tensor
auto output_t = context.Output<Tensor>("Out"); // float tensor auto* output_t = context.Output<LoDTensor>("Out"); // float tensor
int N = table_t->dims()[0]; int N = table_t->dims()[0];
int D = table_t->dims()[1]; int D = table_t->dims()[1];
auto ids = ids_t->data<int64_t>(); auto* ids = ids_t->data<int64_t>();
auto table = table_t->data<T>(); auto* table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto* output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_t->numel(); ++i) { for (int64_t i = 0; i < ids_t->numel(); ++i) {
PADDLE_ENFORCE_LT(ids[i], N); PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
...@@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -49,9 +49,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<Tensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<Tensor>("W"); auto* table = context.Input<LoDTensor>("W");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W")); auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>(); auto* ids_data = ids->data<int64_t>();
...@@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -76,10 +76,10 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims()); PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else { } else {
auto* ids = context.Input<Tensor>("Ids"); auto* ids = context.Input<LoDTensor>("Ids");
auto* d_output = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<Tensor>(framework::GradVarName("W")); auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
auto* table = context.Input<Tensor>("W"); auto* table = context.Input<LoDTensor>("W");
auto* ids_data = ids->data<int64_t>(); auto* ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims(); auto ids_dim = ids->dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册