提交 00360e7e 编写于 作者: T typhoonzero

update

上级 579c92ab
......@@ -74,8 +74,9 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTable<T, 128, 8,
8><<<grids, threads, 0, context.device_context().stream()>>>(
LookupTable<
T, 128, 8,
8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D);
}
};
......@@ -135,7 +136,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
dim3 grids(8, 1);
LookupTableGrad<
T, 128, 8,
8><<<grids, threads, 0, context.device_context().stream()>>>(
8><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
d_table, d_output, ids, N, K, D);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册