未验证 提交 e25b75b6 编写于 作者: H huangxu96 提交者: GitHub

fix a bug which will casue cuda address error when the input size is very large (#41824)

As the title
上级 ea0a164b
...@@ -119,7 +119,7 @@ struct gpu_gather_scatter_functor { ...@@ -119,7 +119,7 @@ struct gpu_gather_scatter_functor {
is_scatter_like ? self_dims[dim] : src_dims[dim]; is_scatter_like ? self_dims[dim] : src_dims[dim];
int64_t inner_dim_size = 1; int64_t inner_dim_size = 1;
int64_t outer_dim_size = 1; int64_t outer_dim_size = 1;
for (int64_t i = 0; i < index_dims.size(); ++i) { for (int64_t i = 0; i < dim; ++i) {
inner_dim_size *= index_dims[i]; inner_dim_size *= index_dims[i];
} }
...@@ -127,11 +127,8 @@ struct gpu_gather_scatter_functor { ...@@ -127,11 +127,8 @@ struct gpu_gather_scatter_functor {
outer_dim_size *= index_dims[i]; outer_dim_size *= index_dims[i];
} }
int64_t slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
int block = 512; int block = 512;
int64_t n = slice_size * index_size; int64_t n = inner_dim_size * select_dim_size * outer_dim_size;
int64_t grid = (n + block - 1) / block; int64_t grid = (n + block - 1) / block;
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
...@@ -215,11 +212,8 @@ void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index, ...@@ -215,11 +212,8 @@ void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
outer_dim_size *= index_dims[i]; outer_dim_size *= index_dims[i];
} }
int64_t slice_size = 1;
for (int i = 1; i < grad_dims.size(); ++i) slice_size *= grad_dims[i];
int block = 512; int block = 512;
int64_t n = slice_size * index_size; int64_t n = inner_dim_size * select_dim_size * outer_dim_size;
int64_t grid = (n + block - 1) / block; int64_t grid = (n + block - 1) / block;
auto stream = auto stream =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册