diff --git a/paddle/phi/kernels/funcs/sparse/convolution.h b/paddle/phi/kernels/funcs/sparse/convolution.h
index aa7a5e996f17f47ac94e50b0de8ceda8f62732db..e6f3a573088b28738c7186ca9a71ff355714ee53 100644
--- a/paddle/phi/kernels/funcs/sparse/convolution.h
+++ b/paddle/phi/kernels/funcs/sparse/convolution.h
@@ -163,17 +163,20 @@ inline void SubmPreProcess(const Context& dev_ctx,
                            DenseTensor* kernel_grad,
                            DenseTensor* x_grad) {
   auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
-  T* d_kernel_ptr = kernel_grad->data<T>();
-  blas.GEMM(CblasTrans,
-            CblasNoTrans,
-            x.non_zero_elements().dims()[1],
-            out_grad.dims()[1],
-            x.non_zero_elements().dims()[0],
-            static_cast<T>(1),
-            x.non_zero_elements().data<T>(),
-            out_grad.data<T>(),
-            static_cast<T>(0),
-            d_kernel_ptr + half_kernel_size * in_channels * out_channels);
+  const bool is_params_freezing = kernel_grad == nullptr;
+  if (!is_params_freezing) {
+    T* d_kernel_ptr = kernel_grad->data<T>();
+    blas.GEMM(CblasTrans,
+              CblasNoTrans,
+              x.non_zero_elements().dims()[1],
+              out_grad.dims()[1],
+              x.non_zero_elements().dims()[0],
+              static_cast<T>(1),
+              x.non_zero_elements().data<T>(),
+              out_grad.data<T>(),
+              static_cast<T>(0),
+              d_kernel_ptr + half_kernel_size * in_channels * out_channels);
+  }

   // call gemm: d_x = out_grad * transpose(kernel)
   // (n, out_channels) * (out_channels, in_channels)
diff --git a/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc b/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
index ff3173ec0a1013171a9f460a9e22702cf1f583ca..73af07da806e05def96169f8e37d4aa6f81aef0c 100644
--- a/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
+++ b/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
@@ -42,8 +42,19 @@ void BatchNormCooGradKernel(const Context& dev_ctx,
                             DenseTensor* scale_grad,
                             DenseTensor* bias_grad) {
   EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
-  *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
-  *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
+
+  // TODO(umiswing): add check for parameter freezing automatically
+  PADDLE_ENFORCE_EQ((scale_grad == nullptr && bias_grad == nullptr) ||
+                        (scale_grad != nullptr && bias_grad != nullptr),
+                    true,
+                    phi::errors::InvalidArgument(
+                        "Weight and bias's stop_gradient of BatchNorm must be "
+                        "True or False at the same time."));
+
+  if (scale_grad && bias_grad) {
+    *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
+    *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
+  }
   phi::BatchNormGradKernel<T, Context>(dev_ctx,
                                        x.values(),
                                        scale,
diff --git a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
index cd2472b453cba0515902288bd36f508bb73bff4a..d4076e5ef0b5d635815386d19fceedc83766fc3a 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
@@ -56,6 +56,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
                             const std::string& key,
                             SparseCooTensor* x_grad,
                             DenseTensor* kernel_grad) {
+  const bool is_params_freezing = kernel_grad == nullptr;
   const auto& kernel_dims = kernel.dims();
   const bool is2D = kernel_dims.size() == 4 ? true : false;
   const int kernel_size =
@@ -79,10 +80,13 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
   T* in_features_ptr = in_features.data<T>();
   T* d_x_features_ptr = d_x_features.data<T>();
   T* out_grad_features_ptr = out_grad_features.data<T>();
-  *kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
-  T* d_kernel_ptr = kernel_grad->data<T>();
-  phi::backends::gpu::GpuMemsetAsync(
-      d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());
+  T* d_kernel_ptr = nullptr;
+  if (!is_params_freezing) {
+    *kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
+    d_kernel_ptr = kernel_grad->data<T>();
+    phi::backends::gpu::GpuMemsetAsync(
+        d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());
+  }

   int half_kernel_size = kernel_size / 2;
   auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
@@ -184,6 +188,8 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
   }
 #endif
   const T* kernel_ptr = kernel.data<T>();
+  T* tmp_d_x_ptr = nullptr;
+  T* tmp_d_kernel_ptr = nullptr;
   for (int i = 0; i < kernel_size; i++) {
     if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
       continue;
     }
@@ -195,8 +201,10 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
     T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
     T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels;
     const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels;
-    T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
-    T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
+    tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
+    if (!is_params_freezing) {
+      tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
+    }

 #if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
     if (cutlass) {
@@ -204,26 +212,28 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
       const IntT* scatter_x_indices = rulebook_ptr + offsets[i];
       const IntT* gather_out_indices = rulebook_ptr + rulebook_len + offsets[i];
       const size_t key = autotune::GenKey(M / features_num_range, N, K);
-      // call gemm: d_kernel = transpose(x) * out_grad
-      // (in_channels, n) * (n, out_channels)
-      static cutlass::device_memory::allocation<uint8_t> workspace(
-          workspace_size);
-      GatherGemmScatterDriver<80, true, false>(
-          dev_ctx,
-          key,
-          x.values().data<T>(),
-          out_grad.values().data<T>(),
-          tmp_d_kernel_ptr,
-          tmp_d_kernel_ptr,
-          in_channels,
-          out_channels,
-          counter_ptr[i],
-          gather_x_indices,
-          gather_out_indices,
-          static_cast<const IntT*>(nullptr),
-          static_cast<T>(1.0),
-          static_cast<T>(0.0),
-          &workspace);
+      if (!is_params_freezing) {
+        // call gemm: d_kernel = transpose(x) * out_grad
+        // (in_channels, n) * (n, out_channels)
+        static cutlass::device_memory::allocation<uint8_t> workspace(
+            workspace_size);
+        GatherGemmScatterDriver<80, true, false>(
+            dev_ctx,
+            key,
+            x.values().data<T>(),
+            out_grad.values().data<T>(),
+            tmp_d_kernel_ptr,
+            tmp_d_kernel_ptr,
+            in_channels,
+            out_channels,
+            counter_ptr[i],
+            gather_x_indices,
+            gather_out_indices,
+            static_cast<const IntT*>(nullptr),
+            static_cast<T>(1.0),
+            static_cast<T>(0.0),
+            &workspace);
+      }
       // call gemm: d_x = out_grad * transpose(kernel)
       // (n, out_channels) * (out_channels, in_channels)
       GatherGemmScatterDriver<80, false, true>(
@@ -244,18 +254,20 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
           nullptr);
     } else {
 #endif
-      // call gemm: d_kernel = transpose(x) * out_grad
-      // (in_channels, n) * (n, out_channels)
-      blas.GEMM(CblasTrans,
-                CblasNoTrans,
-                K,
-                N,
-                M,
-                static_cast<T>(1),
-                tmp_in_ptr,
-                tmp_out_grad_ptr,
-                static_cast<T>(0),
-                tmp_d_kernel_ptr);
+      if (!is_params_freezing) {
+        // call gemm: d_kernel = transpose(x) * out_grad
+        // (in_channels, n) * (n, out_channels)
+        blas.GEMM(CblasTrans,
+                  CblasNoTrans,
+                  K,
+                  N,
+                  M,
+                  static_cast<T>(1),
+                  tmp_in_ptr,
+                  tmp_out_grad_ptr,
+                  static_cast<T>(0),
+                  tmp_d_kernel_ptr);
+      }
       // call gemm: d_x = out_grad * transpose(kernel)
       // (n, out_channels) * (out_channels, in_channels)
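For reviewers: the guard applied in each kernel above can be illustrated with a small standalone sketch (plain C++, no Paddle types; the Matrix alias and TransposeAGemm helper below are hypothetical stand-ins for DenseTensor and blas.GEMM, not code from this PR). When the parameter is frozen the caller passes a null gradient output, the d_kernel GEMM is skipped, and d_x would still be computed unconditionally.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for a dense tensor, for illustration only.
using Matrix = std::vector<std::vector<float>>;

// C = A^T * B (naive GEMM): A is m x k, B is m x n, result is k x n.
static Matrix TransposeAGemm(const Matrix& a, const Matrix& b) {
  const size_t m = a.size(), k = a[0].size(), n = b[0].size();
  Matrix c(k, std::vector<float>(n, 0.f));
  for (size_t i = 0; i < k; ++i)
    for (size_t j = 0; j < n; ++j)
      for (size_t p = 0; p < m; ++p) c[i][j] += a[p][i] * b[p][j];
  return c;
}

// Backward sketch: d_kernel is produced only when the caller asks for it
// (kernel_grad != nullptr), mirroring is_params_freezing in the diff.
void ConvBackwardSketch(const Matrix& x_features,
                        const Matrix& out_grad,
                        Matrix* kernel_grad /* nullptr when the kernel is frozen */) {
  const bool is_params_freezing = (kernel_grad == nullptr);
  if (!is_params_freezing) {
    // call gemm: d_kernel = transpose(x) * out_grad
    *kernel_grad = TransposeAGemm(x_features, out_grad);
  }
  // d_x = out_grad * transpose(kernel) would be computed here unconditionally.
}

int main() {
  Matrix x = {{1.f, 2.f}, {3.f, 4.f}};  // n x in_channels
  Matrix g = {{0.5f}, {1.5f}};          // n x out_channels
  Matrix d_kernel;
  ConvBackwardSketch(x, g, &d_kernel);  // kernel is trainable: GEMM runs
  ConvBackwardSketch(x, g, nullptr);    // kernel frozen: GEMM skipped
  std::printf("d_kernel[0][0] = %f\n", d_kernel[0][0]);  // 1*0.5 + 3*1.5 = 5
  return 0;
}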