Unverified commit ef76f664, authored by Liu-xiandong, committed by GitHub

Rewrite Softmax in Kernel Primitive API, test=develop (#36706)

Parent b151a451
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
@@ -99,6 +100,97 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) {
}
}
namespace kps = paddle::operators::kernel_primitives;
template <typename Tx, typename Ty = Tx>
struct ReduceMaxFunctor {
inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }
__device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
return max(a, b);
}
};
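// How a reduce functor is consumed: initial() supplies the identity element
// (-infinity for max), so lanes padded by kps::Init can never win the
// reduction. A hand-rolled local-mode sketch over a tile of kVItem registers
// (illustrative only, not the real kps::Reduce):
//   AccT acc = ReduceMaxFunctor<AccT>().initial();
//   for (int k = 0; k < kVItem; ++k) {
//     acc = ReduceMaxFunctor<AccT>()(acc, tile[k]);  // running max
//   }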
template <typename Tx, typename Ty = Tx>
struct ExpSubFunctor {
HOSTDEVICE inline ExpSubFunctor() { y = static_cast<Tx>(0.0f); }
HOSTDEVICE explicit inline ExpSubFunctor(Tx y) : y((Tx)(y)) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(std::exp(x - y));
}
private:
Tx y;
};
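// Why exp is fused with the max subtraction: with x <= max the argument of
// exp stays <= 0, so the result lies in (0, 1] and float32 cannot overflow.
// E.g. for a row {1000.0f, 1001.0f}, exp(1000.0f) is inf, while
// ExpSubFunctor<float>(1001.0f) maps the row to {exp(-1), exp(0)} ~
// {0.3679, 1.0}; normalizing gives the exact softmax {0.2689, 0.7311}.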
template <typename Tx, typename Ty = Tx>
struct ExpMulFunctor {
HOSTDEVICE inline ExpMulFunctor() { y = static_cast<Tx>(1.0f); }
HOSTDEVICE explicit inline ExpMulFunctor(Tx y) : y((Tx)(y)) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(std::exp(x) * y);
}
private:
Tx y;
};
template <typename Tx, typename Ty = Tx>
struct UnarySubFunctor {
HOSTDEVICE inline UnarySubFunctor() { y = static_cast<Tx>(0.0f); }
HOSTDEVICE explicit inline UnarySubFunctor(Tx y) : y((Tx)(y)) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x - y);
}
private:
Tx y;
};
template <typename Tx, typename Ty = Tx>
struct UnaryLogFunctor {
HOSTDEVICE inline UnaryLogFunctor() {}
HOSTDEVICE explicit inline UnaryLogFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(std::log(x));
}
};
template <typename Tx, typename Ty>
struct DataTransFunctor {
HOSTDEVICE inline DataTransFunctor() {}
HOSTDEVICE explicit inline DataTransFunctor(int n) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return x == -std::numeric_limits<Tx>::infinity()
? -std::numeric_limits<Ty>::infinity()
: static_cast<Ty>(x);
}
};
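// DataTransFunctor widens T (e.g. platform::float16) to AccT (float) so the
// max and exp-sum accumulate at full precision; its explicit branch pins the
// -infinity padding to an exact AccT -infinity rather than trusting the
// cast. A padded fp16 row {2.0, 4.0, -inf} thus becomes the float tile
// {2.0f, 4.0f, -inf}, whose pad lane drops out of the max and, after
// ExpSubFunctor, contributes exp(-inf) == 0 to the sum.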
template <typename Tx, typename Ty = Tx>
struct UnaryDivFunctor {
HOSTDEVICE inline UnaryDivFunctor() { n_inv = static_cast<Tx>(1.0f); }
HOSTDEVICE explicit inline UnaryDivFunctor(Tx n) : n_inv((Tx)(1.0 / n)) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x * n_inv);
}
private:
Tx n_inv;
};
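// The functors above all capture their scalar state by value, which is what
// lets a kernel-primitive call apply them uniformly across a register tile.
// A minimal hand-rolled analogue of kps::ElementwiseUnary over N register
// elements (a sketch of the calling convention only, not the real kps
// implementation):
template <typename InT, typename OutT, int N, typename Functor>
__device__ __forceinline__ void ElementwiseUnarySketch(OutT* out,
                                                       const InT* in,
                                                       Functor func) {
#pragma unroll
  for (int k = 0; k < N; ++k) {
    out[k] = func(in[k]);  // e.g. UnaryDivFunctor: out[k] = in[k] * (1 / n)
  }
}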
/*
Core function of computing softmax forward for axis=-1.
The computation includes
@@ -117,12 +209,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kIterations = kDimCeil / kWarpSize;
constexpr int kIterationsV =
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
constexpr int kLoops = kDimCeil / kWarpSize;
constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
constexpr int kStep = kBatchSize * kLoopsV * kVSize;
constexpr int kVItem = kLoopsV * kVSize;
constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
using kMode = kps::details::ReduceMode;
// max index to read
int idx_max_v[kBatchSize];
@@ -133,146 +227,51 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
}
// read data from global memory
AccT srcdata[kBatchSize][kIterationsV][kVSize];
AccT srcdata[kBatchSize][kLoopsV][kVSize];
kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf);
T src_tmp[kBatchSize][kLoopsV][kVSize];
kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity());
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// read data
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
int src_idx = threadIdx.x + it * kWarpSize;
if (kVSize == 1) {
if (src_idx < idx_max_v[i]) {
srcdata[i][it][0] =
static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
} else {
srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
}
} else {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
if (src_idx < idx_max_v[i]) {
VecT srctmp = src_v[src_idx];
const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
}
} else {
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
}
}
}
}
int ptr = (first_batch + i) * stride;
const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]);
kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
&reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>(
&srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor<T, AccT>());
}
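// What the guarded vectorized read above does per batch, hand-rolled
// (illustrative sketch; the real kps::ReadData supports more layouts):
//   for (int it = 0; it < kLoopsV; ++it) {
//     int idx = threadIdx.x + it * kWarpSize;
//     if (idx < idx_max_v[i]) reg_v[it] = src_v[idx];  // one VecT-wide load
//   }
// Out-of-range lanes keep the -infinity fill from kps::Init, so padding
// never influences the max or the exp-sum computed below.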
// compute max value
AccT max_value[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// it = 0
AccT valmax = srcdata[i][0][0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
}
max_value[i] = valmax;
// it = 1, 2, ...
#pragma unroll
for (int it = 1; it < kIterationsV; ++it) {
AccT valmax = srcdata[i][it][0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
}
max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
}
}
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
// compute max
AccT max[kBatchSize];
kps::Init<AccT, kBatchSize>(&max[0], kLowInf);
kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>,
kMode::kLocalMode>(&max[0], &srcdata[0][0][0],
ReduceMaxFunctor<AccT>(), true);
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
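// WarpReduceMax combines the per-thread partial maxima so every lane ends up
// holding the full row maximum. The core of such a helper is a butterfly
// shuffle (sketch assuming a full-warp mask; the WarpReduceMax defined
// earlier in this file is the authoritative version):
//   for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
//     val = max(val, __shfl_xor_sync(0xffffffff, val, offset));
// After log2(kWarpSize) steps the row maximum is replicated across the warp.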
// compute sum
AccT sum[kBatchSize];
#pragma unroll
AccT sum[kBatchSize] = {0};
for (int i = 0; i < kBatchSize; ++i) {
// it = 0
if (LogMode) {
sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
} else {
srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
sum[i] = srcdata[i][0][0];
}
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
if (LogMode) {
sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
} else {
srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
sum[i] += srcdata[i][0][s];
}
}
// it = 1, 2, ...
#pragma unroll
for (int it = 1; it < kIterationsV; ++it) {
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (LogMode) {
sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
} else {
srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
sum[i] += srcdata[i][it][s];
}
}
}
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
&srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
}
kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
kMode::kLocalMode>(&sum[0], &srcdata[0][0][0],
kps::AddFunctor<AccT>(), true);
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// write result to global memory
// write result to global memory
T out_tmp[kBatchSize][kLoopsV][kVSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
if (LogMode) {
sum[i] = std::log(sum[i]);
}
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
int idx = threadIdx.x + it * kWarpSize;
if (kVSize == 1) {
if (idx < idx_max_v[i]) {
if (LogMode) {
softmax[(first_batch + i) * stride + idx] =
srcdata[i][it][0] - max_value[i] - sum[i];
} else {
softmax[(first_batch + i) * stride + idx] =
srcdata[i][it][0] / sum[i];
}
} else {
break;
}
} else {
VecT* softmax_v =
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
VecT tmpdata;
T* tmpptr = reinterpret_cast<T*>(&tmpdata);
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (LogMode) {
tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
} else {
tmpptr[s] = srcdata[i][it][s] / sum[i];
}
}
if (idx < idx_max_v[i]) {
softmax_v[idx] = tmpdata;
} else {
break;
}
}
}
kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
&out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
int softmax_ptr = (first_batch + i) * stride;
VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[softmax_ptr]);
VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
&softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
}
}
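// Host-side check of the forward math on one row (plain C++ sketch mirroring
// the max-subtract, exp, and reciprocal-multiply normalization used above):
#include <cmath>
#include <vector>
inline std::vector<float> SoftmaxRowSketch(const std::vector<float>& x) {
  float m = x[0];
  for (float v : x) m = (v > m) ? v : m;  // row max
  std::vector<float> y;
  y.reserve(x.size());
  float sum = 0.0f;
  for (float v : x) {                     // shifted exp with running sum
    y.push_back(std::exp(v - m));
    sum += y.back();
  }
  const float sum_inv = 1.0f / sum;       // divide once, multiply per lane
  for (float& v : y) v *= sum_inv;
  return y;
}
// For x = {1, 2, 3}: max = 3, exps ~ {0.1353, 0.3679, 1.0}, sum ~ 1.5032,
// softmax ~ {0.0900, 0.2447, 0.6652}.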
@@ -293,101 +292,82 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kIterations = kDimCeil / kWarpSize;
constexpr int kLoops = kDimCeil / kWarpSize;
constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
constexpr int kIterationsV =
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
int element_count_v = element_count / kVSize;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
int local_batches = batch_size - first_batch;
if (local_batches > kBatchSize) {
local_batches = kBatchSize;
int local_batches = min(batch_size - first_batch, kBatchSize);
// max index to read
int idx_max_v[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; i++) {
int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
idx_max_v[i] = idx_max / kVSize;
}
// read data from global memory
VecT src_reg[kBatchSize][kIterationsV];
VecT grad_reg[kBatchSize][kIterationsV];
for (int i = 0; i < kBatchSize; ++i) {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
const VecT* grad_v =
reinterpret_cast<const VecT*>(&grad[(first_batch + i) * stride]);
// max index to read
int idx_max = (i < local_batches) ? element_count : 0;
int idx_max_v = idx_max / kVSize;
// read data
for (int it = 0; it < kIterationsV; ++it) {
int src_idx = threadIdx.x + it * kWarpSize;
if (src_idx < idx_max_v) {
src_reg[i][it] = src_v[src_idx];
grad_reg[i][it] = grad_v[src_idx];
} else {
VecT src_reg[kBatchSize][kLoopsV];
VecT grad_reg[kBatchSize][kLoopsV];
VecT k_value;
for (int s = 0; s < kVSize; s++) {
reinterpret_cast<T*>(&k_value)[s] = 0.0;
}
kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
#pragma unroll
for (int s = 0; s < kVSize; s++) {
reinterpret_cast<T*>(&src_reg[i][it])[s] = 0.0;
reinterpret_cast<T*>(&grad_reg[i][it])[s] = 0.0;
}
}
}
for (int i = 0; i < kBatchSize; ++i) {
int flag = i < local_batches ? 1 : 0;
int ptr = (first_batch + i) * stride;
const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
&src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
&grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
}
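// The trailing flag argument gates the read per batch: rows at or beyond
// local_batches skip the load and keep the zero fill from kps::Init, so they
// add nothing to the dot-product reduction below.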
// change T to AccT
AccT src_tmp[kBatchSize][kLoopsV][kVSize];
AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
constexpr int kStep = kBatchSize * kLoopsV * kVSize;
constexpr int kVItem = kLoopsV * kVSize;
kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
&src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
&grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());
// compute sum
AccT sum[kBatchSize]{0.0};
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (LogMode) {
sum[i] += static_cast<AccT>(gradptr[s]);
} else {
sum[i] += static_cast<AccT>(gradptr[s] * srcptr[s]);
}
}
}
}
AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
&sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
kps::details::ReduceMode::kLocalMode>(
&sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// write result
// write result to global memory
AccT out[kBatchSize][kLoopsV][kVSize];
T out_tmp[kBatchSize][kLoopsV][kVSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
if (i >= local_batches) break;
AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
&out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
&out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());
VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
// max index to write
int idx_max = (i < local_batches) ? element_count : 0;
int idx_max_v = idx_max / kVSize;
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
VecT tmpdata;
T* tmpptr = reinterpret_cast<T*>(&tmpdata);
T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (LogMode) {
tmpptr[s] = static_cast<AccT>(gradptr[s]) -
std::exp(static_cast<AccT>(srcptr[s])) * sum[i];
} else {
tmpptr[s] = static_cast<AccT>(srcptr[s]) *
(static_cast<AccT>(gradptr[s]) - sum[i]);
}
}
int idx = threadIdx.x + it * kWarpSize;
if (idx < idx_max_v) {
dst_v[idx] = tmpdata;
}
}
VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
&dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
}
}
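// The chain above implements the softmax gradient dx = y * (dy - sum(dy * y))
// per row: MulFunctor forms dy * y, Reduce plus WarpReduceSum produce the row
// dot product, UnarySubFunctor shifts dy, and the final MulFunctor scales by
// y. Host-side sketch for one row:
inline void SoftmaxGradRowSketch(float* dx, const float* dy, const float* y,
                                 int n) {
  float dot = 0.0f;
  for (int k = 0; k < n; ++k) dot += dy[k] * y[k];          // sum(dy * y)
  for (int k = 0; k < n; ++k) dx[k] = y[k] * (dy[k] - dot);
}
// Example: y = {0.5, 0.5}, dy = {1, 0} gives dot = 0.5 and
// dx = {0.25, -0.25}.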
@@ -493,6 +473,7 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
// vectorization read/write
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;
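// T4/T2 pack four or two T values per memory transaction; dim % 4 == 0 keeps
// every VecT load fully populated. Sketch of the dispatch (the two-way and
// scalar fallbacks are assumed here, since the diff is truncated below):
//   if (dim % 4 == 0)      VecT = T4  // kVSize = 4
//   else if (dim % 2 == 0) VecT = T2  // kVSize = 2
//   else                   VecT = T   // scalar path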
if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, dev_ctx,
out_data, x.data<T>(), N, dim,
......