test_layer_norm_op.py 25.0 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
C
chengduoZH 已提交
2 3 4 5 6 7 8 9 10 11 12 13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

C
chengduoZH 已提交
15
import unittest
16 17 18
from functools import reduce
from operator import mul

C
chengduoZH 已提交
19
import numpy as np
20
from eager_op_test import OpTest, _set_use_system_allocator
C
chengduoZH 已提交
21

22
import paddle
23
import paddle.nn.functional as F
24 25
from paddle import fluid
from paddle.fluid import Program, core, program_guard
26
from paddle.static.amp.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
27 28

# Every test in this module builds static programs unless it explicitly
# switches to dynamic mode.
paddle.enable_static()

# Fix both RNGs so the randomly generated test inputs are reproducible.
np.random.seed(123)
paddle.seed(123)

# NOTE(review): presumably makes OpTest use the system allocator; see
# `_set_use_system_allocator` in eager_op_test to confirm.
_set_use_system_allocator(True)

C
chengduoZH 已提交
35

C
chengduoZH 已提交
36
def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
C
chengduoZH 已提交
37 38
    x_shape = x.shape
    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
39
    D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
C
chengduoZH 已提交
40
    x.shape = [N, D]
C
chengduoZH 已提交
41

C
chengduoZH 已提交
42 43
    mean = np.mean(x, axis=1)
    var = np.var(x, axis=1) + epsilon
44 45 46
    output = np.divide(
        (x - mean.reshape([N, 1])), (np.sqrt(var)).reshape([N, 1])
    )
47 48 49 50
    if scale is not None:
        output = scale.reshape([1, D]) * output
    if beta is not None:
        output = output + beta.reshape([1, D])
C
chengduoZH 已提交
51 52

    x.shape, output.shape = x_shape, x_shape
C
chengduoZH 已提交
53 54 55
    return output, mean, var


56 57 58
def _reference_layer_norm_grad(
    x, grad_y, scale, bias, mean, var, begin_norm_axis=1
):
C
chengduoZH 已提交
59
    x_shape = x.shape
C
chengduoZH 已提交
60
    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
61
    D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
62 63 64 65

    if scale is not None:
        scale_shape = scale.shape
        scale.shape = [1, D]
C
chengduoZH 已提交
66 67
    x.shape, grad_y.shape = [N, D], [N, D]
    var.shape, mean.shape = [N, 1], [N, 1]
C
chengduoZH 已提交
68

C
chengduoZH 已提交
69
    # d_bias
70 71 72 73
    if bias is not None:
        d_bias = np.sum(grad_y, axis=0).reshape([1, D])
    else:
        d_bias = None
C
chengduoZH 已提交
74
    # d_scale
75
    if scale is not None:
76 77 78
        d_scale = np.sum(
            ((x - mean) * np.sqrt(1 / var)) * grad_y, axis=0
        ).reshape([1, D])
79 80
    else:
        d_scale = None
C
chengduoZH 已提交
81
    # dx
82 83
    if scale is not None:
        dx_end = scale * np.sqrt(1.0 / var) * grad_y
84 85 86
        d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * scale, axis=1).reshape(
            [N, 1]
        )  # the second part equals to zero.
87
        d_mean = 1.0 / D * d_mean_0
88 89 90 91 92
        d_std = np.sum(
            -(1.0 / var) * (x - mean) * grad_y * scale, axis=1
        ).reshape([N, 1]) * (
            1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean)
        )
93 94
    else:
        dx_end = 1.0 * np.sqrt(1.0 / var) * grad_y
95 96 97
        d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * 1.0, axis=1).reshape(
            [N, 1]
        )  # the second part equals to zero.
98
        d_mean = 1.0 / D * d_mean_0
99 100 101 102 103
        d_std = np.sum(
            -(1.0 / var) * (x - mean) * grad_y * 1.0, axis=1
        ).reshape([N, 1]) * (
            1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean)
        )
C
chengduoZH 已提交
104

