#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

C
chengduoZH 已提交
15
import unittest
16 17 18
from functools import reduce
from operator import mul

C
chengduoZH 已提交
19
import numpy as np
20 21 22 23 24
from eager_op_test import (
    OpTest,
    _set_use_system_allocator,
    convert_float_to_uint16,
)
C
chengduoZH 已提交
25

26
import paddle
27
import paddle.nn.functional as F
28 29
from paddle import fluid
from paddle.fluid import Program, core, program_guard
30
from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
31 32

# Tests in this module build static-graph programs unless they switch to
# dygraph mode themselves (several call paddle.disable_static()).
paddle.enable_static()

# Seed both random generators so the generated test data is reproducible.
paddle.seed(123)
np.random.seed(123)

_set_use_system_allocator(True)

C
chengduoZH 已提交
39

C
chengduoZH 已提交
40
def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
C
chengduoZH 已提交
41 42
    x_shape = x.shape
    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
43
    D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
C
chengduoZH 已提交
44
    x.shape = [N, D]
C
chengduoZH 已提交
45

C
chengduoZH 已提交
46 47
    mean = np.mean(x, axis=1)
    var = np.var(x, axis=1) + epsilon
48 49 50
    output = np.divide(
        (x - mean.reshape([N, 1])), (np.sqrt(var)).reshape([N, 1])
    )
51 52 53 54
    if scale is not None:
        output = scale.reshape([1, D]) * output
    if beta is not None:
        output = output + beta.reshape([1, D])
C
chengduoZH 已提交
55 56

    x.shape, output.shape = x_shape, x_shape
C
chengduoZH 已提交
57 58 59
    return output, mean, var


60 61 62
def _reference_layer_norm_grad(
    x, grad_y, scale, bias, mean, var, begin_norm_axis=1
):
C
chengduoZH 已提交
63
    x_shape = x.shape
C
chengduoZH 已提交
64
    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
65
    D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
66 67 68 69

    if scale is not None:
        scale_shape = scale.shape
        scale.shape = [1, D]
C
chengduoZH 已提交
70 71
    x.shape, grad_y.shape = [N, D], [N, D]
    var.shape, mean.shape = [N, 1], [N, 1]
C
chengduoZH 已提交
72

C
chengduoZH 已提交
73
    # d_bias
74 75 76 77
    if bias is not None:
        d_bias = np.sum(grad_y, axis=0).reshape([1, D])
    else:
        d_bias = None
C
chengduoZH 已提交
78
    # d_scale
79
    if scale is not None:
80 81 82
        d_scale = np.sum(
            ((x - mean) * np.sqrt(1 / var)) * grad_y, axis=0
        ).reshape([1, D])
83 84
    else:
        d_scale = None
C
chengduoZH 已提交
85
    # dx
86 87
    if scale is not None:
        dx_end = scale * np.sqrt(1.0 / var) * grad_y
88 89 90
        d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * scale, axis=1).reshape(
            [N, 1]
        )  # the second part equals to zero.
91
        d_mean = 1.0 / D * d_mean_0
92 93 94 95 96
        d_std = np.sum(
            -(1.0 / var) * (x - mean) * grad_y * scale, axis=1
        ).reshape([N, 1]) * (
            1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean)
        )
97 98
    else:
        dx_end = 1.0 * np.sqrt(1.0 / var) * grad_y
99 100 101
        d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * 1.0, axis=1).reshape(
            [N, 1]
        )  # the second part equals to zero.
102
        d_mean = 1.0 / D * d_mean_0
103 104 105 106 107
        d_std = np.sum(
            -(1.0 / var) * (x - mean) * grad_y * 1.0, axis=1
        ).reshape([N, 1]) * (
            1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean)
        )
C
chengduoZH 已提交
108

C
chengduoZH 已提交
109
    grad_x = dx_end + d_mean + d_std
C
chengduoZH 已提交
110

C
chengduoZH 已提交
111
    grad_x.shape, x.shape, grad_y.shape = x_shape, x_shape, x_shape
