diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index ae5f30e431aba4cae04b0fb35f00bce84f18de33..842fde1ec5f16aeec28a790f1869eaed64e3516c 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -100,7 +100,7 @@ paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_att
 paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None)
-paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label'], varargs=None, keywords=None, defaults=(False,))
+paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
 paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None))
@@ -142,7 +142,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
-paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label'], varargs=None, keywords=None, defaults=(False,))
+paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
 paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index 578ab63bc380ee62d76e34b7cf3cbd590bfa2eda..66f19fe7ecfa51b2ce917f0c5fcb6d486f1a7307 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -138,6 +138,11 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
               "(bool, default false), a flag indicating whether to "
               "interpretate the given labels as soft labels.")
         .SetDefault(false);
+    AddAttr<int>("ignore_index",
+                 "(int, default -100), Specifies a target value that is "
+                 "ignored and does not contribute to the input gradient. "
+                 "Only valid if soft_label is set to False.")
+        .SetDefault(-100);
     AddComment(R"DOC(
 CrossEntropy Operator.
diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h
index 36b58d80144d242277f6fc970a3a61a6721d4b50..03974a7fc511b1e1cb5b0eca532b260fdf9bf964 100644
--- a/paddle/fluid/operators/cross_entropy_op.h
+++ b/paddle/fluid/operators/cross_entropy_op.h
@@ -40,7 +40,7 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
     math::CrossEntropyFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,
-        ctx.Attr<bool>("soft_label"));
+        ctx.Attr<bool>("soft_label"), ctx.Attr<int>("ignore_index"));
   }
 };
 
@@ -74,16 +74,22 @@ class XeGradFunctor {
                const T* dy,           // NOLINT
                const T* x,            // NOLINT
                const int64_t* label,  // NOLINT
-               size_t num_classes)
-      : dx_(dx), dy_(dy), x_(x), label_(label), num_classes_(num_classes) {}
+               size_t num_classes, size_t ignore_index)
+      : dx_(dx),
+        dy_(dy),
+        x_(x),
+        label_(label),
+        num_classes_(num_classes),
+        ignore_index_(ignore_index) {}
 
   HOSTDEVICE void operator()(size_t sample_id) {
     auto x_is_true_offset = sample_id * num_classes_ + label_[sample_id];
     for (size_t x_offset = sample_id * num_classes_;
          x_offset < (sample_id + 1) * num_classes_; ++x_offset) {
-      dx_[x_offset] = x_offset != x_is_true_offset
-                          ? static_cast<T>(0)
-                          : -dy_[sample_id] / x_[x_offset];
+      dx_[x_offset] =
+          (x_offset != x_is_true_offset || label_[sample_id] == ignore_index_)
+              ? static_cast<T>(0)
+              : -dy_[sample_id] / x_[x_offset];
     }
   }
 
@@ -93,6 +99,7 @@ class XeGradFunctor {
   const T* x_;
   const int64_t* label_;
   size_t num_classes_;
+  size_t ignore_index_;
 };
 
 template <typename DeviceContext, typename T>
@@ -109,6 +116,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
     // unnecessary to convert tensors to 2-D views.
     int rank = x->dims().size();
     int64_t class_num = x->dims()[rank - 1];
+    int64_t ignore_index = ctx.Attr<int>("ignore_index");
     if (ctx.Attr<bool>("soft_label")) {
       XeSoftlabelGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
                                         label->data<T>(),
@@ -118,9 +126,9 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
           static_cast<size_t>(dx->numel()));
       for_range(functor);
     } else {
-      XeGradFunctor<T> functor(dx_data, dy->data<T>(), x->data<T>(),
-                               label->data<int64_t>(),
-                               static_cast<size_t>(class_num));
+      XeGradFunctor<T> functor(
+          dx_data, dy->data<T>(), x->data<T>(), label->data<int64_t>(),
+          static_cast<size_t>(class_num), static_cast<size_t>(ignore_index));
       platform::ForRange<DeviceContext> for_range(
           ctx.template device_context<DeviceContext>(),
           static_cast<size_t>(dy->numel()));
diff --git a/paddle/fluid/operators/math/cross_entropy.cc b/paddle/fluid/operators/math/cross_entropy.cc
index caff35e03ae3a144f799d982c859ded62cb3e93d..18bf1a66f6d9903f32048574dc93faf7e98953ac 100644
--- a/paddle/fluid/operators/math/cross_entropy.cc
+++ b/paddle/fluid/operators/math/cross_entropy.cc
@@ -28,7 +28,8 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& ctx,
                   framework::Tensor* out, const framework::Tensor* prob,
-                  const framework::Tensor* labels, const bool softLabel) {
+                  const framework::Tensor* labels, const bool softLabel,
+                  const int ignore_index) {
     const int batch_size = prob->dims()[0];
     if (softLabel) {
       auto in = EigenMatrix<T>::From(*prob);
@@ -49,8 +50,12 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
         int lbl = label_data[i];
         PADDLE_ENFORCE_GE(lbl, 0);
         PADDLE_ENFORCE_LT(lbl, class_num);
+        PADDLE_ENFORCE((lbl >= 0 && lbl < class_num) || lbl == ignore_index);
         int index = i * class_num + lbl;
-        loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
+        loss_data[i] =
+            lbl == ignore_index
+                ? 0
+                : -math::TolerableValue<T>()(std::log(prob_data[index]));
       }
     }
   }
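Note for readers: taken together, `CrossEntropyFunctor` (forward) and `XeGradFunctor` (backward) give an ignored row zero loss and zero gradient. Below is a minimal NumPy sketch of that contract as a mental model for the kernels above; the function names are illustrative and not part of the patch, and the -100 default follows the convention familiar from PyTorch's losses.

```python
import numpy as np

def xe_forward(prob, label, ignore_index=-100):
    # Mirrors CrossEntropyFunctor (hard labels): a row whose label
    # equals ignore_index contributes exactly zero loss.
    loss = np.zeros((prob.shape[0], 1), dtype=prob.dtype)
    for i in range(prob.shape[0]):
        if label[i] != ignore_index:
            loss[i, 0] = -np.log(prob[i, label[i]])
    return loss

def xe_backward(prob, label, dloss, ignore_index=-100):
    # Mirrors XeGradFunctor: dx is -dy/x at the true-class entry of
    # non-ignored rows and zero everywhere else.
    dx = np.zeros_like(prob)
    for i in range(prob.shape[0]):
        if label[i] != ignore_index:
            dx[i, label[i]] = -dloss[i] / prob[i, label[i]]
    return dx
```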
diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu
index 0de58d5fddd84d33f708c4c73e5a19dc2fe8a86b..c92341ea55ea21773acba33665e267b2f1c25fe3 100644
--- a/paddle/fluid/operators/math/cross_entropy.cu
+++ b/paddle/fluid/operators/math/cross_entropy.cu
@@ -23,11 +23,14 @@ namespace math {
 namespace {
 template <typename T>
 __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
-                                   const int N, const int D) {
+                                   const int N, const int D,
+                                   const int ignore_index) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
        i += blockDim.x * gridDim.x) {
-    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -math::TolerableValue<T>()(log(X[i * D + label[i]]));
+    PADDLE_ASSERT((label[i] >= 0 && label[i] < D) || label[i] == ignore_index);
+    Y[i] = ignore_index == label[i]
+               ? 0
+               : -math::TolerableValue<T>()(log(X[i * D + label[i]]));
   }
 }
 
@@ -57,7 +60,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& ctx,
                   framework::Tensor* out, const framework::Tensor* prob,
-                  const framework::Tensor* labels, bool softLabel) {
+                  const framework::Tensor* labels, bool softLabel,
+                  const int ignore_index) {
     const T* prob_data = prob->data<T>();
     T* loss_data = out->mutable_data<T>(ctx.GetPlace());
 
@@ -77,7 +81,8 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
       int block = 512;
       int grid = (batch_size + block - 1) / block;
       CrossEntropyKernel<T><<<grid, block, 0, ctx.stream()>>>(
-          loss_data, prob_data, label_data, batch_size, class_num);
+          loss_data, prob_data, label_data, batch_size, class_num,
+          ignore_index);
     }
   }
 };
diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h
index adc5b3fe47cd3bf524eb56747b6bd51e345a2eb6..e8aeb5d0575ac0f6b8761e97896df73578e8a103 100644
--- a/paddle/fluid/operators/math/cross_entropy.h
+++ b/paddle/fluid/operators/math/cross_entropy.h
@@ -38,7 +38,8 @@ class CrossEntropyFunctor {
  public:
   void operator()(const DeviceContext& context, framework::Tensor* out,
                   const framework::Tensor* prob,
-                  const framework::Tensor* labels, const bool softLabel);
+                  const framework::Tensor* labels, const bool softLabel,
+                  const int ignore_index);
 };
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
index 53cb716a979229c99fcbdc12f1aeab4e21b320f3..1a9324ec862fc3dd7ce669c5fed94527cac22b8f 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc
@@ -44,6 +44,12 @@ class SoftmaxWithCrossEntropyOpMaker
         "(bool, default: false), A flag to indicate whether to interpretate "
         "the given labels as soft labels.")
         .SetDefault(false);
+    AddAttr<int>(
+        "ignore_index",
+        "(int, default -100), Specifies a target value that is ignored and "
+        "does not contribute to the input gradient. Only valid if soft_label "
+        "is set to False.")
+        .SetDefault(-100);
     AddComment(R"DOC(
 Softmax With Cross Entropy Operator.
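The fused operator applies the same masking after computing softmax inside the op. A rough NumPy sketch of its forward pass, assuming the usual max-subtraction trick for numerical stability (the helper name is hypothetical):

```python
import numpy as np

def softmax_xe_forward(logits, label, ignore_index=-100):
    # Numerically stable softmax, then the same masked negative
    # log-likelihood as the plain cross_entropy op. Two values are
    # returned, matching the op's Softmax and Loss outputs.
    shifted = logits - logits.max(axis=1, keepdims=True)
    softmax = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    loss = np.zeros((logits.shape[0], 1), dtype=logits.dtype)
    for i in range(logits.shape[0]):
        if label[i] != ignore_index:
            loss[i, 0] = -np.log(softmax[i, label[i]])
    return softmax, loss
```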
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
index a559b01ed32a48e3befb37c2ae8935b4f3a4acb0..148faec4af50c4fe3a8e9d1f22e0da70c8ddcb44 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu
@@ -26,7 +26,8 @@ using Tensor = framework::Tensor;
 namespace {
 template <typename T>
 __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
-                                 const int batch_size, const int class_num) {
+                                 const int batch_size, const int class_num,
+                                 const int ignore_index) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batch_size;
        i += blockDim.x * gridDim.x) {
     int idx = i * class_num + labels[i];
@@ -260,6 +261,7 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     auto* loss_data = loss->mutable_data<T>(context.GetPlace());
 
     auto soft_label = context.Attr<bool>("soft_label");
+    auto ignore_index = context.Attr<int>("ignore_index");
     if (soft_label) {
       int batch_size = logits->dims()[0];
       int feature_size = logits->dims()[1];
@@ -272,7 +274,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
       math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
                                      softmax);
       math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
-          context.cuda_device_context(), loss, softmax, labels, false);
+          context.cuda_device_context(), loss, softmax, labels, false,
+          ignore_index);
     }
   }
 };
@@ -295,7 +298,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     const int class_num = logit_grad->dims()[1];
     int block = 512;
     auto stream = context.cuda_device_context().stream();
-
+    auto ignore_index = context.Attr<int>("ignore_index");
     if (context.Attr<bool>("soft_label")) {
       int grid = (batch_size * class_num + block - 1) / block;
       const T* label_data = labels->data<T>();
@@ -305,7 +308,7 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
       int grid = (batch_size + block - 1) / block;
       const int64_t* label_data = labels->data<int64_t>();
       CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
-          logit_grad_data, label_data, batch_size, class_num);
+          logit_grad_data, label_data, batch_size, class_num, ignore_index);
       int num = batch_size * class_num;
       grid = (num + block - 1) / block;
       Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h
index dd6f6aca5ada7aa215d3b3444194fc53efeb7020..e9aba3b37b8cc01d4fe5de5200579d4e93f67e56 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h
@@ -45,7 +45,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
     math::SoftmaxFunctor<DeviceContext, T>()(dev_ctx, logits, softmax);
     math::CrossEntropyFunctor<DeviceContext, T>()(
-        dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"));
+        dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
+        context.Attr<int>("ignore_index"));
   }
 };
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 8408e6d2a12edacb310ed5eb543ad51585f3d82a..3ae0fac4bef5c47964f9a9cd8dd45b57e705e1f8 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -968,7 +968,7 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
     return out
 
 
-def cross_entropy(input, label, soft_label=False):
+def cross_entropy(input, label, soft_label=False, ignore_index=-100):
     """
     **Cross Entropy Layer**
 
@@ -1012,7 +1012,10 @@ def cross_entropy(input, label, soft_label=False):
                            tensor with shape [N x D].
         soft_label (bool): a flag indicating whether to
                            interpretate the given labels as soft
-                           labels, default `False`.
+                           labels. Default: `False`.
+        ignore_index (int): Specifies a target value that is ignored and does
+                            not contribute to the input gradient. Only valid
+                            if soft_label is set to False. Default: -100
 
     Returns:
         A 2-D tensor with shape [N x 1], the cross entropy loss.
@@ -1037,7 +1040,8 @@ def cross_entropy(input, label, soft_label=False):
     helper.append_op(
         type='cross_entropy',
         inputs={'X': [input],
                 'Label': [label]},
         outputs={'Y': [out]},
-        attrs={"soft_label": soft_label})
+        attrs={"soft_label": soft_label,
+               "ignore_index": ignore_index})
     return out
 
 
@@ -4242,7 +4246,10 @@ def multiplex(inputs, index):
     return out
 
 
-def softmax_with_cross_entropy(logits, label, soft_label=False):
+def softmax_with_cross_entropy(logits,
+                               label,
+                               soft_label=False,
+                               ignore_index=-100):
     """
     **Softmax With Cross Entropy Operator.**
 
@@ -4284,6 +4291,10 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
                 soft_label is set to true, Label is a Tensor<float/double> with
         soft_label (bool): A flag to indicate whether to interpretate the given
             labels as soft labels. By default, `soft_label` is set to False.
+        ignore_index (int): Specifies a target value that is ignored and does
+                            not contribute to the input gradient. Only valid
+                            if soft_label is set to False. Default: -100
+
     Returns:
         Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
 
@@ -4305,7 +4316,8 @@ def softmax_with_cross_entropy(logits, label, soft_label=False):
         inputs={'Logits': logits,
                 'Label': label},
         outputs={'Softmax': softmax,
                  'Loss': loss},
-        attrs={'soft_label': soft_label})
+        attrs={'soft_label': soft_label,
+               'ignore_index': ignore_index})
     return loss
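On the Python side, the typical use case is masking padding positions in sequence tasks. A hypothetical usage snippet (treating label id 0 as the padding token is an assumption for illustration, not part of the patch):

```python
import paddle.fluid as fluid

# Per-sample class probabilities and integer labels; the batch
# dimension is implicit in fluid.layers.data.
predict = fluid.layers.data(name="predict", shape=[10], dtype="float32")
label = fluid.layers.data(name="label", shape=[1], dtype="int64")

# Positions labeled with the padding id produce zero loss and
# receive zero gradient during backpropagation.
loss = fluid.layers.cross_entropy(
    input=predict, label=label, soft_label=False, ignore_index=0)
avg_loss = fluid.layers.mean(loss)
```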
diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
index fa367f95fc9c65dd782d53a2799cacadf74dcfd2..f22badbea0c67b210f7ac4e14e5d647f1cffa6cc 100644
--- a/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
+++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_op.py
@@ -209,5 +209,34 @@ class TestCrossEntropyOp6(OpTest):
             ["X"], "Y", max_relative_error=0.05, numeric_grad_delta=0.001)
 
 
+class TestCrossEntropyOp7(OpTest):
+    """Test cross-entropy with ignore index.
+    """
+
+    def setUp(self):
+        self.op_type = "cross_entropy"
+        batch_size = 30
+        class_num = 10
+        ignore_index = 3
+
+        X = randomize_probability(batch_size, class_num, dtype='float64')
+
+        label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64")
+        cross_entropy = np.asmatrix(
+            [[-np.log(X[i][label[i][0]])]
+             if label[i][0] != ignore_index else [0]
+             for i in range(X.shape[0])],
+            dtype="float64")
+        self.inputs = {"X": X, "Label": label}
+        self.outputs = {"Y": cross_entropy}
+        self.attrs = {"soft_label": False, "ignore_index": ignore_index}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
+
+
 if __name__ == "__main__":
     unittest.main()
+ """ + + def setUp(self): + self.op_type = "cross_entropy" + batch_size = 30 + class_num = 10 + ignore_index = 3 + + X = randomize_probability(batch_size, class_num, dtype='float64') + + label = np.random.randint(0, class_num, (batch_size, 1), dtype="int64") + cross_entropy = np.asmatrix( + [[-np.log(X[i][label[i][0]])] + if label[i][0] != ignore_index else [0] + for i in range(X.shape[0])], + dtype="float64") + self.inputs = {"X": X, "Label": label} + self.outputs = {"Y": cross_entropy} + self.attrs = {"soft_label": False, "ignore_index": ignore_index} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y", numeric_grad_delta=0.001) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index bc4d364c74c6cb6b8f0df59e7ede77e6271f4b96..b04346b052903959f44aa96f6fccb7d20652e854 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -556,6 +556,15 @@ class TestBook(unittest.TestCase): out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0) print(str(program)) + def test_cross_entropy(self): + program = Program() + with program_guard(program): + x = layers.data(name="x", shape=[30, 10], dtype="float32") + label = layers.data(name="label", shape=[30, 1], dtype="int32") + mode = 'channel' + out = layers.cross_entropy(x, label, False, 4) + self.assertIsNotNone(out) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index b7e5ff6d52ad7dde3dd94b3bd660cfca383e1ada..a18941dd3126ac027f022ddafbbaed8516166233 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -88,5 +88,40 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest): self.check_grad(["Logits"], "Loss") +class TestSoftmaxWithCrossEntropyOp3(OpTest): + """ + Test softmax with cross entropy operator with ignore_index. + """ + + def setUp(self): + self.op_type = "softmax_with_cross_entropy" + batch_size = 41 + class_num = 37 + + logits = np.random.uniform(0.1, 1.0, + [batch_size, class_num]).astype("float64") + softmax = np.apply_along_axis(stable_softmax, 1, logits) + labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int64") + ignore_index = 7 + cross_entropy = np.asmatrix( + [[-np.log(softmax[i][labels[i][0]])] + if labels[i] != ignore_index else [0] + for i in range(softmax.shape[0])], + dtype="float64") + + self.inputs = {"Logits": logits, "Label": labels} + self.outputs = { + "Softmax": softmax.astype("float64"), + "Loss": cross_entropy.astype("float64") + } + self.attrs = {"ignore_index": ignore_index} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["Logits"], "Loss") + + if __name__ == "__main__": unittest.main()