C
chengduoZH 已提交
105
    grad_x = dx_end + d_mean + d_std
C
chengduoZH 已提交
106

C
chengduoZH 已提交
107
    grad_x.shape, x.shape, grad_y.shape = x_shape, x_shape, x_shape
108
    var.shape, mean.shape = [N], [N]
109 110 111

    if scale is not None:
        scale.shape = scale_shape
C
chengduoZH 已提交
112
    return grad_x, d_scale, d_bias
C
chengduoZH 已提交
113 114


115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348
def layer_norm_wrapper(
    x, scale=None, bias=None, epsilon=1e-05, begin_norm_axis=1
):
    """Adapt ``paddle.nn.functional.layer_norm`` to the op's attribute layout.

    The ``layer_norm`` op is parameterized by ``begin_norm_axis``. The
    functional API instead takes ``normalized_shape``, which is every
    dimension of ``x`` from ``begin_norm_axis`` onward.
    """
    trailing_dims = list(x.shape)[begin_norm_axis:]
    return F.layer_norm(
        x,
        trailing_dims,
        weight=scale,
        bias=bias,
        epsilon=epsilon,
    )


class TestLayerNormOpByOpTest(OpTest):
    """OpTest-driven check of the ``layer_norm`` op: output, grads and prim.

    Subclasses override ``initConfig`` to change the dtype, the tolerances,
    and whether the optional ``Scale``/``Bias`` inputs are present.
    """

    def setUp(self):
        self.op_type = "layer_norm"
        self.prim_op_type = "comp"
        self.python_api = layer_norm_wrapper
        self.public_python_api = layer_norm_wrapper
        self.python_out_sig = ["Y"]
        self.initConfig()
        self.initTestCase()

    def test_check_output(self):
        # Mean/Variance are not compared; only Y is checked.
        self.check_output(
            no_check_set=["Mean", "Variance"],
            atol=self.ori_atol,
            rtol=self.ori_rtol,
            check_prim=True,
        )

    def test_check_grad(self):
        self.check_grad(
            self.check_grad_input_list,
            ['Y'],
            max_relative_error=self.max_relative_error,
            check_prim=True,
        )

    def initConfig(self):
        """Default configuration: float64 input with both Scale and Bias."""
        self.rev_comp_atol = self.rev_comp_rtol = 1e-7
        self.fw_comp_atol = self.fw_comp_rtol = 1e-6

        self.ori_atol = self.ori_rtol = 1e-4
        self.cinn_atol = self.cinn_rtol = 1e-5

        self.max_relative_error = 1e-5

        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = self.has_bias = True

    def initTestCase(self):
        """Generate inputs and the reference outputs for the current config."""
        np.random.seed(123)

        norm_dims = self.x_shape[self.begin_norm_axis :]
        self.D = reduce(mul, norm_dims, 1)
        self.scale_shape = [self.D]

        def draw(shape):
            return np.random.random(shape).astype(self.dtype)

        # Under the fixed seed the draw order is X, then Scale, then Bias.
        x = draw(self.x_shape)
        scale = draw(self.scale_shape) if self.has_scale else None
        bias = draw(self.scale_shape) if self.has_bias else None

        self.inputs = {"X": x}
        for key, value in (("Scale", scale), ("Bias", bias)):
            if value is not None:
                self.inputs[key] = value
        # Gradients are checked for every input that is present.
        self.check_grad_input_list = list(self.inputs)

        self.attrs = {
            "epsilon": self.epsilon,
            "begin_norm_axis": self.begin_norm_axis,
        }
        y, mean, variance = _reference_layer_norm_naive(
            x, scale, bias, self.epsilon, self.begin_norm_axis
        )
        self.outputs = {"Y": y, "Mean": mean, "Variance": variance}


class TestLayerNormOpByOpTestFP64_case2(TestLayerNormOpByOpTest):
    """float64 input with neither Scale nor Bias."""

    def initConfig(self):
        self.rev_comp_atol = self.rev_comp_rtol = 1e-6
        self.fw_comp_atol = self.fw_comp_rtol = 1e-7
        self.ori_atol = self.ori_rtol = 1e-4
        self.cinn_atol = self.cinn_rtol = 1e-5
        self.max_relative_error = 1e-5

        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = self.has_bias = False


