Unverified commit 250e254f authored by Zhang Zheng, committed by GitHub

Optimize performance of log_softmax (#38992)

* Optimize performance of log_softmax

* delete unity build

* modify to phi

* fix

* fixfixfixfix

* fix

* fix

* fix

* fix

* simplify

* fix

* fix enforce
Parent 02e80f59
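Editor's note, for orientation only (not part of the patch): both the removed hand-written kernels below and the phi softmax driver invoked with LogMode = true compute log_softmax in its numerically stable form. With m = max_j x_j along the reduced axis:

\operatorname{log\_softmax}(x_i) = \log\frac{e^{x_i - m}}{\sum_j e^{x_j - m}} = (x_i - m) - \log\sum_j e^{x_j - m}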
@@ -12,459 +12,43 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <limits>

#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

#define LAUNCH_WARP_FORWAR_COMPUTE(near_greater_power_of_two)                \
  case near_greater_power_of_two:                                           \
    ComputeLogSoftmaxForwardInWarp<                                          \
        T, AccT, near_greater_power_of_two><<<blocks, threads, 0, stream>>>( \
        dst, src, outer_size, dim_size);                                     \
    break;
template <typename T, int KernelWarpSize>
__device__ __forceinline__ T WarpReduceSum(T value) {
#pragma unroll
for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) {
T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, value, offset);
value = value + sum_val;
}
return value;
}
template <typename T, int KernelWarpSize>
__device__ __forceinline__ T WarpReduceMax(T value) {
#pragma unroll
for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) {
T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, value, offset);
value = max(value, max_val);
}
return value;
}
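Editor's note: the two helpers above implement a butterfly (XOR-shuffle) reduction, after which every lane in the warp holds the reduced value. A minimal standalone CUDA sketch of the same pattern against the raw intrinsic (assuming platform::CudaShuffleXorSync wraps __shfl_xor_sync and a full 32-lane warp):

// Editor's sketch: butterfly warp reduction using the raw CUDA intrinsic.
// After log2(32) = 5 steps every lane holds the sum of all 32 lane values.
__device__ __forceinline__ float WarpAllReduceSum(float value) {
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2) {
    value += __shfl_xor_sync(0xFFFFFFFFu, value, offset);
  }
  return value;
}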
int GetNearGreaterPowerOfTwo(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) {
++log2_value;
}
return 1 << log2_value;
}
template <typename T, typename AccT, int NearGreaterPowerOfTwo>
__global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src,
int batch_size,
int element_count) {
constexpr int near_greater_power_of_two = NearGreaterPowerOfTwo;
constexpr int kernel_warp_size =
(near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
int batch_id = blockDim.y * blockIdx.x + threadIdx.y;
int thread_in_warp_idx = threadIdx.x;
// 1.read data from global memory to registers
AccT elements[warp_iter];
  // effective_element_count is the number of elements when this warp does
  // effective work, and 0 when it does ineffective work
int effective_element_count = (batch_id < batch_size) ? element_count : 0;
for (int it = 0; it < warp_iter; ++it) {
int element_index = thread_in_warp_idx + it * kernel_warp_size;
if (element_index < effective_element_count) {
elements[it] =
static_cast<AccT>(src[batch_id * element_count + element_index]);
} else {
elements[it] = -std::numeric_limits<AccT>::infinity();
}
}
// 2.compute max_value. For each thread, loop all registers to find max
AccT max_value = elements[0];
#pragma unroll
for (int it = 1; it < warp_iter; ++it) {
max_value = (max_value > elements[it]) ? max_value : elements[it];
}
max_value = WarpReduceMax<AccT, kernel_warp_size>(max_value);
// 3.For each warp, accumulate all thread registers
AccT sum = 0.0f;
#pragma unroll
for (int it = 0; it < warp_iter; ++it) {
sum += std::exp(elements[it] - max_value);
}
sum = WarpReduceSum<AccT, kernel_warp_size>(sum);
// 4.store result.
sum = std::log(sum);
#pragma unroll
for (int it = 0; it < warp_iter; ++it) {
int element_index = thread_in_warp_idx + it * kernel_warp_size;
if (element_index < effective_element_count) {
dst[batch_id * element_count + element_index] =
static_cast<T>(elements[it] - max_value - sum);
} else {
break;
}
}
}
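Editor's sketch (not part of the patch): stripped of the warp tiling, steps 1-4 above amount to the following per-row host computation.

// Editor's sketch: reference log_softmax for one row of length dim_size,
// mirroring the kernel: out = x - max(x) - log(sum(exp(x - max(x)))).
#include <algorithm>
#include <cmath>

void LogSoftmaxRowReference(const float* src, float* dst, int dim_size) {
  float max_value = src[0];
  for (int i = 1; i < dim_size; ++i) max_value = std::max(max_value, src[i]);
  float sum = 0.0f;
  for (int i = 0; i < dim_size; ++i) sum += std::exp(src[i] - max_value);
  const float log_sum = std::log(sum);
  for (int i = 0; i < dim_size; ++i) dst[i] = src[i] - max_value - log_sum;
}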
template <typename T, typename AccT>
void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size,
int outer_size, gpuStream_t stream) {
int threads_per_block = 128;
int near_greater_power_of_two = GetNearGreaterPowerOfTwo(dim_size);
int kernel_warp_size =
(near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
int warps_per_block = (threads_per_block / kernel_warp_size);
int blocks = (outer_size + warps_per_block - 1) / warps_per_block;
dim3 threads(kernel_warp_size, warps_per_block, 1);
switch (near_greater_power_of_two) {
LAUNCH_WARP_FORWAR_COMPUTE(1);
LAUNCH_WARP_FORWAR_COMPUTE(2);
LAUNCH_WARP_FORWAR_COMPUTE(4); // dim_size: 3~4
LAUNCH_WARP_FORWAR_COMPUTE(8); // dim_size: 5~8
LAUNCH_WARP_FORWAR_COMPUTE(16); // dim_size: 9~16
LAUNCH_WARP_FORWAR_COMPUTE(32); // dim_size: 17~32
LAUNCH_WARP_FORWAR_COMPUTE(64); // dim_size: 33~64
LAUNCH_WARP_FORWAR_COMPUTE(128); // dim_size 65~128
LAUNCH_WARP_FORWAR_COMPUTE(256); // dim_size 129~256
LAUNCH_WARP_FORWAR_COMPUTE(512); // dim_size 257~512
LAUNCH_WARP_FORWAR_COMPUTE(1024); // dim_size 513~1024
default:
break;
}
}
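Editor's worked example for the launch configuration above: with dim_size = 300 and outer_size = 1000, GetNearGreaterPowerOfTwo(300) returns 512, so kernel_warp_size = 32, each thread covers warp_iter = 512 / 32 = 16 elements, warps_per_block = 128 / 32 = 4, and blocks = (1000 + 4 - 1) / 4 = 250.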
// Returns the final item after a reduce operation along block.x.
// Firstly, get the shared memory (smem) offset and find the starting position
// for every y.
// Secondly, initialise every smem position with the thread's own value 'val'.
// Thirdly, apply standard reduction along x direction as below:
//
// -> x direction
// [o o o o o o o o] time 0
// | |/ /
// | /| /
// | / | /
// |/ |/
// [o o o o x x x x] time 1
// | |/ /
// |/|/
// [o o x x x x x x] time 2
// |/
// [o x x x x x x x] time 3
//
// Finally, return the first item.
// Imagine multiple reductions executed in parallel along the y axis.
// Note that when blockDim.x is not 1, it is an EVEN number in all cases,
// and the size of shared memory is even as well.
template <typename T, template <typename> class Functor>
__forceinline__ __device__ T BlockReduceAlongDimX(T *shared, T val) {
Functor<T> func;
  // This is not a block-wise reduction; it only reduces along block.x,
  // therefore the shared mem has offsets for different block.y.
shared += threadIdx.y * blockDim.x;
shared[threadIdx.x] = val;
int offset = blockDim.x / 2;
while (offset > 0) {
__syncthreads();
if (threadIdx.x < offset) {
shared[threadIdx.x] =
func(shared[threadIdx.x], shared[threadIdx.x + offset]);
}
offset /= 2;
}
__syncthreads();
return shared[0];
}
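Editor's sketch (assumed values, not part of the patch): the halving schedule in the diagram, simulated on the host for one block row of blockDim.x = 8 values with an add functor.

// Editor's sketch: shared[0] ends up holding the reduction of all 8 inputs,
// exactly as in BlockReduceAlongDimX for one value of threadIdx.y.
#include <cstdio>

int main() {
  float shared[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  for (int offset = 8 / 2; offset > 0; offset /= 2) {
    // every "thread" with threadIdx.x < offset folds in its partner's value
    for (int tid = 0; tid < offset; ++tid) {
      shared[tid] = shared[tid] + shared[tid + offset];
    }
  }
  std::printf("%f\n", shared[0]);  // 36.000000
  return 0;
}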
template <typename T, typename AccT>
__global__ void LogSoftmaxForwardCUDAKernelNotLastAxis(
T *output, const T *input, int outer_size, int dim_size, int inner_size) {
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<AccT *>(smem);
const int outer_stride = inner_size * dim_size;
const int dim_stride = inner_size;
for (int x_id = blockIdx.x; x_id < outer_size; x_id += gridDim.x) {
for (int y_id = blockIdx.y * blockDim.y + threadIdx.y; y_id < inner_size;
y_id += blockDim.y * gridDim.y) {
const int data_offset = x_id * outer_stride + y_id;
      // When blockDim.x == 1, no reduction along block.x is needed and
      // threadIdx.x is always 0, so the for-loops below are effectively serial
      // loops (no parallel execution). Loop over all elements along the axis
      // and compute Max, Sum and (input[id] - Max - log(Sum)) to get the final
      // log_softmax values along that axis.
// 1. reduce max
AccT max_value = -std::numeric_limits<AccT>::infinity();
      // Each thread iterates over the items it is responsible for and computes
      // its local max_value.
      // If there are N threads, N max_values are produced.
for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
const AccT value =
static_cast<AccT>(input[data_offset + d * dim_stride]);
max_value = phi::funcs::MaxFunctor<AccT>()(max_value, value);
}
      // If there is more than one thread along block.x, reduce all max_values
      // to get the global max_value, which is the max value along "axis".
      // If there is only one thread along block.x, no reduction is needed, as
      // 'max_value' is already the global max_value.
if (blockDim.x > 1) {
max_value = BlockReduceAlongDimX<AccT, phi::funcs::MaxFunctor>(
sdata, max_value);
}
// 2. reduce sum
AccT sum = 0;
// Below is the same execution as '1. reduce max'
for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
sum += std::exp(static_cast<AccT>(input[data_offset + d * dim_stride]) -
max_value);
}
if (blockDim.x > 1) {
sum = BlockReduceAlongDimX<AccT, phi::funcs::AddFunctor>(sdata, sum);
}
// 3. input-max-log_sum and write to output
for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
output[data_offset + d * dim_stride] = static_cast<T>(
static_cast<AccT>(input[data_offset + d * dim_stride]) - max_value -
std::log(sum));
}
}
}
}
// block.y covers inner_size. Threads along the x axis process dim_size
// elements, making sure not to exceed 1024 threads per block.
// Note that dim_threads (namely blockDim.x) is either 1 or an even number.
inline dim3 GetBlockSize(int dim_size, int inner_size) {
int inner_threads = inner_size;
inner_threads = std::min(inner_threads, 1024);
int dim_threads = 1;
while (dim_threads * inner_threads <= 1024 && dim_threads <= dim_size) {
dim_threads *= 2;
}
dim_threads /= 2;
return dim3(dim_threads, inner_threads);
}
// First cover the y axis with as many blocks as possible.
// Then cover the x axis with as many blocks as possible,
// making sure not to exceed max_active_blocks.
inline dim3 GetGridSize(dim3 block, int max_active_blocks, int outer_size,
int dim_size, int inner_size) {
int inner_blocks = (inner_size + block.y - 1) / block.y;
if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks;
int outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
if (outer_blocks > outer_size) outer_blocks = outer_size;
return dim3(outer_blocks, inner_blocks);
}
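Editor's worked example (with an assumed num_sm = 40): for dim_size = 200 and inner_size = 50, GetBlockSize keeps inner_threads = 50 and doubles dim_threads 1 -> 2 -> 4 -> 8 -> 16 -> 32, stops because 32 * 50 > 1024, and halves back to 16, giving a (16, 50) block of 800 threads. With max_active_blocks = 2 * 40 = 80 and outer_size = 30, GetGridSize yields inner_blocks = (50 + 50 - 1) / 50 = 1 and outer_blocks = min(80, 30) = 30, i.e. a (30, 1) grid.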
// When choosing the grid and block sizes, priority is given to the block size;
// the grid is then determined according to the maximum number of active
// blocks, which is set as an empirical value.
template <typename T, typename Kernel>
void ComputeLaunchConfigure(Kernel k, int outer_size, int dim_size,
int inner_size, dim3 &grid, dim3 &block,
int &shared_mem, int num_sm) {
block = GetBlockSize(dim_size, inner_size);
int block_threads = block.x * block.y;
shared_mem = block.x == 1 ? 0 : block_threads * sizeof(T);
int max_active_blocks = num_sm * 2;
grid =
GetGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
}
template <typename T, typename MPDType>
void LaunchLogSoftmaxForwardCUDAKernelNotLastAxis(T *output_data,
const T *input_data,
int outer_size, int dim_size,
int inner_size, int num_sm,
gpuStream_t stream) {
int shared_mem;
dim3 grid;
dim3 block;
ComputeLaunchConfigure<MPDType>(
&LogSoftmaxForwardCUDAKernelNotLastAxis<T, MPDType>, outer_size, dim_size,
inner_size, grid, block, shared_mem, num_sm);
LogSoftmaxForwardCUDAKernelNotLastAxis<
T, MPDType><<<grid, block, shared_mem, stream>>>(
output_data, input_data, outer_size, dim_size, inner_size);
}
template <typename T>
class LogSoftmaxKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const auto *x = context.Input<framework::Tensor>("X");
    auto *out = context.Output<framework::Tensor>("Out");
    const auto *input_data = x->data<T>();
    auto *output_data = out->mutable_data<T>(context.GetPlace());

    const int rank = x->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);

    int dim_size = x->dims()[axis];
    int inner_size = 1;
    for (int i = axis + 1; i < x->dims().size(); ++i) {
      inner_size *= x->dims()[i];
    }
    int outer_size = SizeToAxis(axis, x->dims());
    gpuStream_t stream = context.cuda_device_context().stream();
    int num_sm = context.cuda_device_context().GetSMCount();

    if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
      LaunchSoftmaxForwardForLastAxis<T, MPDType>(output_data, input_data,
                                                  dim_size, outer_size, stream);
    } else {
      LaunchLogSoftmaxForwardCUDAKernelNotLastAxis<T, MPDType>(
          output_data, input_data, outer_size, dim_size, inner_size, num_sm,
          stream);
    }
  }
};

template <typename T>
class LogSoftmaxKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<Tensor>("X");
    auto *out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    int input_axis = ctx.Attr<int>("axis");
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    phi::SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, *x, input_axis, out);
  }
};
// Backward below
#define LAUNCH_WARP_BACKWARD_COMPUTE(near_greater_power_of_two) \
case near_greater_power_of_two: \
ComputeLogSoftmaxBackwardInWarp< \
T, AccT, near_greater_power_of_two><<<blocks, threads, 0, stream>>>( \
output, grad_output, grad_input, outer_size, dim_size); \
break;
template <typename T, typename AccT, int NearGreaterPowerOfTwo>
__global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
const T *grad_output,
T *grad_input, int batch_size,
int element_count) {
constexpr int near_greater_power_of_two = NearGreaterPowerOfTwo;
constexpr int kernel_warp_size =
(near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
int batch_id = blockDim.y * blockIdx.x + threadIdx.y;
int thread_in_warp_idx = threadIdx.x;
// 1.read data from global memory to registers
AccT output_register[warp_iter];
AccT grad_output_register[warp_iter];
int effective_element_count = (batch_id < batch_size) ? element_count : 0;
for (int iter = 0; iter < warp_iter; ++iter) {
int element_index = thread_in_warp_idx + iter * kernel_warp_size;
if (element_index < effective_element_count) {
output_register[iter] =
static_cast<AccT>(output[batch_id * element_count + element_index]);
grad_output_register[iter] = static_cast<AccT>(
grad_output[batch_id * element_count + element_index]);
} else {
output_register[iter] = static_cast<AccT>(0);
grad_output_register[iter] = static_cast<AccT>(0);
}
}
// 2. For each warp, accumulate all thread registers
AccT sum = grad_output_register[0];
#pragma unroll
for (int iter = 1; iter < warp_iter; ++iter) {
sum += grad_output_register[iter];
}
sum = WarpReduceSum<AccT, kernel_warp_size>(sum);
// 3. write result in grad_input
#pragma unroll
for (int iter = 0; iter < warp_iter; ++iter) {
int element_index = thread_in_warp_idx + iter * kernel_warp_size;
if (element_index < effective_element_count) {
grad_input[batch_id * element_count + element_index] = static_cast<T>(
(grad_output_register[iter] - std::exp(output_register[iter]) * sum));
}
}
}
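Editor's sketch (not part of the patch): the backward kernel above implements, per row, grad_input = grad_output - exp(output) * sum(grad_output), where output is the saved log_softmax result. As plain host code:

// Editor's sketch: reference log_softmax gradient for one row of length
// dim_size, mirroring steps 1-3 of the warp kernel.
#include <cmath>

void LogSoftmaxGradRowReference(const float* output, const float* grad_output,
                                float* grad_input, int dim_size) {
  float sum = 0.0f;
  for (int i = 0; i < dim_size; ++i) sum += grad_output[i];
  for (int i = 0; i < dim_size; ++i) {
    grad_input[i] = grad_output[i] - std::exp(output[i]) * sum;
  }
}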
template <typename T, typename AccT>
void LaunchSoftmaxBackwardForLastAxis(T *grad_input, const T *grad_output,
const T *output, int dim_size,
int outer_size, gpuStream_t stream) {
int threads_per_block = 128;
int near_greater_power_of_two = GetNearGreaterPowerOfTwo(dim_size);
int kernel_warp_size =
(near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
int warps_per_block = (threads_per_block / kernel_warp_size);
int blocks = (outer_size + warps_per_block - 1) / warps_per_block;
dim3 threads(kernel_warp_size, warps_per_block, 1);
switch (near_greater_power_of_two) {
LAUNCH_WARP_BACKWARD_COMPUTE(1); // dim_size: 1
LAUNCH_WARP_BACKWARD_COMPUTE(2); // dim_size: 2
LAUNCH_WARP_BACKWARD_COMPUTE(4); // dim_size: 3~4
LAUNCH_WARP_BACKWARD_COMPUTE(8); // dim_size: 5~8
LAUNCH_WARP_BACKWARD_COMPUTE(16); // dim_size: 9~16
LAUNCH_WARP_BACKWARD_COMPUTE(32); // dim_size: 17~32
LAUNCH_WARP_BACKWARD_COMPUTE(64); // dim_size: 33~64
LAUNCH_WARP_BACKWARD_COMPUTE(128); // dim_size: 65~128
LAUNCH_WARP_BACKWARD_COMPUTE(256); // dim_size: 129~256
LAUNCH_WARP_BACKWARD_COMPUTE(512); // dim_size: 257~512
LAUNCH_WARP_BACKWARD_COMPUTE(1024); // dim_size: 513~1024
default:
break;
}
}
template <typename T>
class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext &context) const override {
    const auto *out = context.Input<framework::Tensor>("Out");
    const auto *d_out =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));

    const auto *out_data = out->data<T>();
    const auto *d_out_data = d_out->data<T>();
    auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());

    const int rank = out->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);

    int dim_size = out->dims()[axis];
    int inner_size = 1;
    for (int i = axis + 1; i < out->dims().size(); ++i) {
      inner_size *= out->dims()[i];
    }
    int outer_size = SizeToAxis(axis, out->dims());
    gpuStream_t stream = context.cuda_device_context().stream();

    if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
      LaunchSoftmaxBackwardForLastAxis<T, MPDType>(
          d_x_data, d_out_data, out_data, dim_size, outer_size, stream);
    } else {
      LogSoftmaxGradFunctor<platform::CUDADeviceContext, T>()(
          context.template device_context<platform::CUDADeviceContext>(), out,
          d_out, d_x, axis);
    }
  }
};

template <typename T>
class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *out = ctx.Input<Tensor>("Out");
    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    dx->mutable_data<T>(ctx.GetPlace());

    int input_axis = ctx.Attr<int>("axis");
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    phi::SoftmaxBackwardCUDAKernelDriver<T, true>(dev_ctx, *out, *dout,
                                                  input_axis, dx);
  }
};
@@ -473,6 +57,17 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>

namespace ops = paddle::operators;
namespace plat = paddle::platform;
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(
log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::float16>,
ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(
log_softmax_grad, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::bfloat16>);
#else
REGISTER_OP_CUDA_KERNEL(
    log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
    ops::LogSoftmaxKernel<plat::CUDADeviceContext, double>,
@@ -483,3 +78,4 @@ REGISTER_OP_CUDA_KERNEL(
    ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, double>,
    ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::bfloat16>);
#endif
@@ -351,8 +351,17 @@ __global__ void WarpSoftmaxForward(T* softmax,
    VecT* softmax_v =
        reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
        &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));

    if (LogMode) {
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnaryLogFunctor<AccT>>(
          &srcdata[i][0][0], &srcdata[i][0][0], UnaryLogFunctor<AccT>());
      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnarySubFunctor<AccT>>(
          &out_tmp[i][0][0],
          &srcdata[i][0][0],
          UnarySubFunctor<AccT>(std::log(sum[i])));
    } else {
      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
          &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
    }
    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
        &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
  }
@@ -434,15 +443,25 @@ __global__ void WarpSoftmaxBackward(T* dst,
  AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
  AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
  AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
  kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
      &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
  kps::Reduce<AccT,
              kVItem,
              kBatchSize,
              1,
              kps::AddFunctor<AccT>,
              kps::details::ReduceMode::kLocalMode>(
      &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);

  if (LogMode) {
    kps::Reduce<AccT,
                kVItem,
                kBatchSize,
                1,
                kps::AddFunctor<AccT>,
                kps::details::ReduceMode::kLocalMode>(
        &sum[0], &grad_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
  } else {
    kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
        &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
    kps::Reduce<AccT,
                kVItem,
                kBatchSize,
                1,
                kps::AddFunctor<AccT>,
                kps::details::ReduceMode::kLocalMode>(
        &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
  }
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  // write result to global memory
@@ -453,10 +472,23 @@ __global__ void WarpSoftmaxBackward(T* dst,
    if (i >= local_batches) break;
    AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
    AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
        &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
    kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
        &out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());

    if (LogMode) {
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpMulFunctor<AccT>>(
          &out[i][0][0], &srcptr[0], ExpMulFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::SubFunctor<AccT>>(
          &out_tmp[i][0][0],
          &gradptr[0],
          &out[i][0][0],
          kps::SubFunctor<AccT>());
    } else {
      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
          &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
          &out_tmp[i][0][0],
          &srcptr[0],
          &out[i][0][0],
          kps::MulFunctor<AccT>());
    }
    VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
@@ -639,7 +671,8 @@ __global__ void NormalSoftmaxForward(
template <typename T,
          typename AccT,
          template <typename, typename> class Functor,
          bool LogMode>
__global__ void NormalSoftmaxBackward(T* input_grad,
                                      const T* output_grad,
                                      const T* output,
@@ -656,10 +689,17 @@ __global__ void NormalSoftmaxBackward(T* input_grad,
    // 1. reduce sum
    AccT sum = 0;
    for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
      int data_offset = grad_offset + mid_id * mid_stride;
      sum += static_cast<AccT>(output_grad[data_offset]) *
             static_cast<AccT>(output[data_offset]);
    }

    if (LogMode) {
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        int data_offset = grad_offset + mid_id * mid_stride;
        sum += static_cast<AccT>(output_grad[data_offset]);
      }
    } else {
      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
        int data_offset = grad_offset + mid_id * mid_stride;
        sum += static_cast<AccT>(output_grad[data_offset]) *
               static_cast<AccT>(output[data_offset]);
      }
    }
    if (blockDim.y > 1) {
      kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
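Editor's note on the two reduction branches above: they follow from the different gradient formulas of the two ops, which is why the LogMode branch sums only output_grad while the softmax branch sums output_grad * output:

y = \mathrm{softmax}(x):\quad
\frac{\partial L}{\partial x_i} = y_i\Big(\frac{\partial L}{\partial y_i} - \sum_j \frac{\partial L}{\partial y_j}\,y_j\Big),
\qquad
y = \mathrm{log\_softmax}(x):\quad
\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i} - e^{y_i}\sum_j \frac{\partial L}{\partial y_j}.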
@@ -715,10 +755,10 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
  dim3 grid, block;
  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
  if (LogMode) {
    NormalSoftmaxBackward<T,
                          AccT,
                          LogSoftmaxBackwardFunctor,
                          LogMode><<<grid, block, 0, dev_ctx.stream()>>>(
        input_grad_data,
        output_grad_data,
        output_data,
@@ -726,10 +766,10 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
        mid_dim,
        low_dim);
  } else {
    NormalSoftmaxBackward<T,
                          AccT,
                          SoftmaxBackwardFunctor,
                          LogMode><<<grid, block, 0, dev_ctx.stream()>>>(
        input_grad_data,
        output_grad_data,
        output_data,
@@ -864,6 +904,32 @@ static bool CanUseCudnnSoftmax(const GPUContext& dev_ctx) {
  return false;
}
#if CUDNN_VERSION < 8100
template <>
inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const DenseTensor& x,
const int axis,
const bool log_mode,
DenseTensor* out) {
PADDLE_THROW(errors::Unavailable(
"This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
"8100."));
}
template <>
inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
const GPUContext& dev_ctx,
const DenseTensor& out,
const DenseTensor& dout,
const int axis,
const bool log_mode,
DenseTensor* dx) {
PADDLE_THROW(errors::Unavailable(
"This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
"8100."));
}
#endif
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                    const DenseTensor& x,
...