diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 5f70458a87f5d014a692cd35455a997bdcc15776..94d40890a765413e88a35a6ad995ca97ac84dcda 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -41,8 +41,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { " which is a learnable parameter."); AddInput("Ids", "An input with type int32 or int64" - "contains the ids to be looked up in W.") - .NotInGradient(); + "contains the ids to be looked up in W."); AddOutput("Out", "The lookup results, which have the same type with W."); AddComment( "This operator is used to perform lookups on the parameter W," @@ -56,7 +55,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &context) const override { - context.Output(0)->Resize(context.Input(0)->dims()); + auto table = context.Input("W"); + auto d_table = context.Output(framework::GradVarName("W")); + d_table->Resize(table->dims()); } }; diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu index 94b440e00e872e67cec9dab57034f088a26e5c0a..99678ef681627d93c35aae724d97812fc24a15c1 100644 --- a/paddle/operators/lookup_table_op.cu +++ b/paddle/operators/lookup_table_op.cu @@ -23,7 +23,7 @@ namespace operators { using Tensor = framework::Tensor; template -__global__ void LookupTable(T* output, const T* table, const uint32_t* ids, +__global__ void LookupTable(T* output, const T* table, const int32_t* ids, const int N, const int K, const int D) { int idx = threadIdx.x; int idy = blockIdx.x + threadIdx.y * gridDimX; @@ -32,8 +32,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids, int id = ids[idy]; PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id < N); - T* out = output + idy; - const T* tab = table + id; + T* out = output + idy * D; + const T* tab = table + id * D; for (int i = idx; i < D; i += blockDimX) { out[i] = tab[i]; } @@ -42,9 +42,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids, } template -__global__ void LookupTableGradKernel(T* table, const T* output, - const uint32_t* ids, const int N, - const int K, const int D) { +__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, + const int N, const int K, const int D) { int idx = threadIdx.x; int idy = blockIdx.x + threadIdx.y * gridDimX; @@ -52,10 +51,10 @@ __global__ void LookupTableGradKernel(T* table, const T* output, int id = ids[idy]; PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id < N); - const T* out = output + idy; - T* tab = table + id; + const T* out = output + idy * D; + T* tab = table + id * D; for (int i = idx; i < D; i += blockDimX) { - paddle::platform::CudaAtomicAdd(tab + i, out[i]); + paddle::platform::CudaAtomicAdd(&tab[i], out[i]); } idy += blockDimY * gridDimX; } @@ -72,7 +71,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { size_t N = table_t->dims()[0]; size_t D = table_t->dims()[1]; size_t K = product(ids_t->dims()); - auto ids = ids_t->data(); + auto ids = ids_t->data(); auto table = table_t->data(); auto output = output_t->mutable_data(context.GetPlace()); @@ -83,7 +82,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { }; template -class LookupTableGrad : public framework::OpKernel { +class LookupTableGradCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto ids_t = context.Input("Ids"); @@ -93,9 +92,9 @@ class LookupTableGrad : public framework::OpKernel { int N = d_table_t->dims()[0]; int D = d_table_t->dims()[1]; int K = product(ids_t->dims()); - const uint32_t* ids = ids_t->data(); - T* d_table = d_table_t->mutable_data(context.GetPlace()); + const int32_t* ids = ids_t->data(); const T* d_output = d_output_t->data(); + T* d_table = d_table_t->mutable_data(context.GetPlace()); auto* device_context = const_cast(context.device_context_); @@ -103,8 +102,8 @@ class LookupTableGrad : public framework::OpKernel { device_context); dim3 threads(128, 8); dim3 grids(8, 1); - LookupTableGradKernel<<>>(d_table, d_output, - ids, N, K, D); + LookupTableGrad<<>>(d_table, d_output, ids, N, + K, D); } }; @@ -113,4 +112,5 @@ class LookupTableGrad : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel); -REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad); +REGISTER_OP_GPU_KERNEL(lookup_table_grad, + ops::LookupTableGradCUDAKernel); diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h index 790ecab3c66ada68c48d3306a7565430b340f431..9254e03a1b7d11b3003fe07b784152ddfa05d8c7 100644 --- a/paddle/operators/lookup_table_op.h +++ b/paddle/operators/lookup_table_op.h @@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel { size_t N = table_t->dims()[0]; size_t D = table_t->dims()[1]; - auto ids = ids_t->data(); + auto ids = ids_t->data(); auto table = table_t->data(); auto output = output_t->mutable_data(context.GetPlace()); for (size_t i = 0; i < product(ids_t->dims()); ++i) { @@ -53,9 +53,9 @@ class LookupTableGradKernel : public framework::OpKernel { size_t N = d_table_t->dims()[0]; size_t D = d_table_t->dims()[1]; - auto ids = ids_t->data(); - T* d_table = d_table_t->mutable_data(context.GetPlace()); + auto ids = ids_t->data(); const T* d_output = d_output_t->data(); + T* d_table = d_table_t->mutable_data(context.GetPlace()); auto* device_context = const_cast(context.device_context_); diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py index 071069768bf754eff20a9ee2e67279a3a61a14fc..3056bf53e3d23cf004368bbbe9c1616d3a8efa58 100644 --- a/python/paddle/v2/framework/tests/test_lookup_table.py +++ b/python/paddle/v2/framework/tests/test_lookup_table.py @@ -10,7 +10,7 @@ class TestSigmoidOp(unittest.TestCase): def setUp(self): self.type = 'lookup_table' table = np.random.random((17, 31)).astype('float32') - ids = np.random.randint(0, 17, 4) + ids = np.random.randint(0, 17, 4).astype('int32') self.inputs = {'W': table, 'Ids': ids} self.outputs = {'Out': table[ids]} @@ -19,10 +19,8 @@ class TestSigmoidGradOp(GradientChecker): def test_grad(self): op = create_op('lookup_table') table = np.random.random((17, 31)).astype('float32') - ids = np.random.randint(0, 17, 4) + ids = np.random.randint(0, 17, 4).astype('int32') inputs = {'W': table, 'Ids': ids} - # compare gradients between cpu and gpu - self.compare_grad(op, inputs) # check gradients self.check_grad(op, inputs, set('W'), 'Out')