class TestLayerNormOpByOpTestFP64_case3(TestLayerNormOpByOpTest):
    """float64 input with Scale but without Bias."""

    def initConfig(self):
        self.rev_comp_atol = self.rev_comp_rtol = 1e-7
        self.fw_comp_atol = self.fw_comp_rtol = 1e-7
        self.ori_atol = self.ori_rtol = 1e-4
        self.cinn_atol = self.cinn_rtol = 1e-5
        self.max_relative_error = 1e-5

        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale, self.has_bias = True, False


class TestLayerNormOpByOpTestFP64_case4(TestLayerNormOpByOpTest):
    """float64 input with Bias but without Scale."""

    def initConfig(self):
        self.rev_comp_atol = self.rev_comp_rtol = 1e-6
        self.fw_comp_atol = self.fw_comp_rtol = 1e-7
        self.ori_atol = self.ori_rtol = 1e-4
        self.cinn_atol = self.cinn_rtol = 1e-5
        self.max_relative_error = 1e-5

        self.dtype = "float64"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale, self.has_bias = False, True


class TestLayerNormOpByOpTestFP32(TestLayerNormOpByOpTest):
    """float32 input with both Scale and Bias."""

    def initConfig(self):
        # fw_comp_* / cinn_* are intentionally not set for float32 configs.
        self.rev_comp_atol = self.rev_comp_rtol = 1e-5
        self.ori_atol = self.ori_rtol = 1e-4
        self.max_relative_error = 7e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = self.has_bias = True


class TestLayerNormOpByOpTestFP32_case2(TestLayerNormOpByOpTest):
    """float32 input with neither Scale nor Bias."""

    def initConfig(self):
        # fw_comp_* / cinn_* are intentionally not set for float32 configs.
        self.rev_comp_atol = self.rev_comp_rtol = 1e-5
        self.ori_atol = self.ori_rtol = 1e-4
        self.max_relative_error = 1e-5

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale = self.has_bias = False


class TestLayerNormOpByOpTestFP32_case3(TestLayerNormOpByOpTest):
    """float32 input with Scale but without Bias."""

    def initConfig(self):
        # fw_comp_* / cinn_* are intentionally not set for float32 configs.
        self.rev_comp_atol = self.rev_comp_rtol = 1e-5
        self.ori_atol = self.ori_rtol = 1e-4
        self.max_relative_error = 3e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale, self.has_bias = True, False


class TestLayerNormOpByOpTestFP32_case4(TestLayerNormOpByOpTest):
    """float32 input with Bias but without Scale."""

    def initConfig(self):
        # fw_comp_* / cinn_* are intentionally not set for float32 configs.
        self.rev_comp_atol = self.rev_comp_rtol = 1e-5
        self.ori_atol = self.ori_rtol = 1e-4
        self.max_relative_error = 1e-3

        self.dtype = "float32"
        self.x_shape = [2, 6, 6, 3]
        self.epsilon = 0.00001
        self.begin_norm_axis = 1
        self.has_scale, self.has_bias = False, True


