#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from .op_test import OpTest
from scipy.special import logit
from scipy.special import expit
import unittest


class TestSigmoidCrossEntropyWithLogitsOp1(OpTest):
    """Test the sigmoid_cross_entropy_with_logits op with binary labels.

    Every label is either 0 or 1.
    """

    def setUp(self):
        """Build random logits and labels plus the NumPy reference loss."""
        self.op_type = "sigmoid_cross_entropy_with_logits"
        batch_size = 64
        num_classes = 20
        shape = (batch_size, num_classes)

        # logit() maps 0 to -inf and 1 to +inf. Sampling from [0, 1) and then
        # casting to float32 can land exactly on those end points (a sample
        # just below 1 rounds up to 1.0), which would turn the reference loss
        # below into inf/NaN. Keep the probabilities strictly inside (0, 1).
        eps = 1e-6
        probs = np.random.uniform(eps, 1 - eps, shape).astype("float32")

        self.inputs = {
            'X': logit(probs),
            'Label': np.random.randint(0, 2, shape).astype("float32")
        }

        # The forward pass is an elementwise sigmoid followed by the
        # elementwise logistic loss:
        #   Label * -log(sigmoid(X)) + (1 - Label) * -log(1 - sigmoid(X))
        sigmoid_X = expit(self.inputs['X'])
        term1 = self.inputs['Label'] * np.log(sigmoid_X)
        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
        self.outputs = {'Out': -term1 - term2}

    def test_check_output(self):
        """Delegate to OpTest.check_output() for the forward result."""
        self.check_output()

    def test_check_grad(self):
        """Invoke OpTest.check_grad() with input 'X' and output 'Out'."""
        self.check_grad(['X'], 'Out')


class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
    """Test the sigmoid_cross_entropy_with_logits op with probabilistic labels.

    Every label is a float drawn uniformly from [0, 1) rather than a hard
    0/1 class indicator.
    """

    def setUp(self):
        """Build random logits and soft labels plus the NumPy reference loss."""
        self.op_type = "sigmoid_cross_entropy_with_logits"
        batch_size = 64
        num_classes = 20
        shape = (batch_size, num_classes)

        # logit() maps 0 to -inf and 1 to +inf. Sampling from [0, 1) and then
        # casting to float32 can land exactly on those end points (a sample
        # just below 1 rounds up to 1.0), which would turn the reference loss
        # below into inf/NaN. Keep the probabilities strictly inside (0, 1).
        eps = 1e-6
        probs = np.random.uniform(eps, 1 - eps, shape).astype("float32")

        self.inputs = {
            'X': logit(probs),
            'Label': np.random.uniform(0, 1, shape).astype("float32")
        }

        # The forward pass is an elementwise sigmoid followed by the
        # elementwise logistic loss:
        #   Label * -log(sigmoid(X)) + (1 - Label) * -log(1 - sigmoid(X))
        sigmoid_X = expit(self.inputs['X'])
        term1 = self.inputs['Label'] * np.log(sigmoid_X)
        term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
        self.outputs = {'Out': -term1 - term2}

    def test_check_output(self):
        """Delegate to OpTest.check_output() for the forward result."""
        self.check_output()

    def test_check_grad(self):
        """Invoke OpTest.check_grad() with input 'X' and output 'Out'."""
        self.check_grad(['X'], 'Out')


# Run every test case defined in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()