// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>

#include "paddle/pten/kernels/norm_grad_kernel.h"

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/pten/common/bfloat16.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/funcs/common_shape.h"

namespace pten {

// Backward pass of normalization along one axis.
//
// The input is addressed as a row-major [pre, axis_n, post] tensor. Row
// r = p * post + q (0 <= r < pre * post) is the axis_n-element slice
// x[p, :, q]. For every row this kernel writes
//
//   x_grad = (y_grad - x * <x, y_grad> / s^2) / s,   where s = x_norm[r],
//
// with <x, y_grad> the dot product over the row.
//
// NOTE(review): presumably x_norm holds the per-row norm saved by the forward
// pass (possibly epsilon-adjusted) — confirm against the forward kernel. A
// zero entry in x_norm produces inf/nan here.
//
// Launch layout: 1-D grid of any size >= 1 (each block walks rows
// blockIdx.x, blockIdx.x + gridDim.x, ...). 1-D block of exactly BlockDim
// threads, as cub::BlockReduce<MT, BlockDim> requires. Only static shared
// memory is used. Arithmetic is performed in MT, the type MPTypeTrait maps T
// to (presumably float for float16 — see fp16_type_traits.h).
template <typename T, int BlockDim>
__global__ void NormalizeGradient(const T* x,
                                  const T* x_norm,
                                  const T* y_grad,
                                  const int pre,
                                  const int axis_n,
                                  const int post,
                                  T* x_grad) {
  using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage_sum;
  // Per-row values computed by thread 0 and broadcast to the whole block.
  __shared__ MT row_sum;
  __shared__ MT row_sqrt_norm;
  __shared__ MT row_norm;

  // 64-bit offsets: pre * axis_n * post can exceed INT_MAX for large tensors.
  const int64_t num = static_cast<int64_t>(pre) * post;
  const int64_t row_stride = static_cast<int64_t>(post) * axis_n;
  for (int64_t i = blockIdx.x; i < num; i += gridDim.x) {
    const int64_t base = (i / post) * row_stride + (i % post);

    // This thread's partial dot product <x, y_grad> over its share of the row.
    MT sum = static_cast<MT>(0);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const int64_t index = base + static_cast<int64_t>(j) * post;
      sum += static_cast<MT>(x[index]) * static_cast<MT>(y_grad[index]);
    }
    // CUB returns the block-wide aggregate only to thread 0, hence the
    // broadcast through shared memory below.
    MT reduce_result = BlockReduce(temp_storage_sum).Sum(sum);

    if (threadIdx.x == 0) {
      row_sum = reduce_result;
      row_sqrt_norm = static_cast<MT>(x_norm[i]);
      row_norm = row_sqrt_norm * row_sqrt_norm;
    }
    // Make row_* visible to all threads. This also separates consecutive
    // uses of temp_storage_sum, as CUB requires.
    __syncthreads();

    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const int64_t index = base + static_cast<int64_t>(j) * post;
      const MT x_ij = static_cast<MT>(x[index]);
      const MT dy_ij = static_cast<MT>(y_grad[index]);
      x_grad[index] =
          static_cast<T>((dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm);
    }
    // All threads must finish reading row_* before thread 0 overwrites them
    // for the next row. Previously this relied on a barrier hidden inside
    // BlockReduce::Sum. The loop bounds depend only on blockIdx/gridDim, so
    // every thread of the block reaches this barrier.
    __syncthreads();
  }
}

// GPU implementation of the norm op's gradient.
//
// Reads x, out_grad and the forward-pass norm. Allocates x_grad and writes
// every element of it. A negative axis counts from the last dimension.
// epsilon and is_test are accepted for interface compatibility but are not
// used here.
template <typename T, typename Context>
void NormGradKernel(const Context& ctx,
                    const DenseTensor& out_grad,
                    const DenseTensor& x,
                    const DenseTensor& norm,
                    int axis,
                    float epsilon,
                    bool is_test,
                    DenseTensor* x_grad) {
  ctx.template Alloc<T>(x_grad);
  T* dx = x_grad->data<T>();
  const T* x_data = x.data<T>();
  const T* x_norm = norm.data<T>();
  const T* dy = out_grad.data<T>();

  auto xdim = x.dims();
  if (axis < 0) axis = xdim.size() + axis;
  int pre, n, post;
  funcs::GetPrePostNumel(xdim, axis, &pre, &n, &post);

  // Empty tensor: nothing to compute, and launching a 0-block grid is an
  // invalid-configuration error.
  const int64_t rows = static_cast<int64_t>(pre) * post;
  if (rows == 0 || n == 0) return;

#ifdef __HIPCC__
  const int block = 256;
#else
  const int block = 512;
#endif
  const int max_threads = ctx.GetMaxPhysicalThreadCount();
  const int max_blocks = std::max(max_threads / block, 1);
  // Never launch more blocks than rows; the kernel strides over the rest.
  const int grid = static_cast<int>(std::min<int64_t>(max_blocks, rows));
  NormalizeGradient<T, block><<<grid, block, 0, ctx.stream()>>>(
      x_data, x_norm, dy, pre, n, post, dx);
}

}  // namespace pten

PT_REGISTER_KERNEL(norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   pten::NormGradKernel,
                   float,
                   double,
                   paddle::platform::float16) {}