349
class TestLayerNormOp(unittest.TestCase):
    """Runs the raw ``layer_norm`` op plus its grad op in a hand-built static
    program and compares the results with the NumPy references."""

    def setUp(self):
        # Also run on CUDAPlace(0) when Paddle supports layer_norm on GPU.
        self.use_cudnn = True

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        """Assert ``tensor`` matches ``np_array`` (np.allclose, default rtol)."""
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

    def check_forward_backward(
        self,
        shape,
        begin_norm_axis,
        has_scale=True,
        has_bias=True,
        y_grad_scale=1.0,
        use_mkldnn=False,
    ):
        """Check layer_norm outputs and gradients on every available place.

        Args:
            shape: shape of the input ``x``.
            begin_norm_axis: first axis included in the normalization.
            has_scale: whether the optional ``Scale`` input is used.
            has_bias: whether the optional ``Bias`` input is used.
            y_grad_scale: multiplier applied to the random upstream gradient.
            use_mkldnn: forwarded as the op's ``use_mkldnn`` attribute.
        """

        def test_with_place(
            place, shape, begin_norm_axis, use_mkldnn=use_mkldnn
        ):
            # attr
            epsilon = 0.00001
            x_shape = shape
            D = reduce(mul, x_shape[begin_norm_axis : len(x_shape)], 1)
            scale_shape = [D]

            np.random.seed(123)
            x = np.random.random_sample(x_shape).astype(np.float32)
            scale = (
                np.random.random_sample(scale_shape).astype(np.float32)
                if has_scale
                else None
            )
            bias = (
                np.random.random_sample(scale_shape).astype(np.float32)
                if has_bias
                else None
            )
            y_grad = (np.random.random_sample(x_shape) * y_grad_scale).astype(
                np.float32
            )

            # reference forward & backward
            y, mean, variance = _reference_layer_norm_naive(
                x, scale, bias, epsilon, begin_norm_axis
            )
            x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
                x, y_grad, scale, bias, mean, variance, begin_norm_axis
            )

            # Program variable name -> NumPy value. Only the optional
            # parameters actually in use get a variable and a feed entry.
            ground_truth = {
                'x': x,
                'mean': mean,
                'variance': variance,
                'y': y,
                'y@GRAD': y_grad,
            }
            if has_scale:
                ground_truth['scale'] = scale
            if has_bias:
                ground_truth['bias'] = bias

            program = fluid.Program()
            with fluid.program_guard(program):
                block = program.global_block()
                for name, value in ground_truth.items():
                    block.create_var(
                        name=name,
                        dtype='float32',
                        shape=value.shape,
                    )
                inputs = {"X": block.var('x')}
                fetch_list = [
                    'y',
                    'mean',
                    'variance',
                    'x@GRAD',
                ]
                if has_scale:
                    inputs["Scale"] = block.var('scale')
                    fetch_list += ['scale@GRAD']
                if has_bias:
                    inputs["Bias"] = block.var('bias')
                    fetch_list += ['bias@GRAD']
                layer_norm_op = block.append_op(
                    type="layer_norm",
                    inputs=inputs,
                    outputs={
                        "Y": block.var('y'),
                        "Mean": block.var('mean'),  # share the same memory
                        "Variance": block.var(
                            'variance'
                        ),  # share the same memory
                    },
                    attrs={
                        "epsilon": epsilon,
                        "begin_norm_axis": begin_norm_axis,
                        "use_mkldnn": use_mkldnn,
                    },
                )
                # generate backward op_desc
                grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
                    layer_norm_op.desc, set(), []
                )
                grad_op_desc = grad_op_desc_list[0]
                new_op_desc = block.desc.append_op()
                new_op_desc.copy_from(grad_op_desc)
                for var_name in grad_op_desc.output_arg_names():
                    block.desc.var(var_name.encode("ascii"))
                grad_op_desc.infer_var_type(block.desc)
                grad_op_desc.infer_shape(block.desc)
                for arg in grad_op_desc.output_arg_names():
                    grad_var = block.desc.find_var(arg.encode("ascii"))
                    grad_var.set_dtype(core.VarDesc.VarType.FP32)

                program._sync_with_cpp()
                exe = fluid.Executor(place)
                out = exe.run(
                    program,
                    # Feed only variables that exist in the program; absent
                    # Scale/Bias must not be fed as None.
                    feed={
                        name: ground_truth[name]
                        for name in ['x', 'scale', 'bias', 'y@GRAD']
                        if name in ground_truth
                    },
                    fetch_list=fetch_list,
                )
                self.__assert_close(y, out[0], "y")
                self.__assert_close(mean, out[1], "mean")
                self.__assert_close(variance, out[2], "variance", 1e-3)
                self.__assert_close(x_grad, out[3], "x_grad")
                if has_scale:
                    self.__assert_close(
                        scale_grad,
                        out[fetch_list.index('scale@GRAD')],
                        "scale_grad",
                        1e-3,
                    )
                if has_bias:
                    self.__assert_close(
                        bias_grad,
                        out[fetch_list.index('bias@GRAD')],
                        "bias_grad",
                    )

        places = [core.CPUPlace()]
        if (
            core.is_compiled_with_cuda()
            and core.op_support_gpu("layer_norm")
            and self.use_cudnn
        ):
            places.append(core.CUDAPlace(0))

        for place in places:
            test_with_place(place, shape, begin_norm_axis)

    def test_check_forward_backward_with_scale_and_bias(self):
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=False,
            has_bias=True,
        )
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=True,
            has_bias=False,
        )
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=False,
            has_bias=False,
        )
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
        self.check_forward_backward(
            shape=[92, 513, 129], begin_norm_axis=2, y_grad_scale=0.1
        )
        self.check_forward_backward(shape=[3, 34, 1134], begin_norm_axis=2)
        self.check_forward_backward(shape=[3, 2, 1133], begin_norm_axis=2)
        self.check_forward_backward(
            shape=[92, 513, 1134], begin_norm_axis=2, y_grad_scale=0.1
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=False,
            has_bias=True,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=True,
            has_bias=False,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[92, 513, 1134],
            begin_norm_axis=2,
            has_scale=False,
            has_bias=False,
            y_grad_scale=0.1,
        )
        self.check_forward_backward(
            shape=[512, 1024], begin_norm_axis=1, has_scale=True, has_bias=True
        )
        self.check_forward_backward(
            shape=[1, 128, 256, 256],
            begin_norm_axis=3,
            has_scale=True,
            has_bias=True,
        )
        self.check_forward_backward(
            shape=[1, 256, 384],
            begin_norm_axis=2,
            has_scale=True,
            has_bias=True,
        )