112
    var.shape, mean.shape = [N], [N]
113 114 115

    if scale is not None:
        scale.shape = scale_shape
C
chengduoZH 已提交
116
    return grad_x, d_scale, d_bias
C
chengduoZH 已提交
117 118


119 120 121 122 123 124 125 126 127 128
def layer_norm_wrapper(
    x, scale=None, bias=None, epsilon=1e-05, begin_norm_axis=1
):
    """Expose the ``layer_norm`` op signature through ``F.layer_norm``.

    The op describes the normalized region with ``begin_norm_axis``. The
    functional API instead takes the trailing shape, so that shape is taken
    from ``x`` starting at ``begin_norm_axis``.
    """
    trailing_dims = list(x.shape)[begin_norm_axis:]
    return paddle.nn.functional.layer_norm(
        x, trailing_dims, weight=scale, bias=bias, epsilon=epsilon
    )


129 130 131 132
@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support fp64 LayerNormOpByOp currently",
)
class TestLayerNormOpByOpTest(OpTest):
    """OpTest for the ``layer_norm`` op in float64 with Scale and Bias.

    Subclasses override ``initConfig`` to change the dtype, the tolerances,
    and which of the optional ``Scale``/``Bias`` inputs are fed.
    """

    def setUp(self):
        self.op_type = "layer_norm"
        self.prim_op_type = "comp"
        self.python_api = layer_norm_wrapper
        self.public_python_api = layer_norm_wrapper
        self.python_out_sig = ["Y"]
        self.initConfig()
        self.initTestCase()

    def test_check_output(self):
        # Mean and Variance are excluded via no_check_set; only Y is compared.
        self.check_output(
            atol=self.ori_atol,
            rtol=self.ori_rtol,
            no_check_set=["Mean", "Variance"],
            check_prim=True,
        )

    def test_check_grad(self):
        self.check_grad(
            self.check_grad_input_list,
            ['Y'],
            check_prim=True,
            max_relative_error=self.max_relative_error,
        )

    def initConfig(self):
        # NOTE(review): the fw_comp_*, rev_comp_* and cinn_* tolerances are not
        # read in this file; presumably OpTest consumes them for its
        # prim/CINN checks.
        self.fw_comp_atol = 1e-6
        self.fw_comp_rtol = 1e-6
        self.rev_comp_atol = 1e-7
        self.rev_comp_rtol = 1e-7
        self.cinn_atol = 1e-5
        self.cinn_rtol = 1e-5

        # Tolerances passed by test_check_output / test_check_grad.
        self.ori_atol = 1e-4
        self.ori_rtol = 1e-4
        self.max_relative_error = 1e-5

        # ROCm does not have float64 LayerNorm kernel
        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = True
        self.has_bias = True

    def initTestCase(self):
        np.random.seed(123)

        self.D = reduce(mul, self.x_shape[self.begin_norm_axis :], 1)
        self.scale_shape = [self.D]

        def draw(shape):
            return np.random.random(shape).astype(self.dtype)

        # The draw order (x, then scale, then bias) fixes the generated values.
        x = draw(self.x_shape)
        scale = draw(self.scale_shape) if self.has_scale else None
        bias = draw(self.scale_shape) if self.has_bias else None

        self.inputs = {"X": x}
        self.check_grad_input_list = ['X']
        for slot, value in (("Scale", scale), ("Bias", bias)):
            if value is not None:
                self.inputs[slot] = value
                self.check_grad_input_list.append(slot)

        self.attrs = {
            "epsilon": self.epsilon,
            "begin_norm_axis": self.begin_norm_axis,
        }
        y, mean, variance = _reference_layer_norm_naive(
            x, scale, bias, self.epsilon, self.begin_norm_axis
        )
        self.outputs = {"Y": y, "Mean": mean, "Variance": variance}


