提交 a8d072c7 编写于 作者: D dangqingqing

fix bug.

上级 9bc1a1a1
...@@ -41,8 +41,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -41,8 +41,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
" which is a learnable parameter."); " which is a learnable parameter.");
AddInput("Ids", AddInput("Ids",
"An input with type int32 or int64" "An input with type int32 or int64"
"contains the ids to be looked up in W.") "contains the ids to be looked up in W.");
.NotInGradient();
AddOutput("Out", "The lookup results, which have the same type with W."); AddOutput("Out", "The lookup results, which have the same type with W.");
AddComment( AddComment(
"This operator is used to perform lookups on the parameter W," "This operator is used to perform lookups on the parameter W,"
...@@ -56,7 +55,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -56,7 +55,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &context) const override { void InferShape(const framework::InferShapeContext &context) const override {
context.Output<Tensor>(0)->Resize(context.Input<Tensor>(0)->dims()); auto table = context.Input<Tensor>("W");
auto d_table = context.Output<Tensor>(framework::GradVarName("W"));
d_table->Resize(table->dims());
} }
}; };
......
...@@ -23,7 +23,7 @@ namespace operators { ...@@ -23,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T, int blockDimX, int blockDimY, int gridDimX> template <typename T, int blockDimX, int blockDimY, int gridDimX>
__global__ void LookupTable(T* output, const T* table, const uint32_t* ids, __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
const int N, const int K, const int D) { const int N, const int K, const int D) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * gridDimX; int idy = blockIdx.x + threadIdx.y * gridDimX;
...@@ -32,8 +32,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids, ...@@ -32,8 +32,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
int id = ids[idy]; int id = ids[idy];
PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N); PADDLE_ASSERT(id < N);
T* out = output + idy; T* out = output + idy * D;
const T* tab = table + id; const T* tab = table + id * D;
for (int i = idx; i < D; i += blockDimX) { for (int i = idx; i < D; i += blockDimX) {
out[i] = tab[i]; out[i] = tab[i];
} }
...@@ -42,9 +42,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids, ...@@ -42,9 +42,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
} }
template <typename T, int blockDimX, int blockDimY, int gridDimX> template <typename T, int blockDimX, int blockDimY, int gridDimX>
__global__ void LookupTableGradKernel(T* table, const T* output, __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
const uint32_t* ids, const int N, const int N, const int K, const int D) {
const int K, const int D) {
int idx = threadIdx.x; int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * gridDimX; int idy = blockIdx.x + threadIdx.y * gridDimX;
...@@ -52,10 +51,10 @@ __global__ void LookupTableGradKernel(T* table, const T* output, ...@@ -52,10 +51,10 @@ __global__ void LookupTableGradKernel(T* table, const T* output,
int id = ids[idy]; int id = ids[idy];
PADDLE_ASSERT(id >= 0); PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N); PADDLE_ASSERT(id < N);
const T* out = output + idy; const T* out = output + idy * D;
T* tab = table + id; T* tab = table + id * D;
for (int i = idx; i < D; i += blockDimX) { for (int i = idx; i < D; i += blockDimX) {
paddle::platform::CudaAtomicAdd(tab + i, out[i]); paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
} }
idy += blockDimY * gridDimX; idy += blockDimY * gridDimX;
} }
...@@ -72,7 +71,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { ...@@ -72,7 +71,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
size_t K = product(ids_t->dims()); size_t K = product(ids_t->dims());
auto ids = ids_t->data<uint32_t>(); auto ids = ids_t->data<int32_t>();
auto table = table_t->data<T>(); auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto output = output_t->mutable_data<T>(context.GetPlace());
...@@ -83,7 +82,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { ...@@ -83,7 +82,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class LookupTableGrad : public framework::OpKernel { class LookupTableGradCUDAKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids"); auto ids_t = context.Input<Tensor>("Ids");
...@@ -93,9 +92,9 @@ class LookupTableGrad : public framework::OpKernel { ...@@ -93,9 +92,9 @@ class LookupTableGrad : public framework::OpKernel {
int N = d_table_t->dims()[0]; int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1]; int D = d_table_t->dims()[1];
int K = product(ids_t->dims()); int K = product(ids_t->dims());
const uint32_t* ids = ids_t->data<uint32_t>(); const int32_t* ids = ids_t->data<int32_t>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
const T* d_output = d_output_t->data<T>(); const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto* device_context = auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_); const_cast<platform::DeviceContext*>(context.device_context_);
...@@ -103,8 +102,8 @@ class LookupTableGrad : public framework::OpKernel { ...@@ -103,8 +102,8 @@ class LookupTableGrad : public framework::OpKernel {
device_context); device_context);
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(8, 1); dim3 grids(8, 1);
LookupTableGradKernel<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
ids, N, K, D); K, D);
} }
}; };
...@@ -113,4 +112,5 @@ class LookupTableGrad : public framework::OpKernel { ...@@ -113,4 +112,5 @@ class LookupTableGrad : public framework::OpKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>); REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad<float>); REGISTER_OP_GPU_KERNEL(lookup_table_grad,
ops::LookupTableGradCUDAKernel<float>);
...@@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel { ...@@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel {
size_t N = table_t->dims()[0]; size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1]; size_t D = table_t->dims()[1];
auto ids = ids_t->data<uint32_t>(); auto ids = ids_t->data<int32_t>();
auto table = table_t->data<T>(); auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace()); auto output = output_t->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < product(ids_t->dims()); ++i) { for (size_t i = 0; i < product(ids_t->dims()); ++i) {
...@@ -53,9 +53,9 @@ class LookupTableGradKernel : public framework::OpKernel { ...@@ -53,9 +53,9 @@ class LookupTableGradKernel : public framework::OpKernel {
size_t N = d_table_t->dims()[0]; size_t N = d_table_t->dims()[0];
size_t D = d_table_t->dims()[1]; size_t D = d_table_t->dims()[1];
auto ids = ids_t->data<uint32_t>(); auto ids = ids_t->data<int32_t>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
const T* d_output = d_output_t->data<T>(); const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto* device_context = auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_); const_cast<platform::DeviceContext*>(context.device_context_);
......
...@@ -10,7 +10,7 @@ class TestSigmoidOp(unittest.TestCase): ...@@ -10,7 +10,7 @@ class TestSigmoidOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = 'lookup_table' self.type = 'lookup_table'
table = np.random.random((17, 31)).astype('float32') table = np.random.random((17, 31)).astype('float32')
ids = np.random.randint(0, 17, 4) ids = np.random.randint(0, 17, 4).astype('int32')
self.inputs = {'W': table, 'Ids': ids} self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids]} self.outputs = {'Out': table[ids]}
...@@ -19,10 +19,8 @@ class TestSigmoidGradOp(GradientChecker): ...@@ -19,10 +19,8 @@ class TestSigmoidGradOp(GradientChecker):
def test_grad(self): def test_grad(self):
op = create_op('lookup_table') op = create_op('lookup_table')
table = np.random.random((17, 31)).astype('float32') table = np.random.random((17, 31)).astype('float32')
ids = np.random.randint(0, 17, 4) ids = np.random.randint(0, 17, 4).astype('int32')
inputs = {'W': table, 'Ids': ids} inputs = {'W': table, 'Ids': ids}
# compare gradients between cpu and gpu
self.compare_grad(op, inputs)
# check gradients # check gradients
self.check_grad(op, inputs, set('W'), 'Out') self.check_grad(op, inputs, set('W'), 'Out')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册