C
chengduoZH 已提交
567 568


569 570
class TestLayerNormAPI(unittest.TestCase):
    """Builds ``paddle.static.nn.layer_norm`` layers with and without
    scale/shift to check that graph construction succeeds."""

    def test_case(self):
        # Build into private programs so the parameters created here do not
        # leak into the global default program shared with other tests.
        with program_guard(Program(), Program()):
            x = paddle.static.data(
                name='x', shape=[64, 32, 256], dtype='float32'
            )
            x = paddle.static.nn.layer_norm(
                x,
                scale=True,
                shift=True,
                begin_norm_axis=1,
                epsilon=1e-05,
                param_attr=None,
                bias_attr=None,
            )
            x = paddle.static.nn.layer_norm(
                x,
                scale=False,
                shift=False,
                begin_norm_axis=1,
                epsilon=1e-05,
                param_attr=None,
                bias_attr=None,
            )
            # param_attr/bias_attr are given although scale/shift are False.
            x = paddle.static.nn.layer_norm(
                x,
                scale=False,
                shift=False,
                begin_norm_axis=1,
                epsilon=1e-05,
                param_attr="scale",
                bias_attr="shift",
            )
599 600


601 602 603
class TestDygraphLayerNormAPIError(unittest.TestCase):
    """Checks that ``paddle.nn.LayerNorm`` rejects invalid inputs."""

    def test_errors(self):
        with program_guard(Program(), Program()):
            paddle.enable_static()

            layer_norm = paddle.nn.LayerNorm([32, 32])

            # A raw NumPy array is not accepted as the layer's input.
            numpy_input = np.random.random((3, 32, 32)).astype('float32')
            with self.assertRaises(TypeError):
                layer_norm(numpy_input)

            # An int32 input is rejected. Per the original note, the layer
            # expects float32/float64 (float16 only on GPU).
            int_input = paddle.static.data(
                name='x2', shape=[-1, 3, 32, 32], dtype="int32"
            )
            with self.assertRaises(TypeError):
                layer_norm(int_input)


