Unverified commit 8bdb336d authored by umiswing, committed by GitHub

[Sparse] Fix bugs in parameter freezing (#56154)

* Add enforce for sparse_bn.

* Add enforce for sp conv.
Parent e9c0fe03
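
The changes below all follow one pattern: when a parameter is frozen (its stop_gradient is True), the backward kernel receives a null gradient output (kernel_grad here, scale_grad/bias_grad for batch norm), so the kernel first checks for nullptr and skips producing that gradient instead of writing through the pointer. A minimal, self-contained sketch of that guard, using made-up names (GradBuffer, ConvBackwardSketch) rather than the real phi API:

#include <iostream>
#include <vector>

// Minimal sketch of the freezing guard used throughout this commit.
// GradBuffer and ConvBackwardSketch are illustrative names only; the real
// kernels work on phi::DenseTensor / SparseCooTensor outputs.
struct GradBuffer {
  std::vector<float> data;
};

void ConvBackwardSketch(const std::vector<float>& upstream_grad,
                        GradBuffer* x_grad,
                        GradBuffer* kernel_grad) {  // nullptr when frozen
  const bool is_params_freezing = kernel_grad == nullptr;
  // The input gradient is always produced.
  x_grad->data = upstream_grad;
  if (!is_params_freezing) {
    // The parameter gradient is allocated and filled only when requested.
    kernel_grad->data.assign(upstream_grad.size(), 0.0f);
  }
}

int main() {
  GradBuffer x_grad;
  // Frozen kernel: pass nullptr, exactly like a stop_gradient=True parameter.
  ConvBackwardSketch({1.0f, 2.0f, 3.0f}, &x_grad, nullptr);
  std::cout << "x_grad elements: " << x_grad.data.size() << std::endl;
  return 0;
}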
@@ -163,17 +163,20 @@ inline void SubmPreProcess(const Context& dev_ctx,
DenseTensor* kernel_grad,
DenseTensor* x_grad) {
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
- T* d_kernel_ptr = kernel_grad->data<T>();
- blas.GEMM(CblasTrans,
-     CblasNoTrans,
-     x.non_zero_elements().dims()[1],
-     out_grad.dims()[1],
-     x.non_zero_elements().dims()[0],
-     static_cast<T>(1),
-     x.non_zero_elements().data<T>(),
-     out_grad.data<T>(),
-     static_cast<T>(0),
-     d_kernel_ptr + half_kernel_size * in_channels * out_channels);
+ const bool is_params_freezing = kernel_grad == nullptr;
+ if (!is_params_freezing) {
+   T* d_kernel_ptr = kernel_grad->data<T>();
+   blas.GEMM(CblasTrans,
+       CblasNoTrans,
+       x.non_zero_elements().dims()[1],
+       out_grad.dims()[1],
+       x.non_zero_elements().dims()[0],
+       static_cast<T>(1),
+       x.non_zero_elements().data<T>(),
+       out_grad.data<T>(),
+       static_cast<T>(0),
+       d_kernel_ptr + half_kernel_size * in_channels * out_channels);
+ }
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
......
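
For reference, the GEMM guarded above computes d_kernel = transpose(x) * out_grad, an (in_channels, n) by (n, out_channels) product, which is what the CblasTrans/CblasNoTrans pair together with M = in_channels, N = out_channels, K = n encodes; the result is written at offset half_kernel_size * in_channels * out_channels, i.e. into the kernel element at index half_kernel_size. A naive reference implementation of that product, with plain loops standing in for blas.GEMM (illustrative only):

#include <vector>

// Naive reference for d_kernel = transpose(x) * out_grad:
//   x        is n x in_channels  (row-major)
//   out_grad is n x out_channels (row-major)
//   d_kernel is in_channels x out_channels
// The real code expresses the same product with
// blas.GEMM(CblasTrans, CblasNoTrans, in_channels, out_channels, n, ...).
std::vector<float> DKernelReference(const std::vector<float>& x,
                                    const std::vector<float>& out_grad,
                                    int n, int in_channels, int out_channels) {
  std::vector<float> d_kernel(in_channels * out_channels, 0.0f);
  for (int r = 0; r < n; ++r) {
    for (int ic = 0; ic < in_channels; ++ic) {
      for (int oc = 0; oc < out_channels; ++oc) {
        d_kernel[ic * out_channels + oc] +=
            x[r * in_channels + ic] * out_grad[r * out_channels + oc];
      }
    }
  }
  return d_kernel;
}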
@@ -42,8 +42,19 @@ void BatchNormCooGradKernel(const Context& dev_ctx,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
EmptyLikeCooKernel<T, Context>(dev_ctx, x, x_grad);
- *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
- *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
+ // TODO(umiswing): add check for parameter freezing automatically
+ PADDLE_ENFORCE_EQ((scale_grad == nullptr && bias_grad == nullptr) ||
+         (scale_grad != nullptr && bias_grad != nullptr),
+     true,
+     phi::errors::InvalidArgument(
+         "Weight and bias's stop_gradient of BatchNorm must be "
+         "True or False at the same time."));
+ if (scale_grad && bias_grad) {
+   *scale_grad = phi::EmptyLike<T, Context>(dev_ctx, scale);
+   *bias_grad = phi::EmptyLike<T, Context>(dev_ctx, bias);
+ }
phi::BatchNormGradKernel<T, Context>(dev_ctx,
x.values(),
scale,
......
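
The enforce added above is a both-or-neither check: scale and bias must either both require gradients or both be frozen, otherwise an InvalidArgument error is raised. A standalone equivalent of the condition, with std::invalid_argument standing in for PADDLE_ENFORCE_EQ and phi::errors::InvalidArgument, and raw float pointers standing in for the DenseTensor* outputs:

#include <stdexcept>

// Both-or-neither check mirroring the PADDLE_ENFORCE_EQ above; float*
// stands in for the DenseTensor* gradient outputs.
void CheckScaleBiasFreezing(const float* scale_grad, const float* bias_grad) {
  const bool both_frozen = scale_grad == nullptr && bias_grad == nullptr;
  const bool both_requested = scale_grad != nullptr && bias_grad != nullptr;
  if (!(both_frozen || both_requested)) {
    throw std::invalid_argument(
        "Weight and bias's stop_gradient of BatchNorm must be "
        "True or False at the same time.");
  }
}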
@@ -56,6 +56,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
const std::string& key,
SparseCooTensor* x_grad,
DenseTensor* kernel_grad) {
+ const bool is_params_freezing = kernel_grad == nullptr;
const auto& kernel_dims = kernel.dims();
const bool is2D = kernel_dims.size() == 4 ? true : false;
const int kernel_size =
@@ -79,10 +80,13 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
T* in_features_ptr = in_features.data<T>();
T* d_x_features_ptr = d_x_features.data<T>();
T* out_grad_features_ptr = out_grad_features.data<T>();
- *kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
- T* d_kernel_ptr = kernel_grad->data<T>();
- phi::backends::gpu::GpuMemsetAsync(
-     d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());
+ T* d_kernel_ptr = nullptr;
+ if (!is_params_freezing) {
+   *kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
+   d_kernel_ptr = kernel_grad->data<T>();
+   phi::backends::gpu::GpuMemsetAsync(
+       d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());
+ }
int half_kernel_size = kernel_size / 2;
auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
@@ -184,6 +188,8 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
}
#endif
const T* kernel_ptr = kernel.data<T>();
+ T* tmp_d_x_ptr = nullptr;
+ T* tmp_d_kernel_ptr = nullptr;
for (int i = 0; i < kernel_size; i++) {
if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
continue;
@@ -195,8 +201,10 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
T* tmp_out_grad_ptr = out_grad_features_ptr + offsets[i] * out_channels;
const T* tmp_kernel_ptr = kernel_ptr + i * in_channels * out_channels;
- T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
- T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
+ tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
+ if (!is_params_freezing) {
+   tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
+ }
#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
if (cutlass) {
@@ -204,26 +212,28 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
const IntT* scatter_x_indices = rulebook_ptr + offsets[i];
const IntT* gather_out_indices = rulebook_ptr + rulebook_len + offsets[i];
const size_t key = autotune::GenKey(M / features_num_range, N, K);
- // call gemm: d_kernel = transpose(x) * out_grad
- // (in_channels, n) * (n, out_channels)
- static cutlass::device_memory::allocation<uint8_t> workspace(
-     workspace_size);
- GatherGemmScatterDriver<80, true, false>(
-     dev_ctx,
-     key,
-     x.values().data<T>(),
-     out_grad.values().data<T>(),
-     tmp_d_kernel_ptr,
-     tmp_d_kernel_ptr,
-     in_channels,
-     out_channels,
-     counter_ptr[i],
-     gather_x_indices,
-     gather_out_indices,
-     static_cast<const IntT*>(nullptr),
-     static_cast<const T>(1.0),
-     static_cast<const T>(0.0),
-     &workspace);
+ if (!is_params_freezing) {
+   // call gemm: d_kernel = transpose(x) * out_grad
+   // (in_channels, n) * (n, out_channels)
+   static cutlass::device_memory::allocation<uint8_t> workspace(
+       workspace_size);
+   GatherGemmScatterDriver<80, true, false>(
+       dev_ctx,
+       key,
+       x.values().data<T>(),
+       out_grad.values().data<T>(),
+       tmp_d_kernel_ptr,
+       tmp_d_kernel_ptr,
+       in_channels,
+       out_channels,
+       counter_ptr[i],
+       gather_x_indices,
+       gather_out_indices,
+       static_cast<const IntT*>(nullptr),
+       static_cast<const T>(1.0),
+       static_cast<const T>(0.0),
+       &workspace);
+ }
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
GatherGemmScatterDriver<80, false, true>(
@@ -244,18 +254,20 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
nullptr);
} else {
#endif
- // call gemm: d_kernel = transpose(x) * out_grad
- // (in_channels, n) * (n, out_channels)
- blas.GEMM(CblasTrans,
-     CblasNoTrans,
-     K,
-     N,
-     M,
-     static_cast<T>(1),
-     tmp_in_ptr,
-     tmp_out_grad_ptr,
-     static_cast<T>(0),
-     tmp_d_kernel_ptr);
+ if (!is_params_freezing) {
+   // call gemm: d_kernel = transpose(x) * out_grad
+   // (in_channels, n) * (n, out_channels)
+   blas.GEMM(CblasTrans,
+       CblasNoTrans,
+       K,
+       N,
+       M,
+       static_cast<T>(1),
+       tmp_in_ptr,
+       tmp_out_grad_ptr,
+       static_cast<T>(0),
+       tmp_d_kernel_ptr);
+ }
// call gemm: d_x = out_grad * transpose(kernel)
// (n, out_channels) * (out_channels, in_channels)
......
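
Taken together, the Conv3dCooGradGPUKernel hunks keep d_kernel_ptr at nullptr whenever the kernel is frozen, so the EmptyLike allocation, the GpuMemsetAsync, the per-tap pointer offset inside the loop, and the d_kernel GEMM on both backends (the CUTLASS GatherGemmScatterDriver path and the blas.GEMM path) are each wrapped in the guard, while the d_x path still runs unconditionally. A condensed, CPU-only sketch of that control flow, with std::vector standing in for the kernel_grad tensor and a placeholder in place of the GEMMs:

#include <cstddef>
#include <vector>

// Condensed sketch of the guarded flow in the conv3d backward kernel.
// std::vector stands in for the kernel_grad DenseTensor; the real code uses
// phi::EmptyLike, GpuMemsetAsync, and device GEMMs instead.
void ConvGradFlowSketch(std::vector<float>* kernel_grad,  // nullptr when frozen
                        int kernel_size,
                        int in_channels,
                        int out_channels) {
  const bool is_params_freezing = kernel_grad == nullptr;
  float* d_kernel_ptr = nullptr;
  if (!is_params_freezing) {
    // Allocate and zero the kernel gradient only when it is requested.
    kernel_grad->assign(
        static_cast<std::size_t>(kernel_size) * in_channels * out_channels,
        0.0f);
    d_kernel_ptr = kernel_grad->data();
  }
  for (int i = 0; i < kernel_size; ++i) {
    float* tmp_d_kernel_ptr = nullptr;
    if (!is_params_freezing) {
      // Per-tap offset; skipped when frozen so we never offset a null pointer.
      tmp_d_kernel_ptr = d_kernel_ptr + static_cast<std::ptrdiff_t>(i) *
                                            in_channels * out_channels;
      // The d_kernel GEMM for this tap would accumulate into
      // tmp_d_kernel_ptr here.
      *tmp_d_kernel_ptr += 0.0f;  // placeholder for the GEMM accumulation
    }
    // The d_x GEMM runs unconditionally in the real kernel.
  }
}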