diff --git a/paddle/fluid/operators/math/cross_entropy.cc b/paddle/fluid/operators/math/cross_entropy.cc
index 23840143a443da106349214966aee78f50b25088..ec78b2dd7b01295d369a39e8a6e55ba05e638138 100644
--- a/paddle/fluid/operators/math/cross_entropy.cc
+++ b/paddle/fluid/operators/math/cross_entropy.cc
@@ -29,6 +29,65 @@ template <typename T, int MajorType = Eigen::RowMajor,
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
+template <typename T>
+struct HardLabelCrossEntropyCPUFunctorImpl {
+  HardLabelCrossEntropyCPUFunctorImpl(framework::Tensor* out,
+                                      const framework::Tensor* prob,
+                                      const framework::Tensor* labels,
+                                      const int ignore_index,
+                                      const int axis_dim)
+      : out_(out),
+        prob_(prob),
+        labels_(labels),
+        ignore_index_(ignore_index),
+        axis_dim_(axis_dim) {}
+
+  template <typename U>
+  void apply() const {
+    const int batch_size = prob_->dims()[0];
+    const int num_classes = prob_->dims()[1];
+    const int num_remain = num_classes / axis_dim_;
+
+    const T* prob_data = prob_->template data<T>();
+    T* loss_data = out_->template data<T>();
+
+    const auto* label_data = labels_->template data<U>();
+    for (int i = 0; i < batch_size; ++i) {
+      for (int j = 0; j < num_remain; j++) {
+        int lbl = static_cast<int>(label_data[i * num_remain + j]);
+        if (lbl != ignore_index_) {
+          PADDLE_ENFORCE_GE(lbl, 0,
+                            platform::errors::OutOfRange(
+                                "label value should >= 0 when label "
+                                "value(%f) not equal to ignore_index(%f)",
+                                lbl, ignore_index_));
+          PADDLE_ENFORCE_LT(
+              lbl, axis_dim_,
+              platform::errors::OutOfRange(
+                  "label value should less than the shape of axis dimension "
+                  "when label value(%f) not equal to ignore_index(%f), But "
+                  "received label value as %ld and shape of axis dimension "
+                  "is %d",
+                  lbl, ignore_index_, lbl, axis_dim_));
+        }
+        int index = i * num_classes + lbl * num_remain + j;
+        int loss_idx = i * num_remain + j;
+        loss_data[loss_idx] =
+            lbl == ignore_index_
+                ? 0
+                : -math::TolerableValue<T>()(std::log(prob_data[index]));
+      }
+    }
+  }
+
+ private:
+  framework::Tensor* out_;
+  const framework::Tensor* prob_;
+  const framework::Tensor* labels_;
+  const int ignore_index_;
+  const int axis_dim_;
+};
+
 template <typename T>
 class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
  public:
@@ -36,13 +95,12 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor* prob,
                   const framework::Tensor* labels, const bool softLabel,
                   const int ignore_index, const int axis_dim) {
-    const int batch_size = prob->dims()[0];
-    const int num_classes = prob->dims()[1];
-    const int num_remain = num_classes / axis_dim;
-
-    Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
-
     if (softLabel) {
+      const int batch_size = prob->dims()[0];
+      const int num_classes = prob->dims()[1];
+      const int num_remain = num_classes / axis_dim;
+
+      Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
       auto in = EigenMatrix<T>::From(*prob);
       auto lbl = EigenMatrix<T>::From(*labels);
       auto loss = EigenMatrix<T>::From(*out);
@@ -52,36 +110,9 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
               .reshape(batch_axis_remain)
               .sum(Eigen::DSizes<int, 1>(1)));
     } else {
-      const T* prob_data = prob->data<T>();
-      T* loss_data = out->data<T>();
-
-      const int64_t* label_data = labels->data<int64_t>();
-      for (int i = 0; i < batch_size; ++i) {
-        for (int j = 0; j < num_remain; j++) {
-          int lbl = label_data[i * num_remain + j];
-          if (lbl != ignore_index) {
-            PADDLE_ENFORCE_GE(lbl, 0,
-                              platform::errors::OutOfRange(
-                                  "label value should >= 0 when label "
-                                  "value(%f) not equal to ignore_index(%f)",
-                                  lbl, ignore_index));
-            PADDLE_ENFORCE_LT(
-                lbl, axis_dim,
-                platform::errors::OutOfRange(
-                    "label value should less than the shape of axis dimension "
-                    "when label value(%f) not equal to ignore_index(%f), But "
-                    "received label value as %ld and shape of axis dimension "
-                    "is %d",
-                    lbl, ignore_index, lbl, axis_dim));
-          }
-          int index = i * num_classes + lbl * num_remain + j;
-          int loss_idx = i * num_remain + j;
-          loss_data[loss_idx] =
-              lbl == ignore_index
-                  ? 0
-                  : -math::TolerableValue<T>()(std::log(prob_data[index]));
-        }
-      }
+      HardLabelCrossEntropyCPUFunctorImpl<T> functor_impl(
+          out, prob, labels, ignore_index, axis_dim);
+      framework::VisitIntDataType(labels->type(), functor_impl);
     }
   }
 };
diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu
index 3e80e40f3577c3bf0b8fe4462665e11b85ea3ca7..a13adb63fdc54bc1c17992a34b097e8e943f60f0 100644
--- a/paddle/fluid/operators/math/cross_entropy.cu
+++ b/paddle/fluid/operators/math/cross_entropy.cu
@@ -21,18 +21,19 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-template <typename T>
-__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
+template <typename T, typename LabelT>
+__global__ void CrossEntropyKernel(T* Y, const T* X, const LabelT* label,
                                    const int N, const int D,
                                    const int ignore_index) {
   CUDA_KERNEL_LOOP(i, N) {
-    PADDLE_ENFORCE(label[i] >= 0 && label[i] < D || label[i] == ignore_index,
+    auto lbl = static_cast<int64_t>(label[i]);
+    PADDLE_ENFORCE(lbl >= 0 && lbl < D || lbl == ignore_index,
                    "The value of label[%d] expected >= 0 and < %ld, or == %ld, "
                    "but got %ld. Please check input value.",
-                   i, D, ignore_index, label[i]);
-    Y[i] = ignore_index == label[i]
+                   i, D, ignore_index, lbl);
+    Y[i] = ignore_index == lbl
                ? static_cast<T>(0)
-               : -math::TolerableValue<T>()(real_log(X[i * D + label[i]]));
+               : -math::TolerableValue<T>()(real_log(X[i * D + lbl]));
   }
 }
 
@@ -54,6 +55,43 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
   }
 }
 
+template <typename T>
+struct HardLabelCrossEntropyCUDAFunctorImpl {
+ public:
+  HardLabelCrossEntropyCUDAFunctorImpl(T* loss_data, const T* prob_data,
+                                       const void* label_data,
+                                       const int batch_size,
+                                       const int class_num,
+                                       const int ignore_index,
+                                       const int block_size, gpuStream_t stream)
+      : loss_data_(loss_data),
+        prob_data_(prob_data),
+        label_data_(label_data),
+        batch_size_(batch_size),
+        class_num_(class_num),
+        ignore_index_(ignore_index),
+        block_size_(block_size),
+        stream_(stream) {}
+
+  template <typename U>
+  void apply() const {
+    int grid_size = (batch_size_ + block_size_ - 1) / block_size_;
+    CrossEntropyKernel<T, U><<<grid_size, block_size_, 0, stream_>>>(
+        loss_data_, prob_data_, static_cast<const U*>(label_data_), batch_size_,
+        class_num_, ignore_index_);
+  }
+
+ private:
+  T* loss_data_;
+  const T* prob_data_;
+  const void* label_data_;
+  const int batch_size_;
+  const int class_num_;
+  const int ignore_index_;
+  const int block_size_;
+  gpuStream_t stream_;
+};
+
 template <typename T>
 class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
  public:
@@ -81,12 +119,10 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
       SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
           loss_data, prob_data, label_data, class_num);
     } else {
-      const int64_t* label_data = labels->data<int64_t>();
-      int block = kMaxBlockDim;
-      int grid = (batch_size + block - 1) / block;
-      CrossEntropyKernel<T><<<grid, block, 0, ctx.stream()>>>(
-          loss_data, prob_data, label_data, batch_size, class_num,
-          ignore_index);
+      HardLabelCrossEntropyCUDAFunctorImpl<T> functor(
+          loss_data, prob_data, labels->data(), batch_size, class_num,
+          ignore_index, kMaxBlockDim, ctx.stream());
+      framework::VisitDataType(labels->type(), functor);
     }
   }
 };
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
index b6f89c7af4a31352e7de0f03ac824e73be566300..fe025641330c36db32162cae614ac40098bf7bd7 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -59,9 +59,9 @@ enum class SoftmaxMode { kSoftmax, kLogSoftmax, kCrossEntropy };
 /*
   Hard label cross entropy.
 */
-template <typename T, bool IgnoreIndex>
+template <typename T, typename LabelT, bool IgnoreIndex>
 __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
-                                      const int64_t* labels, const int n,
+                                      const LabelT* labels, const int n,
                                       const int dim, const int d,
                                       const int ignore_idx) {
   int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
@@ -70,13 +70,14 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
 
   // thread ids compute loss[ids] using softmax[idx]
   if (ids < n * d) {
-    if (labels[ids] < 0) {  // label is negative
+    auto lbl = static_cast<int64_t>(labels[ids]);
+    if (lbl < 0) {  // label is negative
       loss[ids] = static_cast<T>(0.0);
     } else {  // label is positive of zero
-      int64_t idx = idx_n * dim * d + labels[ids] * d + idx_d;
+      int64_t idx = idx_n * dim * d + lbl * d + idx_d;
       if (IgnoreIndex == true) {
         // IgnoreIndex is true
-        if (labels[ids] == ignore_idx) {
+        if (lbl == ignore_idx) {
           loss[ids] = static_cast<T>(0.0);
         } else {
           loss[ids] = -Log(softmax[idx]);
@@ -94,9 +95,9 @@ __global__ void CrossEntropyHardLabel(T* loss, const T* softmax,
   Input: log softmax
   Output: loss and exp(input)
 */
-template <typename T, bool IgnoreIndex>
+template <typename T, typename LabelT, bool IgnoreIndex>
 __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
-                                         const int64_t* labels, const int n,
+                                         const LabelT* labels, const int n,
                                          const int dim, const int d,
                                          const int ignore_idx) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -106,10 +107,11 @@ __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
   int64_t ids = idx_n * d + idx_d;
 
   if (idx < n * dim * d) {
+    auto lbl = static_cast<int64_t>(labels[ids]);
     if (IgnoreIndex == true) {
       // IgnoreIndex is true
-      if (idx_dim == labels[ids]) {
-        if (labels[ids] == ignore_idx) {
+      if (idx_dim == lbl) {
+        if (lbl == ignore_idx) {
           loss[ids] = static_cast<T>(0.0);
         } else {
           loss[ids] = -softmax[idx];
@@ -117,8 +119,8 @@ __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
       }
     } else {
       // IgnoreIndex is false
-      if (labels[ids] >= 0 && labels[ids] < dim) {
-        if (labels[ids] == idx_dim) {
+      if (lbl >= 0 && lbl < dim) {
+        if (lbl == idx_dim) {
           loss[ids] = -softmax[idx];
         }
       } else {
@@ -151,10 +153,10 @@ __global__ void CrossEntropyExpHardLabel(T* loss, T* softmax,
   For reduction max (sum), firstly compute max (sum) to one warp, then use
   shuffle api to compute max (sum) in one warp.
 */
-template <typename T, typename VecT, typename AccT, int Log2Elements,
-          SoftmaxMode mode, bool IgnoreIndex>
+template <typename T, typename LabelT, typename VecT, typename AccT,
+          int Log2Elements, SoftmaxMode mode, bool IgnoreIndex>
 __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
-                                   const int64_t* label, const int batch_size,
+                                   const LabelT* label, const int batch_size,
                                    const int stride, const int element_count,
                                    const int ignore_index) {
   constexpr int kDimCeil = 1 << Log2Elements;
@@ -299,10 +301,11 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
         softmax[(first_batch + i) * stride + idx] = std::exp(logsoftmax);
         // label
         int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
+        auto lbl = static_cast<int64_t>(label[first_batch + i]);
         if (IgnoreIndex == true) {
           // IgnoreIndex is true
-          if (label[first_batch + i] == loss_idx) {
-            if (label[first_batch + i] != ignore_index) {
+          if (lbl == loss_idx) {
+            if (lbl != ignore_index) {
               loss[first_batch + i] = -logsoftmax;
             } else {
               loss[first_batch + i] = static_cast<T>(0.0);
@@ -310,9 +313,8 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
           }
         } else {
           // IgnoreIndex is false
-          if (label[first_batch + i] >= 0 &&
-              label[first_batch + i] < element_count) {
-            if (label[first_batch + i] == loss_idx) {
+          if (lbl >= 0 && lbl < element_count) {
+            if (lbl == loss_idx) {
               loss[first_batch + i] = -logsoftmax;
             }
           } else {
@@ -342,17 +344,16 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
         tmpptr[s] = std::exp(logsoftmax);
         // label
         int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
+        auto lbl = static_cast<int64_t>(label[first_batch + i]);
         if (IgnoreIndex == true) {
           // IgnoreIndex is true
-          if (label[first_batch + i] == loss_idx &&
-              label[first_batch + i] != ignore_index) {
+          if (lbl == loss_idx && lbl != ignore_index) {
             loss[first_batch + i] = -logsoftmax;
           }
         } else {
           // IgnoreIndex is false
-          if (label[first_batch + i] >= 0 &&
-              label[first_batch + i] < element_count) {
-            if (label[first_batch + i] == loss_idx) {
+          if (lbl >= 0 && lbl < element_count) {
+            if (lbl == loss_idx) {
               loss[first_batch + i] = -logsoftmax;
             }
           } else {
@@ -373,9 +374,9 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
   }
 }
 
-#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, VecT, AccT)                   \
+#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT)           \
   case Log2Elements:                                                          \
-    WarpSoftmaxForward<T, VecT, AccT, Log2Elements, mode,                     \
+    WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode,             \
                        IgnoreIndex><<<blocks, threads, 0, stream>>>(          \
         loss, softmax, src, label, batch_size, stride, element_count,         \
         ignore_index);                                                        \
@@ -384,9 +385,9 @@ __global__ void WarpSoftmaxForward(T* loss, T* softmax, const T* src,
 /*
   Wrapper of softmax with cross entropy forward hard label.
 */
-template <typename T, SoftmaxMode mode, bool IgnoreIndex>
+template <typename T, typename LabelT, SoftmaxMode mode, bool IgnoreIndex>
 void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
-                              const int64_t* label, const int batch_size,
+                              const LabelT* label, const int batch_size,
                               const int stride, const int element_count,
                               const int ignore_index, gpuStream_t stream) {
   using AccT = typename details::MPTypeTrait<T>::Type;
@@ -403,16 +404,16 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
   dim3 threads(kWarpSize, warps_per_block, 1);
 
   switch (log2_elements) {
-    SOFTMAX_WARP_FORWARD_CASE(0, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(1, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(2, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(3, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(4, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(5, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(6, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(7, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(8, T, AccT);
-    SOFTMAX_WARP_FORWARD_CASE(9, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(0, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(1, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(2, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(3, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(4, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(5, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(6, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(7, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(8, LabelT, T, AccT);
+    SOFTMAX_WARP_FORWARD_CASE(9, LabelT, T, AccT);
     default:
       break;
   }
@@ -423,16 +424,16 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
   - SwitchWarpSoftmaxForward for small size
   - cudnn function for large size
 */
-template <typename T, bool IgnoreIndex>
+template <typename T, typename LabelT, bool IgnoreIndex>
 static void SoftmaxWithCrossEntropyHardLabel(
     const platform::CUDADeviceContext& ctx, int rank, int axis,
-    const T* logits_data, const int64_t* labels_data, T* loss_data,
+    const T* logits_data, const LabelT* labels_data, T* loss_data,
     T* softmax_data, int N, int dim, int D, const int ignore_index) {
   auto stream = ctx.stream();
   constexpr int max_dim = 320;
   if (D == 1 && dim <= max_dim) {  // small size
     const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
-    SwitchWarpSoftmaxForward<T, mode, IgnoreIndex>(
+    SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
         loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
         ignore_index, stream);
   } else {
@@ -465,7 +466,8 @@ static void SoftmaxWithCrossEntropyHardLabel(
     int threads = 128;
     int blocks = (N * dim * D + threads - 1) / threads;
     // compute cross entropy, input is log softmax
-    CrossEntropyExpHardLabel<T, IgnoreIndex><<<blocks, threads, 0, stream>>>(
+    CrossEntropyExpHardLabel<T, LabelT,
+                             IgnoreIndex><<<blocks, threads, 0, stream>>>(
         loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
   }
 }
@@ -473,9 +475,9 @@ static void SoftmaxWithCrossEntropyHardLabel(
 /*
   Wrapper of softmax with cross entropy grad hard label.
 */
-template <typename T>
+template <typename T, typename LabelT>
 __global__ void SoftmaxWithCrossEntropyGradHardLabel(
-    T* logits_grad, const T* loss_grad, const int64_t* labels, const int64_t n,
+    T* logits_grad, const T* loss_grad, const LabelT* labels, const int64_t n,
     const int64_t dim, const int64_t d, const int ignore_index) {
   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   int64_t idx_n = idx / (d * dim);
@@ -484,9 +486,10 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
   int64_t ids = idx_n * d + idx_d;
 
   if (idx < n * dim * d) {
-    if (labels[ids] == ignore_index) {
+    auto lbl = static_cast<int64_t>(labels[ids]);
+    if (lbl == ignore_index) {
       logits_grad[idx] = static_cast<T>(0.0);
-    } else if (labels[ids] == idx_dim) {
+    } else if (lbl == idx_dim) {
       logits_grad[idx] =
           (logits_grad[idx] - static_cast<T>(1.0)) * loss_grad[ids];
     } else {
@@ -887,16 +890,16 @@ __global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad,
   }
 }
 
-template <typename T>
+template <typename T, typename LabelT>
 __global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
-                                                    const int64_t* labels,
+                                                    const LabelT* labels,
                                                     const int n, const int d,
                                                     const int remain,
                                                     const int ignore_index) {
   CUDA_KERNEL_LOOP(index, n * remain) {
     int idx_n = index / remain;
     int idx_remain = index % remain;
-    int tmp = labels[index];
+    int tmp = static_cast<int>(labels[index]);
     int idx = idx_n * d + tmp * remain + idx_remain;
     if (ignore_index != tmp) {
       logit_grad[idx] = -static_cast<T>(1.) / logit_grad[idx];
@@ -904,18 +907,19 @@ __global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad,
   }
 }
 
-template <typename T>
+template <typename T, typename LabelT>
 __global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad,
                                           const int num, const int d,
                                           const int remain,
-                                          const int64_t* labels,
+                                          const LabelT* labels,
                                           const int ignore_index) {
   CUDA_KERNEL_LOOP(index, num) {
     int idx_n = index / d;
     int idx_remain = index % remain;
     int idx_lbl = idx_n * remain + idx_remain;
     int k = (index % d) / remain;
-    if (labels[idx_lbl] == ignore_index || labels[idx_lbl] != k) {
+    auto lbl = static_cast<int64_t>(labels[idx_lbl]);
+    if (lbl == ignore_index || lbl != k) {
       logit_grad[index] = static_cast<T>(0.);
     } else {
       logit_grad[index] *= loss_grad[idx_lbl];
@@ -927,6 +931,12 @@ template <typename T>
 class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
+  }
+
+  template <typename LabelT>
+  static void Apply(const framework::ExecutionContext& context,
+                    const framework::Tensor& labels, const bool soft_label) {
     PADDLE_ENFORCE_EQ(
         platform::is_gpu_place(context.GetPlace()), true,
         platform::errors::Unavailable("softmax_with_cross_entropy operator's "
@@ -936,7 +946,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     // do not with softmax op, and input is softmax
     if (!use_softmax) {
       const Tensor* softmax = context.Input<Tensor>("Logits");
-      const Tensor* labels = context.Input<Tensor>("Label");
       Tensor* softmax_out = context.Output<Tensor>("Softmax");
       Tensor* loss = context.Output<Tensor>("Loss");
 
@@ -947,8 +956,9 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
       const int n = SizeToAxis(axis, softmax->dims());
       const int d = SizeFromAxis(axis, softmax->dims());
 
-      auto* softmax_out_data = softmax_out->mutable_data<T>(context.GetPlace());
-      auto* loss_data = loss->mutable_data<T>(context.GetPlace());
+      auto* softmax_out_data =
+          softmax_out->template mutable_data<T>(context.GetPlace());
+      auto* loss_data = loss->template mutable_data<T>(context.GetPlace());
 
       math::SetConstant<platform::CUDADeviceContext, T> set_constant;
       set_constant(context.cuda_device_context(), loss, static_cast<T>(0));
@@ -958,12 +968,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
         return;
       }
 
-      auto soft_label = context.Attr<bool>("soft_label");
       auto ignore_index = context.Attr<int>("ignore_index");
 
       Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d;
       softmax_2d.ShareDataWith(*softmax).Resize({n, d});
-      labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
+      labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
       loss_2d.ShareDataWith(*loss).Resize({n, 1});
       softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d});
 
@@ -977,8 +986,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
 
       // if axis is not the last, we need a new impliment
       if (soft_label) {
-        auto* logits_data = softmax->data<T>();
-        auto* labels_data = labels->data<T>();
+        auto* logits_data = softmax->template data<T>();
+        auto* labels_data = labels.template data<T>();
 
         const int kDimLog2 = static_cast<int>(Log2Ceil(axis_dim));
         const int kDimCeil = 1 << kDimLog2;
@@ -996,17 +1005,17 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
             loss_data, NULL, logits_data, labels_data, n, axis_dim,
             d / axis_dim, kDimLog2);
       } else {  // HardLabel
-        auto* logits_data = softmax->data<T>();
-        auto* labels_data = labels->data<int64_t>();
+        auto* logits_data = softmax->template data<T>();
+        auto* labels_data = labels.template data<LabelT>();
         int threads = 128;
         int blocks = (n * d / axis_dim + threads - 1) / threads;
         if (ignore_index >= 0 && ignore_index < axis_dim) {
-          CrossEntropyHardLabel<T, true><<<
+          CrossEntropyHardLabel<T, LabelT, true><<<
               blocks, threads, 0, context.cuda_device_context().stream()>>>(
               loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim,
               ignore_index);
         } else {
-          CrossEntropyHardLabel<T, false><<<
+          CrossEntropyHardLabel<T, LabelT, false><<<
              blocks, threads, 0, context.cuda_device_context().stream()>>>(
              loss_data, logits_data, labels_data, n, axis_dim, d / axis_dim,
              ignore_index);
@@ -1022,7 +1031,6 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     }
 
     const Tensor* logits = context.Input<Tensor>("Logits");
-    const Tensor* labels = context.Input<Tensor>("Label");
     Tensor* softmax = context.Output<Tensor>("Softmax");
     Tensor* loss = context.Output<Tensor>("Loss");
 
@@ -1033,8 +1041,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     const int64_t n = SizeToAxis(axis, logits->dims());
     const int64_t d = SizeFromAxis(axis, logits->dims());
 
-    auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
-    auto* loss_data = loss->mutable_data<T>(context.GetPlace());
+    auto* softmax_data = softmax->template mutable_data<T>(context.GetPlace());
+    auto* loss_data = loss->template mutable_data<T>(context.GetPlace());
 
     if (axis_dim == 1) {
       math::SetConstant<platform::CUDADeviceContext, T> set_constant;
@@ -1043,12 +1051,11 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
       return;
     }
 
-    auto soft_label = context.Attr<bool>("soft_label");
    auto ignore_index = context.Attr<int>("ignore_index");
 
     if (soft_label) {
-      auto* logits_data = logits->data<T>();
-      auto* labels_data = labels->data<T>();
+      auto* logits_data = logits->template data<T>();
+      auto* labels_data = labels.template data<T>();
       SoftmaxWithCrossEntropySoftLabel<T>(
           context.cuda_device_context(), rank, axis, logits_data, labels_data,
           softmax_data, loss_data, n, axis_dim, d / axis_dim);
@@ -1058,7 +1065,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
         Tensor logits_2d, softmax_2d, labels_2d, loss_2d;
         logits_2d.ShareDataWith(*logits).Resize({n, d});
         softmax_2d.ShareDataWith(*softmax).Resize({n, d});
-        labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
+        labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
         loss_2d.ShareDataWith(*loss).Resize({n, 1});
         math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
                                        &logits_2d, &softmax_2d);
@@ -1066,15 +1073,15 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
             context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
             false, ignore_index, axis_dim);
       } else {
-        auto* logits_data = logits->data<T>();
-        auto* labels_data = labels->data<int64_t>();
+        auto* logits_data = logits->template data<T>();
+        auto* labels_data = labels.template data<LabelT>();
         if (ignore_index >= 0 && ignore_index < axis_dim) {
-          SoftmaxWithCrossEntropyHardLabel<T, true>(
+          SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(
               context.cuda_device_context(), rank, axis, logits_data,
               labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
               ignore_index);
         } else {
-          SoftmaxWithCrossEntropyHardLabel<T, false>(
+          SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(
              context.cuda_device_context(), rank, axis, logits_data,
              labels_data, loss_data, softmax_data, n, axis_dim, d / axis_dim,
              ignore_index);
@@ -1088,13 +1095,19 @@ template <typename T>
 class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
+  }
+
+  template <typename LabelT>
+  static void Apply(const framework::ExecutionContext& context,
+                    const framework::Tensor& labels, const bool soft_label) {
     PADDLE_ENFORCE_EQ(
         platform::is_gpu_place(context.GetPlace()), true,
         platform::errors::Unavailable("softmax_with_cross_entropy operator's "
                                       "CUDA kernel only runs on GPU device."));
-    const Tensor* labels = context.Input<Tensor>("Label");
     const T* loss_grad_data =
-        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
+        context.Input<Tensor>(framework::GradVarName("Loss"))
+            ->template data<T>();
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
@@ -1102,7 +1115,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
     }
-    T* logit_grad_data = logit_grad->data<T>();
+    T* logit_grad_data = logit_grad->template data<T>();
 
     const int rank = logit_grad->dims().size();
     const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
@@ -1123,21 +1136,22 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 
     // do not with softmax op, and input is softmax
     if (!use_softmax) {
-      if (context.Attr<bool>("soft_label")) {
+      if (soft_label) {
         int grid = (n * d + block - 1) / block;
-        const T* label_data = labels->data<T>();
+        const T* label_data = labels.template data<T>();
         SoftLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
             logit_grad_data, loss_grad_data, label_data, n, d, remain);
       } else {
         Tensor logits_grad_2d;
         logits_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
         int grid = (n * remain + block - 1) / block;
-        const int64_t* label_data = labels->data<int64_t>();
-        HardLabelCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
+        const auto* label_data = labels.template data<LabelT>();
+        HardLabelCrossEntropyGradientKernel<T, LabelT><<<grid, block, 0,
+                                                         stream>>>(
            logit_grad_data, label_data, n, d, remain, ignore_index);
         int num = n * d;
         grid = (num + block - 1) / block;
-        ScaleCrossEntropyGradient<T><<<grid, block, 0, stream>>>(
+        ScaleCrossEntropyGradient<T, LabelT><<<grid, block, 0, stream>>>(
            logit_grad_data, loss_grad_data, num, d, remain, label_data,
            ignore_index);
       }
@@ -1147,13 +1161,13 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
 
     // with softmax, continue
 
-    if (context.Attr<bool>("soft_label")) {
+    if (soft_label) {
       int64_t grid = (n * d + block - 1) / block;
-      const T* label_data = labels->data<T>();
+      const T* label_data = labels.template data<T>();
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
-      const int64_t* label_data = labels->data<int64_t>();
+      const auto* label_data = labels.template data<LabelT>();
       int grid = (n * d + block - 1) / block;
       SoftmaxWithCrossEntropyGradHardLabel<T><<<grid, block, 0, stream>>>(
           logit_grad_data, loss_grad_data, label_data, n, d / remain, remain,
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h
index ab1474954e3dd213c9513526b5d40cb81b4ab942..ef8b0bbdf9d3236b02f04a23fc8d962355d1208a 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h
@@ -24,6 +24,48 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
+template <typename T, typename Visitor>
+struct SoftmaxWithCrossEntropyFunctor {
+ public:
+  SoftmaxWithCrossEntropyFunctor(const framework::ExecutionContext& context,
+                                 const framework::Tensor& labels,
+                                 const bool soft_label, const Visitor& visitor)
+      : context_(context),
+        labels_(labels),
+        soft_label_(soft_label),
+        visitor_(visitor) {}
+
+  template <typename U>
+  void apply() const {
+    visitor_.template Apply<U>(context_, labels_, soft_label_);
+  }
+
+ private:
+  const framework::ExecutionContext& context_;
+  const framework::Tensor& labels_;
+  const bool soft_label_;
+  const Visitor& visitor_;
+};
+
+template <typename T, typename Visitor>
+static void RunSoftmaxWithCrossEntropyFunctor(
+    const framework::ExecutionContext& context, const Visitor& visitor) {
+  const auto* labels = context.Input<Tensor>("Label");
+  const bool soft_label = context.Attr<bool>("soft_label");
+  SoftmaxWithCrossEntropyFunctor<T, Visitor> functor(context, *labels,
+                                                     soft_label, visitor);
+  auto dtype = labels->type();
+  if (soft_label) {
+    PADDLE_ENFORCE_EQ(
+        dtype, framework::DataTypeTrait<T>::DataType(),
+        platform::errors::InvalidArgument("The Input(Label) should be with the "
+                                          "same data type as Input(Logits)."));
+    functor.template apply<T>();
+  } else {
+    framework::VisitIntDataType(dtype, functor);
+  }
+}
+
 template <typename DeviceContext, typename T>
 class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
  public:
@@ -32,14 +74,14 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
         platform::is_cpu_place(context.GetPlace()), true,
         platform::errors::Unimplemented("This kernel only runs on CPU."));
     const bool use_softmax = context.Attr<bool>("use_softmax");
+    const Tensor* labels = context.Input<Tensor>("Label");
+    const bool soft_label = context.Attr<bool>("soft_label");
 
     // do not with softmax op, and input is softmax
     if (!use_softmax) {
       const Tensor* softmax = context.Input<Tensor>("Logits");
-      const Tensor* labels = context.Input<Tensor>("Label");
       Tensor* softmax_out = context.Output<Tensor>("Softmax");
       Tensor* loss = context.Output<Tensor>("Loss");
-      const bool soft_label = context.Attr<bool>("soft_label");
       const int rank = softmax->dims().size();
       const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
       int axis_dim = softmax->dims()[axis];
@@ -86,10 +128,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     }
 
     const Tensor* logits = context.Input<Tensor>("Logits");
-    const Tensor* labels = context.Input<Tensor>("Label");
     Tensor* softmax = context.Output<Tensor>("Softmax");
     Tensor* loss = context.Output<Tensor>("Loss");
-    const bool soft_label = context.Attr<bool>("soft_label");
 
     const int rank = logits->dims().size();
     const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
@@ -132,9 +172,14 @@ template <typename DeviceContext, typename T>
 class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    RunSoftmaxWithCrossEntropyFunctor<T>(context, *this);
+  }
+
+  template <typename LabelT>
+  static void Apply(const framework::ExecutionContext& context,
+                    const framework::Tensor& labels, const bool soft_label) {
     const Tensor* out_grad =
         context.Input<Tensor>(framework::GradVarName("Loss"));
-    const Tensor* labels = context.Input<Tensor>("Label");
     Tensor* logit_grad =
         context.Output<Tensor>(framework::GradVarName("Logits"));
     const Tensor* softmax = context.Input<Tensor>("Softmax");
@@ -143,7 +188,6 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
       framework::TensorCopy(*softmax, context.GetPlace(),
                             context.device_context(), logit_grad);
     }
-    const bool soft_label = context.Attr<bool>("soft_label");
     auto ignore_index = context.Attr<int>("ignore_index");
 
     const int rank = logit_grad->dims().size();
@@ -166,7 +210,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
     const int d = SizeFromAxis(axis, logit_grad->dims());
     Tensor logit_grad_2d, labels_2d, out_grad_2d;
     logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d});
-    labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n});
+    labels_2d.ShareDataWith(labels).Resize({n, labels.numel() / n});
     out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim});
     auto out_grad_mat = framework::EigenMatrix<T>::From(out_grad_2d);
     auto logit_grad_mat = framework::EigenMatrix<T>::From(logit_grad_2d);
@@ -183,23 +227,24 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
           logit_grad_mat;
     } else {
       // use_softmax step2
-      const int64_t* label_data = labels->data<int64_t>();
-      T* logit_grad_data = logit_grad->data<T>();
-      const T* out_grad_data = out_grad->data<T>();
+      const auto* label_data = labels.template data<LabelT>();
+      T* logit_grad_data = logit_grad->template data<T>();
+      const T* out_grad_data = out_grad->template data<T>();
       const int remain = d / axis_dim;
       for (int i = 0; i < n; ++i) {         // for each sample_1_dim
         for (int j = 0; j < remain; j++) {  // for each sample_other_dims
           int idx = i * remain + j;  // this sample's label_idx. for 1d case,
                                      // remain=1 and j=0, so, idx = i
-          if (label_data[idx] == ignore_index) {
+          auto lbl = static_cast<int64_t>(label_data[idx]);
+          if (lbl == ignore_index) {
            for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
              logit_grad_data[i * d + k * remain + j] = 0;
            }
          } else {
            // only for this sample's label_idx, the label is 1, others is 0,
            // so, only compute this label_idx's class
-            logit_grad_data[i * d + label_data[idx] * remain + j] =
-                (-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) *
+            logit_grad_data[i * d + lbl * remain + j] =
+                (-1 / logit_grad_data[i * d + lbl * remain + j]) *
                out_grad_data[idx];
            for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
              if (k !=
@@ -233,15 +278,16 @@
           logit_grad_mat *  // element_wise multiply
           out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, axis_dim));
 
-      const int64_t* label_data = labels->data<int64_t>();
-      T* logit_grad_data = logit_grad->data<T>();
-      const T* out_grad_data = out_grad->data<T>();
+      const auto* label_data = labels.template data<LabelT>();
+      T* logit_grad_data = logit_grad->template data<T>();
+      const T* out_grad_data = out_grad->template data<T>();
       const int remain = d / axis_dim;
       for (int i = 0; i < n; ++i) {         // for each sample_1_dim
         for (int j = 0; j < remain; j++) {  // for each sample_other_dims
           int idx = i * remain + j;  // this sample's label_idx. for 1d case,
                                      // remain=1 and j=0, so, idx = i
-          if (label_data[idx] == ignore_index) {
+          auto lbl = static_cast<int64_t>(label_data[idx]);
+          if (lbl == ignore_index) {
             for (int k = 0; k < axis_dim; ++k) {  // for each class id's label
               logit_grad_data[i * d + k * remain + j] = 0;
             }
@@ -258,8 +304,7 @@
             // out_grad_data[idx]
             // means:       dy/dp * dy= ( p - y ) * dy
 
-            logit_grad_data[i * d + label_data[idx] * remain + j] -=
-                out_grad_data[idx];
+            logit_grad_data[i * d + lbl * remain + j] -= out_grad_data[idx];
           }
         }
       }
diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
index e754999d5d2055dccd0ae7b565f1aa140309bcb6..69f6a87dd9ed18909741fa6da0bd9acb6b6dd831 100644
--- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
+++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py
@@ -16,6 +16,7 @@ from __future__ import print_function
 
 import unittest
 import numpy as np
+import paddle
 import paddle.fluid.core as core
 from op_test import OpTest
 
@@ -58,6 +59,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         self.shape = [41, 37]
         self.use_softmax = True
 
+    def hard_label_dtype(self):
+        return "int64"
+
     def setUp(self):
         self.initParams()
 
@@ -72,7 +76,8 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         else:
             axis_dim = self.shape[self.axis]
             self.shape[self.axis] = 1
-            labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
+            labels = np.random.randint(
+                0, axis_dim, self.shape, dtype=self.hard_label_dtype())
 
         loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
                              self.ignore_index)
@@ -107,6 +112,26 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
         self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001)
 
 
+class TestSoftmaxWithCrossEntropyOpInt32(TestSoftmaxWithCrossEntropyOp):
+    def hard_label_dtype(self):
+        return "int32"
+
+
+class TestSoftmaxWithCrossEntropyOpInt16(TestSoftmaxWithCrossEntropyOp):
+    def hard_label_dtype(self):
+        return "int16"
+
+
+class TestSoftmaxWithCrossEntropyOpInt8(TestSoftmaxWithCrossEntropyOp):
+    def hard_label_dtype(self):
+        return "int8"
+
+
+class TestSoftmaxWithCrossEntropyOpUInt8(TestSoftmaxWithCrossEntropyOp):
+    def hard_label_dtype(self):
+        return "uint8"
+
+
 class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
         TestSoftmaxWithCrossEntropyOp):
     def initParams(self):
@@ -711,4 +736,5 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):
 
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 90ada8c3c5ee63455d1248af7083c925bfa9d3a2..711fd1e94cae9eff403de685f152d05a8fb52a31 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -1783,7 +1783,8 @@ def cross_entropy(input,
         fluid.data_feeder.check_variable_and_dtype(
             input, 'input', ['float32', 'float64'], 'softmax_cross_entropy')
         fluid.data_feeder.check_variable_and_dtype(
-            label, 'label', ['int32', 'int64', 'float32', 'float64'],
+            label, 'label',
+            ['uint8', 'int8', 'int16', 'int32', 'int64', 'float32', 'float64'],
             'softmax_cross_entropy')
         attrs = {
             'soft_label': soft_label,
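
Usage note (not part of the diff): a minimal sketch of what this patch enables. Hard labels fed to the softmax_with_cross_entropy path may now be uint8/int8/int16/int32/int64 rather than int64 only. The shapes and values below are made up for illustration; only `paddle.nn.functional.cross_entropy` and the newly allowed label dtypes come from the patch.

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# 4 samples, 10 classes; hard labels kept as int32 instead of int64.
logits = paddle.to_tensor(np.random.rand(4, 10).astype("float32"))
labels = paddle.to_tensor(np.random.randint(0, 10, size=(4,)).astype("int32"))

# Dispatches to the softmax_with_cross_entropy kernels patched above; the
# C++ kernels previously accepted only int64 label data.
loss = F.cross_entropy(logits, labels)
print(loss.numpy())
```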