未验证 提交 63abd500 编写于 作者: X xingfeng01 提交者: GitHub

softmax reconstruction and optimization (#31821)

上级 8552a182
...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h" #include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_impl.cuh"
#include "paddle/fluid/operators/softmax_op.h" #include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
...@@ -21,7 +23,6 @@ limitations under the License. */ ...@@ -21,7 +23,6 @@ limitations under the License. */
#else #else
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
#endif #endif
#include "paddle/fluid/platform/gpu_launch_config.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -37,288 +38,414 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; ...@@ -37,288 +38,414 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout; using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
#define LAUNCH_SOFTMAX_WARP_FORWARD(Log2Elements) \ // Vectorization trait 4 * sizeof(T)
case Log2Elements: \ template <typename T>
WarpSoftmaxForward<T, float, Log2Elements><<< \ class VecT4 {};
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \ template <>
out_data, x->data<T>(), N, dim, dim); \ class VecT4<double> {
break; public:
using Type = long4;
#define LAUNCH_SOFTMAX_WARP_BACKWARD(Log2Elements) \
case Log2Elements: \
softmax_warp_backward<T, float, Log2Elements><<< \
blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
dx_data, mul_grad.data<T>(), out->data<T>(), N, dim, dim); \
break;
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template <typename T, int VLEN>
union vec_t {
static_assert(sizeof(T) == -1, "vec_t is only available by specialization.");
}; };
template <> template <>
union vec_t<float, 4> { class VecT4<float> {
float4 s; public:
float v[4]; using Type = int4;
};
template <>
class VecT4<platform::float16> {
public:
using Type = int2;
}; };
// Vectorization trait 2 * sizeof(T)
template <typename T>
class VecT2 {};
template <>
class VecT2<double> {
public:
using Type = int4;
};
template <> template <>
union vec_t<platform::float16, 4> { class VecT2<float> {
int2 s; public:
platform::float16 v[4]; using Type = int2;
};
template <>
class VecT2<platform::float16> {
public:
using Type = int;
}; };
template <typename T, typename VECT, int VPT, int WARP_PER_BLOCK> int static inline log2_ceil(int value) {
__global__ void VecSoftmaxForward(T* dst, const T* src, const int batch_size, int log2_value = 0;
const int softmax_ele) { while ((1 << log2_value) < value) ++log2_value;
int offset = blockIdx.x * softmax_ele * WARP_PER_BLOCK; return log2_value;
int idx = threadIdx.x * VPT;
VECT buf = reinterpret_cast<const VECT*>(&src[offset + idx])[0];
T* bufp = reinterpret_cast<T*>(&buf);
float4 val4;
float* val4p = reinterpret_cast<float*>(&val4);
for (int i = 0; i < VPT; ++i) {
val4p[i] = static_cast<float>(bufp[i]);
}
float val = val4.x + val4.y + val4.z + val4.w;
float max_val = math::warpReduceMax<float>(
max(max(val4.x, val4.y), max(val4.z, val4.w)), 0xffffffff);
float4 tmp4 = make_float4(__expf(val4.x - max_val), __expf(val4.y - max_val),
__expf(val4.z - max_val), __expf(val4.w - max_val));
float* tmp4p = reinterpret_cast<float*>(&tmp4);
float invsum = 1.f / (math::warpReduceSum<float>(
tmp4.x + tmp4.y + tmp4.z + tmp4.w, 0xffffffff) +
1e-6f);
for (int i = 0; i < VPT; ++i) {
bufp[i] = static_cast<T>(tmp4p[i] * invsum);
}
reinterpret_cast<VECT*>(&dst[offset + idx])[0] = buf;
} }
template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX> /*
__device__ __forceinline__ void warp_reduce_sum(T* sum) { Core function of computing softmax forward for axis=-1.
The computation includes
- Compute maximum of batch: maxvalue_{i} = max_j src_{i,j}
- Compute sum of exp batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} }
- Compute: (a_{i,j} - maxvalue_{i}) / s_{i}
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax, const T* src,
const int batch_size, const int stride,
const int element_count) {
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kIterations = kDimCeil / kWarpSize;
constexpr int kIterationsV =
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
// max index to read
int idx_max_v[kBatchSize];
#pragma unroll #pragma unroll
for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) { for (int i = 0; i < kBatchSize; i++) {
#pragma unroll int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
for (int i = 0; i < WARP_BATCH; ++i) { idx_max_v[i] = idx_max / kVSize;
T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
sum[i] = sum[i] + sum_val;
}
} }
}
template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX> // read data from global memory
__device__ __forceinline__ void warp_reduce_max(T* sum) { AccT srcdata[kBatchSize][kIterationsV][kVSize];
#pragma unroll #pragma unroll
for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) { for (int i = 0; i < kBatchSize; ++i) {
// read data
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) { for (int it = 0; it < kIterationsV; ++it) {
T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset); int src_idx = threadIdx.x + it * kWarpSize;
sum[i] = max(sum[i], max_val); if (kVSize == 1) {
if (src_idx < idx_max_v[i]) {
srcdata[i][it][0] =
static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
} else {
srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
}
} else {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
if (src_idx < idx_max_v[i]) {
VecT srctmp = src_v[src_idx];
const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
}
} else {
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
}
}
} }
} }
}
template <typename T, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size,
const int stride, const int element_count) {
constexpr int next_power_of_two = 1 << Log2Elements;
constexpr int warp_size_softmax =
(next_power_of_two < 32) ? next_power_of_two : 32;
constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH) {
local_batches = WARP_BATCH;
} }
int local_idx = threadIdx.x; // compute max value
AccT max_value[kBatchSize];
src += first_batch * stride + local_idx; #pragma unroll
dst += first_batch * stride + local_idx; for (int i = 0; i < kBatchSize; ++i) {
// it = 0
AccT valmax = srcdata[i][0][0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
}
max_value[i] = valmax;
// load data from global memory // it = 1, 2, ...
AccT elements[WARP_BATCH][WARP_ITERATIONS]; #pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) { for (int it = 1; it < kIterationsV; ++it) {
int batch_element_count = (i >= local_batches) ? 0 : element_count; AccT valmax = srcdata[i][it][0];
for (int it = 0; it < WARP_ITERATIONS; ++it) { #pragma unroll
int element_index = local_idx + it * warp_size_softmax; for (int s = 1; s < kVSize; ++s) {
if (element_index < batch_element_count) { valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
elements[i][it] =
static_cast<float>(src[i * element_count + it * warp_size_softmax]);
} else {
elements[i][it] = -std::numeric_limits<AccT>::infinity();
} }
max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
} }
} }
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
// compute max_value // compute sum
AccT max_value[WARP_BATCH]; AccT sum[kBatchSize];
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) { for (int i = 0; i < kBatchSize; ++i) {
max_value[i] = elements[i][0]; // it = 0
if (LogMode) {
sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
} else {
srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
sum[i] = srcdata[i][0][0];
}
#pragma unroll #pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int s = 1; s < kVSize; ++s) {
max_value[i] = if (LogMode) {
(max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
} else {
srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
sum[i] += srcdata[i][0][s];
} }
} }
warp_reduce_max<AccT, WARP_BATCH, warp_size_softmax>(max_value);
AccT sum[WARP_BATCH]{0.0f}; // it = 1, 2, ...
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) { for (int it = 1; it < kIterationsV; ++it) {
#pragma unroll #pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) { for (int s = 0; s < kVSize; ++s) {
elements[i][it] = (std::exp((elements[i][it] - max_value[i]))); if (LogMode) {
sum[i] += elements[i][it]; sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
} else {
srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
sum[i] += srcdata[i][it][s];
}
} }
} }
warp_reduce_sum<AccT, WARP_BATCH, warp_size_softmax>(sum); }
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// store result // write result to global memory
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) { for (int i = 0; i < kBatchSize; ++i) {
if (i >= local_batches) break; if (LogMode) {
sum[i] = std::log(sum[i]);
}
#pragma unroll #pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) { for (int it = 0; it < kIterationsV; ++it) {
int element_index = local_idx + it * warp_size_softmax; int idx = threadIdx.x + it * kWarpSize;
if (element_index < element_count) { if (kVSize == 1) {
dst[i * element_count + it * warp_size_softmax] = if (idx < idx_max_v[i]) {
elements[i][it] / sum[i]; if (LogMode) {
softmax[(first_batch + i) * stride + idx] =
srcdata[i][it][0] - max_value[i] - sum[i];
} else {
softmax[(first_batch + i) * stride + idx] =
srcdata[i][it][0] / sum[i];
}
} else { } else {
break; break;
} }
} else {
VecT* softmax_v =
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
VecT tmpdata;
T* tmpptr = reinterpret_cast<T*>(&tmpdata);
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (LogMode) {
tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
} else {
tmpptr[s] = srcdata[i][it][s] / sum[i];
} }
} }
}
template <typename T, typename AccT, int Log2Elements> if (idx < idx_max_v[i]) {
__global__ void softmax_warp_backward(T* gradInput, const T* grad, softmax_v[idx] = tmpdata;
const T* output, int batch_size, } else {
int stride, int element_count) { break;
constexpr int next_power_of_two = 1 << Log2Elements; }
constexpr int warp_size_softmax = }
(next_power_of_two < 32) ? next_power_of_two : 32; }
constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax; }
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; }
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
/*
Core function of computing softmax backward for axis=-1.
The computation includes
- Compute sum of exp batch: s_{i} = sum_{j} {src_{i,j} * grad_{i,j}
- Compute src_{i,j} * ( grad_{i,j}) - s_{i} )
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
bool LogMode = false>
__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
int batch_size, int stride,
int element_count) {
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kIterations = kDimCeil / kWarpSize;
constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
constexpr int kIterationsV =
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
int element_count_v = element_count / kVSize;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
int local_batches = batch_size - first_batch; int local_batches = batch_size - first_batch;
if (local_batches > WARP_BATCH) { if (local_batches > kBatchSize) {
local_batches = WARP_BATCH; local_batches = kBatchSize;
} }
int local_idx = threadIdx.x % warp_size_softmax; // read data from global memory
VecT src_reg[kBatchSize][kIterationsV];
int thread_offset = first_batch * stride + local_idx; VecT grad_reg[kBatchSize][kIterationsV];
grad += thread_offset;
output += thread_offset; for (int i = 0; i < kBatchSize; ++i) {
gradInput += thread_offset; const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
// load data from global memory const VecT* grad_v =
AccT grad_reg[WARP_BATCH][WARP_ITERATIONS]; reinterpret_cast<const VecT*>(&grad[(first_batch + i) * stride]);
AccT output_reg[WARP_BATCH][WARP_ITERATIONS];
for (int i = 0; i < WARP_BATCH; ++i) { // max index to read
int batch_element_count = (i >= local_batches) ? 0 : element_count; int idx_max = (i < local_batches) ? element_count : 0;
for (int it = 0; it < WARP_ITERATIONS; ++it) { int idx_max_v = idx_max / kVSize;
int element_index = local_idx + it * warp_size_softmax;
if (element_index < batch_element_count) { // read data
grad_reg[i][it] = for (int it = 0; it < kIterationsV; ++it) {
static_cast<AccT>(grad[i * element_count + it * warp_size_softmax]); int src_idx = threadIdx.x + it * kWarpSize;
output_reg[i][it] = static_cast<AccT>( if (src_idx < idx_max_v) {
output[i * element_count + it * warp_size_softmax]); src_reg[i][it] = src_v[src_idx];
grad_reg[i][it] = grad_v[src_idx];
} else { } else {
grad_reg[i][it] = AccT(0); #pragma unroll
output_reg[i][it] = AccT(0); for (int s = 0; s < kVSize; s++) {
reinterpret_cast<T*>(&src_reg[i][it])[s] = 0.0;
reinterpret_cast<T*>(&grad_reg[i][it])[s] = 0.0;
}
} }
} }
} }
AccT sum[WARP_BATCH]; // compute sum
AccT sum[kBatchSize]{0.0};
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) { for (int it = 0; it < kIterationsV; ++it) {
sum[i] = grad_reg[i][0]; T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
#pragma unroll #pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) { for (int s = 0; s < kVSize; ++s) {
sum[i] += grad_reg[i][it]; if (LogMode) {
sum[i] += static_cast<AccT>(gradptr[s]);
} else {
sum[i] += static_cast<AccT>(gradptr[s] * srcptr[s]);
}
}
} }
} }
warp_reduce_sum<AccT, WARP_BATCH, warp_size_softmax>(sum); WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// store result // write result
#pragma unroll #pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) { for (int i = 0; i < kBatchSize; ++i) {
if (i >= local_batches) break; if (i >= local_batches) break;
VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
// max index to write
int idx_max = (i < local_batches) ? element_count : 0;
int idx_max_v = idx_max / kVSize;
#pragma unroll #pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) { for (int it = 0; it < kIterationsV; ++it) {
int element_index = local_idx + it * warp_size_softmax; VecT tmpdata;
if (element_index < element_count) { T* tmpptr = reinterpret_cast<T*>(&tmpdata);
// compute gradients T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
gradInput[i * element_count + it * warp_size_softmax] = T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
(grad_reg[i][it] - output_reg[i][it] * sum[i]); #pragma unroll
} for (int s = 0; s < kVSize; ++s) {
if (LogMode) {
tmpptr[s] = static_cast<AccT>(gradptr[s]) -
std::exp(static_cast<AccT>(srcptr[s])) * sum[i];
} else {
tmpptr[s] = static_cast<AccT>(srcptr[s]) *
(static_cast<AccT>(gradptr[s]) - sum[i]);
} }
} }
}
template <typename T> int idx = threadIdx.x + it * kWarpSize;
__global__ void MultiplyCUDAKernel(T* C, const T* A, const T* B, int N) { if (idx < idx_max_v) {
CUDA_KERNEL_LOOP(i, N) { dst_v[idx] = tmpdata;
C[i] = static_cast<T>(static_cast<float>(A[i]) * static_cast<float>(B[i])); }
}
} }
} }
template <typename T, int VPT, int WARP_PER_BLOCK> #define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \
__global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src, case Log2Elements: \
const int batch_size, WarpSoftmaxForward< \
const int softmax_ele) { T, VecT, AccT, Log2Elements, \
const int offset = LogMode><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
blockIdx.x * softmax_ele * WARP_PER_BLOCK + threadIdx.x * VPT; dst, src, batch_size, stride, element_count); \
break;
float local_sum_gy = 0.f;
vec_t<T, VPT> local_grad;
vec_t<T, VPT> local_src;
local_grad.s =
reinterpret_cast<const decltype(local_grad.s)*>(&grad[offset])[0];
local_src.s = reinterpret_cast<const decltype(local_src.s)*>(&src[offset])[0];
for (int i = 0; i < VPT; ++i) { /*
local_sum_gy += static_cast<float>(local_grad.v[i]) * Wrapper of softmax formward with template instantiation on size of input.
static_cast<float>(local_src.v[i]); */
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads,
const framework::ExecutionContext& ctx, T* dst,
const T* src, const int batch_size,
const int stride, const int element_count,
int Log2Elements) {
using AccT = typename details::MPTypeTrait<T>::Type;
switch (Log2Elements) {
SOFTMAX_WARP_FORWARD_CASE(0, AccT);
SOFTMAX_WARP_FORWARD_CASE(1, AccT);
SOFTMAX_WARP_FORWARD_CASE(2, AccT);
SOFTMAX_WARP_FORWARD_CASE(3, AccT);
SOFTMAX_WARP_FORWARD_CASE(4, AccT);
SOFTMAX_WARP_FORWARD_CASE(5, AccT);
SOFTMAX_WARP_FORWARD_CASE(6, AccT);
SOFTMAX_WARP_FORWARD_CASE(7, AccT);
SOFTMAX_WARP_FORWARD_CASE(8, AccT);
SOFTMAX_WARP_FORWARD_CASE(9, AccT);
default:
break;
} }
float sum_gy = math::warpReduceSum<float>(local_sum_gy, 0xffffffff); }
vec_t<T, VPT> local_dst; #define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT) \
for (int i = 0; i < VPT; ++i) { case Log2Elements: \
local_dst.v[i] = WarpSoftmaxBackward< \
static_cast<T>(static_cast<float>(local_src.v[i]) * T, VecT, AccT, Log2Elements, \
(static_cast<float>(local_grad.v[i]) - sum_gy)); LogMode><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
dst, grad, src, batch_size, stride, element_count); \
break;
/*
Wrapper of softmax backward with template instantiation on size of input.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
const framework::ExecutionContext& ctx, T* dst,
const T* grad, const T* src,
const int batch_size, const int stride,
const int element_count, int Log2Elements) {
using AccT = typename details::MPTypeTrait<T>::Type;
switch (Log2Elements) {
SOFTMAX_WARP_BACKWARD_CASE(0, AccT);
SOFTMAX_WARP_BACKWARD_CASE(1, AccT);
SOFTMAX_WARP_BACKWARD_CASE(2, AccT);
SOFTMAX_WARP_BACKWARD_CASE(3, AccT);
SOFTMAX_WARP_BACKWARD_CASE(4, AccT);
SOFTMAX_WARP_BACKWARD_CASE(5, AccT);
SOFTMAX_WARP_BACKWARD_CASE(6, AccT);
SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
default:
break;
} }
reinterpret_cast<decltype(local_dst.s)*>(&dst[offset])[0] = local_dst.s;
} }
template <typename T> #undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE
template <typename T, bool LogMode = false>
class SoftmaxCUDNNKernel : public framework::OpKernel<T> { class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -335,60 +462,39 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> { ...@@ -335,60 +462,39 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
const int D = SizeOutAxis(axis, dims); const int D = SizeOutAxis(axis, dims);
constexpr int max_dim = 320; constexpr int max_dim = 320;
bool optimize = false;
constexpr int warps_per_block = 4; constexpr int warps_per_block = 4;
if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
if (dim == 128 && N % warps_per_block == 0) {
optimize = true;
// a warp for a batch, 4 elements for a thread, only support the softmax
// dim size = 128 currently
if (sizeof(T) == 2) {
VecSoftmaxForward<T, int2, 4, warps_per_block><<<
N / warps_per_block, warps_per_block * WARP_SIZE, 0,
ctx.cuda_device_context().stream()>>>(out_data, x->data<T>(), N,
dim);
} else if (sizeof(T) == 4) {
VecSoftmaxForward<T, int4, 4, warps_per_block><<<
N / warps_per_block, warps_per_block * WARP_SIZE, 0,
ctx.cuda_device_context().stream()>>>(out_data, x->data<T>(), N,
dim);
} else {
assert(false && "not support");
}
} else if (dim < max_dim) {
optimize = true;
int log2_elements = static_cast<int>(log2_ceil(dim));
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
const int kDimLog2 = static_cast<int>(log2_ceil(dim));
const int kDimCeil = 1 << kDimLog2;
int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
int batches_per_warp = (kDimCeil <= 32) ? 2 : 1;
// use 128 threads per block to maximimize gpu utilization // use 128 threads per block to maximimize gpu utilization
constexpr int threads_per_block = 128; constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size); int warps_per_block = (threads_per_block / kWarpSize);
int batches_per_block = warps_per_block * batches_per_warp; int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block; int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1); dim3 threads(kWarpSize, warps_per_block, 1);
switch (log2_elements) { // vectorization read/write
LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 using T4 = typename VecT4<T>::Type;
LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 using T2 = typename VecT2<T>::Type;
LAUNCH_SOFTMAX_WARP_FORWARD(2); // 4 if (dim % 4 == 0) {
LAUNCH_SOFTMAX_WARP_FORWARD(3); // 8 SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, ctx, out_data,
LAUNCH_SOFTMAX_WARP_FORWARD(4); // 16 x->data<T>(), N, dim, dim,
LAUNCH_SOFTMAX_WARP_FORWARD(5); // 32 kDimLog2);
LAUNCH_SOFTMAX_WARP_FORWARD(6); // 64 } else if (dim % 2 == 0) {
LAUNCH_SOFTMAX_WARP_FORWARD(7); // 128 SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks, threads, ctx, out_data,
LAUNCH_SOFTMAX_WARP_FORWARD(8); // 256 x->data<T>(), N, dim, dim,
LAUNCH_SOFTMAX_WARP_FORWARD(9); // 512 kDimLog2);
default: } else {
break; SwitchWarpSoftmaxForward<T, T, LogMode>(blocks, threads, ctx, out_data,
} x->data<T>(), N, dim, dim,
} kDimLog2);
} }
if (!optimize) { } else {
ScopedTensorDescriptor desc; ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1}; std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW; DataLayout layout = DataLayout::kNCHW;
...@@ -405,22 +511,37 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> { ...@@ -405,22 +511,37 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL; : MIOPEN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward( if (LogMode) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
handle, platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(), handle, platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, out_data)); platform::CudnnDataType<T>::kZero(), desc_, out_data,
MIOPEN_SOFTMAX_LOG, mode));
} else {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
handle, platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, out_data,
MIOPEN_SOFTMAX_ACCURATE, mode));
}
#else #else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL; : CUDNN_SOFTMAX_MODE_CHANNEL;
if (LogMode) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
desc_, x->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
out_data));
} else {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_ACCURATE, mode, handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(), platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, out_data)); platform::CudnnDataType<T>::kZero(), desc_, out_data));
}
#endif #endif
} }
} }
}; };
template <typename T> template <typename T, bool LogMode = false>
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> { class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -437,78 +558,38 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> { ...@@ -437,78 +558,38 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
const int N = SizeToAxis(axis, dims); const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims); const int D = SizeOutAxis(axis, dims);
constexpr int max_dim = 320;
constexpr int warps_per_block = 4; constexpr int warps_per_block = 4;
constexpr bool warp_softmax_available =
std::is_same<T, float>::value ||
std::is_same<T, platform::float16>::value;
bool optimize = false;
if (D == 1 && warp_softmax_available) {
if (dim == 128 && N % warps_per_block == 0) {
optimize = true;
if (std::is_same<T, float>::value) {
VecSoftmaxBackward<float, 4, warps_per_block><<<
N / warps_per_block, warps_per_block * WARP_SIZE, 0,
ctx.cuda_device_context().stream()>>>(dx->data<float>(),
dout->data<float>(),
out->data<float>(), N, dim);
} else if (std::is_same<T, platform::float16>::value) {
VecSoftmaxBackward<platform::float16, 4, warps_per_block><<<
N / warps_per_block, warps_per_block * WARP_SIZE, 0,
ctx.cuda_device_context().stream()>>>(
dx->data<platform::float16>(), dout->data<platform::float16>(),
out->data<platform::float16>(), N, dim);
} else {
PADDLE_ENFORCE_EQ(
warp_softmax_available, true,
platform::errors::Unimplemented(
"Warp softmax backward is only available for fp32 and fp16"));
}
} else if (dim < 40 && dim % 32 != 0) {
optimize = true;
Tensor mul_grad;
int numel = N * dim;
mul_grad.mutable_data<T>({numel}, ctx.GetPlace());
auto stream = ctx.cuda_device_context().stream();
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto config = GetGpuLaunchConfig1D(dev_ctx, numel);
MultiplyCUDAKernel<T><<<config.block_per_grid.x,
config.thread_per_block.x, 0, stream>>>(
mul_grad.data<T>(), dout->data<T>(), out->data<T>(), numel);
int log2_elements = log2_ceil(dim);
const int next_power_of_two = 1 << log2_elements;
int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
const int kDimLog2 = log2_ceil(dim);
const int kDimCeil = 1 << kDimLog2;
int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
constexpr int threads_per_block = 128; constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / warp_size); int warps_per_block = (threads_per_block / kWarpSize);
int batches_per_block = warps_per_block * batches_per_warp; int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block; int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1); dim3 threads(kWarpSize, warps_per_block, 1);
switch (log2_elements) { // vectorization read/write
LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 using T4 = typename VecT4<T>::Type;
LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 using T2 = typename VecT2<T>::Type;
LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 if (dim % 4 == 0) {
LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 SwitchWarpSoftmaxBackward<T, T4, LogMode>(
LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 blocks, threads, ctx, dx_data, dout->data<T>(), out->data<T>(), N,
LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 dim, dim, kDimLog2);
LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 } else if (dim % 2 == 0) {
LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 SwitchWarpSoftmaxBackward<T, T2, LogMode>(
LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 blocks, threads, ctx, dx_data, dout->data<T>(), out->data<T>(), N,
LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 dim, dim, kDimLog2);
default: } else {
break; SwitchWarpSoftmaxBackward<T, T, LogMode>(
} blocks, threads, ctx, dx_data, dout->data<T>(), out->data<T>(), N,
} dim, dim, kDimLog2);
} }
if (!optimize) { } else {
ScopedTensorDescriptor desc; ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1}; std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW; DataLayout layout = DataLayout::kNCHW;
...@@ -525,18 +606,32 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> { ...@@ -525,18 +606,32 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL; : MIOPEN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward( if (LogMode) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2(
handle, platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(), handle, platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(),
desc_, dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_, desc_, dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
dx_data)); dx_data, MIOPEN_SOFTMAX_LOG, mode));
} else {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2(
handle, platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(),
desc_, dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
dx_data, MIOPEN_SOFTMAX_ACCURATE, mode));
}
#else #else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL; : CUDNN_SOFTMAX_MODE_CHANNEL;
if (LogMode) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
desc_, out->data<T>(), desc_, dout->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, dx_data));
} else {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
handle, CUDNN_SOFTMAX_ACCURATE, mode, handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(), desc_, platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(), desc_,
dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_, dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
dx_data)); dx_data));
}
#endif #endif
} }
} }
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
namespace operators {
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceSum(T* sum) {
#pragma unroll
for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < BatchSize; ++i) {
T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
sum[i] = sum[i] + sum_val;
}
}
}
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceMax(T* sum) {
#pragma unroll
for (int offset = WarpSize / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < BatchSize; ++i) {
T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
sum[i] = max(sum[i], max_val);
}
}
}
} // namespace operators
} // namespace paddle
\ No newline at end of file
...@@ -45,6 +45,14 @@ static inline int SizeFromAxis(const int axis, DDim dims) { ...@@ -45,6 +45,14 @@ static inline int SizeFromAxis(const int axis, DDim dims) {
return size; return size;
} }
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SoftmaxKernel : public framework::OpKernel<T> { class SoftmaxKernel : public framework::OpKernel<T> {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册