未验证 提交 e7980adf 编写于 作者: H huangxu96 提交者: GitHub

[Cherry-Pick] take along axis bug fix (#41863)

This PR is a cherry-pick of #41824.

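This PR fixes a bug that causes a CUDA address error. The cause was that the grid size of the CUDA kernel launch was set incorrectly: the element count `n` used to compute the grid was derived from `slice_size * index_size` instead of `inner_dim_size * select_dim_size * outer_dim_size`.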
上级 623f8308
......@@ -119,7 +119,7 @@ struct gpu_gather_scatter_functor {
is_scatter_like ? self_dims[dim] : src_dims[dim];
int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1;
for (int64_t i = 0; i < index_dims.size(); ++i) {
for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i];
}
......@@ -127,11 +127,8 @@ struct gpu_gather_scatter_functor {
outer_dim_size *= index_dims[i];
}
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
int block = 512;
int64_t n = slice_size * index_size;
int64_t n = inner_dim_size * select_dim_size * outer_dim_size;
int64_t grid = (n + block - 1) / block;
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
......@@ -215,11 +212,8 @@ void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
outer_dim_size *= index_dims[i];
}
int64_t slice_size = 1;
for (int i = 1; i < grad_dims.size(); ++i) slice_size *= grad_dims[i];
int block = 512;
int64_t n = slice_size * index_size;
int64_t n = inner_dim_size * select_dim_size * outer_dim_size;
int64_t grid = (n + block - 1) / block;
auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册