223 224
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.is_compiled_with_rocm()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestLayerNormBF16OpByOpTest(OpTest):
    """OpTest for the ``layer_norm`` op in bfloat16 on ``CUDAPlace(0)``.

    Data is generated in float32 and converted with
    ``convert_float_to_uint16`` before it is handed to the op.
    """

    def setUp(self):
        self.op_type = "layer_norm"
        self.prim_op_type = "comp"
        self.python_api = layer_norm_wrapper
        self.public_python_api = layer_norm_wrapper
        self.python_out_sig = ["Y"]
        self.initConfig()
        self.initTestCase()

    def test_check_output(self):
        # Mean and Variance are excluded via no_check_set; only Y is compared.
        self.check_output_with_place(
            place=core.CUDAPlace(0),
            atol=self.ori_atol,
            rtol=self.ori_rtol,
            no_check_set=["Mean", "Variance"],
            check_prim=True,
        )

    def test_check_grad(self):
        self.check_grad_with_place(
            core.CUDAPlace(0),
            self.check_grad_input_list,
            ['Y'],
            check_prim=True,
            max_relative_error=self.max_relative_error,
        )

    def initConfig(self):
        self.ori_atol = self.ori_rtol = 1e-2
        self.max_relative_error = 1e-5

        self.dtype = np.uint16
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = self.has_bias = True

    def initTestCase(self):
        np.random.seed(123)

        self.D = reduce(mul, self.x_shape[self.begin_norm_axis :], 1)
        self.scale_shape = [self.D]

        def draw_fp32(shape):
            return np.random.random(shape).astype("float32")

        # The draw order (x, then scale, then bias) fixes the generated values.
        x = draw_fp32(self.x_shape)
        scale = draw_fp32(self.scale_shape) if self.has_scale else None
        bias = draw_fp32(self.scale_shape) if self.has_bias else None

        self.inputs = {"X": convert_float_to_uint16(x)}
        self.check_grad_input_list = ['X']
        for slot, value in (("Scale", scale), ("Bias", bias)):
            if value is not None:
                self.inputs[slot] = convert_float_to_uint16(value)
                self.check_grad_input_list.append(slot)

        self.attrs = {
            "epsilon": self.epsilon,
            "begin_norm_axis": self.begin_norm_axis,
        }
        # The reference is computed in float32 and converted afterwards.
        y, mean, variance = _reference_layer_norm_naive(
            x, scale, bias, self.epsilon, self.begin_norm_axis
        )
        self.outputs = {
            name: convert_float_to_uint16(value)
            for name, value in (
                ("Y", y),
                ("Mean", mean),
                ("Variance", variance),
            )
        }


314 315 316 317
@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support fp64 LayerNormOpByOp currently",
)
class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest):
    """float64 layer_norm without Scale and without Bias."""

    def initConfig(self):
        # Start from the base float64 configuration, which sets the same
        # attributes, then override only what differs for this case.
        super().initConfig()
        self.rev_comp_atol = self.rev_comp_rtol = 1e-6
        self.fw_comp_atol = self.fw_comp_rtol = 1e-7
        self.has_scale = False
        self.has_bias = False


340 341 342 343
@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support bf16 LayerNormOpByOp currently",
)
class TestLayerNormBF16OpByOpTest_case2(TestLayerNormBF16OpByOpTest):
    """bfloat16 layer_norm without Scale and without Bias."""

    def initConfig(self):
        # Reuse the base bf16 configuration and drop both optional inputs.
        super().initConfig()
        self.has_scale = False
        self.has_bias = False


359 360 361 362
@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support fp64 LayerNormOpByOp currently",
)
class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest):
    """float64 layer_norm with Scale but without Bias."""

    def initConfig(self):
        # Start from the base float64 configuration, which sets the same
        # attributes, then override only what differs for this case.
        super().initConfig()
        self.rev_comp_atol = self.rev_comp_rtol = 1e-7
        self.fw_comp_atol = self.fw_comp_rtol = 1e-7
        self.has_scale = True
        self.has_bias = False


