Unverified commit 7be6191b, authored by Feng Xing, committed by GitHub

optimize softmax with cross entropy hard label (#32290)

* optimize softmax with cross entropy hard label

* label ignore_index cleaning
Parent 0e5d832c
@@ -15,44 +15,481 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_impl.cuh"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;
namespace {
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int64_t n, const int64_t d,
                                 const int64_t remain, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) {
    int64_t idx_n = index / remain;
    int64_t idx_remain = index % remain;
    int64_t tmp = labels[index];
    if (ignore_index != tmp) {
      int64_t idx = idx_n * d + tmp * remain + idx_remain;
      logit_grad[idx] -= static_cast<T>(1.);
    }
  }
}

// Wrapper of log function. Use log(float32) for float16
template <typename T>
static __device__ __forceinline__ T Log(T x) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  AccT logx = std::log(static_cast<AccT>(x));
  return math::TolerableValue<T>()(static_cast<T>(logx));
}

// Wrapper of exp function. Use exp(float32) for float16
template <typename T>
static __device__ __forceinline__ T Exp(T x) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  AccT expx = std::exp(static_cast<AccT>(x));
  return math::TolerableValue<T>()(static_cast<T>(expx));
}
// log2(value)
static inline int Log2Ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
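A quick aside, not part of the diff: Log2Ceil is what later maps a class count to the Log2Elements dispatch in SwitchWarpSoftmaxForward, so class counts up to the warp path's max_dim of 320 land in cases 0 through 9. A minimal host-side sketch in plain C++, with illustrative values only:

#include <cstdio>

// Host-side copy of the rounding-up log2 helper above, for illustration.
static int Log2CeilHost(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}

int main() {
  // 320 is the max_dim cutoff used later by the warp softmax path.
  std::printf("Log2Ceil(1)   = %d\n", Log2CeilHost(1));    // 0
  std::printf("Log2Ceil(33)  = %d\n", Log2CeilHost(33));   // 6
  std::printf("Log2Ceil(320) = %d\n", Log2CeilHost(320));  // 9
  return 0;
}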
enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy };
/*
Hard label cross entropy.
*/
template <typename T, bool IgnoreIndex>
__global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
const int64_t* labels, const int n,
const int dim, const int d,
const int ignore_idx) {
int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx_n = ids / d;
int64_t idx_d = ids % d;
// thread ids compute loss[ids] using softmax[idx]
if (ids < n * d) {
int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d;
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (labels[ids] == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
loss[ids] = -Log(softmax[idx]);
}
} else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]);
}
}
}
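To make the indexing concrete: softmax is viewed as [N, dim, D] (batch, classes, inner size) and labels/loss as [N, D]. The following CPU sketch mirrors the loss computation above; it is illustrative only, with made-up sizes and values, and is not part of this commit:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int N = 1, dim = 3, D = 2;  // batch, classes, inner size
  // softmax laid out as [N, dim, D], normalized along the class axis
  std::vector<float> softmax = {0.2f, 0.5f, 0.3f, 0.2f, 0.5f, 0.3f};
  std::vector<int64_t> labels = {1, 2};  // shape [N, D]
  std::vector<float> loss(N * D, 0.0f);

  for (int64_t ids = 0; ids < N * D; ++ids) {
    int64_t idx_n = ids / D;
    int64_t idx_d = ids % D;
    // flat offset of softmax[idx_n][labels[ids]][idx_d]
    int64_t idx = idx_n * dim * D + labels[ids] * D + idx_d;
    loss[ids] = -std::log(softmax[idx]);
    std::printf("loss[%lld] = %f\n", static_cast<long long>(ids), loss[ids]);
  }
  return 0;
}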
/*
Hard label cross entropy with exp.
Input: log softmax
Output: loss and exp(input)
*/
template <typename T, bool IgnoreIndex>
__global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
const int64_t* labels, const int n,
const int dim, const int d,
const int ignore_idx) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
int64_t idx_n = idx / (d * dim);
int64_t idx_dim = (idx / d) % dim;
int64_t idx_d = idx % d;
int64_t ids = idx_n * d + idx_d;
if (idx < n * dim * d) {
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (idx_dim == labels[ids]) {
if (labels[ids] == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
loss[ids] = -softmax[idx];
        }
      }
} else {
// IgnoreIndex is false
if (labels[ids] >= 0 && labels[ids] < dim) {
if (labels[ids] == idx_dim) {
loss[ids] = -softmax[idx];
}
} else {
loss[ids] = static_cast<T>(0.0);
}
}
softmax[idx] = Exp(softmax[idx]);
}
}
/*
Core function of softmax with cross entropy forward
- softmax, SoftmaxMode=kSoftmax
- log softmax, SoftmaxMode=kLogSoftmax
- softmax with cross entropy hard label, SoftmaxMode=kCrossEntropy
The computation includes
- Compute max value: maxvalue_{i} = max_j src_{i,j}
- Compute sum of exp: s_{i} = sum_{j}{e^{src_{i,j} - maxvalue_{i}}}
- Compute: softmax_{i,j} = e^{src_{i,j} - maxvalue_{i}} / s_{i}
- Compute: logsoftmax_{i,j} = src_{i,j} - maxvalue_{i} - log(s_{i})
- Compute: loss_{i} = -logsoftmax[i,label[i]] (Hard label)
This computation results from following formula:
softmax_{i,j} = e^{src_{i,j}} / sum_{j}{e^{src_{i,j}}}
= e^{src_{i,j} - maxvalue_{i}}
/ sum_{j}{e^{src_{i,j} - maxvalue_{i}}}
= e^{src_{i,j} - maxvalue_{i}} / s_{i}
logsoftmax_{i,j} = log(softmax_{i,j})
= src_{i,j} - maxvalue_{i} - log(s_{i})
One warp (32 threads) computes one or two rows of the batch (kBatchSize).
For the max (sum) reduction, each thread first accumulates a per-thread
max (sum), then warp shuffle intrinsics reduce it across the warp.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
SoftmaxMode mode, bool IgnoreIndex>
__global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
const int64_t* label, const int batch_size,
const int stride, const int element_count,
const int ignore_index) {
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kIterations = kDimCeil / kWarpSize;
constexpr int kIterationsV =
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
// max index to read
int idx_max_v[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; i++) {
int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
idx_max_v[i] = idx_max / kVSize;
}
// read data from global memory
AccT srcdata[kBatchSize][kIterationsV][kVSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// read data to srcdata: - KVSize==1, - KVSize>1
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
int src_idx = threadIdx.x + it * kWarpSize;
if (kVSize == 1) {
if (src_idx < idx_max_v[i]) {
srcdata[i][it][0] =
static_cast<AccT>(src[(first_batch + i) * stride + src_idx]);
} else {
srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
}
} else {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
if (src_idx < idx_max_v[i]) {
VecT srctmp = src_v[src_idx];
const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
}
} else {
#pragma unroll
for (int s = 0; s < kVSize; s++) {
srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
}
}
}
}
}
// compute max value: maxvalue_{i} = max_j src_{i,j}
AccT max_value[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// it = 0
AccT valmax = srcdata[i][0][0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
}
max_value[i] = valmax;
// it = 1, 2, ...
#pragma unroll
for (int it = 1; it < kIterationsV; ++it) {
AccT valmax = srcdata[i][it][0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
}
max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
}
}
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
// compute sum: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i} }
AccT sum[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
// it = 0
if (mode == SoftmaxMode::kLogSoftmax ||
mode == SoftmaxMode::kCrossEntropy) {
sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
} else {
srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
sum[i] = srcdata[i][0][0];
}
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
if (mode == SoftmaxMode::kLogSoftmax ||
mode == SoftmaxMode::kCrossEntropy) {
sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
} else {
srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
sum[i] += srcdata[i][0][s];
}
}
// it = 1, 2, ...
#pragma unroll
for (int it = 1; it < kIterationsV; ++it) {
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (mode == SoftmaxMode::kLogSoftmax ||
mode == SoftmaxMode::kCrossEntropy) {
sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
} else {
srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
sum[i] += srcdata[i][it][s];
}
}
}
}
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// write data
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
if (mode == SoftmaxMode::kLogSoftmax ||
mode == SoftmaxMode::kCrossEntropy) {
sum[i] = std::log(sum[i]);
}
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
int idx = threadIdx.x + it * kWarpSize;
if (kVSize == 1) { // kVSize==1
if (idx < idx_max_v[i]) {
if (mode == SoftmaxMode::kLogSoftmax) { // log softmax
softmax[(first_batch + i) * stride + idx] =
srcdata[i][it][0] - max_value[i] - sum[i];
// softmax with cross entropy hard label
} else if (mode == SoftmaxMode::kCrossEntropy) {
AccT logsoftmax = srcdata[i][it][0] - max_value[i] - sum[i];
// softmax
softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax);
// label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (label[first_batch + i] == loss_idx) {
if (label[first_batch + i] != ignore_index) {
loss[first_batch + i] = -logsoftmax;
} else {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else {
// IgnoreIndex is false
if (label[first_batch + i] >= 0 &&
label[first_batch + i] < element_count) {
if (label[first_batch + i] == loss_idx) {
loss[first_batch + i] = -logsoftmax;
}
} else {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else { // softmax
softmax[(first_batch + i) * stride + idx] =
srcdata[i][it][0] / sum[i];
}
} else {
break;
}
} else { // KVSize>1
VecT* softmax_v =
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
VecT tmpdata;
T* tmpptr = reinterpret_cast<T*>(&tmpdata);
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (mode == SoftmaxMode::kLogSoftmax) { // log softmax
tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
// softmax with cross entropy hard label
} else if (mode == SoftmaxMode::kCrossEntropy) {
AccT logsoftmax = srcdata[i][it][s] - max_value[i] - sum[i];
// softmax
tmpptr[s] = std::exp(logsoftmax);
// label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (label[first_batch + i] == loss_idx &&
label[first_batch + i] != ignore_index) {
loss[first_batch + i] = -logsoftmax;
}
} else {
// IgnoreIndex is false
if (label[first_batch + i] >= 0 &&
label[first_batch + i] < element_count) {
if (label[first_batch + i] == loss_idx) {
loss[first_batch + i] = -logsoftmax;
}
} else {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else { // softmax
tmpptr[s] = srcdata[i][it][s] / sum[i];
}
}
if (idx < idx_max_v[i]) {
softmax_v[idx] = tmpdata;
} else {
break;
}
}
}
}
}
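For readers checking the formulas in the comment before WarpSoftmaxForward, here is a plain C++ reference of the kCrossEntropy mode for a single row. It is a sketch only, not part of the commit, and is numerically equivalent to what the vectorized, warp-reduced kernel computes:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> src = {1.0f, 2.0f, 3.0f};  // one row of logits
  int label = 2;                                // hard label for this row

  // maxvalue_i = max_j src_{i,j}
  float maxv = *std::max_element(src.begin(), src.end());

  // s_i = sum_j exp(src_{i,j} - maxvalue_i)
  float sum = 0.0f;
  for (float v : src) sum += std::exp(v - maxv);

  // softmax_{i,j} = exp(src_{i,j} - maxvalue_i) / s_i
  // logsoftmax_{i,j} = src_{i,j} - maxvalue_i - log(s_i)
  for (std::size_t j = 0; j < src.size(); ++j) {
    float logsoftmax = src[j] - maxv - std::log(sum);
    std::printf("softmax[%zu] = %f\n", j, std::exp(logsoftmax));
  }

  // loss_i = -logsoftmax[i, label[i]]
  float loss = -(src[label] - maxv - std::log(sum));
  std::printf("loss = %f\n", loss);
  return 0;
}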
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, VecT, AccT) \
case Log2Elements: \
WarpSoftmaxForward<T, VecT, AccT, Log2Elements, mode, \
IgnoreIndex><<<blocks, threads, 0, stream>>>( \
loss, softmax, src, label, batch_size, stride, element_count, \
ignore_index); \
break;
/*
Wrapper of softmax with cross entropy forward hard label.
*/
template <typename T, SoftmaxMode mode, bool IgnoreIndex>
void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
const int64_t* label, const int batch_size,
const int stride, const int element_count,
const int ignore_index, gpuStream_t stream) {
using AccT = typename details::MPTypeTrait<T>::Type;
  // use 128 threads per block to maximize gpu utilization
const int Log2Elements = static_cast<int>(Log2Ceil(element_count));
const int kDimCeil = 1 << Log2Elements;
int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / kWarpSize);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (batch_size + batches_per_block - 1) / batches_per_block;
dim3 threads(kWarpSize, warps_per_block, 1);
switch (Log2Elements) {
SOFTMAX_WARP_FORWARD_CASE(0, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(1, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(2, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(3, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(4, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(5, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(6, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(7, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(8, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(9, T, AccT);
default:
break;
}
}
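A worked example of the launch configuration above, using illustrative numbers that are not taken from the source: element_count = 100 gives Log2Elements = 7, so kDimCeil = 128, each 32-thread warp handles 2 rows, and a 128-thread block covers 8 rows.

#include <cstdio>

static int Log2CeilHost(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}

int main() {
  // Illustrative inputs only; not taken from the source.
  const int batch_size = 1000, element_count = 100;

  const int log2_elements = Log2CeilHost(element_count);            // 7
  const int dim_ceil = 1 << log2_elements;                          // 128
  const int warp_size = (dim_ceil < 32) ? dim_ceil : 32;            // 32
  const int batches_per_warp = (dim_ceil <= 128) ? 2 : 1;           // 2
  const int threads_per_block = 128;
  const int warps_per_block = threads_per_block / warp_size;        // 4
  const int batches_per_block = warps_per_block * batches_per_warp; // 8
  const int blocks = (batch_size + batches_per_block - 1) / batches_per_block;

  std::printf("blocks = %d, threads = (%d, %d)\n", blocks, warp_size,
              warps_per_block);  // blocks = 125, threads = (32, 4)
  return 0;
}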
/*
Wrapper of softmax with cross entropy hard label.
- SwitchWarpSoftmaxForward for small size
- cudnn function for large size
*/
template <typename T, bool IgnoreIndex>
static void SoftmaxWithCrossEntropyHardLabel(
const platform::CUDADeviceContext& ctx, int rank, int axis,
const T* logits_data, const int64_t* labels_data, T* loss_data,
T* softmax_data, int N, int dim, int D, const int ignore_index) {
auto stream = ctx.stream();
constexpr int max_dim = 320;
if (D == 1 && dim <= max_dim) { // small size
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
SwitchWarpSoftmaxForward<T, mode, IgnoreIndex>(
loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
ignore_index, stream);
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#else
cudnnTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#endif
auto handle = ctx.cudnn_handle();
#ifdef PADDLE_WITH_HIP
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
handle, platform::CudnnDataType<T>::kOne(), descp, logits_data,
platform::CudnnDataType<T>::kZero(), descp, softmax_data,
MIOPEN_SOFTMAX_LOG, mode));
#else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
descp, logits_data, platform::CudnnDataType<T>::kZero(), descp,
softmax_data));
#endif
int threads = 128;
int blocks = (N * dim * D + threads - 1) / threads;
// compute cross entropy, input is log softmax
CrossEntropyExpHardLabel<T, IgnoreIndex><<<blocks, threads, 0, stream>>>(
loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
}
}
template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num,
                      const int64_t d, const int64_t remain,
                      const int64_t* labels, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) {
    int64_t idx_n = index / d;
    int64_t idx_remain = index % remain;
    int64_t idx_lbl = idx_n * remain + idx_remain;
    if (labels[idx_lbl] == ignore_index) {
      logit_grad[index] = static_cast<T>(0.);
    } else {
      logit_grad[index] *= loss_grad[idx_lbl];
    }
  }
}

/*
Wrapper of softmax with cross entropy grad hard label.
*/
template <typename T>
__global__ void SoftmaxWithCrossEntropyGradHardLabel(
    T* logits_grad, const T* loss_grad, const int64_t* labels, const int64_t n,
    const int64_t dim, const int64_t d, const int ignore_index) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  int64_t idx_n = idx / (d * dim);
  int64_t idx_dim = (idx / d) % dim;
  int64_t idx_d = idx % d;
  int64_t ids = idx_n * d + idx_d;

  if (idx < n * dim * d) {
    if (labels[ids] == ignore_index) {
      logits_grad[idx] = static_cast<T>(0.0);
    } else if (labels[ids] == idx_dim) {
      logits_grad[idx] =
          (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
    } else {
      logits_grad[idx] *= loss_grad[ids];
    }
  }
}
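The SoftmaxWithCrossEntropyGradHardLabel kernel above, which replaces the old CrossEntropyGrad/Scale pair, applies the standard hard-label gradient dL/dlogit_j = (softmax_j - 1[j == label]) * dloss and zeroes rows whose label equals ignore_index. A minimal CPU sketch of the same update for one row, with illustrative values only, not part of the commit:

#include <cstdint>
#include <cstdio>

int main() {
  const int dim = 3;
  // logits_grad initially holds softmax(logits), as produced by the forward pass
  float logits_grad[dim] = {0.09f, 0.24f, 0.67f};
  const int64_t label = 2;
  const float loss_grad = 1.0f;   // upstream gradient for this row
  const int ignore_index = -100;  // illustrative value

  for (int j = 0; j < dim; ++j) {
    if (label == ignore_index) {
      logits_grad[j] = 0.0f;
    } else if (j == label) {
      logits_grad[j] = (logits_grad[j] - 1.0f) * loss_grad;
    } else {
      logits_grad[j] *= loss_grad;
    }
    std::printf("grad[%d] = %f\n", j, logits_grad[j]);
  }
  return 0;
}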
@@ -123,8 +560,6 @@ __global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad,
  }
}
} // namespace
static __device__ __forceinline__ platform::float16 exp_on_device(
    platform::float16 x) {
  return ::Eigen::numext::exp(x);
@@ -396,278 +831,6 @@ static __global__ void RowReductionForCrossEntropy(const T* logits_data,
  if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}
template <typename T>
struct HardLabelCrossEntropyFunctor {
public:
HardLabelCrossEntropyFunctor(const int64_t* labels, T* loss,
const T* logits_data, int d, int axis_dim)
: labels_(labels),
loss_(loss),
logits_data_(logits_data),
d_(d),
axis_dim_(axis_dim) {}
__device__ void operator()(int idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
int remain = d_ / axis_dim_;
int idx_n = idx / d_;
int idx_axis = (idx % d_) / remain;
int idx_remain = idx % remain;
// labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain;
// It also would ignore labels not in range(class_num).
if (idx_axis != labels_[idx_lbl]) {
} else {
loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
}
}
private:
const int64_t* labels_;
T* loss_;
const T* logits_data_;
int d_;
int axis_dim_;
};
template <typename T>
struct HardLabelCrossEntropyFunctorWithIgnoreIdx {
public:
HardLabelCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss,
const T* logits_data, int d,
int axis_dim, int ignore_idx)
: labels_(labels),
loss_(loss),
logits_data_(logits_data),
d_(d),
axis_dim_(axis_dim),
ignore_idx_(ignore_idx) {}
__device__ void operator()(int idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
int remain = d_ / axis_dim_;
int idx_n = idx / d_;
int idx_axis = (idx % d_) / remain;
int idx_remain = idx % remain;
// labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain;
if (idx_axis == labels_[idx_lbl] && idx_axis != ignore_idx_) {
loss_[idx_lbl] = -log_on_device(logits_data_[idx]);
}
}
private:
const int64_t* labels_;
T* loss_;
const T* logits_data_;
int d_;
int axis_dim_;
int ignore_idx_;
};
template <typename T>
static void HardLabelCrossEntropy(const platform::CUDADeviceContext& ctx,
const T* logits_data,
const int64_t* labels_data, T* loss_data,
int n, int d, int axis_dim, int ignore_idx) {
constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim)));
int grid_dim = n * d / axis_dim;
auto stream = ctx.stream();
#define CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d); \
if (ignore_idx >= 0 && ignore_idx < axis_dim) { \
for_range(HardLabelCrossEntropyFunctorWithIgnoreIdx<T>( \
labels_data, loss_data, logits_data, d, axis_dim, ignore_idx)); \
} else { \
for_range(HardLabelCrossEntropyFunctor<T>(labels_data, loss_data, \
logits_data, d, axis_dim)); \
} \
} break
switch (block_dim) {
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(512);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(256);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(128);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(64);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(32);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(16);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(2);
default:
PADDLE_THROW(platform::errors::Unavailable(
"Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
break;
}
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor {
public:
HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
T* log_softmax, int64_t d,
int axis_dim, int ignore_idx)
: labels_(labels),
loss_(loss),
log_softmax_(log_softmax),
d_(d),
axis_dim_(axis_dim),
ignore_idx_(ignore_idx) {}
__device__ void operator()(int64_t idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
int64_t remain = d_ / axis_dim_;
int64_t idx_n = idx / d_;
int64_t idx_axis = (idx % d_) / remain;
int64_t idx_remain = idx % remain;
// labels, loss view as [n, remain]
int64_t idx_lbl = idx_n * remain + idx_remain;
PADDLE_ENFORCE(labels_[idx_lbl] >= 0 && labels_[idx_lbl] < d_ ||
labels_[idx_lbl] == ignore_idx_,
"The value of label[%ld] expected >= 0 and < %ld, or == %d,"
"but got %ld. Please check input value.",
idx_lbl, d_, ignore_idx_, labels_[idx_lbl]);
// It also would ignore labels not in range(class_num).
if (idx_axis != labels_[idx_lbl]) {
log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
} else {
auto softmax = log_softmax_[idx];
log_softmax_[idx] = exp_on_device(softmax);
loss_[idx_lbl] = -softmax;
}
}
private:
const int64_t* labels_;
T* loss_;
T* log_softmax_;
int64_t d_;
int axis_dim_;
int ignore_idx_;
};
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
public:
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
T* loss, T* log_softmax,
int64_t d, int axis_dim,
int ignore_idx)
: labels_(labels),
loss_(loss),
log_softmax_(log_softmax),
d_(d),
axis_dim_(axis_dim),
ignore_idx_(ignore_idx) {}
__device__ void operator()(int64_t idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain
int64_t remain = d_ / axis_dim_;
int64_t idx_n = idx / d_;
int64_t idx_axis = (idx % d_) / remain;
int64_t idx_remain = idx % remain;
// labels, loss view as [n, remain]
int64_t idx_lbl = idx_n * remain + idx_remain;
if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) {
log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
} else {
auto softmax = log_softmax_[idx];
log_softmax_[idx] = exp_on_device(softmax);
loss_[idx_lbl] = -softmax;
}
}
private:
const int64_t* labels_;
T* loss_;
T* log_softmax_;
int64_t d_;
int axis_dim_;
int ignore_idx_;
};
template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
const platform::CUDADeviceContext& ctx, const T* logits_data,
const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
int64_t d, int axis_dim, int ignore_idx) {
#ifdef __HIPCC__
// HIP platform will have loss nan if dim size > 256
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512;
#endif
int64_t block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim)));
int64_t grid_dim = n * d / axis_dim;
auto stream = ctx.stream();
#ifdef __HIPCC__
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, d, axis_dim); \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, softmax_data, d, axis_dim); \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForDiff<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, softmax_data, d, axis_dim); \
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d); \
if (ignore_idx >= 0 && ignore_idx < axis_dim) { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>( \
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} \
} break
#else
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
logits_data, loss_data, d, axis_dim); \
RowReductionForDiffMaxSum<T, BlockDim, \
true><<<grid_dim, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, d, axis_dim); \
platform::ForRange<platform::CUDADeviceContext> for_range(ctx, n* d); \
if (ignore_idx >= 0 && ignore_idx < axis_dim) { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>( \
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \
} \
} break
#endif
switch (block_dim) {
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
default:
PADDLE_THROW(platform::errors::Unavailable(
"Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
break;
}
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(
    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
@@ -783,7 +946,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
    const int rank = softmax->dims().size();
    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
    int axis_dim = softmax->dims()[axis];
    const int axis_dim = softmax->dims()[axis];
    const int n = SizeToAxis(axis, softmax->dims());
    const int d = SizeFromAxis(axis, softmax->dims());
@@ -826,10 +989,20 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
    } else {  // HardLabel
      auto* logits_data = softmax->data<T>();
      auto* labels_data = labels->data<int64_t>();
      HardLabelCrossEntropy<T>(context.cuda_device_context(), logits_data,
                               labels_data, loss_data, n, d, axis_dim,
                               ignore_index);
      int threads = 128;
      int blocks = (n * d / axis_dim + threads - 1) / threads;
      if (ignore_index >= 0 && ignore_index < axis_dim) {
        CrossEntropyHardLabel<T, true><<<
            blocks, threads, 0, context.cuda_device_context().stream()>>>(
            loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim,
            ignore_index);
      } else {
        CrossEntropyHardLabel<T, false><<<
            blocks, threads, 0, context.cuda_device_context().stream()>>>(
            loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim,
            ignore_index);
      }
    }
    // because the input is already softmax,
    // copy it to the output softmax directly
@@ -886,9 +1059,17 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
      } else {
        auto* logits_data = logits->data<T>();
        auto* labels_data = labels->data<int64_t>();
        HardLabelSoftmaxWithCrossEntropy<T>(
            context.cuda_device_context(), logits_data, labels_data, loss_data,
            softmax_data, n, d, axis_dim, ignore_index);
        if (ignore_index >= 0 && ignore_index < axis_dim) {
          SoftmaxWithCrossEntropyHardLabel<T, true>(
              context.cuda_device_context(), rank, axis, logits_data,
              labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
              ignore_index);
        } else {
          SoftmaxWithCrossEntropyHardLabel<T, false>(
              context.cuda_device_context(), rank, axis, logits_data,
              labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
              ignore_index);
        }
      }
    }
  }
@@ -959,14 +1140,11 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
    } else {
      int64_t grid = (n * remain + block - 1) / block;
      const int64_t* label_data = labels->data<int64_t>();
      CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, label_data, n, d, remain, ignore_index);
      int64_t num = n * d;
      grid = (num + block - 1) / block;
      Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
                                           d, remain, label_data, ignore_index);
      int grid = (n * d + block - 1) / block;
      SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
          ignore_index);
    }
  }
};