From 3e47eee948e4e538149ee8b7eed27c223d08ad3b Mon Sep 17 00:00:00 2001 From: Jiawei Wang Date: Thu, 13 May 2021 10:26:35 +0800 Subject: [PATCH] fix stack grad gpu (#32781) (#32877) --- paddle/fluid/operators/stack_op.cu | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/stack_op.cu b/paddle/fluid/operators/stack_op.cu index 4800f5f9eb5..9e5e45f4d22 100644 --- a/paddle/fluid/operators/stack_op.cu +++ b/paddle/fluid/operators/stack_op.cu @@ -96,9 +96,10 @@ class StackGPUKernel : public framework::OpKernel { }; template -__global__ void UnStackCUDAKernel(const T* __restrict__ input, int pre_dim_size, - int split_dim_size, int suf_dim_size, - int num_split, T** output_ptrs) { +__global__ void UnStackHelperCUDAKernel(const T* __restrict__ input, + int pre_dim_size, int split_dim_size, + int suf_dim_size, int num_split, + T** output_ptrs) { assert(blockDim.y == 1); assert(blockDim.z == 1); // In this case they are equal @@ -114,6 +115,9 @@ __global__ void UnStackCUDAKernel(const T* __restrict__ input, int pre_dim_size, IntType k = offset % suf_dim_size; T* output = output_ptrs[j / each_dim_size]; + if (output == nullptr) { + return; + } IntType output_ind = i * each_dim_size * suf_dim_size + (j % each_dim_size) * suf_dim_size + k; *(output + output_ind) = input[offset]; @@ -142,6 +146,9 @@ class StackGradGPUKernel : public framework::OpKernel { std::vector outputs(n); auto out_var_names = ctx.OutputNames(framework::GradVarName("X")); for (size_t j = 0; j < dx.size(); ++j) { + if (dx[j] == nullptr) { + outputs[j] = nullptr; + } if (out_var_names[j] != framework::kEmptyVarName && dx[j]->numel() != 0UL) { T* ptr = dx[j]->mutable_data(ctx.GetPlace()); @@ -170,13 +177,13 @@ class StackGradGPUKernel : public framework::OpKernel { auto config = GetGpuLaunchConfig1D(dev_ctx, dy_pre * split_dim * dy_suf); if (dy->numel() < std::numeric_limits::max()) { - UnStackCUDAKernel< + UnStackHelperCUDAKernel< T, int32_t><<>>( dy_data, dy_pre, split_dim, dy_suf, split_dim, reinterpret_cast(tmp_out_data->ptr())); } else { - UnStackCUDAKernel< + UnStackHelperCUDAKernel< T, int64_t><<>>( dy_data, dy_pre, split_dim, dy_suf, split_dim, -- GitLab