385 386 387 388
@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support bf16 LayerNormOpByOp currently",
)
class TestLayerNormBF16OpByOpTest_case3(TestLayerNormBF16OpByOpTest):
    """bfloat16 layer_norm with Scale but without Bias."""

    def initConfig(self):
        # Reuse the base bf16 configuration and drop only the bias input.
        super().initConfig()
        self.has_scale = True
        self.has_bias = False


404 405 406 407
@unittest.skipIf(
    paddle.is_compiled_with_rocm(),
    "ROCm doesn't support fp64 LayerNormOpByOp currently",
)
class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest):
    """float64 layer_norm with Bias but without Scale."""

    def initConfig(self):
        # Start from the base float64 configuration, which sets the same
        # attributes, then override only what differs for this case.
        super().initConfig()
        self.rev_comp_atol = self.rev_comp_rtol = 1e-6
        self.fw_comp_atol = self.fw_comp_rtol = 1e-7
        self.has_scale = False
        self.has_bias = True


430 431 432 433 434 435 436 437 438 439 440 441 442 443 444
class TestLayerNormBF16OpByOpTest_case4(TestLayerNormBF16OpByOpTest):
    """bfloat16 layer_norm with Bias but without Scale."""

    def initConfig(self):
        # Reuse the base bf16 configuration and drop only the scale input.
        super().initConfig()
        self.has_scale = False
        self.has_bias = True


445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512
class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest):
    """float32 layer_norm with both Scale and Bias."""

    def initConfig(self):
        # Does not call super().initConfig(), so the fw_comp_* and cinn_*
        # tolerances are not set here (presumably OpTest defaults apply).
        self.rev_comp_atol = self.rev_comp_rtol = 1e-5
        self.ori_atol = self.ori_rtol = 1e-4
        self.max_relative_error = 7e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = self.has_bias = True


class TestLayerNormOpByOpTestFP32_case2(TestLayerNormOpByOpTest):
    """float32 layer_norm without Scale and without Bias."""

    def initConfig(self):
        # Does not call super().initConfig(), so the fw_comp_* and cinn_*
        # tolerances are not set here (presumably OpTest defaults apply).
        self.rev_comp_atol = self.rev_comp_rtol = 1e-5
        self.ori_atol = self.ori_rtol = 1e-4
        self.max_relative_error = 1e-5

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = self.has_bias = False


class TestLayerNormOpByOpTestFP32_case3(TestLayerNormOpByOpTest):
    """float32 layer_norm with Scale but without Bias."""

    def initConfig(self):
        # Does not call super().initConfig(), so the fw_comp_* and cinn_*
        # tolerances are not set here (presumably OpTest defaults apply).
        self.rev_comp_atol = self.rev_comp_rtol = 1e-5
        self.ori_atol = self.ori_rtol = 1e-4
        self.max_relative_error = 3e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale, self.has_bias = True, False


class TestLayerNormOpByOpTestFP32_case4(TestLayerNormOpByOpTest):
    """float32 layer_norm with Bias but without Scale."""

    def initConfig(self):
        # Does not call super().initConfig(), so the fw_comp_* and cinn_*
        # tolerances are not set here (presumably OpTest defaults apply).
        self.rev_comp_atol = self.rev_comp_rtol = 1e-5
        self.ori_atol = self.ori_rtol = 1e-4
        self.max_relative_error = 1e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale, self.has_bias = False, True


