提交 a8d072c7 编写于 作者: D dangqingqing

fix bug.

上级 9bc1a1a1
......@@ -41,8 +41,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
" which is a learnable parameter.");
AddInput("Ids",
"An input with type int32 or int64"
"contains the ids to be looked up in W.")
.NotInGradient();
"contains the ids to be looked up in W.");
AddOutput("Out", "The lookup results, which have the same type with W.");
AddComment(
"This operator is used to perform lookups on the parameter W,"
......@@ -56,7 +55,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &context) const override {
context.Output<Tensor>(0)->Resize(context.Input<Tensor>(0)->dims());
auto table = context.Input<Tensor>("W");
auto d_table = context.Output<Tensor>(framework::GradVarName("W"));
d_table->Resize(table->dims());
}
};
......
......@@ -23,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T, int blockDimX, int blockDimY, int gridDimX>
__global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
__global__ void LookupTable(T* output, const T* table, const int32_t* ids,
const int N, const int K, const int D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * gridDimX;
......@@ -32,8 +32,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
int id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
T* out = output + idy;
const T* tab = table + id;
T* out = output + idy * D;
const T* tab = table + id * D;
for (int i = idx; i < D; i += blockDimX) {
out[i] = tab[i];
}
......@@ -42,9 +42,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
}
template <typename T, int blockDimX, int blockDimY, int gridDimX>
__global__ void LookupTableGradKernel(T* table, const T* output,
const uint32_t* ids, const int N,
const int K, const int D) {
__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
const int N, const int K, const int D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * gridDimX;
......@@ -52,10 +51,10 @@ __global__ void LookupTableGradKernel(T* table, const T* output,
int id = ids[idy];
PADDLE_ASSERT(id >= 0);
PADDLE_ASSERT(id < N);
const T* out = output + idy;
T* tab = table + id;
const T* out = output + idy * D;
T* tab = table + id * D;
for (int i = idx; i < D; i += blockDimX) {
paddle::platform::CudaAtomicAdd(tab + i, out[i]);
paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
}
idy += blockDimY * gridDimX;
}
......@@ -72,7 +71,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = product(ids_t->dims());
auto ids = ids_t->data<uint32_t>();
auto ids = ids_t->data<int32_t>();
auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace());
......@@ -83,7 +82,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
};
template <typename T>
class LookupTableGrad : public framework::OpKernel {
class LookupTableGradCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids");
......@@ -93,9 +92,9 @@ class LookupTableGrad : public framework::OpKernel {
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = product(ids_t->dims());
const uint32_t* ids = ids_t->data<uint32_t>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
const int32_t* ids = ids_t->data<int32_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
......@@ -103,8 +102,8 @@ class LookupTableGrad : public framework::OpKernel {
device_context);
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableGradKernel<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output,
ids, N, K, D);
LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output, ids, N,
K, D);
}
};
......@@ -113,4 +112,5 @@ class LookupTableGrad : public framework::OpKernel {
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGrad<float>);
REGISTER_OP_GPU_KERNEL(lookup_table_grad,
ops::LookupTableGradCUDAKernel<float>);
......@@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel {
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
auto ids = ids_t->data<uint32_t>();
auto ids = ids_t->data<int32_t>();
auto table = table_t->data<T>();
auto output = output_t->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < product(ids_t->dims()); ++i) {
......@@ -53,9 +53,9 @@ class LookupTableGradKernel : public framework::OpKernel {
size_t N = d_table_t->dims()[0];
size_t D = d_table_t->dims()[1];
auto ids = ids_t->data<uint32_t>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto ids = ids_t->data<int32_t>();
const T* d_output = d_output_t->data<T>();
T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
......
......@@ -10,7 +10,7 @@ class TestSigmoidOp(unittest.TestCase):
def setUp(self):
self.type = 'lookup_table'
table = np.random.random((17, 31)).astype('float32')
ids = np.random.randint(0, 17, 4)
ids = np.random.randint(0, 17, 4).astype('int32')
self.inputs = {'W': table, 'Ids': ids}
self.outputs = {'Out': table[ids]}
......@@ -19,10 +19,8 @@ class TestSigmoidGradOp(GradientChecker):
def test_grad(self):
op = create_op('lookup_table')
table = np.random.random((17, 31)).astype('float32')
ids = np.random.randint(0, 17, 4)
ids = np.random.randint(0, 17, 4).astype('int32')
inputs = {'W': table, 'Ids': ids}
# compare gradients between cpu and gpu
self.compare_grad(op, inputs)
# check gradients
self.check_grad(op, inputs, set('W'), 'Out')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册