Unverified commit eaa3fd45, authored by sneaxiy, committed by GitHub

add more int type support for softmax_with_cross_entropy (#39409)

Parent 8d87b3bc
@@ -30,59 +30,90 @@ template <typename T, int MajorType = Eigen::RowMajor,
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

 template <typename T>
-class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
- public:
-  void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out,
-                  const framework::Tensor* prob,
-                  const framework::Tensor* labels, const bool softLabel,
-                  const int ignore_index, const int axis_dim) {
-    const int batch_size = prob->dims()[0];
-    const int num_classes = prob->dims()[1];
-    const int num_remain = num_classes / axis_dim;
-
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-
-    if (softLabel) {
-      auto in = EigenMatrix<T>::From(*prob);
-      auto lbl = EigenMatrix<T>::From(*labels);
-      auto loss = EigenMatrix<T>::From(*out);
-      loss.device(*ctx.eigen_device()) =
-          -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
-                .reshape(batch_axis_remain)
-                .sum(Eigen::DSizes<int, 1>(1)));
-    } else {
-      const T* prob_data = prob->data<T>();
-      T* loss_data = out->data<T>();
-      const int64_t* label_data = labels->data<int64_t>();
-      for (int i = 0; i < batch_size; ++i) {
-        for (int j = 0; j < num_remain; j++) {
-          int lbl = label_data[i * num_remain + j];
-          if (lbl != ignore_index) {
-            PADDLE_ENFORCE_GE(lbl, 0,
-                              platform::errors::OutOfRange(
-                                  "label value should >= 0 when label "
-                                  "value(%f) not equal to ignore_index(%f)",
-                                  lbl, ignore_index));
-            PADDLE_ENFORCE_LT(
-                lbl, axis_dim,
-                platform::errors::OutOfRange(
-                    "label value should less than the shape of axis dimension "
-                    "when label value(%f) not equal to ignore_index(%f), But "
-                    "received label value as %ld and shape of axis dimension "
-                    "is %d",
-                    lbl, ignore_index, lbl, axis_dim));
-          }
-          int index = i * num_classes + lbl * num_remain + j;
-          int loss_idx = i * num_remain + j;
-          loss_data[loss_idx] =
-              lbl == ignore_index
-                  ? 0
-                  : -math::TolerableValue<T>()(std::log(prob_data[index]));
-        }
-      }
-    }
-  }
-};
+struct HardLabelCrossEntropyCPUFunctorImpl {
+  HardLabelCrossEntropyCPUFunctorImpl(framework::Tensor* out,
+                                      const framework::Tensor* prob,
+                                      const framework::Tensor* labels,
+                                      const int ignore_index,
+                                      const int axis_dim)
+      : out_(out),
+        prob_(prob),
+        labels_(labels),
+        ignore_index_(ignore_index),
+        axis_dim_(axis_dim) {}
+
+  template <typename U>
+  void apply() const {
+    const int batch_size = prob_->dims()[0];
+    const int num_classes = prob_->dims()[1];
+    const int num_remain = num_classes / axis_dim_;
+
+    const T* prob_data = prob_->template data<T>();
+    T* loss_data = out_->template data<T>();
+
+    const auto* label_data = labels_->template data<U>();
+    for (int i = 0; i < batch_size; ++i) {
+      for (int j = 0; j < num_remain; j++) {
+        int lbl = static_cast<int>(label_data[i * num_remain + j]);
+        if (lbl != ignore_index_) {
+          PADDLE_ENFORCE_GE(lbl, 0,
+                            platform::errors::OutOfRange(
+                                "label value should >= 0 when label "
+                                "value(%f) not equal to ignore_index(%f)",
+                                lbl, ignore_index_));
+          PADDLE_ENFORCE_LT(
+              lbl, axis_dim_,
+              platform::errors::OutOfRange(
+                  "label value should less than the shape of axis dimension "
+                  "when label value(%f) not equal to ignore_index(%f), But "
+                  "received label value as %ld and shape of axis dimension "
+                  "is %d",
+                  lbl, ignore_index_, lbl, axis_dim_));
+        }
+        int index = i * num_classes + lbl * num_remain + j;
+        int loss_idx = i * num_remain + j;
+        loss_data[loss_idx] =
+            lbl == ignore_index_
+                ? 0
+                : -math::TolerableValue<T>()(std::log(prob_data[index]));
+      }
+    }
+  }
+
+ private:
+  framework::Tensor* out_;
+  const framework::Tensor* prob_;
+  const framework::Tensor* labels_;
+  const int ignore_index_;
+  const int axis_dim_;
+};
+
+template <typename T>
+class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out,
+                  const framework::Tensor* prob,
+                  const framework::Tensor* labels, const bool softLabel,
+                  const int ignore_index, const int axis_dim) {
+    if (softLabel) {
+      const int batch_size = prob->dims()[0];
+      const int num_classes = prob->dims()[1];
+      const int num_remain = num_classes / axis_dim;
+
+      Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
+      auto in = EigenMatrix<T>::From(*prob);
+      auto lbl = EigenMatrix<T>::From(*labels);
+      auto loss = EigenMatrix<T>::From(*out);
+      loss.device(*ctx.eigen_device()) =
+          -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
+                .reshape(batch_axis_remain)
+                .sum(Eigen::DSizes<int, 1>(1)));
+    } else {
+      HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(
+          out, prob, labels, ignore_index, axis_dim);
+      framework::VisitIntDataType(labels->type(), functor_impl);
+    }
+  }
+};
...
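The CPU change above moves the hard-label loop into `HardLabelCrossEntropyCPUFunctorImpl`, whose `apply<U>()` is instantiated once per supported integer label type and selected at runtime through `framework::VisitIntDataType`. Below is a minimal standalone sketch of that dispatch pattern; the `DataType` enum, `VisitIntDataType` helper and `PrintFirstLabel` functor are illustrative stand-ins, not Paddle's actual framework code.

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Toy stand-in for a runtime dtype tag (hypothetical, not Paddle's enum).
enum class DataType { kInt8, kUInt8, kInt16, kInt32, kInt64 };

// Calls visitor.template apply<U>() with U chosen from the runtime tag,
// mirroring how a VisitIntDataType-style helper is typically written.
template <typename Visitor>
void VisitIntDataType(DataType dtype, const Visitor& visitor) {
  switch (dtype) {
    case DataType::kInt8:  visitor.template apply<int8_t>();  break;
    case DataType::kUInt8: visitor.template apply<uint8_t>(); break;
    case DataType::kInt16: visitor.template apply<int16_t>(); break;
    case DataType::kInt32: visitor.template apply<int32_t>(); break;
    case DataType::kInt64: visitor.template apply<int64_t>(); break;
    default: throw std::runtime_error("unsupported label dtype");
  }
}

// A functor shaped like HardLabelCrossEntropyCPUFunctorImpl: the label
// buffer is type-erased and reinterpreted as U inside apply<U>().
struct PrintFirstLabel {
  const void* label_data;

  template <typename U>
  void apply() const {
    auto value = static_cast<int64_t>(static_cast<const U*>(label_data)[0]);
    std::cout << "first label = " << value << "\n";
  }
};

int main() {
  int32_t labels[] = {3, 1, 2};
  PrintFirstLabel functor{labels};
  VisitIntDataType(DataType::kInt32, functor);  // prints "first label = 3"
  return 0;
}
```

The same idea is reused on the GPU side below, where the functor stores the label pointer as `const void*` and casts it back to the visited type inside `apply<U>()`.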
@@ -21,18 +21,19 @@ namespace paddle {
 namespace operators {
 namespace math {

-template <typename T>
-__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
+template <typename T, typename LabelT>
+__global__ void CrossEntropyKernel(T* Y, const T* X, const LabelT* label,
                                    const int N, const int D,
                                    const int ignore_index) {
   CUDA_KERNEL_LOOP(i, N) {
-    PADDLE_ENFORCE(label[i] >= 0 && label[i] < D || label[i] == ignore_index,
+    auto lbl = static_cast<int64_t>(label[i]);
+    PADDLE_ENFORCE(lbl >= 0 && lbl < D || lbl == ignore_index,
                    "The value of label[%d] expected >= 0 and < %ld, or == %ld, "
                    "but got %ld. Please check input value.",
-                   i, D, ignore_index, label[i]);
-    Y[i] = ignore_index == label[i]
+                   i, D, ignore_index, lbl);
+    Y[i] = ignore_index == lbl
                ? static_cast<T>(0)
-               : -math::TolerableValue<T>()(real_log(X[i * D + label[i]]));
+               : -math::TolerableValue<T>()(real_log(X[i * D + lbl]));
   }
 }
@@ -54,6 +55,43 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
   }
 }

+template <typename T>
+struct HardLabelCrossEntropyCUDAFunctorImpl {
+ public:
+  HardLabelCrossEntropyCUDAFunctorImpl(T* loss_data, const T* prob_data,
+                                       const void* label_data,
+                                       const int batch_size,
+                                       const int class_num,
+                                       const int ignore_index,
+                                       const int block_size, gpuStream_t stream)
+      : loss_data_(loss_data),
+        prob_data_(prob_data),
+        label_data_(label_data),
+        batch_size_(batch_size),
+        class_num_(class_num),
+        ignore_index_(ignore_index),
+        block_size_(block_size),
+        stream_(stream) {}
+
+  template <typename U>
+  void apply() const {
+    int grid_size = (batch_size_ + block_size_ - 1) / block_size_;
+    CrossEntropyKernel<T, U><<<grid_size, block_size_, 0, stream_>>>(
+        loss_data_, prob_data_, static_cast<const U*>(label_data_), batch_size_,
+        class_num_, ignore_index_);
+  }
+
+ private:
+  T* loss_data_;
+  const T* prob_data_;
+  const void* label_data_;
+  const int batch_size_;
+  const int class_num_;
+  const int ignore_index_;
+  const int block_size_;
+  gpuStream_t stream_;
+};
+
 template <typename T>
 class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
  public:
@@ -81,12 +119,10 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
       SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
           loss_data, prob_data, label_data, class_num);
     } else {
-      const int64_t* label_data = labels->data<int64_t>();
-      int block = kMaxBlockDim;
-      int grid = (batch_size + block - 1) / block;
-      CrossEntropyKernel<T><<<grid, block, 0, ctx.stream()>>>(
-          loss_data, prob_data, label_data, batch_size, class_num,
-          ignore_index);
+      HardLabelCrossEntropyCUDAFunctorImpl<T> functor(
+          loss_data, prob_data, labels->data(), batch_size, class_num,
+          ignore_index, kMaxBlockDim, ctx.stream());
+      framework::VisitDataType(labels->type(), functor);
     }
   }
 };
...
@@ -59,9 +59,9 @@ enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy };
 /*
   Hard label cross entropy.
 */
-template <typename T, bool IgnoreIndex>
+template <typename T, typename LabelT, bool IgnoreIndex>
 __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
-                                      const int64_t* labels, const int n,
+                                      const LabelT* labels, const int n,
                                       const int dim, const int d,
                                       const int ignore_idx) {
   int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
@@ -70,13 +70,14 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
   // thread ids compute loss[ids] using softmax[idx]
   if (ids < n * d) {
-    if (labels[ids] < 0) {  // label is negative
+    auto lbl = static_cast<int64_t>(labels[ids]);
+    if (lbl < 0) {  // label is negative
       loss[ids] = static_cast<T>(0.0);
     } else {  // label is positive of zero
-      int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d;
+      int64_t idx = idx_n * dim * d + lbl * d + idx_d;
       if (IgnoreIndex == true) {
         // IgnoreIndex is true
-        if (labels[ids] == ignore_idx) {
+        if (lbl == ignore_idx) {
           loss[ids] = static_cast<T>(0.0);
         } else {
           loss[ids] = -Log(softmax[idx]);
@@ -94,9 +95,9 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
   Input: log softmax
   Output: loss and exp(input)
 */
-template <typename T, bool IgnoreIndex>
+template <typename T, typename LabelT, bool IgnoreIndex>
 __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
-                                         const int64_t* labels, const int n,
+                                         const LabelT* labels, const int n,
                                          const int dim, const int d,
                                          const int ignore_idx) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -106,10 +107,11 @@ __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
   int64_t ids = idx_n * d + idx_d;
   if (idx < n * dim * d) {
+    auto lbl = static_cast<int64_t>(labels[ids]);
     if (IgnoreIndex == true) {
       // IgnoreIndex is true
-      if (idx_dim == labels[ids]) {
-        if (labels[ids] == ignore_idx) {
+      if (idx_dim == lbl) {
+        if (lbl == ignore_idx) {
           loss[ids] = static_cast<T>(0.0);
         } else {
           loss[ids] = -softmax[idx];
@@ -117,8 +119,8 @@ __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
       }
     } else {
       // IgnoreIndex is false
-      if (labels[ids] >= 0 && labels[ids] < dim) {
-        if (labels[ids] == idx_dim) {
+      if (lbl >= 0 && lbl < dim) {
+        if (lbl == idx_dim) {
           loss[ids] = -softmax[idx];
         }
       } else {
@@ -151,10 +153,10 @@ __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
   For reduction max (sum), firstly compute max (sum) to one warp, then use
   shuffle api to compute max (sum) in one warp.
 */
-template <typename T, typename VecT, typename AccT, int Log2Elements,
-          SoftmaxMode mode, bool IgnoreIndex>
+template <typename T, typename LabelT, typename VecT, typename AccT,
+          int Log2Elements, SoftmaxMode mode, bool IgnoreIndex>
 __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
-                                   const int64_t* label, const int batch_size,
+                                   const LabelT* label, const int batch_size,
                                    const int stride, const int element_count,
                                    const int ignore_index) {
   constexpr int kDimCeil = 1 << Log2Elements;
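The comment above describes the warp-level reduction used for the max and sum passes: each value is reduced with the shuffle intrinsics so no shared-memory round trip is needed. Below is a small self-contained sketch of that idea (a hypothetical `WarpMaxDemo` kernel, not the kernel from this file), assuming one full warp of 32 threads; error checking is omitted.

```cuda
#include <cstdio>

// Butterfly reduction: after log2(32) __shfl_xor_sync steps every lane of
// the warp holds the maximum. The sum variant just replaces fmaxf with +.
__inline__ __device__ float WarpReduceMax(float val) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, offset));
  }
  return val;
}

__global__ void WarpMaxDemo(const float* in, float* out) {
  float warp_max = WarpReduceMax(in[threadIdx.x]);
  if (threadIdx.x == 0) *out = warp_max;  // every lane already has the max
}

int main() {
  float h_in[32], *d_in, *d_out, h_out = 0.f;
  for (int i = 0; i < 32; ++i) h_in[i] = static_cast<float>((7 * i) % 32);
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  WarpMaxDemo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("warp max = %f\n", h_out);  // 31
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
```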
@@ -299,10 +301,11 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
         softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax);
         // label
         int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
+        auto lbl = static_cast<int64_t>(label[first_batch + i]);
         if (IgnoreIndex == true) {
           // IgnoreIndex is true
-          if (label[first_batch + i] == loss_idx) {
-            if (label[first_batch + i] != ignore_index) {
+          if (lbl == loss_idx) {
+            if (lbl != ignore_index) {
               loss[first_batch + i] = -logsoftmax;
             } else {
               loss[first_batch + i] = static_cast<T>(0.0);
@@ -310,9 +313,8 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
           }
         } else {
           // IgnoreIndex is false
-          if (label[first_batch + i] >= 0 &&
-              label[first_batch + i] < element_count) {
-            if (label[first_batch + i] == loss_idx) {
+          if (lbl >= 0 && lbl < element_count) {
+            if (lbl == loss_idx) {
               loss[first_batch + i] = -logsoftmax;
             }
           } else {
@@ -342,17 +344,16 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
           tmpptr[s] = std::exp(logsoftmax);
           // label
           int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
+          auto lbl = static_cast<int64_t>(label[first_batch + i]);
           if (IgnoreIndex == true) {
             // IgnoreIndex is true
-            if (label[first_batch + i] == loss_idx &&
-                label[first_batch + i] != ignore_index) {
+            if (lbl == loss_idx && lbl != ignore_index) {
               loss[first_batch + i] = -logsoftmax;
             }
           } else {
             // IgnoreIndex is false
-            if (label[first_batch + i] >= 0 &&
-                label[first_batch + i] < element_count) {
-              if (label[first_batch + i] == loss_idx) {
+            if (lbl >= 0 && lbl < element_count) {
+              if (lbl == loss_idx) {
                 loss[first_batch + i] = -logsoftmax;
               }
             } else {
@@ -373,9 +374,9 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
   }
 }

-#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, VecT, AccT)            \
+#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT)    \
   case Log2Elements:                                                   \
-    WarpSoftmaxForward<T, VecT, AccT, Log2Elements, mode,              \
+    WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode,      \
                        IgnoreIndex><<<blocks, threads, 0, stream>>>(   \
         loss, softmax, src, label, batch_size, stride, element_count,  \
         ignore_index);                                                 \
@@ -384,9 +385,9 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
 /*
   Wrapper of softmax with cross entropy forward hard label.
 */
-template <typename T, SoftmaxMode mode, bool IgnoreIndex>
+template <typename T, typename LabelT, SoftmaxMode mode, bool IgnoreIndex>
 void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
-                              const int64_t* label, const int batch_size,
+                              const LabelT* label, const int batch_size,
                               const int stride, const int element_count,
                               const int ignore_index, gpuStream_t stream) {
   using AccT = typename details::MPTypeTrait<T>::Type;
@@ -403,16 +404,16 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
   dim3 threads(kWarpSize, warps_per_block, 1);
   switch (log2_elements) {
-    SOFTMAX_WARP_FORWARD_CASE(0, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(1, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(2, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(3, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(4, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(5, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(6, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(7, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(8, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(9, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(0, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(1, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(2, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(3, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(4, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(5, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(6, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(7, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(8, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(9, LabelT, T, AccT);
     default:
       break;
   }
@@ -423,16 +424,16 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
   - SwitchWarpSoftmaxForward for small size
   - cudnn function for large size
 */
-template <typename T, bool IgnoreIndex>
+template <typename T, typename LabelT, bool IgnoreIndex>
 static void SoftmaxWithCrossEntropyHardLabel(
     const platform::CUDADeviceContext& ctx, int rank, int axis,
-    const T* logits_data, const int64_t* labels_data, T* loss_data,
+    const T* logits_data, const LabelT* labels_data, T* loss_data,
     T* softmax_data, int N, int dim, int D, const int ignore_index) {
   auto stream = ctx.stream();
   constexpr int max_dim = 320;
   if (D == 1 && dim <= max_dim) {  // small size
     const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
-    SwitchWarpSoftmaxForward<T, mode, IgnoreIndex>(
+    SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
         loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
         ignore_index, stream);
   } else {
@@ -465,7 +466,8 @@ static void SoftmaxWithCrossEntropyHardLabel(
     int threads = 128;
     int blocks = (N * dim * D + threads - 1) / threads;
     // compute cross entropy, input is log softmax
-    CrossEntropyExpHardLabel<T, IgnoreIndex><<<blocks, threads, 0, stream>>>(
+    CrossEntropyExpHardLabel<T, LabelT,
+                             IgnoreIndex><<<blocks, threads, 0, stream>>>(
         loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
   }
 }
@@ -473,9 +475,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
 /*
   Wrapper of softmax with cross entropy grad hard label.
 */
-template <typename T>
+template <typename T, typename LabelT>
 __global__ void SoftmaxWithCrossEntropyGradHardLabel(
-    T* logits_grad, const T* loss_grad, const int64_t* labels, const int64_t n,
+    T* logits_grad, const T* loss_grad, const LabelT* labels, const int64_t n,
     const int64_t dim, const int64_t d, const int ignore_index) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   int64_t idx_n = idx / (d * dim);
@@ -484,9 +486,10 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
   int64_t ids = idx_n * d + idx_d;
   if (idx < n * dim * d) {
-    if (labels[ids] == ignore_index) {
+    auto lbl = static_cast<int64_t>(labels[ids]);
+    if (lbl == ignore_index) {
       logits_grad[idx] = static_cast<T>(0.0);
-    } else if (labels[ids] == idx_dim) {
+    } else if (lbl == idx_dim) {
       logits_grad[idx] =
           (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
     } else {
@@ -887,16 +890,16 @@ __global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad,
   }
 }

-template <typename T>
+template <typename T, typename LabelT>
 __global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
-                                                    const int64_t* labels,
+                                                    const LabelT* labels,
                                                     const int n, const int d,
                                                     const int remain,
                                                     const int ignore_index) {
   CUDA_KERNEL_LOOP(index, n * remain) {
     int idx_n = index / remain;
     int idx_remain = index % remain;
-    int tmp = labels[index];
+    int tmp = static_cast<int>(labels[index]);
     int idx = idx_n * d + tmp * remain + idx_remain;
     if (ignore_index != tmp) {
       logit_grad[idx] = -static_cast<T>(1.) / logit_grad[idx];
@@ -904,18 +907,19 @@ __global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
   }
 }

-template <typename T>
+template <typename T, typename LabelT>
 __global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad,
                                           const int num, const int d,
                                           const int remain,
-                                          const int64_t* labels,
+                                          const LabelT* labels,
                                           const int ignore_index) {
   CUDA_KERNEL_LOOP(index, num) {
     int idx_n = index / d;
     int idx_remain = index % remain;
     int idx_lbl = idx_n * remain + idx_remain;
     int k = (index % d) / remain;
-    if (labels[idx_lbl] == ignore_index || labels[idx_lbl] != k) {
+    auto lbl = static_cast<int64_t>(labels[idx_lbl]);
+    if (lbl == ignore_index || lbl != k) {
       logit_grad[index] = static_cast<T>(0.);
     } else {
       logit_grad[index] *= loss_grad[idx_lbl];
@@ -927,6 +931,12 @@ template <typename T>
 class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
+  }
+
+  template <typename LabelT>
+  static void Apply(const framework::ExecutionContext& context,
+                    const framework::Tensor& labels, const bool soft_label) {
     PADDLE_ENFORCE_EQ(
         platform::is_gpu_place(context.GetPlace()), true,
         platform::errors::Unavailable("softmax_with_cross_entropy operator's "
@@ -936,7 +946,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     // do not with softmax op, and input is softmax
     if (!use_softmax) {
       const Tensor* softmax = context.Input<Tensor>("Logits");
-      const Tensor* labels = context.Input<Tensor>("Label");
       Tensor* softmax_out = context.Output<Tensor>("Softmax");
       Tensor* loss = context.Output<Tensor>("Loss");
@@ -947,8 +956,9 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
       const int n = SizeToAxis(axis, softmax->dims());
       const int d = SizeFromAxis(axis, softmax->dims());

-      auto* softmax_out_data = softmax_out->mutable_data<T>(context.GetPlace());
-      auto* loss_data = loss->mutable_data<T>(context.GetPlace());
+      auto* softmax_out_data =
+          softmax_out->template mutable_data<T>(context.GetPlace());
+      auto* loss_data = loss->template mutable_data<T>(context.GetPlace());

       math::SetConstant<platform::CUDADeviceContext, T> set_constant;
       set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
@@ -958,12 +968,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
         return;
       }

-      auto soft_label = context.Attr<bool>("soft_label");
       auto ignore_index = context.Attr<int>("ignore_index");

       Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
       softmax_2d.ShareDataWith(*softmax).Resize({n, d});
-      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
+      labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
       loss_2d.ShareDataWith(*loss).Resize({n, 1});
       softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});
@@ -977,8 +986,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
       // if axis is not the last, we need a new impliment
       if (soft_label) {
-        auto* logits_data = softmax->data<T>();
-        auto* labels_data = labels->data<T>();
+        auto* logits_data = softmax->template data<T>();
+        auto* labels_data = labels.template data<T>();
         const int kDimLog2 = static_cast<int>(Log2Ceil(axis_dim));
         const int kDimCeil = 1 << kDimLog2;
@@ -996,17 +1005,17 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
             loss_data, NULL, logits_data, labels_data, n, axis_dim,
             d / axis_dim, kDimLog2);
       } else {  // HardLabel
-        auto* logits_data = softmax->data<T>();
-        auto* labels_data = labels->data<int64_t>();
+        auto* logits_data = softmax->template data<T>();
+        auto* labels_data = labels.template data<LabelT>();
         int threads = 128;
         int blocks = (n * d / axis_dim + threads - 1) / threads;
         if (ignore_index >= 0 && ignore_index < axis_dim) {
-          CrossEntropyHardLabel<T, true><<<
+          CrossEntropyHardLabel<T, LabelT, true><<<
              blocks, threads, 0, context.cuda_device_context().stream()>>>(
              loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim,
              ignore_index);
         } else {
-          CrossEntropyHardLabel<T, false><<<
+          CrossEntropyHardLabel<T, LabelT, false><<<
              blocks, threads, 0, context.cuda_device_context().stream()>>>(
              loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim,
              ignore_index);
@@ -1022,7 +1031,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     }

     const Tensor* logits = context.Input<Tensor>("Logits");
-    const Tensor* labels = context.Input<Tensor>("Label");
     Tensor* softmax = context.Output<Tensor>("Softmax");
     Tensor* loss = context.Output<Tensor>("Loss");
@@ -1033,8 +1041,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     const int64_t n = SizeToAxis(axis, logits->dims());
     const int64_t d = SizeFromAxis(axis, logits->dims());

-    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
-    auto* loss_data = loss->mutable_data<T>(context.GetPlace());
+    auto* softmax_data = softmax->template mutable_data<T>(context.GetPlace());
+    auto* loss_data = loss->template mutable_data<T>(context.GetPlace());

     if (axis_dim == 1) {
       math::SetConstant<platform::CUDADeviceContext, T> set_constant;
@@ -1043,12 +1051,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
       return;
     }

-    auto soft_label = context.Attr<bool>("soft_label");
     auto ignore_index = context.Attr<int>("ignore_index");

     if (soft_label) {
-      auto* logits_data = logits->data<T>();
-      auto* labels_data = labels->data<T>();
+      auto* logits_data = logits->template data<T>();
+      auto* labels_data = labels.template data<T>();
       SoftmaxWithCrossEntropySoftLabel<T>(
           context.cuda_device_context(), rank, axis, logits_data, labels_data,
           softmax_data, loss_data, n, axis_dim, d / axis_dim);
@@ -1058,7 +1065,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
       Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
       logits_2d.ShareDataWith(*logits).Resize({n, d});
       softmax_2d.ShareDataWith(*softmax).Resize({n, d});
-      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
+      labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
       loss_2d.ShareDataWith(*loss).Resize({n, 1});
       math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                      &logits_2d, &softmax_2d);
@@ -1066,15 +1073,15 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
           context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
           false, ignore_index, axis_dim);
     } else {
-      auto* logits_data = logits->data<T>();
-      auto* labels_data = labels->data<int64_t>();
+      auto* logits_data = logits->template data<T>();
+      auto* labels_data = labels.template data<LabelT>();
       if (ignore_index >= 0 && ignore_index < axis_dim) {
-        SoftmaxWithCrossEntropyHardLabel<T, true>(
+        SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(
             context.cuda_device_context(), rank, axis, logits_data,
             labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
             ignore_index);
       } else {
-        SoftmaxWithCrossEntropyHardLabel<T, false>(
+        SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(
            context.cuda_device_context(), rank, axis, logits_data,
            labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
            ignore_index);
@@ -1088,13 +1095,19 @@ template <typename T>
 class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
+  }
+
+  template <typename LabelT>
+  static void Apply(const framework::ExecutionContext& context,
+                    const framework::Tensor& labels, const bool soft_label) {
     PADDLE_ENFORCE_EQ(
         platform::is_gpu_place(context.GetPlace()), true,
         platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                       "CUDA kernel only runs on GPU device."));
-    const Tensor* labels = context.Input<Tensor>("Label");
     const T* loss_grad_data =
-        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
+        context.Input<Tensor>(framework::GradVarName("Loss"))
+            ->template data<T>();
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
@@ -1102,7 +1115,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
     }
-    T* logit_grad_data = logit_grad->data<T>();
+    T* logit_grad_data = logit_grad->template data<T>();

     const int rank = logit_grad->dims().size();
     const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
@@ -1123,21 +1136,22 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     // do not with softmax op, and input is softmax
     if (!use_softmax) {
-      if (context.Attr<bool>("soft_label")) {
+      if (soft_label) {
         int grid = (n * d + block - 1) / block;
-        const T* label_data = labels->data<T>();
+        const T* label_data = labels.template data<T>();
         SoftLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
             logit_grad_data, loss_grad_data, label_data, n, d, remain);
       } else {
         Tensor logits_grad_2d;
         logits_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
         int grid = (n * remain + block - 1) / block;
-        const int64_t* label_data = labels->data<int64_t>();
-        HardLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
+        const auto* label_data = labels.template data<LabelT>();
+        HardLabelCrossEntropyGradientKernel<T,
+                                            LabelT><<<grid, block, 0, stream>>>(
            logit_grad_data, label_data, n, d, remain, ignore_index);
        int num = n * d;
        grid = (num + block - 1) / block;
-        ScaleCrossEntropyGradient<T><<<grid, block, 0, stream>>>(
+        ScaleCrossEntropyGradient<T, LabelT><<<grid, block, 0, stream>>>(
            logit_grad_data, loss_grad_data, num, d, remain, label_data,
            ignore_index);
      }
@@ -1147,13 +1161,13 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     // with softmax, continue

-    if (context.Attr<bool>("soft_label")) {
+    if (soft_label) {
       int64_t grid = (n * d + block - 1) / block;
-      const T* label_data = labels->data<T>();
+      const T* label_data = labels.template data<T>();
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
-      const int64_t* label_data = labels->data<int64_t>();
+      const auto* label_data = labels.template data<LabelT>();
       int grid = (n * d + block - 1) / block;
       SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
...
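`SwitchWarpSoftmaxForward` above turns the runtime `log2_elements` value into the compile-time `Log2Elements` template argument with a switch whose cases are generated by `SOFTMAX_WARP_FORWARD_CASE`; the same trick maps the runtime `ignore_index` check onto the compile-time `IgnoreIndex` flag. A compact sketch of that switch-plus-macro idiom, with made-up names (`ProcessRows`, `PROCESS_ROWS_CASE`):

```cpp
#include <cstdio>

// The element count per row is baked in at compile time as
// 1 << kLog2Elements (stand-in for WarpSoftmaxForward's Log2Elements).
template <int kLog2Elements>
void ProcessRows(int batch_size) {
  constexpr int kElements = 1 << kLog2Elements;
  std::printf("batch=%d, %d elements per row\n", batch_size, kElements);
}

// Each case instantiates the template for one power of two, keeping the
// switch below compact (same role as SOFTMAX_WARP_FORWARD_CASE).
#define PROCESS_ROWS_CASE(log2)      \
  case log2:                         \
    ProcessRows<log2>(batch_size);   \
    break;

void SwitchProcessRows(int batch_size, int log2_elements) {
  switch (log2_elements) {
    PROCESS_ROWS_CASE(0)
    PROCESS_ROWS_CASE(1)
    PROCESS_ROWS_CASE(2)
    PROCESS_ROWS_CASE(3)
    PROCESS_ROWS_CASE(4)
    PROCESS_ROWS_CASE(5)
    default:
      break;  // larger sizes take a different path (cudnn in the real code)
  }
}

int main() {
  SwitchProcessRows(/*batch_size=*/8, /*log2_elements=*/4);  // 16 elements
  return 0;
}
```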
@@ -24,6 +24,48 @@ namespace operators {

 using Tensor = framework::Tensor;

+template <typename T, typename Visitor>
+struct SoftmaxWithCrossEntropyFunctor {
+ public:
+  SoftmaxWithCrossEntropyFunctor(const framework::ExecutionContext& context,
+                                 const framework::Tensor& labels,
+                                 const bool soft_label, const Visitor& visitor)
+      : context_(context),
+        labels_(labels),
+        soft_label_(soft_label),
+        visitor_(visitor) {}
+
+  template <typename U>
+  void apply() const {
+    visitor_.template Apply<U>(context_, labels_, soft_label_);
+  }
+
+ private:
+  const framework::ExecutionContext& context_;
+  const framework::Tensor& labels_;
+  const bool soft_label_;
+  const Visitor& visitor_;
+};
+
+template <typename T, typename Visitor>
+static void RunSoftmaxWithCrossEntropyFunctor(
+    const framework::ExecutionContext& context, const Visitor& visitor) {
+  const auto* labels = context.Input<framework::Tensor>("Label");
+  const bool soft_label = context.Attr<bool>("soft_label");
+  SoftmaxWithCrossEntropyFunctor<T, Visitor> functor(context, *labels,
+                                                     soft_label, visitor);
+  auto dtype = labels->type();
+  if (soft_label) {
+    PADDLE_ENFORCE_EQ(
+        dtype, framework::DataTypeTrait<T>::DataType(),
+        platform::errors::InvalidArgument("The Input(Label) should be with the "
+                                          "same data type as Input(Logits)."));
+    functor.template apply<T>();
+  } else {
+    framework::VisitIntDataType(dtype, functor);
+  }
+}
+
 template <typename T>
 class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
  public:
@@ -32,14 +74,14 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
         platform::is_cpu_place(context.GetPlace()), true,
         platform::errors::Unimplemented("This kernel only runs on CPU."));
     const bool use_softmax = context.Attr<bool>("use_softmax");
+    const Tensor* labels = context.Input<Tensor>("Label");
+    const bool soft_label = context.Attr<bool>("soft_label");

     // do not with softmax op, and input is softmax
     if (!use_softmax) {
       const Tensor* softmax = context.Input<Tensor>("Logits");
-      const Tensor* labels = context.Input<Tensor>("Label");
       Tensor* softmax_out = context.Output<Tensor>("Softmax");
       Tensor* loss = context.Output<Tensor>("Loss");
-      const bool soft_label = context.Attr<bool>("soft_label");
       const int rank = softmax->dims().size();
       const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
       int axis_dim = softmax->dims()[axis];
@@ -86,10 +128,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     }

     const Tensor* logits = context.Input<Tensor>("Logits");
-    const Tensor* labels = context.Input<Tensor>("Label");
     Tensor* softmax = context.Output<Tensor>("Softmax");
     Tensor* loss = context.Output<Tensor>("Loss");
-    const bool soft_label = context.Attr<bool>("soft_label");

     const int rank = logits->dims().size();
     const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
@@ -132,9 +172,14 @@ template <typename T>
 class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
+  }
+
+  template <typename LabelT>
+  static void Apply(const framework::ExecutionContext& context,
+                    const framework::Tensor& labels, const bool soft_label) {
     const Tensor* out_grad =
         context.Input<Tensor>(framework::GradVarName("Loss"));
-    const Tensor* labels = context.Input<Tensor>("Label");
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
@@ -143,7 +188,6 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
     }
-    const bool soft_label = context.Attr<bool>("soft_label");
     auto ignore_index = context.Attr<int>("ignore_index");

     const int rank = logit_grad->dims().size();
@@ -166,7 +210,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
     const int d = SizeFromAxis(axis, logit_grad->dims());
     Tensor logit_grad_2d, labels_2d, out_grad_2d;
     logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
-    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
+    labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
     out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
     auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
     auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
@@ -183,23 +227,24 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
           logit_grad_mat;
     } else {
       // use_softmax step2
-      const int64_t* label_data = labels->data<int64_t>();
-      T* logit_grad_data = logit_grad->data<T>();
-      const T* out_grad_data = out_grad->data<T>();
+      const auto* label_data = labels.template data<LabelT>();
+      T* logit_grad_data = logit_grad->template data<T>();
+      const T* out_grad_data = out_grad->template data<T>();
       const int remain = d / axis_dim;
       for (int i = 0; i < n; ++i) {         // for each sample_1_dim
         for (int j = 0; j < remain; j++) {  // for each sample_other_dims
           int idx = i * remain + j;  // this sample's label_idx. for 1d case,
                                      // remain=1 and j=0, so, idx = i
-          if (label_data[idx] == ignore_index) {
+          auto lbl = static_cast<int64_t>(label_data[idx]);
+          if (lbl == ignore_index) {
             for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
               logit_grad_data[i * d + k * remain + j] = 0;
             }
           } else {
             // only for this sample's label_idx, the label is 1, others is 0,
             // so, only compute this label_idx's class
-            logit_grad_data[i * d + label_data[idx] * remain + j] =
-                (-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) *
+            logit_grad_data[i * d + lbl * remain + j] =
+                (-1 / logit_grad_data[i * d + lbl * remain + j]) *
                 out_grad_data[idx];
             for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
               if (k !=
@@ -233,15 +278,16 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
           logit_grad_mat *  // element_wise multiply
           out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));

-      const int64_t* label_data = labels->data<int64_t>();
-      T* logit_grad_data = logit_grad->data<T>();
-      const T* out_grad_data = out_grad->data<T>();
+      const auto* label_data = labels.template data<LabelT>();
+      T* logit_grad_data = logit_grad->template data<T>();
+      const T* out_grad_data = out_grad->template data<T>();
       const int remain = d / axis_dim;
       for (int i = 0; i < n; ++i) {         // for each sample_1_dim
         for (int j = 0; j < remain; j++) {  // for each sample_other_dims
           int idx = i * remain + j;  // this sample's label_idx. for 1d case,
                                      // remain=1 and j=0, so, idx = i
-          if (label_data[idx] == ignore_index) {
+          auto lbl = static_cast<int64_t>(label_data[idx]);
+          if (lbl == ignore_index) {
             for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
               logit_grad_data[i * d + k * remain + j] = 0;
             }
@@ -258,8 +304,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
             //  out_grad_data[idx]
             // means:           dy/dp * dy= ( p - y ) * dy

-            logit_grad_data[i * d + label_data[idx] * remain + j] -=
-                out_grad_data[idx];
+            logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx];
           }
         }
       }
...
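For reference, the hard-label loops above implement the standard softmax-with-cross-entropy derivative that the inline comments describe ("dy/dp * dy = (p - y) * dy"). With p the softmax output, hard label y_i for row i, and incoming loss gradient dL_i:

$$
\mathrm{loss}_i = -\log p_{i,\,y_i}, \qquad
\frac{\partial\, \mathrm{loss}_i}{\partial z_{i,k}} = \bigl(p_{i,k} - \mathbf{1}[k = y_i]\bigr)\, dL_i ,
$$

rows with \(y_i = \texttt{ignore\_index}\) receive a zero gradient, and in the `use_softmax == false` path (the input is already a probability) the only nonzero entry is \(\partial\,\mathrm{loss}_i / \partial p_{i,y_i} = -\,dL_i / p_{i,y_i}\).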
@@ -16,6 +16,7 @@ from __future__ import print_function

 import unittest
 import numpy as np
+import paddle
 import paddle.fluid.core as core
 from op_test import OpTest
@@ -58,6 +59,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         self.shape = [41, 37]
         self.use_softmax = True

+    def hard_label_dtype(self):
+        return "int64"
+
     def setUp(self):
         self.initParams()
@@ -72,7 +76,8 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         else:
             axis_dim = self.shape[self.axis]
             self.shape[self.axis] = 1
-        labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
+        labels = np.random.randint(
+            0, axis_dim, self.shape, dtype=self.hard_label_dtype())

         loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
                              self.ignore_index)
@@ -107,6 +112,26 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001)


+class TestSoftmaxWithCrossEntropyOpInt32(TestSoftmaxWithCrossEntropyOp):
+    def hard_label_dtype(self):
+        return "int32"
+
+
+class TestSoftmaxWithCrossEntropyOpInt16(TestSoftmaxWithCrossEntropyOp):
+    def hard_label_dtype(self):
+        return "int16"
+
+
+class TestSoftmaxWithCrossEntropyOpInt8(TestSoftmaxWithCrossEntropyOp):
+    def hard_label_dtype(self):
+        return "int8"
+
+
+class TestSoftmaxWithCrossEntropyOpUInt8(TestSoftmaxWithCrossEntropyOp):
+    def hard_label_dtype(self):
+        return "uint8"
+
+
 class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
         TestSoftmaxWithCrossEntropyOp):
     def initParams(self):
@@ -711,4 +736,5 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):

 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
@@ -1783,7 +1783,8 @@ def cross_entropy(input,
         fluid.data_feeder.check_variable_and_dtype(
             input, 'input', ['float32', 'float64'], 'softmax_cross_entropy')
         fluid.data_feeder.check_variable_and_dtype(
-            label, 'label', ['int32', 'int64', 'float32', 'float64'],
+            label, 'label',
+            ['uint8', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
             'softmax_cross_entropy')
         attrs = {
             'soft_label': soft_label,
...