513
class TestLayerNormOp(unittest.TestCase):
    """Check the raw ``layer_norm`` op and its grad op against NumPy references.

    The forward op is appended to a static program by hand, and its grad op
    is generated with ``core.get_grad_op_desc``.
    """

    def setUp(self):
        self.use_cudnn = True

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        # np.allclose broadcasts its operands. The [1, D] reference parameter
        # gradients therefore need not match the fetched shapes exactly.
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

    def check_forward_backward(
        self,
        shape,
        begin_norm_axis,
        has_scale=True,
        has_bias=True,
        y_grad_scale=1.0,
        use_mkldnn=False,
    ):
        """Run ``layer_norm`` forward and backward on every available place.

        Args:
            shape: shape of the random input ``x``.
            begin_norm_axis: first axis included in the normalization.
            has_scale: whether to feed a ``Scale`` input and check its grad.
            has_bias: whether to feed a ``Bias`` input and check its grad.
            y_grad_scale: multiplier applied to the random upstream gradient.
            use_mkldnn: value of the op's ``use_mkldnn`` attribute.
        """

        def test_with_place(
            place, shape, begin_norm_axis, use_mkldnn=use_mkldnn
        ):
            # attr
            epsilon = 0.00001
            x_shape = shape
            D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
            scale_shape = [D]

            np.random.seed(123)
            x = np.random.random_sample(x_shape).astype(np.float32)
            scale = (
                np.random.random_sample(scale_shape).astype(np.float32)
                if has_scale
                else None
            )
            bias = (
                np.random.random_sample(scale_shape).astype(np.float32)
                if has_bias
                else None
            )
            y_grad = (np.random.random_sample(x_shape) * y_grad_scale).astype(
                np.float32
            )

            # reference forward & backward
            y, mean, variance = _reference_layer_norm_naive(
                x, scale, bias, epsilon, begin_norm_axis
            )
            x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
                x, y_grad, scale, bias, mean, variance, begin_norm_axis
            )

            # Name every tensor explicitly instead of scraping ``locals()``.
            # Optional inputs are present only when enabled.
            ground_truth = {
                'x': x,
                'mean': mean,
                'variance': variance,
                'y': y,
                'y@GRAD': y_grad,
            }
            if has_scale:
                ground_truth['scale'] = scale
            if has_bias:
                ground_truth['bias'] = bias

            program = fluid.Program()
            with fluid.program_guard(program):
                block = program.global_block()
                for name in ground_truth:
                    block.create_var(
                        name=name,
                        dtype='float32',
                        shape=ground_truth[name].shape,
                    )
                inputs = {"X": block.var('x')}
                fetch_list = [
                    'y',
                    'mean',
                    'variance',
                    'x@GRAD',
                ]
                if has_scale:
                    inputs["Scale"] = block.var('scale')
                    fetch_list += ['scale@GRAD']
                if has_bias:
                    inputs["Bias"] = block.var('bias')
                    fetch_list += ['bias@GRAD']
                layer_norm_op = block.append_op(
                    type="layer_norm",
                    inputs=inputs,
                    outputs={
                        "Y": block.var('y'),
                        "Mean": block.var('mean'),  # share the same memory
                        "Variance": block.var(
                            'variance'
                        ),  # share the same memory
                    },
                    attrs={
                        "epsilon": epsilon,
                        "begin_norm_axis": begin_norm_axis,
                        "use_mkldnn": use_mkldnn,
                    },
                )
                # generate backward op_desc
                grad_op_desc_list, _ = core.get_grad_op_desc(
                    layer_norm_op.desc, set(), []
                )
                grad_op_desc = grad_op_desc_list[0]
                new_op_desc = block.desc.append_op()
                new_op_desc.copy_from(grad_op_desc)
                for var_name in grad_op_desc.output_arg_names():
                    block.desc.var(var_name.encode("ascii"))
                grad_op_desc.infer_var_type(block.desc)
                grad_op_desc.infer_shape(block.desc)
                for arg in grad_op_desc.output_arg_names():
                    grad_var = block.desc.find_var(arg.encode("ascii"))
                    grad_var.set_dtype(core.VarDesc.VarType.FP32)

                program._sync_with_cpp()
                exe = fluid.Executor(place)
                # Feed only variables declared in the program. Otherwise the
                # disabled optional inputs would be fed as ``None`` under
                # names the program does not contain.
                feed = {
                    name: ground_truth[name]
                    for name in ['x', 'scale', 'bias', 'y@GRAD']
                    if name in ground_truth
                }
                out = exe.run(
                    program,
                    feed=feed,
                    fetch_list=fetch_list,
                )
                self.__assert_close(y, out[0], "y")
                self.__assert_close(mean, out[1], "mean")
                self.__assert_close(variance, out[2], "variance", 1e-3)
                self.__assert_close(x_grad, out[3], "x_grad")
                if has_scale:
                    self.__assert_close(
                        scale_grad,
                        out[fetch_list.index('scale@GRAD')],
                        "scale_grad",
                        1e-3,
                    )
                if has_bias:
                    self.__assert_close(
                        bias_grad,
                        out[fetch_list.index('bias@GRAD')],
                        "bias_grad",
                    )

        places = [core.CPUPlace()]
        if (
            core.is_compiled_with_cuda()
            and core.op_support_gpu("layer_norm")
            and self.use_cudnn
        ):
            places.append(core.CUDAPlace(0))

        for place in places:
            test_with_place(place, shape, begin_norm_axis)

    def test_check_forward_backward_with_scale_and_bias(self):
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=False,
            has_bias=True,
        )
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=True,
            has_bias=False,
        )
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=False,
            has_bias=False,
        )
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
        self.check_forward_backward(
            shape=[92, 513, 129], begin_norm_axis=2, y_grad_scale=0.1
        )
        self.check_forward_backward(shape=[3, 34, 1134], begin_norm_axis=2)
        self.check_forward_backward(shape=[3, 2, 1133], begin_norm_axis=2)
        self.check_forward_backward(
            shape=[92, 513, 1134], begin_norm_axis=2, y_grad_scale=0.1
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=False,
            has_bias=True,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=True,
            has_bias=False,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=False,
            has_bias=False,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[512, 1024], begin_norm_axis=1, has_scale=True, has_bias=True
        )
        self.check_forward_backward(
            shape=[1, 128, 256, 256],
            begin_norm_axis=3,
            has_scale=True,
            has_bias=True,
        )
        self.check_forward_backward(
            shape=[1, 256, 384],
            begin_norm_axis=2,
            has_scale=True,
            has_bias=True,
        )
