diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 03c7f32a1261a184e6bdf4689aa411aa99ea8e68..375b05aff249c2562020f56e260bbd4fbc6ed737 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -133,7 +133,7 @@ paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, k paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', 'de6a906950bae9f3c245cb744d22b94e')) paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '419c3a24a83cc89219a029cf4092788b')) paddle.fluid.layers.spectral_norm (ArgSpec(args=['weight', 'dim', 'power_iters', 'eps', 'name'], varargs=None, keywords=None, defaults=(0, 1, 1e-12, None)), ('document', '3f536aafba30d793287b52d231baff1b')) -paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, True, False)), ('document', 'bce1b75e3d95b75cacd1099655cbb3c3')) +paddle.fluid.layers.softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax', 'axis'], varargs=None, keywords=None, defaults=(False, -100, True, False, -1)), ('document', '8b074f9c56b4233a2b65d03254eb309e')) paddle.fluid.layers.smooth_l1 (ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'c6b175d253c55baf4b9c0eca9b1dda88')) paddle.fluid.layers.one_hot (ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None), ('document', '960fc799549c202da1e85d626cb2c962')) paddle.fluid.layers.autoincreased_step_counter (ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)), ('document', '67afefa80b6cc38801bd5b631fed8a4a')) diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index 7eb663773ed072760c47a2914377b5306ceeb7af..1d625579052055ac732c4ff0bae73888dfdc2233 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -39,9 +39,10 @@ class CrossEntropyOpKernel : public framework::OpKernel { Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1); Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1); + int axis_dim = x->dims()[rank - 1]; math::CrossEntropyFunctor()( ctx.template device_context(), &y_2d, &x_2d, &labels_2d, - ctx.Attr("soft_label"), ctx.Attr("ignore_index")); + ctx.Attr("soft_label"), ctx.Attr("ignore_index"), axis_dim); } }; diff --git a/paddle/fluid/operators/math/cross_entropy.cc b/paddle/fluid/operators/math/cross_entropy.cc index 18bf1a66f6d9903f32048574dc93faf7e98953ac..9f7884fe05f2f446b1fb6eb7dfd53e293d8e19aa 100644 --- a/paddle/fluid/operators/math/cross_entropy.cc +++ b/paddle/fluid/operators/math/cross_entropy.cc @@ -29,8 +29,13 @@ class CrossEntropyFunctor { void operator()(const platform::CPUDeviceContext& ctx, framework::Tensor* out, const framework::Tensor* prob, const framework::Tensor* labels, const bool softLabel, - const int ignore_index) { + const int ignore_index, const int axis_dim) { const int batch_size = prob->dims()[0]; + 
const int num_classes = prob->dims()[1]; + const int num_remain = num_classes / axis_dim; + + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + if (softLabel) { auto in = EigenMatrix::From(*prob); auto lbl = EigenMatrix::From(*labels); @@ -38,24 +43,24 @@ class CrossEntropyFunctor { loss.device(*ctx.eigen_device()) = -((lbl * in.log().unaryExpr(math::TolerableValue())) - .sum(Eigen::DSizes(1)) - .reshape(Eigen::DSizes(batch_size, 1))); + .reshape(batch_axis_remain) + .sum(Eigen::DSizes(1))); } else { - const int class_num = prob->dims()[1]; const T* prob_data = prob->data(); T* loss_data = out->data(); const int64_t* label_data = labels->data(); for (int i = 0; i < batch_size; ++i) { - int lbl = label_data[i]; - PADDLE_ENFORCE_GE(lbl, 0); - PADDLE_ENFORCE_LT(lbl, class_num); - PADDLE_ENFORCE((lbl >= 0 && lbl < class_num) || lbl == ignore_index); - int index = i * class_num + lbl; - loss_data[i] = - lbl == ignore_index - ? 0 - : -math::TolerableValue()(std::log(prob_data[index])); + for (int j = 0; j < num_remain; j++) { + int lbl = label_data[i * num_remain + j]; + PADDLE_ENFORCE((lbl >= 0 && lbl < axis_dim) || lbl == ignore_index); + int index = i * num_classes + lbl * num_remain + j; + int loss_idx = i * num_remain + j; + loss_data[loss_idx] = + lbl == ignore_index + ? 0 + : -math::TolerableValue()(std::log(prob_data[index])); + } } } } diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 44cbdf2e9882195819bc3ca047dbac6e2fa4e631..5bc05257aa9d3db7881330ca4547da439dab03bd 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -57,8 +57,8 @@ class CrossEntropyFunctor { public: void operator()(const platform::CUDADeviceContext& ctx, framework::Tensor* out, const framework::Tensor* prob, - const framework::Tensor* labels, bool softLabel, - const int ignore_index) { + const framework::Tensor* labels, const bool softLabel, + const int ignore_index, const int axis_dim) { const T* prob_data = prob->data(); T* loss_data = out->mutable_data(ctx.GetPlace()); diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h index 99a4935186e1e6f9e3bf36eb029ce3d230510117..48082a7273dd7ad713fbc964ebbd1445ed887cdd 100644 --- a/paddle/fluid/operators/math/cross_entropy.h +++ b/paddle/fluid/operators/math/cross_entropy.h @@ -60,7 +60,7 @@ class CrossEntropyFunctor { void operator()(const DeviceContext& context, framework::Tensor* out, const framework::Tensor* prob, const framework::Tensor* labels, const bool softLabel, - const int ignore_index); + const int ignore_index, const int axis_dim); }; } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 371ab0384a3fa2ff22ac4e5c3d1e54aff237b47d..447b5b3a199857ddd60aad202151ea04498648d6 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -26,23 +26,28 @@ class SoftmaxWithCrossEntropyOpMaker public: void Make() override { AddInput("Logits", - "(Tensor, default: Tensor), The unscaled log probabilities " - "which is a 2-D tensor with shape [N x K]. N is the batch_size, " - "and K is the class number."); - AddInput("Label", - "(Tensor) The ground truth which is a 2-D tensor. If soft_label " - "is set to false, Label is a Tensor with shape [N x 1]. 
If " - "soft_label is set to true, Label is a Tensor with " - "shape [N x K]."); + "(Tensor, default: Tensor), The input tensor of unscaled " + "log probabilities, whose dimension :attr:`axis` should be scaled " + "by softmax."); + AddInput( + "Label", + "(Tensor) The input tesnor of groud truth label. If :attr:`soft_label` " + "is set to false, Label is a Tensor in same shape with " + "Input(Logits) except the shape in dimension :attr:`axis` as 1. If " + "soft_label is set to true, Label is a Tensor in same " + "shape with Input(Logits)."); AddOutput( "Softmax", - "(Tensor, default: Tensor), A 2-D tensor with shape [N x K]. " + "(Tensor, default: Tensor), A tensor in same shape with " + "Input(Logits). " "The outputs value of softmax activation by given the input batch, " "which will be used in backward calculation.") .AsIntermediate(); AddOutput("Loss", - "(Tensor, default: Tensor), A 2-D tensor. The cross " - "entropy loss with shape [N x 1]."); + "(Tensor, default: Tensor), A tensor in same shape with " + "Input(Logits) " + "except the shape in dimension :attr:`axis` as 1. The cross " + "entropy loss."); AddAttr( "soft_label", "(bool, default: false), A flag to indicate whether to interpretate " @@ -60,6 +65,10 @@ class SoftmaxWithCrossEntropyOpMaker "does not contribute to the input gradient. Only valid if soft_label" "is set to False") .SetDefault(-100); + AddAttr("axis", + "The dimension index of Input(Logits) to perform softmax," + "default -1 for last dimension") + .SetDefault(-1); AddComment(R"DOC( Softmax With Cross Entropy Operator. @@ -107,38 +116,53 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { "Output(Softmax) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("Loss"), "Output(Loss) should be not null."); + auto axis = ctx->Attrs().Get("axis"); auto logits_dims = ctx->GetInputDim("Logits"); auto labels_dims = ctx->GetInputDim("Label"); + auto logits_rank = logits_dims.size(); + PADDLE_ENFORCE(axis >= -logits_rank && axis < logits_rank, + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(Logits)."); + + axis = CanonicalAxis(axis, logits_rank); + for (int i = 0; i < logits_rank; i++) { + if (i != axis) { + if (ctx->IsRuntime() || (logits_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ( + logits_dims[i], labels_dims[i], + "Input(Logits) and Input(Label) should in same shape in " + "dimensions except axis."); + } + } + } - int rank = logits_dims.size(); - PADDLE_ENFORCE_EQ( - rank, labels_dims.size(), - "Input(logits) and Input(Label) shall have the same rank."); - bool check = ctx->IsRuntime() || (framework::product(logits_dims) > 0 && - framework::product(labels_dims) > 0); - if (check) { - PADDLE_ENFORCE_EQ(framework::slice_ddim(logits_dims, 0, rank - 1), - framework::slice_ddim(labels_dims, 0, rank - 1), - "Input(X) and Input(Label) shall have the same shape " - "except the last dimension."); + auto numeric_stable_mode = ctx->Attrs().Get("numeric_stable_mode"); + if (axis != logits_rank - 1) { + PADDLE_ENFORCE( + numeric_stable_mode, + "Attr(axis) can only be -1 when not in numeric_stable_mode."); } - if (ctx->Attrs().Get("soft_label")) { - if (check) { - PADDLE_ENFORCE_EQ(logits_dims[rank - 1], labels_dims[rank - 1], - "If Attr(soft_label) == true, the last dimension of " + bool soft_label = ctx->Attrs().Get("soft_label"); + if (soft_label) { + if (ctx->IsRuntime() || + (logits_dims[axis] > 0 && labels_dims[axis] > 0)) { + PADDLE_ENFORCE_EQ(logits_dims[axis], labels_dims[axis], + "If Attr(soft_label) == 
true, the axis dimension of " "Input(X) and Input(Label) should be equal."); } } else { - PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL, - "If Attr(softLabel) == false, the last dimension of " - "Input(Label) should be 1."); + if (ctx->IsRuntime() || labels_dims[axis] > 0) { + PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL, + "If Attr(soft_label) == false, the axis dimension of " + "Input(Label) should be 1."); + } } ctx->SetOutputDim("Softmax", logits_dims); - auto loss_dims = logits_dims; - loss_dims[rank - 1] = 1; - ctx->SetOutputDim("Loss", loss_dims); + + logits_dims[axis] = 1; + ctx->SetOutputDim("Loss", logits_dims); ctx->ShareLoD("Logits", /*->*/ "Softmax"); ctx->ShareLoD("Logits", /*->*/ "Loss"); @@ -165,36 +189,40 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")), "Output(Logits@Grad) should be not null."); + auto axis = ctx->Attrs().Get("axis"); auto softmax_dims = ctx->GetInputDim("Softmax"); auto labels_dims = ctx->GetInputDim("Label"); - - int rank = softmax_dims.size(); - PADDLE_ENFORCE_EQ( - rank, labels_dims.size(), - "Input(logits) and Input(Label) shall have the same rank."); - bool check = true; - if ((!ctx->IsRuntime()) && (framework::product(softmax_dims) <= 0 || - framework::product(labels_dims) <= 0)) { - check = false; - } - if (check) { - PADDLE_ENFORCE_EQ( - framework::slice_ddim(softmax_dims, 0, rank - 1), - framework::slice_ddim(labels_dims, 0, rank - 1), - "Input(Softmax) and Input(Label) shall have the same shape " - "except the last dimension."); + auto softmax_rank = softmax_dims.size(); + PADDLE_ENFORCE(axis >= -softmax_rank && axis < softmax_rank, + "Attr(axis) value should be in range [-R, R-1], " + "R is the rank of Input(Logits)."); + + axis = CanonicalAxis(axis, softmax_rank); + for (int i = 0; i < softmax_rank; i++) { + if (i != axis) { + if (ctx->IsRuntime() || (softmax_dims[i] > 0 && labels_dims[i] > 0)) { + PADDLE_ENFORCE_EQ( + softmax_dims[i], labels_dims[i], + "Input(Logits) and Input(Label) should in same shape in " + "dimensions except axis."); + } + } } - if (ctx->Attrs().Get("soft_label")) { - if (check) { - PADDLE_ENFORCE_EQ(softmax_dims[rank - 1], labels_dims[rank - 1], - "If Attr(soft_label) == true, the last dimension of " - "Input( Softmax) and Input(Label) should be equal."); + bool soft_label = ctx->Attrs().Get("soft_label"); + if (soft_label) { + if (ctx->IsRuntime() || + (softmax_dims[axis] > 0 && labels_dims[axis] > 0)) { + PADDLE_ENFORCE_EQ(softmax_dims[axis], labels_dims[axis], + "If Attr(soft_label) == true, the axis dimension of " + "Input(X) and Input(Label) should be equal."); } } else { - PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL, - "If Attr(softLabel) == false, the last dimension of " - "Input(Label) should be 1."); + if (ctx->IsRuntime() || labels_dims[axis] > 0) { + PADDLE_ENFORCE_EQ(labels_dims[axis], 1UL, + "If Attr(soft_label) == false, the axis dimension of " + "Input(Label) should be 1."); + } } ctx->SetOutputDim(framework::GradVarName("Logits"), diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index dc5ec7bc38cb60d15f796f6523b920b6696510cd..19b4698aca8a80fd4a0845dee1122cafb79c6b87 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -21,11 +21,13 @@ using Tensor = framework::Tensor; namespace { template __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* 
labels, - const int batch_size, const int class_num, + const int n, const int d, const int remain, const int ignore_index) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size; + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n * remain; i += blockDim.x * gridDim.x) { - int idx = i * class_num + labels[i]; + int idx_n = i / remain; + int idx_remain = i % remain; + int idx = idx_n * d + labels[i] * remain + idx_remain; logit_grad[idx] -= ignore_index == labels[i] ? static_cast(0.) : static_cast(1.); } @@ -33,23 +35,26 @@ __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, template __global__ void Scale(T* logit_grad, const T* loss_grad, const int num, - const int class_num) { + const int d, const int remain) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x) { - logit_grad[i] *= loss_grad[i / class_num]; + int idx_n = i / d; + int idx_remain = i % remain; + logit_grad[i] *= loss_grad[idx_n * remain + idx_remain]; } } template __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, const T* loss_grad, - const T* labels, - const int batch_size, - const int class_num) { + const T* labels, const int n, + const int d, const int remain) { int ids = blockIdx.x * blockDim.x + threadIdx.x; - if (ids < batch_size * class_num) { - int row_ids = ids / class_num; - logit_grad[ids] = loss_grad[row_ids] * (logit_grad[ids] - labels[ids]); + if (ids < n * d) { + int idx_n = ids / d; + int idx_remain = ids % remain; + int idx_loss = idx_n * remain + idx_remain; + logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]); } } @@ -116,23 +121,30 @@ using BlockReduce = template using BlockReduceTempStorage = typename BlockReduce::TempStorage; -// Make sure that BlockDim <= feature_size +// Make sure that BlockDim <= axis_dim // This kernel is used to calculate the max element of each row template static __global__ void RowReductionForMax(const T* logits_data, T* max_data, - int feature_size) { + int d, int axis_dim) { __shared__ BlockReduceTempStorage temp_storage; - auto beg_idx = feature_size * blockIdx.x + threadIdx.x; - auto end_idx = feature_size * (blockIdx.x + 1); + // logits_data view as [n, axis_dim, remain] + // max_data view as [n, 1, remain] + // blockDim = n * remain, split blockIdx to idx_n and idx_remain + int remain = d / axis_dim; + int idx_n = blockIdx.x / remain; + int idx_remain = blockIdx.x % remain; + int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; + int end_idx = (idx_n + 1) * d; + int step = BlockDim * remain; T cur_max = logits_data[beg_idx]; - beg_idx += BlockDim; + beg_idx += step; while (beg_idx < end_idx) { if (cur_max < logits_data[beg_idx]) { cur_max = logits_data[beg_idx]; } - beg_idx += BlockDim; + beg_idx += step; } cur_max = BlockReduce(temp_storage).Reduce(cur_max, cub::Max()); @@ -143,25 +155,32 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data, } } -// Make sure that BlockDim <= feature_size +// Make sure that BlockDim <= axis_dim template static __global__ void RowReductionForDiffMaxSum(const T* logits_data, - T* max_data, T* softmax, - int feature_size) { + T* max_data, T* softmax, int d, + int axis_dim) { __shared__ BlockReduceTempStorage temp_storage; - auto beg_idx = feature_size * blockIdx.x + threadIdx.x; - auto end_idx = feature_size * (blockIdx.x + 1); + // logits, softmax data view as [n, axis_dim, remain] + // max_data view as [n, 1, remain] + // blockDim = n * remain, split blockIdx to idx_n and idx_remain + int 
remain = d / axis_dim; + int idx_n = blockIdx.x / remain; + int idx_remain = blockIdx.x % remain; + int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; + int end_idx = (idx_n + 1) * d; auto block_max = max_data[blockIdx.x]; + int step = BlockDim * remain; softmax[beg_idx] = logits_data[beg_idx] - block_max; T diff_max_sum = exp_on_device(softmax[beg_idx]); - auto idx = beg_idx + BlockDim; + auto idx = beg_idx + step; while (idx < end_idx) { softmax[idx] = logits_data[idx] - block_max; diff_max_sum += exp_on_device(softmax[idx]); - idx += BlockDim; + idx += step; } diff_max_sum = @@ -172,34 +191,42 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data, __syncthreads(); diff_max_sum = max_data[blockIdx.x]; softmax[beg_idx] -= diff_max_sum; - beg_idx += BlockDim; + beg_idx += step; while (beg_idx < end_idx) { softmax[beg_idx] -= diff_max_sum; - beg_idx += BlockDim; + beg_idx += step; } if (threadIdx.x == 0) max_data[blockIdx.x] = 0; } -// Make sure that BlockDim <= feature_size +// Make sure that BlockDim <= axis_dim template static __global__ void RowReductionForSoftmaxAndCrossEntropy( - const T* labels_data, T* loss_data, T* softmax, int feature_size) { + const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d, + int axis_dim) { __shared__ BlockReduceTempStorage temp_storage; - auto beg_idx = feature_size * blockIdx.x + threadIdx.x; - auto end_idx = feature_size * (blockIdx.x + 1); + // logits, softmax, labels data view as [n, axis_dim, remain] + // loss_data view as [n, 1, remain] + // blockDim = n * remain, split blockIdx to idx_n and idx_remain + int remain = d / axis_dim; + int idx_n = blockIdx.x / remain; + int idx_remain = blockIdx.x % remain; + int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; + int end_idx = (idx_n + 1) * d; // log_diff_max_sum shares memory with loss auto block_log_diff_max_sum = loss_data[blockIdx.x]; auto tmp = softmax[beg_idx] - block_log_diff_max_sum; softmax[beg_idx] = exp_on_device(tmp); auto loss = -labels_data[beg_idx] * tmp; - beg_idx += BlockDim; + int step = BlockDim * remain; + beg_idx += step; while (beg_idx < end_idx) { tmp = softmax[beg_idx] - block_log_diff_max_sum; softmax[beg_idx] = exp_on_device(tmp); loss -= (labels_data[beg_idx] * tmp); - beg_idx += BlockDim; + beg_idx += step; } loss = BlockReduce(temp_storage).Reduce(loss, cub::Sum()); @@ -210,21 +237,27 @@ template struct HardLabelSoftmaxWithCrossEntropyFunctor { public: HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss, - T* log_softmax, int feature_size) + T* log_softmax, int d, int axis_dim) : labels_(labels), loss_(loss), log_softmax_(log_softmax), - feature_size_(feature_size) {} + d_(d), + axis_dim_(axis_dim) {} __device__ void operator()(int idx) const { - auto row_idx = idx / feature_size_; - auto col_idx = idx % feature_size_; - if (col_idx != labels_[row_idx]) { + // logits view as [n, axis_dim, remain], where d = axis_dim * remain + int remain = d_ / axis_dim_; + int idx_n = idx / d_; + int idx_axis = (idx % d_) / remain; + int idx_remain = idx % remain; + // labels, loss view as [n, remain] + int idx_lbl = idx_n * remain + idx_remain; + if (idx_axis != labels_[idx_lbl]) { log_softmax_[idx] = exp_on_device(log_softmax_[idx]); } else { auto softmax = log_softmax_[idx]; log_softmax_[idx] = exp_on_device(softmax); - loss_[row_idx] = -softmax; + loss_[idx_lbl] = -softmax; } } @@ -232,7 +265,8 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { const int64_t* labels_; T* loss_; T* log_softmax_; - int 
feature_size_; + int d_; + int axis_dim_; }; template @@ -240,23 +274,29 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { public: HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss, T* log_softmax, - int feature_size, + int d, int axis_dim, int ignore_idx) : labels_(labels), loss_(loss), log_softmax_(log_softmax), - feature_size_(feature_size), + d_(d), + axis_dim_(axis_dim), ignore_idx_(ignore_idx) {} __device__ void operator()(int idx) const { - auto row_idx = idx / feature_size_; - auto col_idx = idx % feature_size_; - if (col_idx != labels_[row_idx] || col_idx == ignore_idx_) { + // logits view as [n, axis_dim, remain], where d = axis_dim * remain + int remain = d_ / axis_dim_; + int idx_n = idx / d_; + int idx_axis = (idx % d_) / remain; + int idx_remain = idx % remain; + // labels, loss view as [n, remain] + int idx_lbl = idx_n * remain + idx_remain; + if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) { log_softmax_[idx] = exp_on_device(log_softmax_[idx]); } else { auto softmax = log_softmax_[idx]; log_softmax_[idx] = exp_on_device(softmax); - loss_[row_idx] = -softmax; + loss_[idx_lbl] = -softmax; } } @@ -264,44 +304,44 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { const int64_t* labels_; T* loss_; T* log_softmax_; - int feature_size_; + int d_; + int axis_dim_; int ignore_idx_; }; template -static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, - int batch_size) { +static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int n) { auto idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < batch_size) out[idx] = static_cast(1); + if (idx < n) out[idx] = static_cast(1); } template static void HardLabelSoftmaxWithCrossEntropy( const platform::CUDADeviceContext& ctx, const T* logits_data, - const int64_t* labels_data, T* loss_data, T* softmax_data, int batch_size, - int feature_size, int ignore_idx) { + const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d, + int axis_dim, int ignore_idx) { constexpr int kMaxBlockDim = 512; - int block_dim = feature_size >= kMaxBlockDim + int block_dim = axis_dim >= kMaxBlockDim ? 
kMaxBlockDim - : (1 << static_cast(std::log2(feature_size))); + : (1 << static_cast(std::log2(axis_dim))); + int grid_dim = n * d / axis_dim; auto stream = ctx.stream(); -#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ - case BlockDim: { \ - RowReductionForMax<<>>( \ - logits_data, loss_data, feature_size); \ - RowReductionForDiffMaxSum<<>>( \ - logits_data, loss_data, softmax_data, feature_size); \ - platform::ForRange for_range( \ - ctx, batch_size* feature_size); \ - if (ignore_idx >= 0 && ignore_idx < feature_size) { \ - for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx( \ - labels_data, loss_data, softmax_data, feature_size, ignore_idx)); \ - } else { \ - for_range(HardLabelSoftmaxWithCrossEntropyFunctor( \ - labels_data, loss_data, softmax_data, feature_size)); \ - } \ +#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ + case BlockDim: { \ + RowReductionForMax<<>>( \ + logits_data, loss_data, d, axis_dim); \ + RowReductionForDiffMaxSum<<>>( \ + logits_data, loss_data, softmax_data, d, axis_dim); \ + platform::ForRange for_range(ctx, n* d); \ + if (ignore_idx >= 0 && ignore_idx < axis_dim) { \ + for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx( \ + labels_data, loss_data, softmax_data, d, axis_dim, ignore_idx)); \ + } else { \ + for_range(HardLabelSoftmaxWithCrossEntropyFunctor( \ + labels_data, loss_data, softmax_data, d, axis_dim)); \ + } \ } break switch (block_dim) { @@ -315,11 +355,11 @@ static void HardLabelSoftmaxWithCrossEntropy( CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); case 1: - SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) / + SetSoftmaxToOneWhenFeatureSizeIsOne<<<(grid_dim + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0, stream>>>( - softmax_data, batch_size); - cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream); + softmax_data, grid_dim); + cudaMemsetAsync(loss_data, 0, grid_dim * sizeof(T), stream); break; default: PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op"); @@ -332,23 +372,23 @@ template static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data, - int batch_size, int feature_size, + int n, int d, int axis_dim, cudaStream_t stream) { constexpr int kMaxBlockDim = 512; - int block_dim = feature_size >= kMaxBlockDim + int block_dim = axis_dim >= kMaxBlockDim ? 
kMaxBlockDim - : (1 << static_cast(std::log2(feature_size))); - -#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ - case BlockDim: \ - RowReductionForMax<<>>( \ - logits_data, loss_data, feature_size); \ - RowReductionForDiffMaxSum<<>>( \ - logits_data, loss_data, softmax_data, feature_size); \ - RowReductionForSoftmaxAndCrossEntropy< \ - T, BlockDim><<>>( \ - labels_data, loss_data, softmax_data, feature_size); \ + : (1 << static_cast(std::log2(axis_dim))); + int grid_dim = n * d / axis_dim; + +#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ + case BlockDim: \ + RowReductionForMax<<>>( \ + logits_data, loss_data, d, axis_dim); \ + RowReductionForDiffMaxSum<<>>( \ + logits_data, loss_data, softmax_data, d, axis_dim); \ + RowReductionForSoftmaxAndCrossEntropy< \ + T, BlockDim><<>>( \ + logits_data, labels_data, loss_data, softmax_data, d, axis_dim); \ break switch (block_dim) { @@ -362,11 +402,11 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data, CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4); CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2); case 1: - SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) / + SetSoftmaxToOneWhenFeatureSizeIsOne<<<(grid_dim + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0, stream>>>( - softmax_data, batch_size); - cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream); + softmax_data, n); + cudaMemsetAsync(loss_data, 0, grid_dim * sizeof(T), stream); break; default: PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op"); @@ -385,51 +425,46 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { const Tensor* logits = context.Input("Logits"); const Tensor* labels = context.Input("Label"); Tensor* softmax = context.Output("Softmax"); - Tensor* loss = context.Output("Loss"); + + const int rank = logits->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis_dim = logits->dims()[axis]; + + const int n = SizeToAxis(axis, logits->dims()); + const int d = SizeFromAxis(axis, logits->dims()); + auto* softmax_data = softmax->mutable_data(context.GetPlace()); auto* loss_data = loss->mutable_data(context.GetPlace()); auto soft_label = context.Attr("soft_label"); auto ignore_index = context.Attr("ignore_index"); - int rank = logits->dims().size(); if (soft_label) { - int batch_size = 1; - for (int i = 0; i < rank - 1; ++i) { - batch_size *= logits->dims()[i]; - } - - int feature_size = logits->dims()[rank - 1]; auto* logits_data = logits->data(); auto* labels_data = labels->data(); SoftmaxWithCrossEntropyFusedKernel( - logits_data, labels_data, softmax_data, loss_data, batch_size, - feature_size, context.cuda_device_context().stream()); + logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim, + context.cuda_device_context().stream()); } else { if (!context.Attr("numeric_stable_mode")) { - // reshape to 2d - Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1); - Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1); - Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1); - Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1); - + // CUDNN kernel only suppoer 2-D tensor and perfome softmax on last dim + Tensor logits_2d, softmax_2d, labels_2d, loss_2d; + logits_2d.ShareDataWith(*logits).Resize({n, d}); + softmax_2d.ShareDataWith(*softmax).Resize({n, d}); + labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); + loss_2d.ShareDataWith(*loss).Resize({n, 1}); 
math::SoftmaxCUDNNFunctor()(context.cuda_device_context(), &logits_2d, &softmax_2d); math::CrossEntropyFunctor()( context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d, - false, ignore_index); + false, ignore_index, axis_dim); } else { - int batch_size = 1; - for (int i = 0; i < rank - 1; ++i) { - batch_size *= logits->dims()[i]; - } - int feature_size = logits->dims()[rank - 1]; auto* logits_data = logits->data(); auto* labels_data = labels->data(); HardLabelSoftmaxWithCrossEntropy( context.cuda_device_context(), logits_data, labels_data, loss_data, - softmax_data, batch_size, feature_size, ignore_index); + softmax_data, n, d, axis_dim, ignore_index); } } } @@ -453,30 +488,31 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { } T* logit_grad_data = logit_grad->data(); - int rank = logit_grad->dims().size(); - int batch_size = 1; - for (int i = 0; i < rank - 1; ++i) { - batch_size *= logit_grad->dims()[i]; - } + const int rank = logit_grad->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis_dim = logit_grad->dims()[axis]; + + const int n = SizeToAxis(axis, logit_grad->dims()); + const int d = SizeFromAxis(axis, logit_grad->dims()); + const int remain = d / axis_dim; - const int class_num = logit_grad->dims()[rank - 1]; int block = 512; auto stream = context.cuda_device_context().stream(); auto ignore_index = context.Attr("ignore_index"); if (context.Attr("soft_label")) { - int grid = (batch_size * class_num + block - 1) / block; + int grid = (n * d + block - 1) / block; const T* label_data = labels->data(); SoftCrossEntropyGradientKernel<<>>( - logit_grad_data, loss_grad_data, label_data, batch_size, class_num); + logit_grad_data, loss_grad_data, label_data, n, d, remain); } else { - int grid = (batch_size + block - 1) / block; + int grid = (n * remain + block - 1) / block; const int64_t* label_data = labels->data(); CrossEntropyGrad<<>>( - logit_grad_data, label_data, batch_size, class_num, ignore_index); - int num = batch_size * class_num; + logit_grad_data, label_data, n, d, remain, ignore_index); + int num = n * d; grid = (num + block - 1) / block; Scale<<>>(logit_grad_data, loss_grad_data, num, - class_num); + d, remain); } } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index 7ef7c4f7424f2690f95fae0a70c1bdc6eb387502..4533295a8d8c0d7f36522143adc2820020179ace 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/softmax.h" +#include "paddle/fluid/operators/softmax_op.h" namespace paddle { namespace operators { @@ -36,26 +37,30 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { const Tensor* labels = context.Input("Label"); Tensor* softmax = context.Output("Softmax"); Tensor* loss = context.Output("Loss"); + const bool soft_label = context.Attr("soft_label"); + + const int rank = logits->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis_dim = logits->dims()[axis]; softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - // reshape to 2D tensor - int rank = logits->dims().size(); - Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1); - Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1); - Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1); - Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1); - - int axis_dim = logits->dims()[rank - 1]; + const int n = SizeToAxis(axis, logits->dims()); + const int d = SizeFromAxis(axis, logits->dims()); + Tensor logits_2d, softmax_2d, labels_2d, loss_2d; + logits_2d.ShareDataWith(*logits).Resize({n, d}); + softmax_2d.ShareDataWith(*softmax).Resize({n, d}); + labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); + loss_2d.ShareDataWith(*loss).Resize({n, d / axis_dim}); auto& dev_ctx = context.template device_context(); math::SoftmaxFunctor()( dev_ctx, axis_dim, &logits_2d, &softmax_2d); math::CrossEntropyFunctor()( - dev_ctx, &loss_2d, &softmax_2d, &labels_2d, - context.Attr("soft_label"), context.Attr("ignore_index")); + dev_ctx, &loss_2d, &softmax_2d, &labels_2d, soft_label, + context.Attr("ignore_index"), axis_dim); } }; @@ -75,34 +80,43 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { context.device_context(), logit_grad); } - int rank = logit_grad->dims().size(); - const int class_num = logit_grad->dims()[rank - 1]; - // reshape to 2d - Tensor logit_grad_2d = framework::ReshapeToMatrix(*logit_grad, rank - 1); - Tensor out_grad_2d = framework::ReshapeToMatrix(*out_grad, rank - 1); + const bool soft_label = context.Attr("soft_label"); + + const int rank = logit_grad->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + int axis_dim = logit_grad->dims()[axis]; + + const int n = SizeToAxis(axis, logit_grad->dims()); + const int d = SizeFromAxis(axis, logit_grad->dims()); + Tensor logit_grad_2d, labels_2d, out_grad_2d; + logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d}); + labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); + out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim}); auto out_grad_mat = EigenMatrix::From(out_grad_2d); auto logit_grad_mat = EigenMatrix::From(logit_grad_2d); auto& place = *context.template device_context() .eigen_device(); - if (context.Attr("soft_label")) { - Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1); + if (soft_label) { auto lbl_mat = EigenMatrix::From(labels_2d); logit_grad_mat.device(place) = - out_grad_mat.broadcast(Eigen::DSizes(1, class_num)) * + out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)) * (logit_grad_mat - lbl_mat); } else { logit_grad_mat.device(place) = logit_grad_mat * - out_grad_mat.broadcast(Eigen::DSizes(1, class_num)); - - const int batch_size = logit_grad_2d.dims()[0]; + out_grad_mat.broadcast(Eigen::DSizes(1, axis_dim)); 
const int64_t* label_data = labels->data(); T* logit_grad_data = logit_grad->data(); const T* out_grad_data = out_grad->data(); - for (int i = 0; i < batch_size; ++i) { - logit_grad_data[i * class_num + label_data[i]] -= out_grad_data[i]; + const int remain = d / axis_dim; + for (int i = 0; i < n; ++i) { + for (int j = 0; j < remain; j++) { + int idx = i * remain + j; + logit_grad_data[i * d + label_data[idx] * remain + j] -= + out_grad_data[idx]; + } } } } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 428692cc63a9a6a75891b74b6581b4fc34388e86..0bf06f557c6925935fabd875fccbe6463fed0927 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -6099,22 +6099,24 @@ def softmax_with_cross_entropy(logits, soft_label=False, ignore_index=kIgnoreIndex, numeric_stable_mode=True, - return_softmax=False): + return_softmax=False, + axis=-1): """ **Softmax With Cross Entropy Operator.** Cross entropy loss with softmax is used as the output layer extensively. This - operator computes the softmax normalized values for each row of the input - tensor, after which cross-entropy loss is computed. This provides a more - numerically stable gradient. + operator computes the softmax normalized values for dimension :attr:`axis` of + the input tensor, after which cross-entropy loss is computed. This provides + a more numerically stable gradient. Because this operator performs a softmax on logits internally, it expects unscaled logits. This operator should not be used with the output of softmax operator since that would produce incorrect results. - When the attribute soft_label is set false, this operators expects mutually - exclusive hard labels, each sample in a batch is in exactly one class with a - probability of 1.0. Each sample in the batch will have a single label. + When the attribute :attr:`soft_label` is set :attr:`False`, this operators + expects mutually exclusive hard labels, each sample in a batch is in exactly + one class with a probability of 1.0. Each sample in the batch will have a + single label. The equation is as follows: @@ -6133,7 +6135,8 @@ def softmax_with_cross_entropy(logits, \\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K} \\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K - 3) If numeric_stable_mode is True, softmax is calculated first by: + 3) If :attr:`numeric_stable_mode` is :attr:`True`, softmax is calculated + first by: .. math:: @@ -6146,32 +6149,39 @@ def softmax_with_cross_entropy(logits, and then cross entropy loss is calculated by softmax and label. Args: - logits (Variable): The unscaled log probabilities, which is a 2-D tensor - with shape [N x K]. N is the batch_size, and K is the class number. - label (Variable): The ground truth which is a 2-D tensor. If soft_label - is set to false, Label is a Tensor with shape [N x 1]. If - soft_label is set to true, Label is a Tensor with + logits (Variable): The input tensor of unscaled log probabilities. + label (Variable): The ground truth tensor. If :attr:`soft_label` + is set to :attr:`True`, Label is a Tensor in the + same shape with :attr:`logits`. If :attr:`soft_label` is set to + :attr:`True`, Label is a Tensor in the same shape with + :attr:`logits` expect shape in dimension :attr:`axis` as 1. soft_label (bool): A flag to indicate whether to interpretate the given - labels as soft labels. By default, `soft_label` is set to False. + labels as soft labels. Default False. 
ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. Only valid - if soft_label is set to False. Default: kIgnoreIndex + if :attr:`soft_label` is set to :attr:`False`. + Default: kIgnoreIndex numeric_stable_mode (bool): A flag to indicate whether to use a more numerically stable algorithm. Only valid - when soft_label is False and GPU is used. - When soft_label is True or CPU is used, - the algorithm is always numerically stable. + when :attr:`soft_label` is :attr:`False` + and GPU is used. When :attr:`soft_label` + is :attr:`True` or CPU is used, the + algorithm is always numerically stable. Note that the speed may be slower when use stable algorithm. Default: True return_softmax (bool): A flag indicating whether to return the softmax along with the cross entropy loss. Default: False + axis (int): The index of dimension to perform softmax calculations. It + should be in range :math:`[-1, rank - 1]`, while :math:`rank` + is the rank of input :attr:`logits`. Default: -1. Returns: Variable or Tuple of two Variables: Return the cross entropy loss if \ `return_softmax` is False, otherwise the tuple \ - (loss, softmax), where the cross entropy loss is \ - a 2-D tensor with shape [N x 1], and softmax is a \ - 2-D tensor with shape [N x K]. + (loss, softmax), softmax is in the same shape \ + with input logits and cross entropy loss is in \ + the same shape with input logits except shape \ + in dimension :attr:`axis` as 1. Examples: .. code-block:: python @@ -6194,7 +6204,8 @@ def softmax_with_cross_entropy(logits, attrs={ 'soft_label': soft_label, 'ignore_index': ignore_index, - 'numeric_stable_mode': numeric_stable_mode + 'numeric_stable_mode': numeric_stable_mode, + 'axis': axis }) if return_softmax: diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 46f025c33bc9cc3a7197a4e87475b4d9c132b4ed..2474125835fbf54316e26d272eec940fc380a448 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1221,10 +1221,25 @@ class TestBook(LayerTest): y = self._get_data(name='label', shape=[1], dtype='int64') loss, softmax = layers.softmax_with_cross_entropy( x, y, return_softmax=True) - return (loss) - return (softmax) + self.assertIsNotNone(loss) + self.assertIsNotNone(softmax) + loss = layers.softmax_with_cross_entropy(x, y) - return (loss) + self.assertIsNotNone(loss) + + x1 = self._get_data(name='x1', shape=[16, 32, 64], dtype='float32') + y1 = self._get_data(name='label1', shape=[1, 32, 64], dtype='int64') + y2 = self._get_data(name='label2', shape=[16, 1, 64], dtype='int64') + y3 = self._get_data(name='label3', shape=[16, 32, 1], dtype='int64') + loss1 = layers.softmax_with_cross_entropy(x1, y1, axis=1) + loss2 = layers.softmax_with_cross_entropy(x1, y2, axis=2) + loss3 = layers.softmax_with_cross_entropy(x1, y3, axis=3) + loss4 = layers.softmax_with_cross_entropy(x1, y3, axis=-1) + self.assertIsNotNone(loss1) + self.assertIsNotNone(loss2) + self.assertIsNotNone(loss3) + self.assertIsNotNone(loss4) + return (loss4) def make_smooth_l1(self): with program_guard(fluid.default_main_program(), diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index b06b52f75d21a720e2473feba6ba2e1dccc2db89..d37731146d9c431bb6a0c333149ac62a0c4efd3b 100644 --- 
a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -21,37 +21,70 @@ from op_test import OpTest from test_softmax_op import stable_softmax +def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1): + if soft_label: + return (-label * np.log(softmax)).sum(axis=axis, keepdims=True) + + shape = softmax.shape + axis %= len(shape) + n = int(np.prod(shape[:axis])) + axis_dim = shape[axis] + remain = int(np.prod(shape[axis + 1:])) + softmax_reshape = softmax.reshape((n, axis_dim, remain)) + label_reshape = label.reshape((n, 1, remain)) + result = np.zeros_like(label_reshape, dtype=softmax.dtype) + for i in range(n): + for j in range(remain): + lbl = label_reshape[i, 0, j] + if lbl != ignore_index: + result[i, 0, j] -= np.log(softmax_reshape[i, lbl, j]) + return result.reshape(label.shape) + + class TestSoftmaxWithCrossEntropyOp(OpTest): """ Test softmax with cross entropy operator with discreate one-hot labels. """ def initParams(self): + self.op_type = "softmax_with_cross_entropy" self.numeric_stable_mode = False + self.soft_label = False self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -1 + self.shape = [41, 37] def setUp(self): self.initParams() - self.op_type = "softmax_with_cross_entropy" - batch_size = 41 - class_num = 37 - logits = np.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype(self.dtype) - softmax = np.apply_along_axis(stable_softmax, 1, logits) - labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64") + logits = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + softmax = np.apply_along_axis(stable_softmax, self.axis, logits) + + if self.soft_label: + labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + labels /= np.sum(labels, axis=self.axis, keepdims=True) + else: + axis_dim = self.shape[self.axis] + self.shape[self.axis] = 1 + labels = np.random.randint(0, axis_dim, self.shape, dtype="int64") - cross_entropy = np.asmatrix( - [[-np.log(softmax[i][labels[i][0]])] - for i in range(softmax.shape[0])], - dtype=self.dtype) + loss = cross_entropy(softmax, labels, self.soft_label, self.axis, + self.ignore_index) self.inputs = {"Logits": logits, "Label": labels} self.outputs = { "Softmax": softmax.astype(self.dtype), - "Loss": cross_entropy.astype(self.dtype) + "Loss": loss.astype(self.dtype) } - self.attrs = {"numeric_stable_mode": self.numeric_stable_mode} + self.attrs = { + "numeric_stable_mode": self.numeric_stable_mode, + "soft_label": self.soft_label, + } + if self.ignore_index >= 0: + self.attrs['ignore_index'] = self.ignore_index + if self.axis != -1: + self.attrs['axis'] = self.axis def test_check_output(self): self.check_output() @@ -62,30 +95,38 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp): def initParams(self): + self.op_type = "softmax_with_cross_entropy" self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = -1 + self.ignore_index = -1 + self.dtype = np.float64 class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): def initParams(self): + self.op_type = "softmax_with_cross_entropy" self.numeric_stable_mode = False + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = -1 + self.ignore_index = -1 self.dtype = np.float16 def setUp(self): self.initParams() self.op_type = "softmax_with_cross_entropy" - batch_size = 41 - class_num = 37 # NOTE: 
numpy float16 have very low accuracy, use float32 for numpy check. - logits = np.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype(np.float32) - softmax = np.apply_along_axis(stable_softmax, 1, logits) - labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64") + logits = np.random.uniform(0.1, 1.0, self.shape).astype(np.float32) + softmax = np.apply_along_axis(stable_softmax, self.axis, logits) - cross_entropy = np.asmatrix( - [[-np.log(softmax[i][labels[i][0]])] - for i in range(softmax.shape[0])], - dtype=np.float32) + axis_dim = self.shape[self.axis] + self.shape[self.axis] = 1 + labels = np.random.randint(0, axis_dim, self.shape, dtype="int64") + + loss = cross_entropy(softmax, labels, self.soft_label, self.axis) self.inputs = { "Logits": logits.astype(self.dtype).view(np.uint16), @@ -93,9 +134,14 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): } self.outputs = { "Softmax": softmax.astype(self.dtype), - "Loss": cross_entropy.astype(self.dtype) + "Loss": loss.astype(self.dtype) + } + self.attrs = { + "numeric_stable_mode": self.numeric_stable_mode, + "soft_label": self.soft_label, } - self.attrs = {"numeric_stable_mode": self.numeric_stable_mode} + if self.axis != -1: + self.attrs['axis'] = self.axis def test_check_output(self): self.check_output(atol=1e-2) @@ -107,39 +153,31 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): class TestSoftmaxWithCrossEntropyOpNoCudnnFp16( TestSoftmaxWithCrossEntropyOpFp16): def initParams(self): + self.op_type = "softmax_with_cross_entropy" self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = -1 + self.ignore_index = -1 self.dtype = np.float16 def test_check_grad(self): self.check_grad(["Logits"], "Loss", max_relative_error=0.1) -class TestSoftmaxWithCrossEntropyOp2(OpTest): +class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp): """ Test softmax with cross entropy operator with soft labels. """ - def setUp(self): + def initParams(self): self.op_type = "softmax_with_cross_entropy" - batch_size = 41 - class_num = 37 - - logits = np.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype("float64") - softmax = np.apply_along_axis(stable_softmax, 1, logits) - labels = np.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype("float64") - labels /= np.sum(labels, axis=1, keepdims=True) - - cross_entropy = (-labels * np.log(softmax)).sum( - axis=1, keepdims=True).astype("float64") - - self.inputs = {"Logits": logits, "Label": labels} - self.outputs = { - "Softmax": softmax.astype("float64"), - "Loss": cross_entropy.astype("float64") - } - self.attrs = {"soft_label": True} + self.numeric_stable_mode = True + self.soft_label = True + self.dtype = np.float64 + self.axis = -1 + self.ignore_index = -1 + self.shape = [41, 37] def test_check_output(self): self.check_output() @@ -148,190 +186,226 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest): self.check_grad(["Logits"], "Loss") -class TestSoftmaxWithCrossEntropyOp3(OpTest): +class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp): """ Test softmax with cross entropy operator with ignore_index. 
""" def initParams(self): + self.op_type = "softmax_with_cross_entropy" self.numeric_stable_mode = False + self.soft_label = False + self.shape = [41, 37] + self.ignore_index = 5 + self.axis = -1 + self.dtype = np.float64 - def setUp(self): - self.initParams() + +class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): + def initParams(self): self.op_type = "softmax_with_cross_entropy" - batch_size = 41 - class_num = 37 - - logits = np.random.uniform(0.1, 1.0, - [batch_size, class_num]).astype("float64") - softmax = np.apply_along_axis(stable_softmax, 1, logits) - labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64") - ignore_index = 7 - cross_entropy = np.asmatrix( - [[-np.log(softmax[i][labels[i][0]])] - if labels[i] != ignore_index else [0] - for i in range(softmax.shape[0])], - dtype="float64") + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.ignore_index = 4 + self.axis = -1 + self.dtype = np.float64 - self.inputs = {"Logits": logits, "Label": labels} - self.outputs = { - "Softmax": softmax.astype("float64"), - "Loss": cross_entropy.astype("float64") - } - self.attrs = { - "ignore_index": ignore_index, - "numeric_stable_mode": self.numeric_stable_mode - } - def test_check_output(self): - self.check_output() +class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp): + """ + Test softmax with cross entropy operator with discreate one-hot labels. + Given axis != -1 + """ - def test_check_grad(self): - self.check_grad(["Logits"], "Loss") + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = 0 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] -class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3): +class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp): + """ + Test softmax with cross entropy operator with discreate one-hot labels. + Given axis != -1 + """ + def initParams(self): + self.op_type = "softmax_with_cross_entropy" self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = 1 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] -class TestSoftmaxWithCrossEntropyOp5(OpTest): +class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp): """ - Test softmax with cross entropy operator with ignore_index. + Test softmax with cross entropy operator with discreate one-hot labels. + Given axis != -1 """ def initParams(self): - self.numeric_stable_mode = False + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = 2 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] - def setUp(self): - self.initParams() + +class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp): + """ + Test softmax with cross entropy operator with discreate one-hot labels. 
+ Given axis != -1 + """ + + def initParams(self): self.op_type = "softmax_with_cross_entropy" - batch_size = [6, 10] - class_num = 47 - - logits = np.random.uniform( - 0.1, 1.0, tuple(batch_size + [class_num])).astype("float64") - softmax = np.apply_along_axis(stable_softmax, 2, logits) - labels = np.random.randint( - 0, class_num, tuple(batch_size + [1]), dtype="int64") - ignore_index = 7 - - softmax_2d = np.reshape(softmax, [-1, class_num]) - labels_2d = np.reshape(labels, [-1, 1]) - cross_entropy = np.asmatrix( - [[-np.log(softmax_2d[i][labels_2d[i][0]])] - if labels_2d[i] != ignore_index else [0] - for i in range(softmax_2d.shape[0])], - dtype="float64") - - cross_entropy = np.reshape(cross_entropy, batch_size) - - output_shape = tuple(batch_size + [1]) - output_res = cross_entropy.astype("float64") - output_res = np.expand_dims(output_res, axis=2) - self.inputs = {"Logits": logits, "Label": labels} - self.outputs = { - "Softmax": softmax.astype("float64"), - "Loss": output_res, - } - self.attrs = { - "ignore_index": ignore_index, - "numeric_stable_mode": self.numeric_stable_mode - } + self.numeric_stable_mode = True + self.soft_label = False + self.dtype = np.float64 + self.axis = 3 + self.ignore_index = -1 + self.shape = [3, 5, 7, 11] - def test_check_output(self): - self.check_output() - def test_check_grad(self): - self.check_grad(["Logits"], "Loss") +class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1( + TestSoftmaxWithCrossEntropyOpNoCudnnFp16): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = 0 + self.ignore_index = -1 + self.dtype = np.float16 -class TestSoftmaxWithCrossEntropyOp5NoCudnn(TestSoftmaxWithCrossEntropyOp5): +class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2( + TestSoftmaxWithCrossEntropyOpNoCudnnFp16): def initParams(self): + self.op_type = "softmax_with_cross_entropy" self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = 1 + self.ignore_index = -1 + self.dtype = np.float16 -class TestSoftmaxWithCrossEntropyOp6(OpTest): - """ - Test softmax with cross entropy operator with soft labels. 
- """ +class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3( + TestSoftmaxWithCrossEntropyOpNoCudnnFp16): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = 2 + self.ignore_index = -1 + self.dtype = np.float16 - def setUp(self): + +class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1( + TestSoftmaxWithCrossEntropyOp2): + def initParams(self): self.op_type = "softmax_with_cross_entropy" - batch_size = [6, 10] - class_num = 37 + self.numeric_stable_mode = True + self.soft_label = True + self.shape = [3, 5, 7, 11] + self.axis = 0 + self.ignore_index = -1 + self.dtype = np.float64 - logits = np.random.uniform( - 0.1, 1.0, tuple(batch_size + [class_num])).astype("float64") - softmax = np.apply_along_axis(stable_softmax, 2, logits) - labels = np.random.uniform( - 0.1, 1.0, tuple(batch_size + [class_num])).astype("float64") - labels /= np.sum(labels, axis=2, keepdims=True) - cross_entropy = (-labels * np.log(softmax)).sum( - axis=2, keepdims=True).astype("float64") +class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2( + TestSoftmaxWithCrossEntropyOp2): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = True + self.shape = [3, 5, 7, 11] + self.axis = 1 + self.ignore_index = -1 + self.dtype = np.float64 - self.inputs = {"Logits": logits, "Label": labels} - self.outputs = { - "Softmax": softmax.astype("float64"), - "Loss": cross_entropy.astype("float64") - } - self.attrs = {"soft_label": True} - def test_check_output(self): - self.check_output() +class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3( + TestSoftmaxWithCrossEntropyOp2): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = True + self.shape = [3, 5, 7, 11] + self.axis = 2 + self.ignore_index = -1 + self.dtype = np.float64 - def test_check_grad(self): - self.check_grad(["Logits"], "Loss") + +class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4( + TestSoftmaxWithCrossEntropyOp2): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = True + self.shape = [3, 5, 7, 11] + self.axis = 3 + self.ignore_index = -1 + self.dtype = np.float64 -class TestSoftmaxWithCrossEntropyOpFp16_2(TestSoftmaxWithCrossEntropyOp): +class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1( + TestSoftmaxWithCrossEntropyOp3): def initParams(self): - self.numeric_stable_mode = False - self.dtype = np.float16 + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.ignore_index = 1 + self.axis = 0 + self.dtype = np.float64 - def setUp(self): - self.initParams() + +class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2( + TestSoftmaxWithCrossEntropyOp3): + def initParams(self): self.op_type = "softmax_with_cross_entropy" - batch_size = [64, 10] - class_num = 37 + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.ignore_index = 0 + self.axis = 1 + self.dtype = np.float64 - # NOTE: numpy float16 have very low accuracy, use float32 for numpy check. 
- logits = np.random.uniform( - 0.1, 1.0, tuple(batch_size + [class_num])).astype(np.float32) - softmax = np.apply_along_axis(stable_softmax, 2, logits) - labels = np.random.randint( - 0, class_num, tuple(batch_size + [1]), dtype="int64") - - softmax_2d = np.reshape(softmax, [-1, class_num]) - labels_2d = np.reshape(labels, [-1, 1]) - - cross_entropy = np.asmatrix( - [[-np.log(softmax_2d[i][labels_2d[i][0]])] - for i in range(softmax_2d.shape[0])], - dtype=np.float32) - - cross_entropy = np.reshape(cross_entropy, batch_size) - output_shape = tuple(batch_size + [1]) - output_res = cross_entropy.astype(self.dtype) - output_res = np.expand_dims(output_res, axis=2) - self.inputs = {"Logits": logits, "Label": labels} - self.inputs = { - "Logits": logits.astype(self.dtype).view(np.uint16), - "Label": labels - } - self.outputs = { - "Softmax": softmax.astype(self.dtype), - "Loss": output_res, - } - self.attrs = {"numeric_stable_mode": self.numeric_stable_mode} +class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3( + TestSoftmaxWithCrossEntropyOp3): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.ignore_index = 3 + self.axis = 2 + self.dtype = np.float64 - def test_check_output(self): - self.check_output(atol=1e-2) - def test_check_grad(self): - self.check_grad(["Logits"], "Loss", max_relative_error=0.1) +class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4( + TestSoftmaxWithCrossEntropyOp3): + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.ignore_index = 3 + self.axis = 3 + self.dtype = np.float64 if __name__ == "__main__":
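For reference, a minimal usage sketch of the `axis` argument added by this patch, mirroring the new `test_layers` case (axis=1 over a [16, 32, 64] feature map); the variable names and shapes below are illustrative only, not part of the patch:

```python
import paddle.fluid as fluid

# Logits with a non-trailing class dimension: with append_batch_size the
# runtime shape is [N, 16, 32, 64], and axis=1 runs softmax over the 16
# classes instead of the last dimension.
logits = fluid.layers.data(name='logits', shape=[16, 32, 64], dtype='float32')
label = fluid.layers.data(name='label', shape=[1, 32, 64], dtype='int64')

# Per the updated docstring: softmax has the same shape as logits, and loss
# keeps the logits shape except that dimension `axis` becomes 1.
loss, softmax = fluid.layers.softmax_with_cross_entropy(
    logits, label, axis=1, return_softmax=True)
```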
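The C++/CPU and CUDA changes all hinge on viewing the flattened logits `[n, d]` as `[n, axis_dim, remain]` and addressing the labelled element at flat offset `i * d + lbl * remain + j`. Below is a small NumPy sketch of that indexing, consistent with the `cross_entropy` reference added to the unit test; it assumes `SizeToAxis`/`SizeFromAxis` (used but not shown in this diff, from softmax_op.h) return the products of the dimensions before and from the axis, respectively:

```python
import numpy as np

shape, axis = (3, 5, 7, 11), 1
n = int(np.prod(shape[:axis]))                 # SizeToAxis:   3
d = int(np.prod(shape[axis:]))                 # SizeFromAxis: 5 * 7 * 11
axis_dim = shape[axis]                         # 5
remain = d // axis_dim                         # 77

# Softmax computed over the axis_dim dimension of the [n, axis_dim, remain] view.
logits = np.random.uniform(0.1, 1.0, (n, d))
prob = np.exp(logits).reshape(n, axis_dim, remain)
prob /= prob.sum(axis=1, keepdims=True)
prob = prob.reshape(n, d)

labels = np.random.randint(0, axis_dim, (n, remain))
loss = np.empty((n, remain))
for i in range(n):
    for j in range(remain):
        lbl = labels[i, j]
        # Same offset the CPU functor and CUDA kernels use:
        # element (i, lbl, j) of the [n, axis_dim, remain] view.
        loss[i, j] = -np.log(prob[i, lbl * remain + j])
```

When `axis` is the last dimension, `remain` is 1 and this degenerates to the previous `[batch_size, class_num]` indexing, which is why the old 2-D code paths remain a special case of the new one.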