#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle.fluid.core as core

from op_test import OpTest
from test_softmax_op import stable_softmax

def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
    """Numpy reference implementation of cross entropy.

    Args:
        softmax: probabilities (already softmax-normalized) of any rank.
        label: when ``soft_label`` is True, a distribution with the same
            shape as ``softmax``; otherwise integer class ids whose size
            along ``axis`` is 1.
        soft_label: whether ``label`` holds probability distributions.
        axis: the class dimension of ``softmax``.
        ignore_index: hard-label entries equal to this value contribute a
            zero loss (hard-label mode only).

    Returns:
        Loss array with the same shape as ``label``.
    """
    if soft_label:
        return (-label * np.log(softmax)).sum(axis=axis, keepdims=True)

    shape = softmax.shape
    axis %= len(shape)
    n = int(np.prod(shape[:axis]))
    axis_dim = shape[axis]
    remain = int(np.prod(shape[axis + 1:]))
    softmax_reshape = softmax.reshape((n, axis_dim, remain))
    label_reshape = label.reshape((n, 1, remain))
    ignored = label_reshape == ignore_index
    # Clamp ignored entries to a valid index so the gather cannot go out of
    # range, then gather the labelled-class probability for every sample in
    # one vectorized pass instead of a python double loop.
    safe_idx = np.where(ignored, 0, label_reshape)
    picked = np.take_along_axis(softmax_reshape, safe_idx, axis=1)
    result = np.where(ignored, 0.0, -np.log(picked)).astype(softmax.dtype)
    return result.reshape(label.shape)


C
caoying03 已提交
45
class TestSoftmaxWithCrossEntropyOp(OpTest):
    """
    Test softmax with cross entropy operator with discreate one-hot labels.
    """

    def initParams(self):
        """Default configuration; subclasses override selected attributes."""
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = False
        self.soft_label = False
        self.axis = -1
        self.ignore_index = -1
        self.shape = [41, 37]
        # explicilty use float32 for ROCm, as MIOpen does not yet support float64
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # True: feed raw logits; False: "Logits" already holds softmax output.
        self.softmax_switch = True

    def setUp(self):
        self.initParams()

        logits = getattr(
            self, "logits",
            np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))
        softmax = np.apply_along_axis(stable_softmax, self.axis, logits)

        if self.soft_label:
            # Random distributions normalized along the class axis.
            labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
            labels /= np.sum(labels, axis=self.axis, keepdims=True)
        else:
            # Integer class ids; the class axis of the label shape collapses to 1.
            axis_dim = self.shape[self.axis]
            self.shape[self.axis] = 1
            labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")

        loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
                             self.ignore_index)

        logits_input = logits if self.softmax_switch else softmax
        self.inputs = {"Logits": logits_input, "Label": labels}
        self.outputs = {
            "Softmax": softmax.astype(self.dtype),
            "Loss": loss.astype(self.dtype)
        }
        self.attrs = {
            "numeric_stable_mode": self.numeric_stable_mode,
            "soft_label": self.soft_label,
            "ignore_index": self.ignore_index,
            "softmax_switch": self.softmax_switch,
        }
        if self.axis != -1:
            self.attrs['axis'] = self.axis

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        if core.is_compiled_with_rocm():
            # HIP will have accuracy fail when using float32 in CPU place
            self.check_grad(["Logits"], "Loss", max_relative_error=5e-1)
        else:
            self.check_grad(["Logits"], "Loss", numeric_grad_delta=0.001)


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_1D(
        TestSoftmaxWithCrossEntropyOp):
    """1-D soft-label case where "Logits" already holds softmax output."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = -1
        self.ignore_index = -1
        self.shape = [13, 8]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_1D(
        TestSoftmaxWithCrossEntropyOp):
    """1-D hard-label case where "Logits" already holds softmax output."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = -1
        self.ignore_index = -1
        self.shape = [13, 8]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


##############################################################################
#NotWithSoftmax_SoftLabel_2D start
##############################################################################
class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D(
        TestSoftmaxWithCrossEntropyOp):
    """4-D soft-label case where "Logits" already holds softmax output."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = -1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis2(
        TestSoftmaxWithCrossEntropyOp):
    """Soft-label, pre-softmaxed input, class axis = 1."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = 1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis3(
        TestSoftmaxWithCrossEntropyOp):
    """Soft-label, pre-softmaxed input, class axis = 2."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = 2
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_SoftLabel_2D_Axis4(
        TestSoftmaxWithCrossEntropyOp):
    """Soft-label, pre-softmaxed input, class axis = 3 (last axis)."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = 3
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


##############################################################################
#NotWithSoftmax_SoftLabel_2D end
##############################################################################