C
chengduoZH 已提交
731 732


733 734
class TestLayerNormAPI(unittest.TestCase):
    """Build ``paddle.static.nn.layer_norm`` under several configurations."""

    def test_case(self):
        x = paddle.static.data(name='x', shape=[64, 32, 256], dtype='float32')
        # Each entry is (scale, shift, param_attr, bias_attr). Each layer
        # consumes the previous layer's output.
        configs = (
            (True, True, None, None),
            (False, False, None, None),
            (False, False, "scale", "shift"),
        )
        for use_scale, use_shift, weight_attr, shift_attr in configs:
            x = paddle.static.nn.layer_norm(
                x,
                scale=use_scale,
                shift=use_shift,
                begin_norm_axis=1,
                epsilon=1e-05,
                param_attr=weight_attr,
                bias_attr=shift_attr,
            )
763 764


765 766 767
class TestDygraphLayerNormAPIError(unittest.TestCase):
    """``paddle.nn.LayerNorm`` must reject unsupported inputs with TypeError."""

    def test_errors(self):
        main_prog, startup_prog = Program(), Program()
        with program_guard(main_prog, startup_prog):
            paddle.enable_static()

            layer_norm = paddle.nn.LayerNorm([32, 32])

            # A plain numpy array is not a Variable and must be rejected.
            np_input = np.random.random((3, 32, 32)).astype('float32')
            self.assertRaises(TypeError, layer_norm, np_input)

            # An int32 input is not an accepted dtype. (The original note
            # says the input must be float32 or float64, and float16 can only
            # be set on a GPU place.)
            int_input = paddle.static.data(
                name='x2', shape=[-1, 3, 32, 32], dtype="int32"
            )
            self.assertRaises(TypeError, layer_norm, int_input)


