未验证 提交 95fa383d 编写于 作者: D donproc 提交者: GitHub

Optimize the embedding CUDA kernel lookup_table_v2, test=develop (#25587)

上级 72064172
......@@ -105,17 +105,17 @@ class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(128, 8);
dim3 grids(8, 1);
dim3 threads(256, 4);
dim3 grids(80, 1);
if (padding_idx == -1)
LookupTableV2<
T, 128, 8, 8,
T, 256, 4, 80,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
else
LookupTableV2<
T, 128, 8, 8,
T, 256, 4, 80,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册