##############################################################################
#NotWithSoftmax_HardLabel_2D start
##############################################################################


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D(
        TestSoftmaxWithCrossEntropyOp):
    """4-D hard-label case where "Logits" already holds softmax output."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = -1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis2(
        TestSoftmaxWithCrossEntropyOp):
    """Hard-label, pre-softmaxed input, class axis = 1."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis3(
        TestSoftmaxWithCrossEntropyOp):
    """Hard-label, pre-softmaxed input, class axis = 2."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 2
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Axis4(
        TestSoftmaxWithCrossEntropyOp):
    """Hard-label, pre-softmaxed input, class axis = 3 (last axis)."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 3
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


##############################################################################
#NotWithSoftmax_HardLabel_2D end
##############################################################################

##############################################################################
#NotWithSoftmax_HardLabel_2D_Ignore start
##############################################################################


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore(
        TestSoftmaxWithCrossEntropyOp):
    """Hard-label, pre-softmaxed input, with an ignored class id."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = False
        self.soft_label = False
        self.axis = -1
        self.ignore_index = 2
        self.shape = [13, 8]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_Ignore_Axis(
        TestSoftmaxWithCrossEntropyOp):
    """Hard-label, pre-softmaxed input, ignored class id, class axis = 1."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = False
        self.soft_label = False
        self.axis = 1
        self.ignore_index = 2
        self.shape = [13, 8]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore(
        TestSoftmaxWithCrossEntropyOp):
    """4-D hard-label, pre-softmaxed input, with an ignored class id."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = -1
        self.ignore_index = 2
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


class TestSoftmaxWithCrossEntropyOp_NotWithSoftmax_HardLabel_2D_Ignore_Axis3(
        TestSoftmaxWithCrossEntropyOp):
    """4-D hard-label, pre-softmaxed input, ignored class id, axis = 2."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 2
        self.ignore_index = 2
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # False: the "Logits" input is pre-computed softmax, not raw logits.
        self.softmax_switch = False


##############################################################################
#NotWithSoftmax_HardLabel_2D_Ignore end
##############################################################################
C
caoying03 已提交
316 317


S
sneaxiy 已提交
318 319
class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
    """Hard-label case forced down the numerically stable (non-cuDNN) path."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = -1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
S
sneaxiy 已提交
328 329


330 331
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp):
    """Float16 variant; reference data is built in a wider dtype."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = False
        self.soft_label = False
        self.shape = [3, 5, 7, 11]
        self.axis = -1
        self.ignore_index = -1
        self.dtype = np.float16

    def setUp(self):
        self.initParams()
        self.op_type = "softmax_with_cross_entropy"

        # numpy float16 has very low accuracy, so compute the reference in a
        # wider dtype and only cast the operator inputs/outputs to fp16.
        check_dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        logits = getattr(
            self, "logits",
            np.random.uniform(0.1, 1.0, self.shape).astype(check_dtype))
        softmax = np.apply_along_axis(stable_softmax, self.axis, logits)

        # Hard labels: the class axis of the label shape collapses to 1.
        axis_dim = self.shape[self.axis]
        self.shape[self.axis] = 1
        labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
        loss = cross_entropy(softmax, labels, self.soft_label, self.axis)

        self.inputs = {"Logits": logits.astype(self.dtype), "Label": labels}
        self.outputs = {
            "Softmax": softmax.astype(self.dtype),
            "Loss": loss.astype(self.dtype)
        }
        self.attrs = {
            "numeric_stable_mode": self.numeric_stable_mode,
            "soft_label": self.soft_label,
        }
        if self.axis != -1:
            self.attrs['axis'] = self.axis

    def test_check_output(self):
        # fp16 needs a looser absolute tolerance.
        self.check_output(atol=1e-2)

    def test_check_grad(self):
        self.check_grad(["Logits"], "Loss", max_relative_error=0.1)


class TestSoftmaxWithCrossEntropyOpNoCudnnFp16(
        TestSoftmaxWithCrossEntropyOpFp16):
    """Float16 variant exercising the numerically stable (non-cuDNN) path."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.shape = [3, 5, 7, 11]
        self.axis = -1
        self.ignore_index = -1
        self.dtype = np.float16

    def test_check_grad(self):
        self.check_grad(["Logits"], "Loss", max_relative_error=0.1)


393
class TestSoftmaxWithCrossEntropyOp2(TestSoftmaxWithCrossEntropyOp):
    """
    Test softmax with cross entropy operator with soft labels.
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = -1
        self.ignore_index = -1
        self.shape = [41, 37]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        if core.is_compiled_with_rocm():
            # HIP will have accuracy fail when using float32 in CPU place
            self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
        else:
            self.check_grad(["Logits"], "Loss")
417 418


419
class TestSoftmaxWithCrossEntropyOp3(TestSoftmaxWithCrossEntropyOp):
    """
    Test softmax with cross entropy operator with ignore_index.
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = False
        self.soft_label = False
        self.axis = -1
        self.ignore_index = 5
        self.shape = [41, 37]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
S
sneaxiy 已提交
433

434 435 436

class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
    """ignore_index case forced down the numerically stable path."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = -1
        self.ignore_index = 4
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
445 446