783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813
class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
    """fp16 layer_norm must match exactly with fp16 or fp32 scale/bias."""

    def check_main(self, x_np, weight_np, bias_np, dtype):
        """Run layer_norm forward/backward in dygraph and return numpy results.

        ``weight_np`` and ``bias_np`` are cast to ``dtype`` first. ``x_np`` is
        used as given.

        Returns:
            ``(y, x_grad, weight_grad, bias_grad)``. All are cast to float32
            except ``weight_grad``, which is cast to float16.
        """
        paddle.disable_static()

        x = paddle.to_tensor(x_np)
        weight = paddle.to_tensor(weight_np.astype(dtype))
        bias = paddle.to_tensor(bias_np.astype(dtype))
        tensors = [x, weight, bias]
        for tensor in tensors:
            tensor.stop_gradient = False

        y = F.layer_norm(x, x.shape[1:], weight, bias)
        x_g, w_g, b_g = paddle.grad(y, tensors)
        results = (
            y.numpy().astype('float32'),
            x_g.numpy().astype('float32'),
            # NOTE(review): only the weight grad is cast to float16; presumably
            # deliberate so both runs are compared at fp16 precision - confirm.
            w_g.numpy().astype('float16'),
            b_g.numpy().astype('float32'),
        )

        paddle.enable_static()
        return results

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            return
        x_np = np.random.random([10, 20]).astype('float16')
        weight_np = np.random.random([20]).astype('float16')
        bias_np = np.random.random([20]).astype('float16')

        fp16_params = self.check_main(x_np, weight_np, bias_np, 'float16')
        fp32_params = self.check_main(x_np, weight_np, bias_np, 'float32')

        # y, x_grad, weight_grad and bias_grad must be bit-identical.
        for from_fp16, from_fp32 in zip(fp16_params, fp32_params):
            np.testing.assert_array_equal(from_fp16, from_fp32)


829 830 831 832
@unittest.skipIf(
    not core.is_compiled_with_cuda() or paddle.is_compiled_with_rocm(),
    "BF16 is only supported on CUDA.",
)
class TestBF16ScaleBiasLayerNorm(unittest.TestCase):
    """bf16-input layer_norm must stay close to the fp32 result."""

    def check_main(self, x_np, weight_np, bias_np, dtype):
        """Run layer_norm forward/backward in dygraph mode.

        Only ``x`` is cast (to bf16 when ``dtype == "bfloat16"``). Weight and
        bias keep the dtype of the given arrays.

        Returns:
            ``(y, x_grad, weight_grad, bias_grad)``, each cast to float32.
        """
        paddle.disable_static()

        x = paddle.to_tensor(x_np)
        weight = paddle.to_tensor(weight_np)
        bias = paddle.to_tensor(bias_np)

        if dtype == "bfloat16":
            x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)

        tensors = [x, weight, bias]
        for tensor in tensors:
            tensor.stop_gradient = False

        y = F.layer_norm(x, x.shape[1:], weight, bias)
        x_g, w_g, b_g = paddle.grad(y, tensors)

        results = tuple(
            item.cast('float32').numpy() for item in (y, x_g, w_g, b_g)
        )

        paddle.enable_static()
        return results

    def test_main(self):
        # Runs only with CUDA, core.cudnn_version() >= 8100 and a device
        # whose major compute capability is at least 8.
        supported = (
            core.is_compiled_with_cuda()
            and core.cudnn_version() >= 8100
            and paddle.device.cuda.get_device_capability()[0] >= 8
        )
        if not supported:
            return
        x_np = np.random.random([10, 20]).astype('float32')
        weight_np = np.random.random([20]).astype('float32')
        bias_np = np.random.random([20]).astype('float32')

        fp32_results = self.check_main(x_np, weight_np, bias_np, 'float32')
        bf16_results = self.check_main(x_np, weight_np, bias_np, 'bfloat16')

        for fp32_out, bf16_out in zip(fp32_results, bf16_results):
            np.testing.assert_allclose(fp32_out, bf16_out, rtol=1e-05, atol=3e-2)


