diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu b/paddle/fluid/operators/softmax_cudnn_op.cu
index b62d71bdbc4dba43d749c6d7eeb20519908b1822..83b7b78aaec909f7d8924eaf2a2ff46372bbb8c7 100644
--- a/paddle/fluid/operators/softmax_cudnn_op.cu
+++ b/paddle/fluid/operators/softmax_cudnn_op.cu
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/math/math_cuda_utils.h"
+#include "paddle/fluid/operators/softmax_impl.cuh"
 #include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 #ifdef PADDLE_WITH_HIP
@@ -21,7 +23,6 @@ limitations under the License. */
 #else
 #include "paddle/fluid/platform/cudnn_helper.h"
 #endif
-#include "paddle/fluid/platform/gpu_launch_config.h"
 namespace paddle {
 namespace platform {
@@ -37,288 +38,414 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using DataLayout = platform::DataLayout;
 using Tensor = framework::Tensor;
-#define LAUNCH_SOFTMAX_WARP_FORWARD(Log2Elements) \
-  case Log2Elements: \
-    WarpSoftmaxForward<<< \
-        blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
-        out_data, x->data<T>(), N, dim, dim); \
-    break;
-
-#define LAUNCH_SOFTMAX_WARP_BACKWARD(Log2Elements) \
-  case Log2Elements: \
-    softmax_warp_backward<<< \
-        blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
-        dx_data, mul_grad.data<T>(), out->data<T>(), N, dim, dim); \
-    break;
-
-static inline int SizeOutAxis(const int axis, DDim dims) {
-  int size = 1;
-  for (int i = axis + 1; i < dims.size(); i++) {
-    size *= dims[i];
-  }
-  return size;
-}
-
-int log2_ceil(int value) {
-  int log2_value = 0;
-  while ((1 << log2_value) < value) ++log2_value;
-  return log2_value;
-}
-
-template <typename T>
-union vec_t {
-  static_assert(sizeof(T) == -1, "vec_t is only available by specialization.");
+// Vectorization trait 4 * sizeof(T)
+template <typename T>
+class VecT4 {};
+template <>
+class VecT4<double> {
+ public:
+  using Type = long4;
 };
-
 template <>
-union vec_t<float> {
-  float4 s;
-  float v[4];
+class VecT4<float> {
+ public:
+  using Type = int4;
+};
+template <>
+class VecT4<platform::float16> {
+ public:
+  using Type = int2;
 };
+// Vectorization trait 2 * sizeof(T)
+template <typename T>
+class VecT2 {};
 template <>
-union vec_t<platform::float16> {
-  int2 s;
-  platform::float16 v[4];
+class VecT2<double> {
+ public:
+  using Type = int4;
+};
+template <>
+class VecT2<float> {
+ public:
+  using Type = int2;
+};
+template <>
+class VecT2<platform::float16> {
+ public:
+  using Type = int;
 };
-template <typename T, int VPT, int WARP_PER_BLOCK, typename VECT>
-__global__ void VecSoftmaxForward(T* dst, const T* src, const int batch_size,
-                                  const int softmax_ele) {
-  int offset = blockIdx.x * softmax_ele * WARP_PER_BLOCK;
-  int idx = threadIdx.x * VPT;
-
-  VECT buf = reinterpret_cast<const VECT*>(&src[offset + idx])[0];
-  T* bufp = reinterpret_cast<T*>(&buf);
-  float4 val4;
-  float* val4p = reinterpret_cast<float*>(&val4);
-  for (int i = 0; i < VPT; ++i) {
-    val4p[i] = static_cast<float>(bufp[i]);
-  }
-  float val = val4.x + val4.y + val4.z + val4.w;
-  float max_val = math::warpReduceMax<float>(
-      max(max(val4.x, val4.y), max(val4.z, val4.w)), 0xffffffff);
-  float4 tmp4 = make_float4(__expf(val4.x - max_val), __expf(val4.y - max_val),
-                            __expf(val4.z - max_val), __expf(val4.w - max_val));
-  float* tmp4p = reinterpret_cast<float*>(&tmp4);
-  float invsum = 1.f / (math::warpReduceSum<float>(
-                            tmp4.x + tmp4.y + tmp4.z + tmp4.w, 0xffffffff) +
-                        1e-6f);
-  for (int i = 0; i < VPT; ++i) {
-    bufp[i] = static_cast<T>(tmp4p[i] * invsum);
-  }
-  reinterpret_cast<VECT*>(&dst[offset +
idx])[0] = buf; +int static inline log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template -__device__ __forceinline__ void warp_reduce_sum(T* sum) { +/* +Core function of computing softmax forward for axis=-1. +The computation includes + - Compute maximum of batch: maxvalue_{i} = max_j src_{i,j} + - Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} } + - Compute: (a_{i,j} - maxvalue_{i}) / s_{i} +One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). +For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle +api to compute max (sum) in one warp. +*/ +template +__global__ void WarpSoftmaxForward(T* softmax, const T* src, + const int batch_size, const int stride, + const int element_count) { + constexpr int kDimCeil = 1 << Log2Elements; + constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + constexpr int kVSize = sizeof(VecT) / sizeof(T); + constexpr int kIterations = kDimCeil / kWarpSize; + constexpr int kIterationsV = + (kIterations >= kVSize) ? (kIterations / kVSize) : 1; + constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; + + // max index to read + int idx_max_v[kBatchSize]; #pragma unroll - for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); - sum[i] = sum[i] + sum_val; - } + for (int i = 0; i < kBatchSize; i++) { + int idx_max = ((i + first_batch) < batch_size) ? element_count : 0; + idx_max_v[i] = idx_max / kVSize; } -} -template -__device__ __forceinline__ void warp_reduce_max(T* sum) { + // read data from global memory + AccT srcdata[kBatchSize][kIterationsV][kVSize]; + +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { +// read data +#pragma unroll + for (int it = 0; it < kIterationsV; ++it) { + int src_idx = threadIdx.x + it * kWarpSize; + if (kVSize == 1) { + if (src_idx < idx_max_v[i]) { + srcdata[i][it][0] = + static_cast(src[(first_batch + i) * stride + src_idx]); + } else { + srcdata[i][it][0] = -std::numeric_limits::infinity(); + } + } else { + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); + if (src_idx < idx_max_v[i]) { + VecT srctmp = src_v[src_idx]; + const T* srcinptr = reinterpret_cast(&srctmp); #pragma unroll - for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) { + for (int s = 0; s < kVSize; s++) { + srcdata[i][it][s] = static_cast(srcinptr[s]); + } + } else { #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); - sum[i] = max(sum[i], max_val); + for (int s = 0; s < kVSize; s++) { + srcdata[i][it][s] = -std::numeric_limits::infinity(); + } + } + } } } -} - -template -__global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size, - const int stride, const int element_count) { - constexpr int next_power_of_two = 1 << Log2Elements; - constexpr int warp_size_softmax = - (next_power_of_two < 32) ? next_power_of_two : 32; - constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 
2 : 1; - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) { - local_batches = WARP_BATCH; - } - - int local_idx = threadIdx.x; - - src += first_batch * stride + local_idx; - dst += first_batch * stride + local_idx; + // compute max value + AccT max_value[kBatchSize]; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { + // it = 0 + AccT valmax = srcdata[i][0][0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s]; + } + max_value[i] = valmax; - // load data from global memory - AccT elements[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * warp_size_softmax; - if (element_index < batch_element_count) { - elements[i][it] = - static_cast(src[i * element_count + it * warp_size_softmax]); - } else { - elements[i][it] = -std::numeric_limits::infinity(); +// it = 1, 2, ... +#pragma unroll + for (int it = 1; it < kIterationsV; ++it) { + AccT valmax = srcdata[i][it][0]; +#pragma unroll + for (int s = 1; s < kVSize; ++s) { + valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s]; } + max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax; } } + WarpReduceMax(max_value); - // compute max_value - AccT max_value[WARP_BATCH]; + // compute sum + AccT sum[kBatchSize]; #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; + for (int i = 0; i < kBatchSize; ++i) { + // it = 0 + if (LogMode) { + sum[i] = std::exp(srcdata[i][0][0] - max_value[i]); + } else { + srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]); + sum[i] = srcdata[i][0][0]; + } #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + for (int s = 1; s < kVSize; ++s) { + if (LogMode) { + sum[i] += std::exp(srcdata[i][0][s] - max_value[i]); + } else { + srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]); + sum[i] += srcdata[i][0][s]; + } } - } - warp_reduce_max(max_value); - AccT sum[WARP_BATCH]{0.0f}; +// it = 1, 2, ... 
#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { + for (int it = 1; it < kIterationsV; ++it) { #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = (std::exp((elements[i][it] - max_value[i]))); - sum[i] += elements[i][it]; + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + sum[i] += std::exp(srcdata[i][it][s] - max_value[i]); + } else { + srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]); + sum[i] += srcdata[i][it][s]; + } + } } } - warp_reduce_sum(sum); + WarpReduceSum(sum); -// store result +// write result to global memory #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) break; + for (int i = 0; i < kBatchSize; ++i) { + if (LogMode) { + sum[i] = std::log(sum[i]); + } + #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * warp_size_softmax; - if (element_index < element_count) { - dst[i * element_count + it * warp_size_softmax] = - elements[i][it] / sum[i]; + for (int it = 0; it < kIterationsV; ++it) { + int idx = threadIdx.x + it * kWarpSize; + if (kVSize == 1) { + if (idx < idx_max_v[i]) { + if (LogMode) { + softmax[(first_batch + i) * stride + idx] = + srcdata[i][it][0] - max_value[i] - sum[i]; + } else { + softmax[(first_batch + i) * stride + idx] = + srcdata[i][it][0] / sum[i]; + } + } else { + break; + } } else { - break; + VecT* softmax_v = + reinterpret_cast(&softmax[(first_batch + i) * stride]); + VecT tmpdata; + T* tmpptr = reinterpret_cast(&tmpdata); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i]; + } else { + tmpptr[s] = srcdata[i][it][s] / sum[i]; + } + } + + if (idx < idx_max_v[i]) { + softmax_v[idx] = tmpdata; + } else { + break; + } } } } } -template -__global__ void softmax_warp_backward(T* gradInput, const T* grad, - const T* output, int batch_size, - int stride, int element_count) { - constexpr int next_power_of_two = 1 << Log2Elements; - constexpr int warp_size_softmax = - (next_power_of_two < 32) ? next_power_of_two : 32; - constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - +/* +Core function of computing softmax backward for axis=-1. +The computation includes + - Compute sum of exp batch: s_{i} = sum_{j} {src_{i,j} * grad_{i,j} + - Compute src_{i,j} * ( grad_{i,j}) - s_{i} ) +One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). +For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle +api to compute max (sum) in one warp. +*/ +template +__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src, + int batch_size, int stride, + int element_count) { + constexpr int kVSize = sizeof(VecT) / sizeof(T); + constexpr int kDimCeil = 1 << Log2Elements; + constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + constexpr int kIterations = kDimCeil / kWarpSize; + constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1; + constexpr int kIterationsV = + (kIterations >= kVSize) ? 
(kIterations / kVSize) : 1; + int element_count_v = element_count / kVSize; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) { - local_batches = WARP_BATCH; + if (local_batches > kBatchSize) { + local_batches = kBatchSize; } - int local_idx = threadIdx.x % warp_size_softmax; - - int thread_offset = first_batch * stride + local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - AccT grad_reg[WARP_BATCH][WARP_ITERATIONS]; - AccT output_reg[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * warp_size_softmax; - if (element_index < batch_element_count) { - grad_reg[i][it] = - static_cast(grad[i * element_count + it * warp_size_softmax]); - output_reg[i][it] = static_cast( - output[i * element_count + it * warp_size_softmax]); + // read data from global memory + VecT src_reg[kBatchSize][kIterationsV]; + VecT grad_reg[kBatchSize][kIterationsV]; + + for (int i = 0; i < kBatchSize; ++i) { + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); + const VecT* grad_v = + reinterpret_cast(&grad[(first_batch + i) * stride]); + + // max index to read + int idx_max = (i < local_batches) ? element_count : 0; + int idx_max_v = idx_max / kVSize; + + // read data + for (int it = 0; it < kIterationsV; ++it) { + int src_idx = threadIdx.x + it * kWarpSize; + if (src_idx < idx_max_v) { + src_reg[i][it] = src_v[src_idx]; + grad_reg[i][it] = grad_v[src_idx]; } else { - grad_reg[i][it] = AccT(0); - output_reg[i][it] = AccT(0); +#pragma unroll + for (int s = 0; s < kVSize; s++) { + reinterpret_cast(&src_reg[i][it])[s] = 0.0; + reinterpret_cast(&grad_reg[i][it])[s] = 0.0; + } } } } - AccT sum[WARP_BATCH]; + // compute sum + AccT sum[kBatchSize]{0.0}; +#pragma unroll + for (int i = 0; i < kBatchSize; ++i) { #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; + for (int it = 0; it < kIterationsV; ++it) { + T* gradptr = reinterpret_cast(&grad_reg[i][it]); + T* srcptr = reinterpret_cast(&src_reg[i][it]); #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + sum[i] += static_cast(gradptr[s]); + } else { + sum[i] += static_cast(gradptr[s] * srcptr[s]); + } + } } } - warp_reduce_sum(sum); + WarpReduceSum(sum); -// store result +// write result #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { + for (int i = 0; i < kBatchSize; ++i) { if (i >= local_batches) break; + + VecT* dst_v = reinterpret_cast(&dst[(first_batch + i) * stride]); + + // max index to write + int idx_max = (i < local_batches) ? 
element_count : 0; + int idx_max_v = idx_max / kVSize; + #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * warp_size_softmax; - if (element_index < element_count) { - // compute gradients - gradInput[i * element_count + it * warp_size_softmax] = - (grad_reg[i][it] - output_reg[i][it] * sum[i]); + for (int it = 0; it < kIterationsV; ++it) { + VecT tmpdata; + T* tmpptr = reinterpret_cast(&tmpdata); + T* gradptr = reinterpret_cast(&grad_reg[i][it]); + T* srcptr = reinterpret_cast(&src_reg[i][it]); +#pragma unroll + for (int s = 0; s < kVSize; ++s) { + if (LogMode) { + tmpptr[s] = static_cast(gradptr[s]) - + std::exp(static_cast(srcptr[s])) * sum[i]; + } else { + tmpptr[s] = static_cast(srcptr[s]) * + (static_cast(gradptr[s]) - sum[i]); + } + } + + int idx = threadIdx.x + it * kWarpSize; + if (idx < idx_max_v) { + dst_v[idx] = tmpdata; } } } } -template -__global__ void MultiplyCUDAKernel(T* C, const T* A, const T* B, int N) { - CUDA_KERNEL_LOOP(i, N) { - C[i] = static_cast(static_cast(A[i]) * static_cast(B[i])); +#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \ + case Log2Elements: \ + WarpSoftmaxForward< \ + T, VecT, AccT, Log2Elements, \ + LogMode><<>>( \ + dst, src, batch_size, stride, element_count); \ + break; + +/* + Wrapper of softmax formward with template instantiation on size of input. +*/ +template +void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads, + const framework::ExecutionContext& ctx, T* dst, + const T* src, const int batch_size, + const int stride, const int element_count, + int Log2Elements) { + using AccT = typename details::MPTypeTrait::Type; + switch (Log2Elements) { + SOFTMAX_WARP_FORWARD_CASE(0, AccT); + SOFTMAX_WARP_FORWARD_CASE(1, AccT); + SOFTMAX_WARP_FORWARD_CASE(2, AccT); + SOFTMAX_WARP_FORWARD_CASE(3, AccT); + SOFTMAX_WARP_FORWARD_CASE(4, AccT); + SOFTMAX_WARP_FORWARD_CASE(5, AccT); + SOFTMAX_WARP_FORWARD_CASE(6, AccT); + SOFTMAX_WARP_FORWARD_CASE(7, AccT); + SOFTMAX_WARP_FORWARD_CASE(8, AccT); + SOFTMAX_WARP_FORWARD_CASE(9, AccT); + default: + break; } } -template -__global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src, - const int batch_size, - const int softmax_ele) { - const int offset = - blockIdx.x * softmax_ele * WARP_PER_BLOCK + threadIdx.x * VPT; - - float local_sum_gy = 0.f; - vec_t local_grad; - vec_t local_src; - - local_grad.s = - reinterpret_cast(&grad[offset])[0]; - local_src.s = reinterpret_cast(&src[offset])[0]; - - for (int i = 0; i < VPT; ++i) { - local_sum_gy += static_cast(local_grad.v[i]) * - static_cast(local_src.v[i]); - } - float sum_gy = math::warpReduceSum(local_sum_gy, 0xffffffff); +#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \ + case Log2Elements: \ + WarpSoftmaxBackward< \ + T, VecT, AccT, Log2Elements, \ + LogMode><<>>( \ + dst, grad, src, batch_size, stride, element_count); \ + break; - vec_t local_dst; - for (int i = 0; i < VPT; ++i) { - local_dst.v[i] = - static_cast(static_cast(local_src.v[i]) * - (static_cast(local_grad.v[i]) - sum_gy)); +/* +Wrapper of softmax backward with template instantiation on size of input. 
+*/ +template +void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads, + const framework::ExecutionContext& ctx, T* dst, + const T* grad, const T* src, + const int batch_size, const int stride, + const int element_count, int Log2Elements) { + using AccT = typename details::MPTypeTrait::Type; + switch (Log2Elements) { + SOFTMAX_WARP_BACKWARD_CASE(0, AccT); + SOFTMAX_WARP_BACKWARD_CASE(1, AccT); + SOFTMAX_WARP_BACKWARD_CASE(2, AccT); + SOFTMAX_WARP_BACKWARD_CASE(3, AccT); + SOFTMAX_WARP_BACKWARD_CASE(4, AccT); + SOFTMAX_WARP_BACKWARD_CASE(5, AccT); + SOFTMAX_WARP_BACKWARD_CASE(6, AccT); + SOFTMAX_WARP_BACKWARD_CASE(7, AccT); + SOFTMAX_WARP_BACKWARD_CASE(8, AccT); + SOFTMAX_WARP_BACKWARD_CASE(9, AccT); + default: + break; } - reinterpret_cast(&dst[offset])[0] = local_dst.s; } -template +#undef SOFTMAX_WARP_FORWARD_CASE +#undef SOFTMAX_WARP_BACKWARD_CASE + +template class SoftmaxCUDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -335,60 +462,39 @@ class SoftmaxCUDNNKernel : public framework::OpKernel { const int D = SizeOutAxis(axis, dims); constexpr int max_dim = 320; - bool optimize = false; constexpr int warps_per_block = 4; + if (D == 1 && dim <= max_dim && sizeof(T) <= 4) { - if (dim == 128 && N % warps_per_block == 0) { - optimize = true; - // a warp for a batch, 4 elements for a thread, only support the softmax - // dim size = 128 currently - if (sizeof(T) == 2) { - VecSoftmaxForward<<< - N / warps_per_block, warps_per_block * WARP_SIZE, 0, - ctx.cuda_device_context().stream()>>>(out_data, x->data(), N, - dim); - } else if (sizeof(T) == 4) { - VecSoftmaxForward<<< - N / warps_per_block, warps_per_block * WARP_SIZE, 0, - ctx.cuda_device_context().stream()>>>(out_data, x->data(), N, - dim); - } else { - assert(false && "not support"); - } - } else if (dim < max_dim) { - optimize = true; - int log2_elements = static_cast(log2_ceil(dim)); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (N + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - switch (log2_elements) { - LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 - LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 - LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 - LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 - LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 - LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 - LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 - LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 - LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 - LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 - default: - break; - } + const int kDimLog2 = static_cast(log2_ceil(dim)); + const int kDimCeil = 1 << kDimLog2; + int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + int batches_per_warp = (kDimCeil <= 32) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / kWarpSize); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (N + batches_per_block - 1) / batches_per_block; + dim3 threads(kWarpSize, warps_per_block, 1); + + // vectorization read/write + using T4 = typename VecT4::Type; + using T2 = typename VecT2::Type; + if (dim % 4 == 0) { + SwitchWarpSoftmaxForward(blocks, threads, ctx, out_data, + x->data(), N, dim, dim, + kDimLog2); + } else if (dim % 2 == 0) { + SwitchWarpSoftmaxForward(blocks, threads, ctx, out_data, + x->data(), N, dim, dim, + kDimLog2); + } else { + SwitchWarpSoftmaxForward(blocks, threads, ctx, out_data, + x->data(), N, dim, dim, + kDimLog2); } - } - if (!optimize) { + } else { ScopedTensorDescriptor desc; std::vector tensor_dims = {N, dim, D, 1}; DataLayout layout = DataLayout::kNCHW; @@ -405,22 +511,37 @@ class SoftmaxCUDNNKernel : public framework::OpKernel { #ifdef PADDLE_WITH_HIP auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE : MIOPEN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward( - handle, platform::CudnnDataType::kOne(), desc_, x->data(), - platform::CudnnDataType::kZero(), desc_, out_data)); + if (LogMode) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( + handle, platform::CudnnDataType::kOne(), desc_, x->data(), + platform::CudnnDataType::kZero(), desc_, out_data, + MIOPEN_SOFTMAX_LOG, mode)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2( + handle, platform::CudnnDataType::kOne(), desc_, x->data(), + platform::CudnnDataType::kZero(), desc_, out_data, + MIOPEN_SOFTMAX_ACCURATE, mode)); + } #else auto mode = axis == rank - 1 ? 
CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( - handle, CUDNN_SOFTMAX_ACCURATE, mode, - platform::CudnnDataType::kOne(), desc_, x->data(), - platform::CudnnDataType::kZero(), desc_, out_data)); + if (LogMode) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( + handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), + desc_, x->data(), platform::CudnnDataType::kZero(), desc_, + out_data)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( + handle, CUDNN_SOFTMAX_ACCURATE, mode, + platform::CudnnDataType::kOne(), desc_, x->data(), + platform::CudnnDataType::kZero(), desc_, out_data)); + } #endif } } }; -template +template class SoftmaxGradCUDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -437,78 +558,38 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { const int N = SizeToAxis(axis, dims); const int D = SizeOutAxis(axis, dims); + constexpr int max_dim = 320; constexpr int warps_per_block = 4; - constexpr bool warp_softmax_available = - std::is_same::value || - std::is_same::value; - bool optimize = false; - if (D == 1 && warp_softmax_available) { - if (dim == 128 && N % warps_per_block == 0) { - optimize = true; - if (std::is_same::value) { - VecSoftmaxBackward<<< - N / warps_per_block, warps_per_block * WARP_SIZE, 0, - ctx.cuda_device_context().stream()>>>(dx->data(), - dout->data(), - out->data(), N, dim); - } else if (std::is_same::value) { - VecSoftmaxBackward<<< - N / warps_per_block, warps_per_block * WARP_SIZE, 0, - ctx.cuda_device_context().stream()>>>( - dx->data(), dout->data(), - out->data(), N, dim); - } else { - PADDLE_ENFORCE_EQ( - warp_softmax_available, true, - platform::errors::Unimplemented( - "Warp softmax backward is only available for fp32 and fp16")); - } - } else if (dim < 40 && dim % 32 != 0) { - optimize = true; - Tensor mul_grad; - int numel = N * dim; - mul_grad.mutable_data({numel}, ctx.GetPlace()); - - auto stream = ctx.cuda_device_context().stream(); - auto& dev_ctx = - ctx.template device_context(); - auto config = GetGpuLaunchConfig1D(dev_ctx, numel); - - MultiplyCUDAKernel<<>>( - mul_grad.data(), dout->data(), out->data(), numel); - - int log2_elements = log2_ceil(dim); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (N + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - switch (log2_elements) { - LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 - LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 - LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 - LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 - LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 - LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 - LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 - LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 - LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 - LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 - default: - break; - } + + if (D == 1 && dim <= max_dim && sizeof(T) <= 4) { + const int kDimLog2 = log2_ceil(dim); + const int kDimCeil = 1 << kDimLog2; + int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + int batches_per_warp = (kDimCeil <= 128) ? 
2 : 1; + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / kWarpSize); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (N + batches_per_block - 1) / batches_per_block; + dim3 threads(kWarpSize, warps_per_block, 1); + + // vectorization read/write + using T4 = typename VecT4::Type; + using T2 = typename VecT2::Type; + if (dim % 4 == 0) { + SwitchWarpSoftmaxBackward( + blocks, threads, ctx, dx_data, dout->data(), out->data(), N, + dim, dim, kDimLog2); + } else if (dim % 2 == 0) { + SwitchWarpSoftmaxBackward( + blocks, threads, ctx, dx_data, dout->data(), out->data(), N, + dim, dim, kDimLog2); + } else { + SwitchWarpSoftmaxBackward( + blocks, threads, ctx, dx_data, dout->data(), out->data(), N, + dim, dim, kDimLog2); } - } - if (!optimize) { + } else { ScopedTensorDescriptor desc; std::vector tensor_dims = {N, dim, D, 1}; DataLayout layout = DataLayout::kNCHW; @@ -525,18 +606,32 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { #ifdef PADDLE_WITH_HIP auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE : MIOPEN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward( - handle, platform::CudnnDataType::kOne(), desc_, out->data(), - desc_, dout->data(), platform::CudnnDataType::kZero(), desc_, - dx_data)); + if (LogMode) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( + handle, platform::CudnnDataType::kOne(), desc_, out->data(), + desc_, dout->data(), platform::CudnnDataType::kZero(), desc_, + dx_data, MIOPEN_SOFTMAX_LOG, mode)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2( + handle, platform::CudnnDataType::kOne(), desc_, out->data(), + desc_, dout->data(), platform::CudnnDataType::kZero(), desc_, + dx_data, MIOPEN_SOFTMAX_ACCURATE, mode)); + } #else auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( - handle, CUDNN_SOFTMAX_ACCURATE, mode, - platform::CudnnDataType::kOne(), desc_, out->data(), desc_, - dout->data(), platform::CudnnDataType::kZero(), desc_, - dx_data)); + if (LogMode) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( + handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType::kOne(), + desc_, out->data(), desc_, dout->data(), + platform::CudnnDataType::kZero(), desc_, dx_data)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( + handle, CUDNN_SOFTMAX_ACCURATE, mode, + platform::CudnnDataType::kOne(), desc_, out->data(), desc_, + dout->data(), platform::CudnnDataType::kZero(), desc_, + dx_data)); + } #endif } } diff --git a/paddle/fluid/operators/softmax_impl.cuh b/paddle/fluid/operators/softmax_impl.cuh new file mode 100755 index 0000000000000000000000000000000000000000..2acc55d2398e99db465f5eeccb7972c456d55a33 --- /dev/null +++ b/paddle/fluid/operators/softmax_impl.cuh @@ -0,0 +1,47 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/platform/cuda_device_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, int BatchSize, int WarpSize>
+__device__ __forceinline__ void WarpReduceSum(T* sum) {
+#pragma unroll
+  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
+#pragma unroll
+    for (int i = 0; i < BatchSize; ++i) {
+      T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
+      sum[i] = sum[i] + sum_val;
+    }
+  }
+}
+
+template <typename T, int BatchSize, int WarpSize>
+__device__ __forceinline__ void WarpReduceMax(T* sum) {
+#pragma unroll
+  for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
+#pragma unroll
+    for (int i = 0; i < BatchSize; ++i) {
+      T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
+      sum[i] = max(sum[i], max_val);
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
\ No newline at end of file
diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h
index a964c3b57a635b3e5f0a4c163e3b3c13d465102b..08266318fb970ba976269991351152c22b38dbf2 100644
--- a/paddle/fluid/operators/softmax_op.h
+++ b/paddle/fluid/operators/softmax_op.h
@@ -45,6 +45,14 @@ static inline int SizeFromAxis(const int axis, DDim dims) {
   return size;
 }
 
+static inline int SizeOutAxis(const int axis, DDim dims) {
+  int size = 1;
+  for (int i = axis + 1; i < dims.size(); i++) {
+    size *= dims[i];
+  }
+  return size;
+}
+
 template <typename DeviceContext, typename T>
 class SoftmaxKernel : public framework::OpKernel<T> {
  public:
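
The warp kernels in this patch (WarpSoftmaxForward, WarpSoftmaxBackward, and the WarpReduceSum/WarpReduceMax helpers in softmax_impl.cuh) all use the same warp-per-row pattern: each of the 32 lanes accumulates a partial maximum and a partial sum over a strided slice of its row, and butterfly shuffles (__shfl_xor_sync, wrapped by platform::CudaShuffleXorSync) merge the lane-local values without shared memory. The standalone sketch below is not part of the patch; it shows only that pattern, with the patch's vectorized loads, two-rows-per-warp batching, and LogMode handling stripped away. The kernel name NaiveWarpSoftmax and the launch shape are illustrative assumptions.

#include <cuda_runtime.h>
#include <math_constants.h>  // for CUDART_INF_F

// Illustrative sketch: one warp (32 lanes) computes softmax over one row of
// length `cols`. Requires blockDim.x to be a multiple of 32.
__global__ void NaiveWarpSoftmax(float* out, const float* in, int rows,
                                 int cols) {
  int warps_per_block = blockDim.x / 32;
  int row = blockIdx.x * warps_per_block + threadIdx.x / 32;
  int lane = threadIdx.x % 32;
  if (row >= rows) return;  // `row` is uniform within a warp, so whole warps exit together
  const float* src = in + row * cols;
  float* dst = out + row * cols;

  // Lane-local max over a strided slice of the row.
  float max_val = -CUDART_INF_F;
  for (int j = lane; j < cols; j += 32) max_val = fmaxf(max_val, src[j]);
  // Butterfly reduction: after log2(32) = 5 steps every lane holds the row max.
  for (int offset = 16; offset > 0; offset /= 2)
    max_val = fmaxf(max_val, __shfl_xor_sync(0xffffffff, max_val, offset));

  // Lane-local sum of exp(x - max), then the same butterfly for the sum.
  float sum = 0.f;
  for (int j = lane; j < cols; j += 32) sum += __expf(src[j] - max_val);
  for (int offset = 16; offset > 0; offset /= 2)
    sum += __shfl_xor_sync(0xffffffff, sum, offset);

  // Normalize and write back.
  for (int j = lane; j < cols; j += 32)
    dst[j] = __expf(src[j] - max_val) / sum;
}

A possible launch for this sketch, mirroring the patch's choice of 128 threads per block, is NaiveWarpSoftmax<<<(rows + 3) / 4, 128>>>(out, in, rows, cols): four warps, and therefore four rows, per block.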
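
For reference, the per-row quantities the forward and backward kernels compute are the standard softmax and log-softmax formulas; writing them out explains the terse LogMode lines (the forward kernel's `srcdata - max_value - sum` after `sum = log(sum)`, and the backward kernel's `gradptr - exp(srcptr) * sum`). With m_i the row maximum and g = dL/dy the incoming gradient (`dout`):

  m_i = \max_j x_{ij}, \qquad s_i = \sum_j e^{x_{ij} - m_i}

  \text{softmax:}\quad y_{ij} = \frac{e^{x_{ij} - m_i}}{s_i}
  \qquad
  \text{log-softmax:}\quad \log y_{ij} = (x_{ij} - m_i) - \log s_i

  \text{softmax backward:}\quad
  \frac{\partial L}{\partial x_{ij}} = y_{ij}\Big(g_{ij} - \sum_k g_{ik}\, y_{ik}\Big)
  \qquad
  \text{log-softmax backward:}\quad
  \frac{\partial L}{\partial x_{ij}} = g_{ij} - y_{ij} \sum_k g_{ik}

In the log-softmax backward path the kernel recovers y_{ij} as exp(src_{ij}), since `src` there holds the forward log-softmax output rather than the probabilities.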