447 448 449 450 451
class TestSoftmaxWithCrossEntropyOpAxis1(TestSoftmaxWithCrossEntropyOp):
    """
    Test softmax with cross entropy operator with discreate one-hot labels.
    Given axis != -1
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 0
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
462 463


464 465 466 467 468 469
class TestSoftmaxWithCrossEntropyOpAxis2(TestSoftmaxWithCrossEntropyOp):
    """
    Test softmax with cross entropy operator with discreate one-hot labels.
    Given axis != -1
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
S
sneaxiy 已提交
479 480


481
class TestSoftmaxWithCrossEntropyOpAxis3(TestSoftmaxWithCrossEntropyOp):
    """
    Test softmax with cross entropy operator with discreate one-hot labels.
    Given axis != -1
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 2
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
496

497 498 499 500 501 502 503 504

class TestSoftmaxWithCrossEntropyOpAxis4(TestSoftmaxWithCrossEntropyOp):
    """
    Test softmax with cross entropy operator with discreate one-hot labels.
    Given axis != -1
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 3
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
513 514


515 516 517 518 519 520 521 522 523 524 525
class TestSoftmaxWithCrossEntropyOpAxisDimEqualOne(
        TestSoftmaxWithCrossEntropyOp):
    """
    Test softmax with cross entropy operator with discreate one-hot labels.
    The size of the class axis is 1 (degenerate softmax).
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = -1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 1]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
531 532


533 534 535 536 537 538 539 540 541 542
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis1(
        TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
    """Float16 non-cuDNN variant with class axis = 0."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 0
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        self.dtype = np.float16
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
544 545


546 547
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis2(
        TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
    """Float16 non-cuDNN variant with class axis = 1."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        self.dtype = np.float16
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
557 558


559 560 561 562 563 564 565 566 567 568
class TestSoftmaxWithCrossEntropyOpNoCudnnFp16Axis3(
        TestSoftmaxWithCrossEntropyOpNoCudnnFp16):
    """Float16 non-cuDNN variant with class axis = 2."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 2
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        self.dtype = np.float16
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
570

571 572 573 574

class TestSoftmaxWithCrossEntropyOpSoftLabelAxis1(
        TestSoftmaxWithCrossEntropyOp2):
    """Soft-label variant with class axis = 0."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = 0
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
583 584


585 586 587 588 589 590 591 592 593
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis2(
        TestSoftmaxWithCrossEntropyOp2):
    """Soft-label variant with class axis = 1."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = 1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
596 597


598 599 600 601 602 603 604 605 606
class TestSoftmaxWithCrossEntropyOpSoftLabelAxis3(
        TestSoftmaxWithCrossEntropyOp2):
    """Soft-label variant with class axis = 2."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = 2
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
609

610 611 612 613 614 615 616 617 618 619

class TestSoftmaxWithCrossEntropyOpSoftLabelAxis4(
        TestSoftmaxWithCrossEntropyOp2):
    """Soft-label variant with class axis = 3 (last axis)."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = True
        self.axis = 3
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
622 623


624 625
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
        TestSoftmaxWithCrossEntropyOp3):
    """ignore_index variant with class axis = 0."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 0
        self.ignore_index = 1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
635

636 637 638 639

class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
        TestSoftmaxWithCrossEntropyOp3):
    """ignore_index variant with class axis = 1."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 1
        self.ignore_index = 0
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
648 649


650 651 652 653 654 655 656 657 658
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
        TestSoftmaxWithCrossEntropyOp3):
    """ignore_index variant with class axis = 2."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 2
        self.ignore_index = 3
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
661 662


663 664 665 666 667 668 669 670 671
class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
        TestSoftmaxWithCrossEntropyOp3):
    """ignore_index variant with class axis = 3 (last axis)."""

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = 3
        self.ignore_index = 3
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
P
phlrain 已提交
674 675


676 677 678 679 680 681 682 683 684 685 686 687 688
class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp):
    """
    Test stable softmax with cross entropy operator will not produce INF
    with very small logits values.
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = -1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Uniform, very negative logits stress the numerically stable path.
        self.logits = np.full(self.shape, -500.0).astype(self.dtype)
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
692 693 694 695 696 697 698 699 700 701 702 703 704 705 706


class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp):
    """
    Test stable softmax with cross entropy operator will not produce INF
    with extreme (mixed very large / very small) logits values.
    """

    def initParams(self):
        self.op_type = "softmax_with_cross_entropy"
        self.numeric_stable_mode = True
        self.soft_label = False
        self.axis = -1
        self.ignore_index = -1
        self.shape = [3, 5, 7, 11]
        # MIOpen lacks float64 kernels, so ROCm builds test with float32.
        self.dtype = np.float32 if core.is_compiled_with_rocm() else np.float64
        # Huge positive logits with one slice forced hugely negative stress
        # the max-subtraction in the numerically stable path.
        self.logits = np.full(self.shape, 1000.0).astype(self.dtype)
        self.logits[:, :, 0, :] = -1000.0
        # Feed raw logits; the operator applies softmax internally.
        self.softmax_switch = True
711 712


C
caoying03 已提交
713 714
# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()