886 887 888 889 890 891 892 893 894
class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
    """Round-trip the ``_keep_layer_norm_scale_bias_to_fp32`` getter/setter."""

    def test_main(self):
        # The flag starts enabled; toggle it off and back on, checking each step.
        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
        for enabled in (False, True):
            _keep_layer_norm_scale_bias_to_fp32(enabled)
            check = self.assertTrue if enabled else self.assertFalse
            check(_keep_layer_norm_scale_bias_to_fp32())


895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965
class TestFastMathLayerNormOp(unittest.TestCase):
    """Check that ``FLAGS_use_fast_math`` does not change layer_norm outputs."""

    def check_layer_norm(
        self, dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
    ):
        """Run ``F.layer_norm`` in dygraph mode and return ``y`` as float32.

        Args:
            dtype: ``"bfloat16"`` casts ``x`` to bf16. Any other value keeps
                the dtype of ``x_np``.
            x_np: input array.
            scale_np: scale values, used only when ``has_scale`` is true.
            bias_np: bias values, used only when ``has_bias`` is true.
            norm_axis: first axis included in the normalization.
            has_scale: whether to pass a scale (weight) tensor.
            has_bias: whether to pass a bias tensor.
        """
        paddle.disable_static()
        epsilon = 0.00001

        x = paddle.to_tensor(x_np)
        if dtype == "bfloat16":
            x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)
        x.stop_gradient = True

        # Gate each optional parameter on its own flag.
        scale = paddle.to_tensor(scale_np) if has_scale else None
        bias = paddle.to_tensor(bias_np) if has_bias else None
        for param in (scale, bias):
            if param is not None:
                param.stop_gradient = True

        y = F.layer_norm(x, x.shape[norm_axis:], scale, bias, epsilon=epsilon)
        y_np = y.cast('float32').numpy()
        paddle.enable_static()
        return y_np

    def check_with_fast_math(
        self, dtype, shape, norm_axis, has_scale, has_bias
    ):
        """Compare layer_norm outputs with fast math disabled vs. enabled.

        The previous ``FLAGS_use_fast_math`` value is restored afterwards, so
        the global flag does not leak into other tests in the same process.
        """
        flag = 'FLAGS_use_fast_math'

        x_np = np.random.random(shape).astype('float32')
        bias_np = np.random.random(shape[norm_axis:]).astype('float32')
        scale_np = np.random.random(shape[norm_axis:]).astype('float32')

        def run():
            return self.check_layer_norm(
                dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
            )

        original = paddle.get_flags([flag])[flag]
        try:
            paddle.set_flags({flag: False})
            y_default = run()
            paddle.set_flags({flag: True})
            y_fast = run()
        finally:
            paddle.set_flags({flag: original})
        np.testing.assert_allclose(y_default, y_fast, rtol=1e-05, atol=1e-04)

    def check_with_dtype(self, dtype):
        self.check_with_fast_math(
            dtype,
            shape=[17, 129],
            norm_axis=1,
            has_scale=False,
            has_bias=True,
        )
        # Keeps the scale path covered now that has_scale/has_bias are no
        # longer swapped.
        self.check_with_fast_math(
            dtype,
            shape=[17, 129],
            norm_axis=1,
            has_scale=True,
            has_bias=True,
        )
        self.check_with_fast_math(
            dtype,
            shape=[8, 512],
            norm_axis=1,
            has_scale=False,
            has_bias=False,
        )
        self.check_with_fast_math(
            dtype,
            shape=[2, 768],
            norm_axis=1,
            has_scale=False,
            has_bias=False,
        )

    def test_main(self):
        if not paddle.is_compiled_with_cuda() or paddle.is_compiled_with_rocm():
            return
        self.check_with_dtype(dtype="float32")
        self.check_with_dtype(dtype="bfloat16")


C
chengduoZH 已提交
972
if __name__ == '__main__':
    # Re-enable static-graph mode before running the suite as a script.
    paddle.enable_static()
    unittest.main()