619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649
class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
    """For a float16 input, float16 and float32 Scale/Bias must give
    identical outputs and gradients."""

    def check_main(self, x_np, weight_np, bias_np, dtype):
        """Run dynamic-mode layer_norm forward + grad with Scale/Bias cast to
        ``dtype``; return ``(y, x_grad, weight_grad, bias_grad)`` as NumPy."""
        paddle.disable_static()
        try:
            weight_np = weight_np.astype(dtype)
            bias_np = bias_np.astype(dtype)

            x = paddle.to_tensor(x_np)
            weight = paddle.to_tensor(weight_np)
            bias = paddle.to_tensor(bias_np)
            x.stop_gradient = False
            weight.stop_gradient = False
            bias.stop_gradient = False
            y = F.layer_norm(x, x.shape[1:], weight, bias)
            x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
            y_np = y.numpy().astype('float32')
            x_g_np = x_g.numpy().astype('float32')
            # NOTE(review): the weight grad is compared at float16 precision,
            # presumably because it takes the weight's dtype and a float32
            # weight would otherwise yield a more precise value.
            w_g_np = w_g.numpy().astype('float16')
            b_g_np = b_g.numpy().astype('float32')
            return y_np, x_g_np, w_g_np, b_g_np
        finally:
            # Restore the module's static mode even if the run above fails.
            paddle.enable_static()

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            return
        x_np = np.random.random([10, 20]).astype('float16')
        weight_np = np.random.random([20]).astype('float16')
        bias_np = np.random.random([20]).astype('float16')

        y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
            x_np, weight_np, bias_np, 'float16'
        )
        y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
            x_np, weight_np, bias_np, 'float32'
        )

        def assert_equal(x, y):
            np.testing.assert_array_equal(x, y)

        assert_equal(y_np_1, y_np_2)
        assert_equal(x_g_np_1, x_g_np_2)
        assert_equal(w_g_np_1, w_g_np_2)
        assert_equal(b_g_np_1, b_g_np_2)


665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691
class TestBF16ScaleBiasLayerNorm(unittest.TestCase):
    """Compares layer_norm with a bfloat16 input against float32 (loose
    tolerance), using float32 Scale/Bias in both runs."""

    def check_main(self, x_np, weight_np, bias_np, dtype):
        """Run dynamic-mode layer_norm forward + grad, casting ``x`` to bf16
        when ``dtype == "bfloat16"``; return float32 NumPy results."""
        paddle.disable_static()
        try:
            x = paddle.to_tensor(x_np)
            weight = paddle.to_tensor(weight_np)
            bias = paddle.to_tensor(bias_np)

            if dtype == "bfloat16":
                x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)

            x.stop_gradient = False
            weight.stop_gradient = False
            bias.stop_gradient = False

            y = F.layer_norm(x, x.shape[1:], weight, bias)
            x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])

            y_np = y.cast('float32').numpy()
            x_g_np = x_g.cast('float32').numpy()
            w_g_np = w_g.cast('float32').numpy()
            b_g_np = b_g.cast('float32').numpy()
            return y_np, x_g_np, w_g_np, b_g_np
        finally:
            # Restore the module's static mode even if the run above fails.
            paddle.enable_static()

    def test_main(self):
        # Requires CUDA, cuDNN >= 8.1 and compute capability >= 8.
        if (
            (not core.is_compiled_with_cuda())
            or (core.cudnn_version() < 8100)
            or (paddle.device.cuda.get_device_capability()[0] < 8)
        ):
            return
        x_np = np.random.random([10, 20]).astype('float32')
        weight_np = np.random.random([20]).astype('float32')
        bias_np = np.random.random([20]).astype('float32')

        y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
            x_np, weight_np, bias_np, 'float32'
        )
        y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
            x_np, weight_np, bias_np, 'bfloat16'
        )

        def assert_equal(x, y):
            np.testing.assert_allclose(x, y, rtol=1e-05, atol=3e-2)

        assert_equal(y_np_1, y_np_2)
        assert_equal(x_g_np_1, x_g_np_2)
        assert_equal(w_g_np_1, w_g_np_2)
        assert_equal(b_g_np_1, b_g_np_2)


