From 79cec5311179e6e50b0126fea0e6dfa8a7cf354a Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Tue, 20 Nov 2018 12:37:04 +0000 Subject: [PATCH] add ignore index for sigmoid cross entropy with logits op, test=develop --- .../sigmoid_cross_entropy_with_logits_op.cc | 5 + .../sigmoid_cross_entropy_with_logits_op.h | 93 ++++++++++++++----- python/paddle/fluid/layers/nn.py | 5 +- .../fluid/tests/unittests/test_layers.py | 3 +- ...st_sigmoid_cross_entropy_with_logits_op.py | 35 +++++++ 5 files changed, 116 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc index 193de05422b..d6a2fa6a179 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -100,6 +100,11 @@ class SigmoidCrossEntropyWithLogitsOpMaker AddOutput("Out", "(Tensor, default Tensor), a 2-D tensor with shape N x D " " of elementwise logistic losses."); + AddAttr( + "ignore_index", + "(int, default -1), Specifies a target value that is ignored and" + "does not contribute to the input gradient.") + .SetDefault(-1); AddComment(R"DOC( SigmoidCrossEntropyWithLogits Operator. diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h index faef72866eb..2bfba6f1704 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h @@ -15,33 +15,82 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/hostdevice.h" +#include "paddle/legacy/utils/Logging.h" namespace paddle { namespace operators { +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; +template +using EigenMatrix = framework::EigenMatrix; + +template +struct SigmoidCrossEntropyWithLogitsForward { + // EIGEN_EMPTY_STRUCT_CTOR(SigmoidCrossEntropyWithLogitsForward) + HOSTDEVICE SigmoidCrossEntropyWithLogitsForward(const int &ignore_index) + : ignore_index(ignore_index) {} + + HOSTDEVICE T operator()(const T &x, const T &label) const { + if (static_cast(label) == ignore_index) { + return static_cast(0.); + } + T term1 = (x > 0) ? x : 0; + T term2 = x * label; + T term3 = std::log(static_cast(1) + std::exp(-(std::abs(x)))); + return term1 - term2 + term3; + } + + int ignore_index; +}; + +template +struct SigmoidCrossEntropyWithLogitsBackward { + // EIGEN_EMPTY_STRUCT_CTOR(SigmoidCrossEntropyWithLogitsForward) + HOSTDEVICE SigmoidCrossEntropyWithLogitsBackward(const int &ignore_index) + : ignore_index(ignore_index) {} + + HOSTDEVICE T operator()(const T &x, const T &label) const { + if (static_cast(label) == ignore_index) { + return static_cast(0.); + } + T simoid_x = static_cast(1) / (static_cast(1) + std::exp(-x)); + return simoid_x - label; + } + + int ignore_index; +}; + // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X))) template class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { - const framework::Tensor *X = context.Input("X"); - const framework::Tensor *Labels = context.Input("Label"); - framework::Tensor *Out = context.Output("Out"); + const Tensor *X = context.Input("X"); + const Tensor *Labels = context.Input("Label"); + Tensor *Out = context.Output("Out"); Out->mutable_data(context.GetPlace()); + int ignore_index = context.Attr("ignore_index"); - auto x = framework::EigenVector::Flatten(*X); - auto labels = framework::EigenVector::Flatten(*Labels); - auto out = framework::EigenVector::Flatten(*Out); + auto x = EigenVector::Flatten(*X); + auto labels = EigenVector::Flatten(*Labels); + auto out = EigenVector::Flatten(*Out); auto &place = *context.device_context().eigen_device(); + out.device(place) = x.binaryExpr( + labels, SigmoidCrossEntropyWithLogitsForward(ignore_index)); // term1 = max(x, 0) - auto term1 = x.cwiseMax(static_cast(0)); + // auto term1 = x.cwiseMax(static_cast(0)); // term2 = x * labels - auto term2 = x * labels; + // auto term2 = x * labels; // term3 = log(1 + exp(-abs(x))) - auto term3 = (static_cast(1) + (-(x.abs())).exp()).log(); + // auto term3 = (static_cast(1) + (-(x.abs())).exp()).log(); - out.device(place) = term1 - term2 + term3; + // out.device(place) = term1 - term2 + term3; } }; @@ -50,23 +99,23 @@ template class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { - const framework::Tensor *X = context.Input("X"); - const framework::Tensor *Labels = context.Input("Label"); - const framework::Tensor *dOut = - context.Input(framework::GradVarName("Out")); - framework::Tensor *dX = - context.Output(framework::GradVarName("X")); + const Tensor *X = context.Input("X"); + const Tensor *Labels = context.Input("Label"); + const Tensor *dOut = context.Input(framework::GradVarName("Out")); + Tensor *dX = context.Output(framework::GradVarName("X")); dX->mutable_data(context.GetPlace()); - auto x = framework::EigenVector::Flatten(*X); - auto labels = framework::EigenVector::Flatten(*Labels); - auto dout = framework::EigenVector::Flatten(*dOut); - auto dx = framework::EigenVector::Flatten(*dX); + auto ignore_index = context.Attr("ignore_index"); + auto x = EigenVector::Flatten(*X); + auto labels = EigenVector::Flatten(*Labels); + auto dout = EigenVector::Flatten(*dOut); + auto dx = EigenVector::Flatten(*dX); auto &place = *context.template device_context().eigen_device(); - auto sigmoid_x = static_cast(1) / (static_cast(1) + (-x).exp()); - dx.device(place) = dout * (sigmoid_x - labels); + auto diff = x.binaryExpr(labels, SigmoidCrossEntropyWithLogitsBackward( + static_cast(ignore_index))); + dx.device(place) = dout * diff; } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 99acd7e3088..e032835de32 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -7892,13 +7892,14 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): @templatedoc() -def sigmoid_cross_entropy_with_logits(x, label, name=None): +def sigmoid_cross_entropy_with_logits(x, label, ignore_index=-1, name=None): """ ${comment} Args: x(${x_type}): ${x_comment} label(${label_type}): ${label_comment} + ignore_index(&{ignore_index}): ${ignore_index_comment} name(basestring|None): Name of the output. Returns: @@ -7917,7 +7918,7 @@ def sigmoid_cross_entropy_with_logits(x, label, name=None): type="sigmoid_cross_entropy_with_logits", inputs={"X": x, "Label": label}, - attrs={}, + attrs={"ignore_index": ignore_index}, outputs={"Out": out}) return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index a8fa5436c43..8e098e4961f 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -170,9 +170,10 @@ class TestBook(unittest.TestCase): with program_guard(program): dat = layers.data(name='data', shape=[10], dtype='float32') lbl = layers.data(name='label', shape=[10], dtype='float32') + ignore_index = -1 self.assertIsNotNone( layers.sigmoid_cross_entropy_with_logits( - x=dat, label=lbl)) + x=dat, label=lbl, ignore_index=-1)) print(str(program)) def test_hsigmoid(self): diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py index 97ff203499c..64f6f088e10 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_cross_entropy_with_logits_op.py @@ -56,6 +56,40 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest): """Test sigmoid_cross_entropy_with_logit_op with probabalistic label """ + def setUp(self): + self.op_type = "sigmoid_cross_entropy_with_logits" + batch_size = 64 + num_classes = 20 + ignore_index = -1 + self.inputs = { + 'X': logit( + np.random.uniform(0, 1, (batch_size, num_classes)) + .astype("float32")), + 'Label': np.random.randint(-1, 2, (batch_size, num_classes)) + .astype("float32") + } + self.attrs = {'ignore_index': ignore_index, } + # Fw Pass is implemented as elementwise sigmoid followed by + # elementwise logistic loss + # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X)) + sigmoid_X = expit(self.inputs['X']) + term1 = self.inputs['Label'] * np.log(sigmoid_X) + term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X) + out = -term1 - term2 + out[np.where(self.inputs['Label'] == ignore_index)] = 0 + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestSigmoidCrossEntropyWithLogitsOp3(OpTest): + """Test sigmoid_cross_entropy_with_logit_op with probabalistic label + """ + def setUp(self): self.op_type = "sigmoid_cross_entropy_with_logits" batch_size = 64 @@ -85,3 +119,4 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest): if __name__ == '__main__': unittest.main() + np.random.seed(0) -- GitLab