Unverified commit 8bdb336d, authored by umiswing, committed by GitHub

[Sparse] Fix bugs in parameter freezing (#56154)

* Add enforce for sparse_bn.

* Add enforce for sp conv.
Parent e9c0fe03
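Context for the fix: when a sparse layer's parameters are frozen (their `stop_gradient` set to `True`), the backward pass hands the C++ kernels a null `kernel_grad`, which the old code dereferenced unconditionally. A minimal Python sketch of the triggering scenario, assuming the `paddle.sparse` eager API; the tensor shapes and layer arguments here are illustrative, not taken from the commit:

```python
import paddle

# Toy NDHWC sparse input: two active sites, three input channels.
indices = paddle.to_tensor(
    [[0, 0],   # batch
     [0, 1],   # depth
     [0, 1],   # height
     [1, 2]],  # width
    dtype='int64')
values = paddle.rand([2, 3])
x = paddle.sparse.sparse_coo_tensor(
    indices, values, shape=[1, 2, 2, 3, 3], stop_gradient=False)

conv = paddle.sparse.nn.Conv3D(
    in_channels=3, out_channels=4, kernel_size=3, padding=1)
# Freeze the kernel: backward then receives kernel_grad == nullptr
# in the sparse conv grad kernels patched below.
conv.weight.stop_gradient = True

out = conv(x)
# Before this fix, the backward kernels wrote d_kernel unconditionally.
out.values().sum().backward()
```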
@@ -163,17 +163,20 @@ inline void SubmPreProcess(const Context& dev_ctx,
                            DenseTensor* kernel_grad,
                            DenseTensor* x_grad) {
   auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
-  T* d_kernel_ptr = kernel_grad->data<T>();
-  blas.GEMM(CblasTrans,
-            CblasNoTrans,
-            x.non_zero_elements().dims()[1],
-            out_grad.dims()[1],
-            x.non_zero_elements().dims()[0],
-            static_cast<T>(1),
-            x.non_zero_elements().data<T>(),
-            out_grad.data<T>(),
-            static_cast<T>(0),
-            d_kernel_ptr + half_kernel_size * in_channels * out_channels);
+  const bool is_params_freezing = kernel_grad == nullptr;
+  if (!is_params_freezing) {
+    T* d_kernel_ptr = kernel_grad->data<T>();
+    blas.GEMM(CblasTrans,
+              CblasNoTrans,
+              x.non_zero_elements().dims()[1],
+              out_grad.dims()[1],
+              x.non_zero_elements().dims()[0],
+              static_cast<T>(1),
+              x.non_zero_elements().data<T>(),
+              out_grad.data<T>(),
+              static_cast<T>(0),
+              d_kernel_ptr + half_kernel_size * in_channels * out_channels);
+  }
   // call gemm: d_x = out_grad * transpose(kernel)
   // (n, out_channels) * (out_channels, in_channels)
...
@@ -42,8 +42,19 @@ void BatchNormCooGradKernel(const Context& dev_ctx,
                             DenseTensor* scale_grad,
                             DenseTensor* bias_grad) {
   EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
-  *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
-  *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
+
+  // TODO(umiswing): add check for parameter freezing automatically
+  PADDLE_ENFORCE_EQ((scale_grad == nullptr && bias_grad == nullptr) ||
+                        (scale_grad != nullptr && bias_grad != nullptr),
+                    true,
+                    phi::errors::InvalidArgument(
+                        "Weight and bias's stop_gradient of BatchNorm must be "
+                        "True or False at the same time."));
+
+  if (scale_grad && bias_grad) {
+    *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
+    *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
+  }
   phi::BatchNormGradKernel<T, Context>(dev_ctx,
                                        x.values(),
                                        scale,
...
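The new `PADDLE_ENFORCE_EQ` rejects a partially frozen batch norm: `scale_grad` and `bias_grad` must be both null (both parameters frozen) or both non-null. In user terms, and assuming `paddle.sparse.nn.BatchNorm` exposes `weight` and `bias` like its dense counterpart, the rule looks like this sketch:

```python
import paddle

bn = paddle.sparse.nn.BatchNorm(num_features=3)

# Allowed: freeze both parameters, or neither.
bn.weight.stop_gradient = True
bn.bias.stop_gradient = True

# Rejected: freezing exactly one of them now fails in backward with
# InvalidArgument: "Weight and bias's stop_gradient of BatchNorm must
# be True or False at the same time."
# bn.weight.stop_gradient = True
# bn.bias.stop_gradient = False
```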
@@ -56,6 +56,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
                             const std::string& key,
                             SparseCooTensor* x_grad,
                             DenseTensor* kernel_grad) {
+  const bool is_params_freezing = kernel_grad == nullptr;
   const auto& kernel_dims = kernel.dims();
   const bool is2D = kernel_dims.size() == 4 ? true : false;
   const int kernel_size =
@@ -79,10 +80,13 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
   T* in_features_ptr = in_features.data<T>();
   T* d_x_features_ptr = d_x_features.data<T>();
   T* out_grad_features_ptr = out_grad_features.data<T>();
-  *kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
-  T* d_kernel_ptr = kernel_grad->data<T>();
-  phi::backends::gpu::GpuMemsetAsync(
-      d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());
+  T* d_kernel_ptr = nullptr;
+  if (!is_params_freezing) {
+    *kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
+    d_kernel_ptr = kernel_grad->data<T>();
+    phi::backends::gpu::GpuMemsetAsync(
+        d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());
+  }

   int half_kernel_size = kernel_size / 2;
   auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
@@ -184,6 +188,8 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
   }
 #endif
   const T* kernel_ptr = kernel.data<T>();
+  T* tmp_d_x_ptr = nullptr;
+  T* tmp_d_kernel_ptr = nullptr;
   for (int i = 0; i < kernel_size; i++) {
     if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
       continue;
@@ -195,8 +201,10 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
     T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
     T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels;
     const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels;
-    T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
-    T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
+    tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
+    if (!is_params_freezing) {
+      tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
+    }

 #if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
     if (cutlass) {
@@ -204,26 +212,28 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
       const IntT* scatter_x_indices = rulebook_ptr + offsets[i];
       const IntT* gather_out_indices = rulebook_ptr + rulebook_len + offsets[i];
       const size_t key = autotune::GenKey(M / features_num_range, N, K);
-      // call gemm: d_kernel = transpose(x) * out_grad
-      // (in_channels, n) * (n, out_channels)
-      static cutlass::device_memory::allocation<uint8_t> workspace(
-          workspace_size);
-      GatherGemmScatterDriver<80, true, false>(
-          dev_ctx,
-          key,
-          x.values().data<T>(),
-          out_grad.values().data<T>(),
-          tmp_d_kernel_ptr,
-          tmp_d_kernel_ptr,
-          in_channels,
-          out_channels,
-          counter_ptr[i],
-          gather_x_indices,
-          gather_out_indices,
-          static_cast<const IntT*>(nullptr),
-          static_cast<const T>(1.0),
-          static_cast<const T>(0.0),
-          &workspace);
+      if (!is_params_freezing) {
+        // call gemm: d_kernel = transpose(x) * out_grad
+        // (in_channels, n) * (n, out_channels)
+        static cutlass::device_memory::allocation<uint8_t> workspace(
+            workspace_size);
+        GatherGemmScatterDriver<80, true, false>(
+            dev_ctx,
+            key,
+            x.values().data<T>(),
+            out_grad.values().data<T>(),
+            tmp_d_kernel_ptr,
+            tmp_d_kernel_ptr,
+            in_channels,
+            out_channels,
+            counter_ptr[i],
+            gather_x_indices,
+            gather_out_indices,
+            static_cast<const IntT*>(nullptr),
+            static_cast<const T>(1.0),
+            static_cast<const T>(0.0),
+            &workspace);
+      }
       // call gemm: d_x = out_grad * transpose(kernel)
       // (n, out_channels) * (out_channels, in_channels)
       GatherGemmScatterDriver<80, false, true>(
@@ -244,18 +254,20 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
           nullptr);
     } else {
 #endif
-      // call gemm: d_kernel = transpose(x) * out_grad
-      // (in_channels, n) * (n, out_channels)
-      blas.GEMM(CblasTrans,
-                CblasNoTrans,
-                K,
-                N,
-                M,
-                static_cast<T>(1),
-                tmp_in_ptr,
-                tmp_out_grad_ptr,
-                static_cast<T>(0),
-                tmp_d_kernel_ptr);
+      if (!is_params_freezing) {
+        // call gemm: d_kernel = transpose(x) * out_grad
+        // (in_channels, n) * (n, out_channels)
+        blas.GEMM(CblasTrans,
+                  CblasNoTrans,
+                  K,
+                  N,
+                  M,
+                  static_cast<T>(1),
+                  tmp_in_ptr,
+                  tmp_out_grad_ptr,
+                  static_cast<T>(0),
+                  tmp_d_kernel_ptr);
+      }

       // call gemm: d_x = out_grad * transpose(kernel)
       // (n, out_channels) * (out_channels, in_channels)
...
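With the kernel frozen, every d_kernel GEMM above is skipped while the d_x GEMM still runs, so the input gradient is produced as before. Continuing the hypothetical conv sketch from the top of this commit, a quick post-backward check:

```python
# After out.values().sum().backward() with conv.weight frozen:
assert conv.weight.grad is None  # no d_kernel buffer was ever allocated
assert x.grad is not None        # the d_x path ran unconditionally
```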