718 719 720 721 722 723 724 725 726
class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
    """Round-trips the global ``_keep_layer_norm_scale_bias_to_fp32`` flag."""

    def test_main(self):
        # The flag is expected to default to True.
        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
        try:
            _keep_layer_norm_scale_bias_to_fp32(False)
            self.assertFalse(_keep_layer_norm_scale_bias_to_fp32())
        finally:
            # Restore the default even if the assertion above fails, so the
            # global flag does not leak into other tests.
            _keep_layer_norm_scale_bias_to_fp32(True)
        self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())


727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803
class TestFastMathLayerNormOp(unittest.TestCase):
    """Checks that enabling ``FLAGS_use_fast_math`` keeps layer_norm results
    within tolerance of the default path."""

    def check_layer_norm(
        self, dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
    ):
        """Run dynamic-mode ``F.layer_norm`` and return ``y`` as float32 NumPy.

        ``scale_np``/``bias_np`` are only used when ``has_scale``/``has_bias``
        are set; ``x`` is cast to bf16 when ``dtype == "bfloat16"``.
        """
        epsilon = 0.00001
        paddle.disable_static()
        try:
            x = paddle.to_tensor(x_np)
            if dtype == "bfloat16":
                x = x.cast(paddle.fluid.core.VarDesc.VarType.BF16)
            x.stop_gradient = True

            # Each optional parameter is gated by its own flag.
            scale = paddle.to_tensor(scale_np) if has_scale else None
            bias = paddle.to_tensor(bias_np) if has_bias else None
            for param in (scale, bias):
                if param is not None:
                    param.stop_gradient = True

            y = F.layer_norm(
                x, x.shape[norm_axis:], scale, bias, epsilon=epsilon
            )
            return y.cast('float32').numpy()
        finally:
            # Restore the module's static mode even if the run above fails.
            paddle.enable_static()

    def check_with_fast_math(
        self, dtype, shape, norm_axis, has_scale, has_bias
    ):
        """Compare results with fast math disabled vs. enabled."""

        def use_fast_math(enabled):
            paddle.set_flags({'FLAGS_use_fast_math': enabled})

        x_np = np.random.random(shape).astype('float32')
        bias_np = np.random.random(shape[norm_axis:]).astype('float32')
        scale_np = np.random.random(shape[norm_axis:]).astype('float32')

        flag_before = paddle.get_flags(['FLAGS_use_fast_math'])[
            'FLAGS_use_fast_math'
        ]
        try:
            use_fast_math(False)
            y_ref = self.check_layer_norm(
                dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
            )
            use_fast_math(True)
            y_fast = self.check_layer_norm(
                dtype, x_np, scale_np, bias_np, norm_axis, has_scale, has_bias
            )
        finally:
            # Do not leak the global flag into other tests.
            use_fast_math(flag_before)
        np.testing.assert_allclose(y_ref, y_fast, rtol=1e-05, atol=1e-04)

    def check_with_dtype(self, dtype):
        self.check_with_fast_math(
            dtype,
            shape=[17, 129],
            norm_axis=1,
            has_scale=False,
            has_bias=True,
        )
        self.check_with_fast_math(
            dtype,
            shape=[17, 129],
            norm_axis=1,
            has_scale=True,
            has_bias=False,
        )
        self.check_with_fast_math(
            dtype,
            shape=[8, 512],
            norm_axis=1,
            has_scale=False,
            has_bias=False,
        )
        self.check_with_fast_math(
            dtype,
            shape=[2, 768],
            norm_axis=1,
            has_scale=False,
            has_bias=False,
        )

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            return
        self.check_with_dtype(dtype="float32")
        self.check_with_dtype(dtype="bfloat16")


C
chengduoZH 已提交
804
if __name__ == '__main__':
    # The tests assume static-graph mode by default.
    paddle.enable_static()
    unittest.main()