未验证 提交 3e47eee9 编写于 作者: J Jiawei Wang 提交者: GitHub

fix stack grad gpu (#32781) (#32877)

上级 b60ab6b6
......@@ -96,9 +96,10 @@ class StackGPUKernel : public framework::OpKernel<T> {
};
template <typename T, typename IntType>
__global__ void UnStackCUDAKernel(const T* __restrict__ input, int pre_dim_size,
int split_dim_size, int suf_dim_size,
int num_split, T** output_ptrs) {
__global__ void UnStackHelperCUDAKernel(const T* __restrict__ input,
int pre_dim_size, int split_dim_size,
int suf_dim_size, int num_split,
T** output_ptrs) {
assert(blockDim.y == 1);
assert(blockDim.z == 1);
// In this case they are equal
......@@ -114,6 +115,9 @@ __global__ void UnStackCUDAKernel(const T* __restrict__ input, int pre_dim_size,
IntType k = offset % suf_dim_size;
T* output = output_ptrs[j / each_dim_size];
if (output == nullptr) {
return;
}
IntType output_ind = i * each_dim_size * suf_dim_size +
(j % each_dim_size) * suf_dim_size + k;
*(output + output_ind) = input[offset];
......@@ -142,6 +146,9 @@ class StackGradGPUKernel : public framework::OpKernel<T> {
std::vector<T*> outputs(n);
auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
for (size_t j = 0; j < dx.size(); ++j) {
if (dx[j] == nullptr) {
outputs[j] = nullptr;
}
if (out_var_names[j] != framework::kEmptyVarName &&
dx[j]->numel() != 0UL) {
T* ptr = dx[j]->mutable_data<T>(ctx.GetPlace());
......@@ -170,13 +177,13 @@ class StackGradGPUKernel : public framework::OpKernel<T> {
auto config = GetGpuLaunchConfig1D(dev_ctx, dy_pre * split_dim * dy_suf);
if (dy->numel() < std::numeric_limits<int32_t>::max()) {
UnStackCUDAKernel<
UnStackHelperCUDAKernel<
T, int32_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
dev_ctx.stream()>>>(
dy_data, dy_pre, split_dim, dy_suf, split_dim,
reinterpret_cast<T**>(tmp_out_data->ptr()));
} else {
UnStackCUDAKernel<
UnStackHelperCUDAKernel<
T, int64_t><<<config.block_per_grid.x, config.thread_per_block.x, 0,
dev_ctx.stream()>>>(
dy_data, dy_pre, split_dim, dy_suf, split_dim,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册