diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 946ede475ce68447db05f2ecd2bd624e90881376..5acebd7525af354af351841b04a60f2db8579797 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/op_version_registry.h" namespace paddle { namespace operators { @@ -53,6 +54,10 @@ class SoftmaxWithCrossEntropyOpMaker "(bool, default: false), A flag to indicate whether to interpretant " "the given labels as soft labels.") .SetDefault(false); + AddAttr( + "use_softmax", + "(bool, default: true), A flag to indicate whether to do softmax ") + .SetDefault(true); AddAttr( "numeric_stable_mode", "(bool, default: true), A flag to indicate whether to use more " @@ -312,3 +317,10 @@ REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy, REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad, ops::SoftmaxWithCrossEntropyGradKernel, ops::SoftmaxWithCrossEntropyGradKernel); + +REGISTER_OP_VERSION(softmax_with_cross_entropy) + .AddCheckpoint( + R"ROC( + Add a new attribute [use_softmax] )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "use_softmax", "A flag to indicate whether to do softmax", true)); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index f3e7a33d9b1ab21cf2dff35f3804e32ea0994b17..9c1ad589ff373fa26278882eb73bb95afbf2d269 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -66,6 +66,57 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, } } +template +__global__ void SoftLabelCrossEntropyGradientKernel(T* logit_grad, + const T* loss_grad, + const T* labels, + const int n, const int d, + const int remain) { + int ids = 
blockIdx.x * blockDim.x + threadIdx.x; + if (ids < n * d) { + int idx_n = ids / d; + int idx_remain = ids % remain; + int idx_loss = idx_n * remain + idx_remain; + logit_grad[ids] = loss_grad[idx_loss] * (-labels[ids] / logit_grad[ids]); + } +} + +template +__global__ void HardLabelCrossEntropyGradientKernel(T* logit_grad, + const int64_t* labels, + const int n, const int d, + const int remain, + const int ignore_index) { + CUDA_KERNEL_LOOP(index, n * remain) { + int idx_n = index / remain; + int idx_remain = index % remain; + int tmp = labels[index]; + int idx = idx_n * d + tmp * remain + idx_remain; + if (ignore_index != tmp) { + logit_grad[idx] = -static_cast(1.) / logit_grad[idx]; + } + } +} + +template +__global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad, + const int num, const int d, + const int remain, + const int64_t* labels, + const int ignore_index) { + CUDA_KERNEL_LOOP(index, num) { + int idx_n = index / d; + int idx_remain = index % remain; + int idx_lbl = idx_n * remain + idx_remain; + int k = (index % d) / remain; + if (labels[idx_lbl] == ignore_index || labels[idx_lbl] != k) { + logit_grad[index] = static_cast(0.); + } else { + logit_grad[index] *= loss_grad[idx_lbl]; + } + } +} + } // namespace static __device__ __forceinline__ platform::float16 exp_on_device( @@ -248,6 +299,160 @@ static __global__ void RowReductionForSoftmaxAndCrossEntropy( if (threadIdx.x == 0) loss_data[blockIdx.x] = loss; } +// Make sure that BlockDim <= axis_dim +template +static __global__ void RowReductionForCrossEntropy(const T* logits_data, + const T* labels_data, + T* loss_data, int d, + int axis_dim) { + __shared__ BlockReduceTempStorage temp_storage; + + // logits, softmax, labels data view as [n, axis_dim, remain] + // loss_data view as [n, 1, remain] + // blockDim = n * remain, split blockIdx to idx_n and idx_remain + int remain = d / axis_dim; + int idx_n = blockIdx.x / remain; + int idx_remain = blockIdx.x % remain; + int beg_idx = idx_n * d 
+ threadIdx.x * remain + idx_remain; + int end_idx = (idx_n + 1) * d; + + // loss_data is only written in this kernel: thread 0 stores the block's + // reduced loss into loss_data[blockIdx.x] at the end; nothing is read back. + auto tmp = log_on_device(logits_data[beg_idx]); // when not with softmax, + // softmax is stored in + // logits_data + auto loss = -labels_data[beg_idx] * tmp; + int step = BlockDim * remain; + beg_idx += step; + while (beg_idx < end_idx) { + tmp = log_on_device(logits_data[beg_idx]); // when not with softmax, + // softmax is stored in + // logits_data + loss -= (labels_data[beg_idx] * tmp); + beg_idx += step; + } + + loss = BlockReduce(temp_storage).Reduce(loss, cub::Sum()); + if (threadIdx.x == 0) loss_data[blockIdx.x] = loss; +} + +template +struct HardLabelCrossEntropyFunctor { + public: + HardLabelCrossEntropyFunctor(const int64_t* labels, T* loss, + const T* logits_data, int d, int axis_dim) + : labels_(labels), + loss_(loss), + logits_data_(logits_data), + d_(d), + axis_dim_(axis_dim) {} + + __device__ void operator()(int idx) const { + // logits view as [n, axis_dim, remain], where d = axis_dim * remain + int remain = d_ / axis_dim_; + int idx_n = idx / d_; + int idx_axis = (idx % d_) / remain; + int idx_remain = idx % remain; + // labels, loss view as [n, remain] + int idx_lbl = idx_n * remain + idx_remain; + // It also would ignore labels not in range(class_num). 
+ if (idx_axis != labels_[idx_lbl]) { + } else { + loss_[idx_lbl] = -log_on_device(logits_data_[idx]); + } + } + + private: + const int64_t* labels_; + T* loss_; + const T* logits_data_; + int d_; + int axis_dim_; +}; + +template +struct HardLabelCrossEntropyFunctorWithIgnoreIdx { + public: + HardLabelCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss, + const T* logits_data, int d, + int axis_dim, int ignore_idx) + : labels_(labels), + loss_(loss), + logits_data_(logits_data), + d_(d), + axis_dim_(axis_dim), + ignore_idx_(ignore_idx) {} + + __device__ void operator()(int idx) const { + // logits view as [n, axis_dim, remain], where d = axis_dim * remain + int remain = d_ / axis_dim_; + int idx_n = idx / d_; + int idx_axis = (idx % d_) / remain; + int idx_remain = idx % remain; + // labels, loss view as [n, remain] + int idx_lbl = idx_n * remain + idx_remain; + + // Only the thread sitting on the label's class writes this loss slot, so + // each slot has a single writer (no race with the ignore_idx_ class). + if (idx_axis != labels_[idx_lbl]) return; + if (idx_axis == ignore_idx_) { + loss_[idx_lbl] = 0; // ignored label contributes zero loss + } else { + loss_[idx_lbl] = -log_on_device(logits_data_[idx]); + } + } + + private: + const int64_t* labels_; + T* loss_; + const T* logits_data_; + int d_; + int axis_dim_; + int ignore_idx_; +}; + +template +static void HardLabelCrossEntropy(const platform::CUDADeviceContext& ctx, + const T* logits_data, + const int64_t* labels_data, T* loss_data, + int n, int d, int axis_dim, int ignore_idx) { + constexpr int kMaxBlockDim = 512; + int block_dim = axis_dim >= kMaxBlockDim + ? 
kMaxBlockDim + : (1 << static_cast(std::log2(axis_dim))); + int grid_dim = n * d / axis_dim; + auto stream = ctx.stream(); + +#define CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ + case BlockDim: { \ + platform::ForRange for_range(ctx, n* d); \ + if (ignore_idx >= 0 && ignore_idx < axis_dim) { \ + for_range(HardLabelCrossEntropyFunctorWithIgnoreIdx( \ + labels_data, loss_data, logits_data, d, axis_dim, ignore_idx)); \ + } else { \ + for_range(HardLabelCrossEntropyFunctor(labels_data, loss_data, \ + logits_data, d, axis_dim)); \ + } \ + } break + // Dispatch on block_dim; the helper macro is undefined right after the switch. + switch (block_dim) { + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(512); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(256); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(128); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(64); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(32); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(16); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(8); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(4); + CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL(2); + default: + PADDLE_THROW(platform::errors::Unavailable( + "Block Dimension must be 2^n in softmax_with_cross_entropy_op.")); + break; + } +#undef CALL_HARD_LABEL_CROSS_ENTROPY_FUSED_KERNEL +} + template struct HardLabelSoftmaxWithCrossEntropyFunctor { public: @@ -420,6 +625,43 @@ static void SoftmaxWithCrossEntropyFusedKernel( #undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL } +// not with softmax +template +static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data, + T* loss_data, int n, int d, int axis_dim, + cudaStream_t stream) { + constexpr int kMaxBlockDim = 512; + int block_dim = axis_dim >= kMaxBlockDim + ? 
kMaxBlockDim + : (1 << static_cast(std::log2(axis_dim))); + int grid_dim = n * d / axis_dim; + +#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ + case BlockDim: \ + RowReductionForCrossEntropy<<>>( \ + logits_data, labels_data, loss_data, d, axis_dim); \ + break + + switch (block_dim) { + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); + CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); + default: + PADDLE_THROW(platform::errors::Unavailable( + "Block Dimension must be 2^n in softmax_with_cross_entropy_op.")); + break; + } + +#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL +} + template class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { public: @@ -428,6 +670,73 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { platform::is_gpu_place(context.GetPlace()), true, platform::errors::Unavailable("softmax_with_cross_entropy operator's " "CUDA kernel only runs on GPU device.")); + const bool use_softmax = context.Attr("use_softmax"); + + // do not with softmax op, and input is softmax + if (!use_softmax) { + const Tensor* softmax = context.Input("Logits"); + const Tensor* labels = context.Input("Label"); + Tensor* softmax_out = context.Output("Softmax"); + Tensor* loss = context.Output("Loss"); + + const int rank = softmax->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis_dim = softmax->dims()[axis]; + + const int n = SizeToAxis(axis, softmax->dims()); + const int d = SizeFromAxis(axis, softmax->dims()); + + auto* softmax_out_data = softmax_out->mutable_data(context.GetPlace()); + auto* loss_data = 
loss->mutable_data(context.GetPlace()); + + if (axis_dim == 1) { + math::SetConstant set_constant; + set_constant(context.cuda_device_context(), softmax_out, + static_cast(1)); + set_constant(context.cuda_device_context(), loss, static_cast(0)); + return; + } + + auto soft_label = context.Attr("soft_label"); + auto ignore_index = context.Attr("ignore_index"); + + Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d; + softmax_2d.ShareDataWith(*softmax).Resize({n, d}); + labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); + loss_2d.ShareDataWith(*loss).Resize({n, 1}); + softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d}); + + // math::CrossEntropyFunctor only handles the last axis. NOTE(review): axis comes from CanonicalAxis above, so it is presumably never -1 and this branch looks unreachable; it also returns before the Softmax copy below - confirm intent. + if (axis == -1) { + math::CrossEntropyFunctor()( + context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d, + soft_label, ignore_index, axis_dim); + return; + } + + // if axis is not the last, we need a separate implementation + if (soft_label) { + auto* logits_data = softmax->data(); + auto* labels_data = labels->data(); + CrossEntropyFusedKernel(logits_data, labels_data, loss_data, n, d, + axis_dim, + context.cuda_device_context().stream()); + } else { // HardLabel + auto* logits_data = softmax->data(); + auto* labels_data = labels->data(); + HardLabelCrossEntropy(context.cuda_device_context(), logits_data, + labels_data, loss_data, n, d, axis_dim, + ignore_index); + } + + // the "Logits" input already holds the softmax result here, + // so it is copied unchanged into the Softmax output + framework::TensorCopy(*softmax, context.GetPlace(), + context.device_context(), softmax_out); + + return; + } + const Tensor* logits = context.Input("Logits"); const Tensor* labels = context.Input("Label"); Tensor* softmax = context.Output("Softmax"); @@ -514,6 +823,34 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { int block = 512; auto stream = context.cuda_device_context().stream(); auto ignore_index = context.Attr("ignore_index"); + auto use_softmax = context.Attr("use_softmax"); + + // do not 
with softmax op, and input is softmax + if (!use_softmax) { + if (context.Attr("soft_label")) { + int grid = (n * d + block - 1) / block; + const T* label_data = labels->data(); + SoftLabelCrossEntropyGradientKernel<<>>( + logit_grad_data, loss_grad_data, label_data, n, d, remain); + } else { + Tensor logits_grad_2d; + logits_grad_2d.ShareDataWith(*logit_grad).Resize({n, d}); + int grid = (n * remain + block - 1) / block; + const int64_t* label_data = labels->data(); + HardLabelCrossEntropyGradientKernel<<>>( + logit_grad_data, label_data, n, d, remain, ignore_index); + int num = n * d; + grid = (num + block - 1) / block; + ScaleCrossEntropyGradient<<>>( + logit_grad_data, loss_grad_data, num, d, remain, label_data, + ignore_index); + } + + return; + } + + // with softmax, continue + if (context.Attr("soft_label")) { int64_t grid = (n * d + block - 1) / block; const T* label_data = labels->data(); diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index 93f2552c3cee90a3eb1a948494fb231a41f6f74d..d0f6df0bdcef8c8603318875ad67be6b8694f06c 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -34,6 +34,46 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { PADDLE_ENFORCE_EQ( platform::is_cpu_place(context.GetPlace()), true, platform::errors::Unimplemented("This kernel only runs on CPU.")); + const bool use_softmax = context.Attr("use_softmax"); + + // do not with softmax op, and input is softmax + if (!use_softmax) { + const Tensor* softmax = context.Input("Logits"); + const Tensor* labels = context.Input("Label"); + Tensor* softmax_out = context.Output("Softmax"); + Tensor* loss = context.Output("Loss"); + const bool soft_label = context.Attr("soft_label"); + const int rank = softmax->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis_dim = softmax->dims()[axis]; + 
+ softmax_out->mutable_data(context.GetPlace()); + loss->mutable_data(context.GetPlace()); + + const int n = SizeToAxis(axis, softmax->dims()); + const int d = SizeFromAxis(axis, softmax->dims()); + + Tensor softmax_2d, labels_2d, loss_2d, softmax_out_2d; + softmax_2d.ShareDataWith(*softmax).Resize({n, d}); + labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); + loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim}); + softmax_out_2d.ShareDataWith(*softmax_out).Resize({n, d}); + + auto& dev_ctx = + context.template device_context(); + + math::CrossEntropyFunctor()( + dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label, + context.Attr("ignore_index"), axis_dim); + + // cause of input is softmax + // copy to output softmax, directly + framework::TensorCopy(*softmax, context.GetPlace(), + context.device_context(), softmax_out); + + return; + } + const Tensor* logits = context.Input("Logits"); const Tensor* labels = context.Input("Label"); Tensor* softmax = context.Output("Softmax"); @@ -76,7 +116,9 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { context.Output(framework::GradVarName("Logits")); const Tensor* softmax = context.Input("Softmax"); - if (logit_grad != softmax) { + const bool use_softmax = context.Attr("use_softmax"); + + if (logit_grad != softmax || !use_softmax) { framework::TensorCopy(*softmax, context.GetPlace(), context.device_context(), logit_grad); } @@ -99,28 +141,94 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { auto logit_grad_mat = EigenMatrix::From(logit_grad_2d); auto& place = *context.template device_context() .eigen_device(); + if (!use_softmax) { + // use_softmax step1 + if (soft_label) { + auto lbl_mat = framework::EigenMatrix::From(labels_2d); + logit_grad_mat.device(place) = + (-lbl_mat / logit_grad_mat); // for each sample ,i is sample id + logit_grad_mat.device(place) = + out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)) * + logit_grad_mat; + } + // use_softmax 
step2 + else { + const int64_t* label_data = labels->data(); + T* logit_grad_data = logit_grad->data(); + const T* out_grad_data = out_grad->data(); + const int remain = d / axis_dim; + for (int i = 0; i < n; ++i) { // for each sample_1_dim + for (int j = 0; j < remain; j++) { // for each sample_other_dims + int idx = i * remain + j; // this sample's label_idx. for 1d case, + // remain=1 and j=0, so, idx = i + if (label_data[idx] == ignore_index) { + for (int k = 0; k < axis_dim; ++k) { // for each class id's label + logit_grad_data[i * d + k * remain + j] = 0; + } + } else { + // only for this sample's label_idx, the label is 1, others is 0, + // so, only compute this label_idx's class + logit_grad_data[i * d + label_data[idx] * remain + j] = + (-1 / logit_grad_data[i * d + label_data[idx] * remain + j]) * + out_grad_data[idx]; + for (int k = 0; k < axis_dim; ++k) { // for each class id's label + if (k != + label_data[idx]) { // label_data[idx]: this sample's label + logit_grad_data[i * d + k * remain + j] = 0; + } + } + } + } + } + } + return; + } + + // use_softmax == true from here on (the !use_softmax branch above returned) + if (soft_label) { - auto lbl_mat = EigenMatrix::From(labels_2d); + // when soft_label = True, ignore_index is not supported + auto lbl_mat = framework::EigenMatrix::From(labels_2d); logit_grad_mat.device(place) = out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)) * - (logit_grad_mat - lbl_mat); + (logit_grad_mat - lbl_mat); // for each sample ,i is sample id + // 1) compute dy/dx by p_j - y_j or P-Y, where j is class id, + // P=logit_grad_mat[i] is all class's probs, Y=lbl_mat[i] is + // all class's labels + // 2) compute dy * dy/dx by Chain rule, dy=out_grad_mat[i] + // for high dims, e.g. 
(n,c) or (n,d1,...,dm, c), compute grad by matrix + // operation + } else { logit_grad_mat.device(place) = - logit_grad_mat * + logit_grad_mat * // element_wise multiply out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)); const int64_t* label_data = labels->data(); T* logit_grad_data = logit_grad->data(); const T* out_grad_data = out_grad->data(); const int remain = d / axis_dim; - for (int i = 0; i < n; ++i) { - for (int j = 0; j < remain; j++) { - int idx = i * remain + j; + for (int i = 0; i < n; ++i) { // for each sample_1_dim + for (int j = 0; j < remain; j++) { // for each sample_other_dims + int idx = i * remain + j; // this sample's label_idx. for 1d case, + // remain=1 and j=0, so, idx = i if (label_data[idx] == ignore_index) { - for (int k = 0; k < axis_dim; ++k) { + for (int k = 0; k < axis_dim; ++k) { // for each class id's label logit_grad_data[i * d + k * remain + j] = 0; } } else { + // only for this sample's label_idx, the label is 1, others is 0, + // so, only compute this label_idx's class + // for 1d case, remain=1 and j=0, so, [i * d + label_data[idx] * + // remain + j] = [i * d + label_data[idx]] + // let idx_x = i * d + label_data[idx] * remain + j, + // logit_grad_data[idx_x] = logit_grad_data[idx_x] - + // out_grad_data[idx] + // note: logit_grad_mat = logit_grad_mat * out_grad_mat + // so: logit_grad_data[idx_x] = (logit_grad_data[idx_x] - 1) * + // out_grad_data[idx] + // means: dy/dp * dy= ( p - y ) * dy + logit_grad_data[i * d + label_data[idx] * remain + j] -= out_grad_data[idx]; } diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index 0ee58d5be15e60f50b0d6f4d0fc7c55075b81aea..b9c03efbe74c7255e8a0360b1ef3f7a01536029b 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -55,6 +55,7 @@ class 
TestSoftmaxWithCrossEntropyOp(OpTest): self.axis = -1 self.ignore_index = -1 self.shape = [41, 37] + self.use_softmax = True def setUp(self): self.initParams() @@ -75,7 +76,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): loss = cross_entropy(softmax, labels, self.soft_label, self.axis, self.ignore_index) - self.inputs = {"Logits": logits, "Label": labels} + if self.use_softmax == False: + self.inputs = {"Logits": softmax, "Label": labels} + else: + self.inputs = {"Logits": logits, "Label": labels} + self.outputs = { "Softmax": softmax.astype(self.dtype), "Loss": loss.astype(self.dtype) @@ -84,6 +89,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): "numeric_stable_mode": self.numeric_stable_mode, "soft_label": self.soft_label, "ignore_index": self.ignore_index, + "use_softmax": self.use_softmax, } if self.axis != -1: @@ -93,7 +99,215 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(["Logits"], "Loss", max_relative_error=5e-5) + self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001) + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = True + self.shape = [13, 8] + self.axis = -1 + self.ignore_index = -1 + self.dtype = np.float64 + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [13, 8] + self.axis = -1 + self.ignore_index = -1 + self.dtype = np.float64 + self.use_softmax = False #default is true, means "with softmax" + + +############################################################################## +#NotWithSoftmax_SoftLabel_2D start 
+############################################################################## +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = True + self.shape = [3, 5, 7, 11] + self.axis = -1 + self.ignore_index = -1 + self.dtype = np.float64 + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = True + self.dtype = np.float64 + self.axis = 1 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = True + self.dtype = np.float64 + self.axis = 2 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = True + self.dtype = np.float64 + self.axis = 3 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] + self.use_softmax = False #default is true, means "with softmax" + + +############################################################################## +#NotWithSoftmax_SoftLabel_2D end +############################################################################## + +############################################################################## +#NotWithSoftmax_HardLabel_2D start 
+############################################################################## + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = -1 + self.ignore_index = -1 + self.dtype = np.float64 + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = 1 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = 2 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = 3 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] + self.use_softmax = False #default is true, means "with softmax" + + +############################################################################## +#NotWithSoftmax_HardLabel_2D end +############################################################################## + +############################################################################## +#NotWithSoftmax_HardLabel_2D_Ignore 
start +############################################################################## + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = False + self.soft_label = False + self.shape = [13, 8] + self.axis = -1 + self.ignore_index = 2 + self.dtype = np.float64 + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = False + self.soft_label = False + self.shape = [13, 8] + self.axis = 1 + self.ignore_index = 2 + self.dtype = np.float64 + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = -1 + self.ignore_index = 2 + self.dtype = np.float64 + self.use_softmax = False #default is true, means "with softmax" + + +class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3( + TestSoftmaxWithCrossEntropyOp): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = 2 + self.ignore_index = 2 + self.shape = [3, 5, 7, 11] + self.use_softmax = False #default is true, means "with softmax" + + +############################################################################## +#NotWithSoftmax_HardLabel_2D_Ignore end +############################################################################## class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): @@ -105,6 +319,7 @@ class 
TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): self.axis = -1 self.ignore_index = -1 self.dtype = np.float64 + self.use_softmax = True @unittest.skipIf(not core.is_compiled_with_cuda(), @@ -182,6 +397,7 @@ class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp): self.axis = -1 self.ignore_index = -1 self.shape = [41, 37] + self.use_softmax = True def test_check_output(self): self.check_output() @@ -203,6 +419,7 @@ class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp): self.ignore_index = 5 self.axis = -1 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): @@ -214,6 +431,7 @@ class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): self.ignore_index = 4 self.axis = -1 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp): @@ -230,6 +448,7 @@ class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp): self.axis = 0 self.ignore_index = -1 self.shape = [3, 5, 7, 11] + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp): @@ -246,6 +465,7 @@ class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp): self.axis = 1 self.ignore_index = -1 self.shape = [3, 5, 7, 11] + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp): @@ -262,6 +482,7 @@ class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp): self.axis = 2 self.ignore_index = -1 self.shape = [3, 5, 7, 11] + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): @@ -278,6 +499,7 @@ class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): self.axis = 3 self.ignore_index = -1 self.shape = [3, 5, 7, 11] + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( @@ -295,6 +517,7 @@ class 
TestSoftmaxWithCrossEntropyOpAxisDimEqualOne( self.axis = -1 self.ignore_index = -1 self.shape = [3, 5, 7, 1] + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( @@ -307,6 +530,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( self.axis = 0 self.ignore_index = -1 self.dtype = np.float16 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2( @@ -319,6 +543,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2( self.axis = 1 self.ignore_index = -1 self.dtype = np.float16 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3( @@ -331,6 +556,7 @@ class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3( self.axis = 2 self.ignore_index = -1 self.dtype = np.float16 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1( @@ -343,6 +569,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1( self.axis = 0 self.ignore_index = -1 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2( @@ -355,6 +582,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2( self.axis = 1 self.ignore_index = -1 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3( @@ -367,6 +595,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3( self.axis = 2 self.ignore_index = -1 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4( @@ -379,6 +608,7 @@ class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4( self.axis = 3 self.ignore_index = -1 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1( @@ -391,6 +621,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1( self.ignore_index = 1 self.axis = 0 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2( @@ -403,6 +634,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2( self.ignore_index = 0 
self.axis = 1 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3( @@ -415,6 +647,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3( self.ignore_index = 3 self.axis = 2 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4( @@ -427,6 +660,7 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4( self.ignore_index = 3 self.axis = 3 self.dtype = np.float64 + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp): @@ -444,6 +678,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp): self.ignore_index = -1 self.dtype = np.float64 self.logits = np.full(self.shape, -500.0).astype(self.dtype) + self.use_softmax = True class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp): @@ -462,6 +697,7 @@ class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp): self.dtype = np.float64 self.logits = np.full(self.shape, 1000.0).astype(self.dtype) self.logits[:, :, 0, :] = -1000.0 + self.use_softmax = True if __name__ == "__main__": diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 52c605d5bb49825503b031f0788008f640398d93..e5e3fa7bf8f76d60d16f82c149b4c7ae5bb3c693 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1388,8 +1388,6 @@ def cross_entropy(input, "should be '-100', but received %s, which is not allowed." 
% ignore_index) - softmax_switch = use_softmax - input_dims = len(list(input.shape)) label_dims = len(list(label.shape)) if input_dims - 1 != label_dims and input_dims != label_dims: @@ -1402,7 +1400,7 @@ def cross_entropy(input, _, out = core.ops.softmax_with_cross_entropy( input, label, 'soft_label', soft_label, 'ignore_index', ignore_index, 'numeric_stable_mode', True, 'axis', axis, - 'softmax_switch', softmax_switch) + 'use_softmax', use_softmax) if weight is not None: @@ -1484,7 +1482,7 @@ def cross_entropy(input, 'ignore_index': ignore_index, 'numeric_stable_mode': True, 'axis': axis, - 'softmax_switch': softmax_switch + 'use_softmax': use_softmax } helper = LayerHelper('softmax_with_cross_entropy', **locals()) softmax = helper.create_variable_for_type_inference(dtype=input.dtype)