未验证 提交 95e33481 编写于 作者: Z zlsh80826 提交者: GitHub

Softmax vectorization (#29404)

* vec softmax fw

* vec softmax bw

* add a message argument for compiler compatibility
上级 a136c9cd
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
......@@ -38,6 +39,81 @@ static inline int SizeOutAxis(const int axis, DDim dims) {
return size;
}
template <typename T, int VLEN>
union vec_t {
static_assert(sizeof(T) == -1, "vec_t is only available by specialization.");
};
template <>
union vec_t<float, 4> {
float4 s;
float v[4];
};
template <>
union vec_t<platform::float16, 4> {
int2 s;
platform::float16 v[4];
};
template <typename T, typename VECT, int VPT, int WARP_PER_BLOCK>
__global__ void VecSoftmaxForward(T* dst, const T* src, const int batch_size,
const int softmax_ele) {
int offset = blockIdx.x * softmax_ele * WARP_PER_BLOCK;
int idx = threadIdx.x * VPT;
VECT buf = reinterpret_cast<const VECT*>(&src[offset + idx])[0];
T* bufp = reinterpret_cast<T*>(&buf);
float4 val4;
float* val4p = reinterpret_cast<float*>(&val4);
for (int i = 0; i < VPT; ++i) {
val4p[i] = static_cast<float>(bufp[i]);
}
float val = val4.x + val4.y + val4.z + val4.w;
float max_val = math::warpReduceMax<float>(
max(max(val4.x, val4.y), max(val4.z, val4.w)), 0xffffffff);
float4 tmp4 = make_float4(__expf(val4.x - max_val), __expf(val4.y - max_val),
__expf(val4.z - max_val), __expf(val4.w - max_val));
float* tmp4p = reinterpret_cast<float*>(&tmp4);
float invsum = 1.f / (math::warpReduceSum<float>(
tmp4.x + tmp4.y + tmp4.z + tmp4.w, 0xffffffff) +
1e-6f);
for (int i = 0; i < VPT; ++i) {
bufp[i] = static_cast<T>(tmp4p[i] * invsum);
}
reinterpret_cast<VECT*>(&dst[offset + idx])[0] = buf;
}
template <typename T, int VPT, int WARP_PER_BLOCK>
__global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src,
const int batch_size,
const int softmax_ele) {
const int offset =
blockIdx.x * softmax_ele * WARP_PER_BLOCK + threadIdx.x * VPT;
float local_sum_gy = 0.f;
vec_t<T, VPT> local_grad;
vec_t<T, VPT> local_src;
local_grad.s =
reinterpret_cast<const decltype(local_grad.s)*>(&grad[offset])[0];
local_src.s = reinterpret_cast<const decltype(local_src.s)*>(&src[offset])[0];
for (int i = 0; i < VPT; ++i) {
local_sum_gy += static_cast<float>(local_grad.v[i]) *
static_cast<float>(local_src.v[i]);
}
float sum_gy = math::warpReduceSum<float>(local_sum_gy, 0xffffffff);
vec_t<T, VPT> local_dst;
for (int i = 0; i < VPT; ++i) {
local_dst.v[i] =
static_cast<T>(static_cast<float>(local_src.v[i]) *
(static_cast<float>(local_grad.v[i]) - sum_gy));
}
reinterpret_cast<decltype(local_dst.s)*>(&dst[offset])[0] = local_dst.s;
}
template <typename T>
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
public:
......@@ -54,20 +130,42 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims);
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, out_data));
constexpr int warps_per_block = 4;
if (D == 1 && dim == 128 && N % warps_per_block == 0 && sizeof(T) <= 4) {
// a warp for a batch, 4 elements for a thread, only support the softmax
// dim size = 128 currently
if (sizeof(T) == 2) {
VecSoftmaxForward<
T, int2, 4,
warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
0, ctx.cuda_device_context().stream()>>>(
out_data, x->data<T>(), N, dim);
} else if (sizeof(T) == 4) {
VecSoftmaxForward<
T, int4, 4,
warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
0, ctx.cuda_device_context().stream()>>>(
out_data, x->data<T>(), N, dim);
} else {
assert(false && "not support");
}
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, out_data));
}
}
};
......@@ -88,20 +186,49 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims);
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(), desc_,
dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_, dx_data));
constexpr int warps_per_block = 4;
constexpr bool warp_softmax_available =
std::is_same<T, float>::value ||
std::is_same<T, platform::float16>::value;
if (D == 1 && dim == 128 && N % warps_per_block == 0 &&
warp_softmax_available) {
if (std::is_same<T, float>::value) {
VecSoftmaxBackward<
float, 4,
warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
0, ctx.cuda_device_context().stream()>>>(
dx->data<float>(), dout->data<float>(), out->data<float>(), N, dim);
} else if (std::is_same<T, platform::float16>::value) {
VecSoftmaxBackward<
platform::float16, 4,
warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
0, ctx.cuda_device_context().stream()>>>(
dx->data<platform::float16>(), dout->data<platform::float16>(),
out->data<platform::float16>(), N, dim);
} else {
PADDLE_ENFORCE_EQ(
warp_softmax_available, true,
platform::errors::Unimplemented(
"Warp softmax backward is only available for fp32 and fp16"));
}
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(), desc_,
dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
dx_data));
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册