Unverified · Commit ef76f664 · authored by Liu-xiandong, committed by GitHub

Rewrite Softmax in Kernel Primitive API, test=develop (#36706)

Parent b151a451
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/operators/math/math_cuda_utils.h"
 #include "paddle/fluid/operators/softmax_op.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
@@ -99,6 +100,97 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) {
   }
 }
+namespace kps = paddle::operators::kernel_primitives;
+
+template <typename Tx, typename Ty = Tx>
+struct ReduceMaxFunctor {
+  inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return max(a, b);
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct ExpSubFunctor {
+  HOSTDEVICE inline ExpSubFunctor() { y = static_cast<Tx>(0.0f); }
+  HOSTDEVICE explicit inline ExpSubFunctor(Tx y) : y((Tx)(y)) {}
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x - y));
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct ExpMulFunctor {
+  HOSTDEVICE inline ExpMulFunctor() { y = static_cast<Tx>(1.0f); }
+  HOSTDEVICE explicit inline ExpMulFunctor(Tx y) : y((Tx)(y)) {}
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::exp(x) * y);
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnarySubFunctor {
+  HOSTDEVICE inline UnarySubFunctor() { y = static_cast<Tx>(0.0f); }
+  HOSTDEVICE explicit inline UnarySubFunctor(Tx y) : y((Tx)(y)) {}
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x - y);
+  }
+
+ private:
+  Tx y;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnaryLogFunctor {
+  HOSTDEVICE inline UnaryLogFunctor() {}
+  HOSTDEVICE explicit inline UnaryLogFunctor(int n) {}
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(std::log(x));
+  }
+};
+
+template <typename Tx, typename Ty>
+struct DataTransFunctor {
+  HOSTDEVICE inline DataTransFunctor() {}
+  HOSTDEVICE explicit inline DataTransFunctor(int n) {}
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return x == -std::numeric_limits<Tx>::infinity()
+               ? -std::numeric_limits<Ty>::infinity()
+               : static_cast<Ty>(x);
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnaryDivFunctor {
+  HOSTDEVICE inline UnaryDivFunctor() { n_inv = static_cast<Tx>(1.0f); }
+  HOSTDEVICE explicit inline UnaryDivFunctor(Tx n) : n_inv((Tx)(1.0 / n)) {}
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(x * n_inv);
+  }
+
+ private:
+  Tx n_inv;
+};
 /*
 Core function of computing softmax forward for axis=-1.
 The computation includes
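For orientation, the functors above are the building blocks that the rewritten kernels hand to the Kernel Primitive (kps) API, and they compose into the usual three-pass softmax: row max, exp of the shifted value with a running sum, then division by the sum. Below is a minimal scalar sketch of that pipeline (illustrative only, not part of this commit; SoftmaxReference is a name invented here, and rows are assumed non-empty):

    // Scalar reference for the pipeline the kernel below vectorizes:
    // ReduceMaxFunctor -> ExpSubFunctor (+ AddFunctor sum) -> UnaryDivFunctor.
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <limits>
    #include <vector>

    std::vector<float> SoftmaxReference(const std::vector<float>& x) {
      // Pass 1: running max (ReduceMaxFunctor's role).
      float max_v = -std::numeric_limits<float>::infinity();
      for (float v : x) max_v = std::max(max_v, v);
      // Pass 2: exp(x - max) and running sum (ExpSubFunctor + AddFunctor).
      std::vector<float> out(x.size());
      float sum = 0.0f;
      for (std::size_t i = 0; i < x.size(); ++i) {
        out[i] = std::exp(x[i] - max_v);
        sum += out[i];
      }
      // Pass 3: divide by the sum (UnaryDivFunctor multiplies by 1/sum).
      for (float& v : out) v /= sum;
      return out;
    }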
@@ -117,12 +209,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
   constexpr int kDimCeil = 1 << Log2Elements;
   constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
   constexpr int kVSize = sizeof(VecT) / sizeof(T);
-  constexpr int kIterations = kDimCeil / kWarpSize;
-  constexpr int kIterationsV =
-      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
+  constexpr int kLoops = kDimCeil / kWarpSize;
+  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
+  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr int kVItem = kLoopsV * kVSize;
+  constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
+  using kMode = kps::details::ReduceMode;

   // max index to read
   int idx_max_v[kBatchSize];
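As a worked instantiation of these constants (an editorial aid, assuming VecT is a 4-wide vector so that kVSize = 4): for Log2Elements = 10, kDimCeil = 1024, kWarpSize = 32, kLoops = 1024 / 32 = 32, kLoopsV = 32 / 4 = 8, and kBatchSize = 1, giving kStep = kVItem = 32; each of the 32 lanes keeps 8 VecT registers that cover its 32 elements of one row. For a short row such as Log2Elements = 4, kDimCeil = 16, so kWarpSize = 16 and kBatchSize = 2: the same group of lanes processes two rows at once, which is what the kBatchSize dimension of srcdata below carries.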
@@ -133,146 +227,51 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
   }

   // read data from global memory
-  AccT srcdata[kBatchSize][kIterationsV][kVSize];
+  AccT srcdata[kBatchSize][kLoopsV][kVSize];
+  kps::Init<AccT, kStep>(&srcdata[0][0][0], kLowInf);
+  T src_tmp[kBatchSize][kLoopsV][kVSize];
+  kps::Init<T, kStep>(&src_tmp[0][0][0], -std::numeric_limits<T>::infinity());
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    // read data
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      int src_idx = threadIdx.x + it * kWarpSize;
-      if (kVSize == 1) {
-        if (src_idx < idx_max_v[i]) {
-          srcdata[i][it][0] =
-              static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
-        } else {
-          srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
-        }
-      } else {
-        const VecT* src_v =
-            reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
-        if (src_idx < idx_max_v[i]) {
-          VecT srctmp = src_v[src_idx];
-          const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
-#pragma unroll
-          for (int s = 0; s < kVSize; s++) {
-            srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
-          }
-        } else {
-#pragma unroll
-          for (int s = 0; s < kVSize; s++) {
-            srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
-          }
-        }
-      }
-    }
+    int ptr = (first_batch + i) * stride;
+    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&src_tmp[i][0][0]);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &reg_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1);
+    kps::ElementwiseUnary<T, AccT, kVItem, 1, 1, DataTransFunctor<T, AccT>>(
+        &srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor<T, AccT>());
   }

-  // compute max value
-  AccT max_value[kBatchSize];
-#pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
-    // it = 0
-    AccT valmax = srcdata[i][0][0];
-#pragma unroll
-    for (int s = 1; s < kVSize; ++s) {
-      valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
-    }
-    max_value[i] = valmax;
-    // it = 1, 2, ...
-#pragma unroll
-    for (int it = 1; it < kIterationsV; ++it) {
-      AccT valmax = srcdata[i][it][0];
-#pragma unroll
-      for (int s = 1; s < kVSize; ++s) {
-        valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
-      }
-      max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
-    }
-  }
-  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
+  // compute max
+  AccT max[kBatchSize];
+  kps::Init<AccT, kBatchSize>(&max[0], kLowInf);
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, ReduceMaxFunctor<AccT>,
+              kMode::kLocalMode>(&max[0], &srcdata[0][0][0],
+                                 ReduceMaxFunctor<AccT>(), true);
+  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
   // compute sum
-  AccT sum[kBatchSize];
-#pragma unroll
+  AccT sum[kBatchSize] = {0};
   for (int i = 0; i < kBatchSize; ++i) {
-    // it = 0
-    if (LogMode) {
-      sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
-    } else {
-      srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
-      sum[i] = srcdata[i][0][0];
-    }
-#pragma unroll
-    for (int s = 1; s < kVSize; ++s) {
-      if (LogMode) {
-        sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
-      } else {
-        srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
-        sum[i] += srcdata[i][0][s];
-      }
-    }
-    // it = 1, 2, ...
-#pragma unroll
-    for (int it = 1; it < kIterationsV; ++it) {
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
-        } else {
-          srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
-          sum[i] += srcdata[i][it][s];
-        }
-      }
-    }
+    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
+        &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
   }
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
+              kMode::kLocalMode>(&sum[0], &srcdata[0][0][0],
+                                 kps::AddFunctor<AccT>(), true);
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
   // write result to global memory
+  T out_tmp[kBatchSize][kLoopsV][kVSize];
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
-    if (LogMode) {
-      sum[i] = std::log(sum[i]);
-    }
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      int idx = threadIdx.x + it * kWarpSize;
-      if (kVSize == 1) {
-        if (idx < idx_max_v[i]) {
-          if (LogMode) {
-            softmax[(first_batch + i) * stride + idx] =
-                srcdata[i][it][0] - max_value[i] - sum[i];
-          } else {
-            softmax[(first_batch + i) * stride + idx] =
-                srcdata[i][it][0] / sum[i];
-          }
-        } else {
-          break;
-        }
-      } else {
-        VecT* softmax_v =
-            reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
-        VecT tmpdata;
-        T* tmpptr = reinterpret_cast<T*>(&tmpdata);
-#pragma unroll
-        for (int s = 0; s < kVSize; ++s) {
-          if (LogMode) {
-            tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
-          } else {
-            tmpptr[s] = srcdata[i][it][s] / sum[i];
-          }
-        }
-        if (idx < idx_max_v[i]) {
-          softmax_v[idx] = tmpdata;
-        } else {
-          break;
-        }
-      }
-    }
+    kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
+        &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
+    int softmax_ptr = (first_batch + i) * stride;
+    VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[softmax_ptr]);
+    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
+    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
   }
 }
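The kps::ReadData and kps::WriteData calls above replace the hand-written guarded loops on the removed side of this diff. The following sketch shows the access pattern they are instantiated with at these call sites, as inferred from this diff and the code it replaces; it is not an authoritative statement of the kps API, and ReadDataSketch is a name invented for illustration:

    // Boundary-guarded, vectorized, warp-strided copy from global memory into
    // per-thread registers, mirroring the call
    // ReadData<VecT, VecT, kLoopsV, 1, 1, true>(dst, src, num, 0, stride_x, 1).
    template <typename VecT, int kLoopsV>
    __device__ void ReadDataSketch(VecT* dst, const VecT* src, int num,
                                   int stride_x) {
    #pragma unroll
      for (int it = 0; it < kLoopsV; ++it) {
        int idx = threadIdx.x + it * stride_x;  // stride_x == kWarpSize here
        if (idx < num) {  // num == idx_max_v[i], counted in units of VecT
          dst[it] = src[idx];
        }
        // Out-of-range slots keep whatever kps::Init stored beforehand
        // (-inf in the forward pass), so padding cannot perturb max or sum.
      }
    }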
@@ -293,101 +292,82 @@ __global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
   constexpr int kVSize = sizeof(VecT) / sizeof(T);
   constexpr int kDimCeil = 1 << Log2Elements;
   constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
-  constexpr int kIterations = kDimCeil / kWarpSize;
+  constexpr int kLoops = kDimCeil / kWarpSize;
   constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
-  constexpr int kIterationsV =
-      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
+  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   int element_count_v = element_count / kVSize;
   int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
-  int local_batches = batch_size - first_batch;
-  if (local_batches > kBatchSize) {
-    local_batches = kBatchSize;
-  }
+  int local_batches = min(batch_size - first_batch, kBatchSize);
-  // read data from global memory
-  VecT src_reg[kBatchSize][kIterationsV];
-  VecT grad_reg[kBatchSize][kIterationsV];
-  for (int i = 0; i < kBatchSize; ++i) {
-    const VecT* src_v =
-        reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
-    const VecT* grad_v =
-        reinterpret_cast<const VecT*>(&grad[(first_batch + i) * stride]);
-    // max index to read
-    int idx_max = (i < local_batches) ? element_count : 0;
-    int idx_max_v = idx_max / kVSize;
-    // read data
-    for (int it = 0; it < kIterationsV; ++it) {
-      int src_idx = threadIdx.x + it * kWarpSize;
-      if (src_idx < idx_max_v) {
-        src_reg[i][it] = src_v[src_idx];
-        grad_reg[i][it] = grad_v[src_idx];
-      } else {
-#pragma unroll
-        for (int s = 0; s < kVSize; s++) {
-          reinterpret_cast<T*>(&src_reg[i][it])[s] = 0.0;
-          reinterpret_cast<T*>(&grad_reg[i][it])[s] = 0.0;
-        }
-      }
-    }
-  }
+  // max index to read
+  int idx_max_v[kBatchSize];
+#pragma unroll
+  for (int i = 0; i < kBatchSize; i++) {
+    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
+    idx_max_v[i] = idx_max / kVSize;
+  }
+
+  // read data from global memory
+  VecT src_reg[kBatchSize][kLoopsV];
+  VecT grad_reg[kBatchSize][kLoopsV];
+  VecT k_value;
+  for (int s = 0; s < kVSize; s++) {
+    reinterpret_cast<T*>(&k_value)[s] = 0.0;
+  }
+  kps::Init<VecT, kBatchSize * kLoopsV>(&src_reg[0][0], k_value);
+  kps::Init<VecT, kBatchSize * kLoopsV>(&grad_reg[0][0], k_value);
+#pragma unroll
+  for (int i = 0; i < kBatchSize; ++i) {
+    int flag = i < local_batches ? 1 : 0;
+    int ptr = (first_batch + i) * stride;
+    const VecT* src_v = reinterpret_cast<const VecT*>(&src[ptr]);
+    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[ptr]);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &src_reg[i][0], &src_v[0], idx_max_v[i], 0, kWarpSize, flag);
+    kps::ReadData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &grad_reg[i][0], &grad_v[0], idx_max_v[i], 0, kWarpSize, flag);
+  }
+
+  // change T to AccT
+  AccT src_tmp[kBatchSize][kLoopsV][kVSize];
+  AccT grad_tmp[kBatchSize][kLoopsV][kVSize];
+  const T* src_ptr = reinterpret_cast<const T*>(&src_reg[0][0]);
+  const T* grad_ptr = reinterpret_cast<const T*>(&grad_reg[0][0]);
+  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr int kVItem = kLoopsV * kVSize;
+  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
+      &src_tmp[0][0][0], &src_ptr[0], DataTransFunctor<T, AccT>());
+  kps::ElementwiseUnary<T, AccT, kStep, 1, 1, DataTransFunctor<T, AccT>>(
+      &grad_tmp[0][0][0], &grad_ptr[0], DataTransFunctor<T, AccT>());
   // compute sum
   AccT sum[kBatchSize]{0.0};
-#pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
-      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          sum[i] += static_cast<AccT>(gradptr[s]);
-        } else {
-          sum[i] += static_cast<AccT>(gradptr[s] * srcptr[s]);
-        }
-      }
-    }
-  }
+  AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
+  AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
+  AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
+  kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
+      &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
+  kps::Reduce<AccT, kVItem, kBatchSize, 1, kps::AddFunctor<AccT>,
+              kps::details::ReduceMode::kLocalMode>(
+      &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
-  // write result
+  // write result to global memory
+  AccT out[kBatchSize][kLoopsV][kVSize];
+  T out_tmp[kBatchSize][kLoopsV][kVSize];
 #pragma unroll
   for (int i = 0; i < kBatchSize; ++i) {
     if (i >= local_batches) break;
+    AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
+    AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
+    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
+        &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
+    kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
+        &out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());
     VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
-    // max index to write
-    int idx_max = (i < local_batches) ? element_count : 0;
-    int idx_max_v = idx_max / kVSize;
-#pragma unroll
-    for (int it = 0; it < kIterationsV; ++it) {
-      VecT tmpdata;
-      T* tmpptr = reinterpret_cast<T*>(&tmpdata);
-      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
-      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
-#pragma unroll
-      for (int s = 0; s < kVSize; ++s) {
-        if (LogMode) {
-          tmpptr[s] = static_cast<AccT>(gradptr[s]) -
-                      std::exp(static_cast<AccT>(srcptr[s])) * sum[i];
-        } else {
-          tmpptr[s] = static_cast<AccT>(srcptr[s]) *
-                      (static_cast<AccT>(gradptr[s]) - sum[i]);
-        }
-      }
-      int idx = threadIdx.x + it * kWarpSize;
-      if (idx < idx_max_v) {
-        dst_v[idx] = tmpdata;
-      }
-    }
+    VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
+    kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
+        &dst_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
   }
 }
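For the non-log path that this rewritten backward kernel expresses with kps::MulFunctor, kps::Reduce, UnarySubFunctor, and kps::MulFunctor again, the per-row math is dx[j] = y[j] * (dy[j] - sum_k dy[k] * y[k]), where y is the forward softmax output. A scalar reference sketch (illustrative only, not part of the commit; SoftmaxGradReference is a name invented here):

    #include <cstddef>

    // dx = y * (dy - dot(dy, y)), one row at a time.
    void SoftmaxGradReference(float* dx, const float* dy, const float* y,
                              std::size_t n) {
      float dot = 0.0f;
      for (std::size_t k = 0; k < n; ++k)
        dot += dy[k] * y[k];  // the MulFunctor + Reduce steps
      for (std::size_t j = 0; j < n; ++j)
        dx[j] = y[j] * (dy[j] - dot);  // UnarySubFunctor, then MulFunctor
    }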
@@ -493,6 +473,7 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
   // vectorization read/write
   using T4 = typename VecT4<T>::Type;
   using T2 = typename VecT2<T>::Type;
   if (dim % 4 == 0) {
     SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, dev_ctx,
                                              out_data, x.data<T>(), N, dim,
 ...
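The driver hunk above is cut off just after the 4-wide branch. Judging from the visible pattern and the T4/T2 aliases, the dispatch presumably falls back to narrower vector types when dim is not a multiple of 4; the helper below expresses that choice, with the 2-wide and scalar branches marked as assumptions rather than quotations from the commit:

    // Hypothetical helper expressing the vector-width choice; the commit's
    // driver inlines the test rather than calling a function like this.
    inline int MaxVecWidthForDim(int dim) {
      if (dim % 4 == 0) return 4;  // visible in the diff: use T4
      if (dim % 2 == 0) return 2;  // assumption: use T2
      return 1;                    // assumption: scalar fallback
    }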