#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import unittest
import warnings
import numpy as np
import random
import functools
import struct
from collections import defaultdict
from copy import copy

import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import _dygraph_tracer
import paddle.fluid.core as core
from paddle.fluid.framework import (
    _in_legacy_dygraph,
    _enable_legacy_dygraph,
    _in_eager_without_dygraph_check,
    _disable_legacy_dygraph,
)
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
from paddle.fluid.framework import (
    OpProtoHolder,
    Program,
    _current_expected_place,
)
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs

sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from testsuite import (
    create_op,
    set_input,
    append_input_output,
    append_loss_ops,
)
from white_list import (
    op_accuracy_white_list,
    check_shape_white_list,
    compile_vs_runtime_white_list,
    no_check_set_white_list,
    op_threshold_white_list,
    no_grad_set_white_list,
)

# For switch new eager mode globally
g_is_in_eager = _in_eager_without_dygraph_check()
g_enable_legacy_dygraph = (
    _enable_legacy_dygraph if g_is_in_eager else lambda: None
)
g_disable_legacy_dygraph = (
    _disable_legacy_dygraph if g_is_in_eager else lambda: None
)


def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs):
    """
    Determines whether dtype of output tensor is as expected.

    Args:
        api_fn(callable):  paddle api function
        in_specs(list[tuple]): list of shape and dtype information for constructing input tensor of api_fn, such as [(shape, dtype), (shape, dtype)].
        expect_dtypes(list[str]): expected dtypes of output tensor.
        target_index(int): indicate which one from in_specs to infer the dtype of output.
        configs(dict): other arguments of paddle api function

    Example:
        check_out_dtype(fluid.layers.pad_constant_like, [([2,3,2,3], 'float64'), ([1, 3, 1,3], )], ['float32', 'float64', 'int64'], target_index=1, pad_value=0.)

    """
    paddle.enable_static()
    for i, expect_dtype in enumerate(expect_dtypes):
        with paddle.static.program_guard(paddle.static.Program()):
            input_t = []
            for index, spec in enumerate(in_specs):
                if len(spec) == 1:
                    shape = spec[0]
                    dtype = expect_dtype if target_index == index else 'float32'
                elif len(spec) == 2:
                    shape, dtype = spec
                else:
                    raise ValueError(
                        "Value of in_specs[{}] should contain two elements: [shape, dtype]".format(
                            index
                        )
                    )
                input_t.append(
                    paddle.static.data(
                        name='data_%s' % index, shape=shape, dtype=dtype
                    )
                )

            out = api_fn(*input_t, **configs)
            out_dtype = fluid.data_feeder.convert_dtype(out.dtype)

            if out_dtype != expect_dtype:
                raise ValueError(
                    "Expected out.dtype is {}, but got {} from {}.".format(
                        expect_dtype, out_dtype, api_fn.__name__
                    )
                )


def _set_use_system_allocator(value=None):
    USE_SYSTEM_ALLOCATOR_FLAG = "FLAGS_use_system_allocator"
    old_value = core.globals()[USE_SYSTEM_ALLOCATOR_FLAG]
    value = old_value if value is None else value
    core.globals()[USE_SYSTEM_ALLOCATOR_FLAG] = value
    return old_value
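
# Illustrative usage sketch (added comment, not part of the original helpers):
# _set_use_system_allocator swaps the global FLAGS_use_system_allocator flag
# and returns the previous value so callers can restore it later, e.g.
#
#     old_flag = _set_use_system_allocator(True)   # force the system allocator
#     ...                                          # run the affected tests
#     _set_use_system_allocator(old_flag)          # restore the previous setting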


def randomize_probability(batch_size, class_num, dtype='float32'):
    prob = np.random.uniform(0.1, 1.0, size=(batch_size, class_num)).astype(
        dtype
    )
    prob_sum = prob.sum(axis=1)
    for i in range(len(prob)):
        prob[i] /= prob_sum[i]
    return prob
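
# Example (illustrative comment, added for clarity): every row of the returned
# array is a valid probability distribution, which is what cross-entropy style
# op tests expect as input.
#
#     probs = randomize_probability(batch_size=4, class_num=10)
#     assert probs.shape == (4, 10)
#     np.testing.assert_allclose(probs.sum(axis=1), np.ones(4), rtol=1e-6)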


def get_numeric_gradient(
    place,
    scope,
    op,
    inputs,
    input_to_check,
    output_names,
    delta=0.005,
    in_place=False,
):
    # FIXME: change this method by compile time concepts
    set_input(scope, op, inputs, place)

    def product(dim):
        return functools.reduce(lambda a, b: a * b, dim, 1)

    tensor_to_check = scope.find_var(input_to_check).get_tensor()
    tensor_size = product(tensor_to_check.shape())
    tensor_to_check_dtype = tensor_to_check._dtype()
    if tensor_to_check_dtype == core.VarDesc.VarType.FP32:
        tensor_to_check_dtype = np.float32
    elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
        tensor_to_check_dtype = np.float64
    elif tensor_to_check_dtype == core.VarDesc.VarType.FP16:
        tensor_to_check_dtype = np.float16
        # set delta as np.float16, will automatic convert to float32, float64
        delta = np.array(delta).astype(np.float16)
    elif tensor_to_check_dtype == core.VarDesc.VarType.BF16:
        tensor_to_check_dtype = np.float32
    elif tensor_to_check_dtype == core.VarDesc.VarType.COMPLEX64:
        tensor_to_check_dtype = np.complex64
    elif tensor_to_check_dtype == core.VarDesc.VarType.COMPLEX128:
        tensor_to_check_dtype = np.complex128
    else:
        raise ValueError(
            "Not supported data type "
            + str(tensor_to_check_dtype)
            + ", tensor name : "
            + str(input_to_check)
        )

    def get_output():
        sum = []
        op.run(scope, place)
        for output_name in output_names:
            output_numpy = np.array(scope.find_var(output_name).get_tensor())
            # numpy.dtype does not have bfloat16, thus we use numpy.uint16 to
            # store bfloat16 data, and need to be converted to float to check
            # the floating precision.
            if tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
                output_numpy = convert_uint16_to_float(output_numpy)
            sum.append(output_numpy.astype(tensor_to_check_dtype).mean())
        return tensor_to_check_dtype(np.array(sum).sum() / len(output_names))

    gradient_flat = np.zeros(shape=(tensor_size,), dtype=tensor_to_check_dtype)

    def __get_elem__(tensor, i):
        if tensor_to_check_dtype == np.float16:
            numpy_tensor = np.array(tensor).astype(np.float16)
            numpy_tensor = numpy_tensor.flatten()
            return numpy_tensor[i]
        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
            numpy_tensor = np.array(tensor).astype(np.uint16)
            numpy_tensor = numpy_tensor.flatten()
            return struct.unpack(
                '<f',
                struct.pack('<I', np.uint32(numpy_tensor[i]) << np.uint32(16)),
            )[0]
        elif tensor_to_check_dtype == np.float32:
            return tensor._get_float_element(i)
        elif tensor_to_check_dtype == np.float64:
            return tensor._get_double_element(i)
        else:
            raise TypeError(
                "Unsupported test data type %s." % tensor_to_check_dtype
            )

    def __set_elem__(tensor, i, e):
        if tensor_to_check_dtype == np.float16:
            numpy_tensor = np.array(tensor).astype(np.float16)
            shape = numpy_tensor.shape
            numpy_tensor = numpy_tensor.flatten()
            numpy_tensor[i] = e
            numpy_tensor = numpy_tensor.reshape(shape)
            tensor.set(numpy_tensor, place)
        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
            numpy_tensor = np.array(tensor).astype(np.uint16)
            shape = numpy_tensor.shape
            numpy_tensor = numpy_tensor.flatten()
            numpy_tensor[i] = np.uint16(copy_bits_from_float_to_uint16(e))
            numpy_tensor = numpy_tensor.reshape(shape)
            tensor.set(numpy_tensor, place)
        elif tensor_to_check_dtype == np.float32:
            tensor._set_float_element(i, e)
        elif tensor_to_check_dtype == np.float64:
            tensor._set_double_element(i, e)
        else:
            raise TypeError(
                "Unsupported test data type %s." % tensor_to_check_dtype
            )

    # we only compute gradient of one element each time.
    # we use a for loop to compute the gradient of every element.
    for i in range(tensor_size):
        if in_place:
            set_input(scope, op, inputs, place)

        # get one input element by its index i.
        origin = __get_elem__(tensor_to_check, i)
        # add delta to it, run op and then get the sum of the result tensor.
        x_pos = origin + delta
        __set_elem__(tensor_to_check, i, x_pos)
        y_pos = get_output()

        if in_place:
            set_input(scope, op, inputs, place)

        x_neg = origin - delta
        __set_elem__(tensor_to_check, i, x_neg)
        y_neg = get_output()

        __set_elem__(tensor_to_check, i, origin)
        gradient_flat[i] = (y_pos - y_neg) / delta / 2

    return gradient_flat.reshape(tensor_to_check.shape())
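
# Note (added for clarity): get_numeric_gradient approximates the derivative of
# the mean of the outputs with respect to every input element i by a central
# difference,
#
#     grad[i] ~= (f(x[i] + delta) - f(x[i] - delta)) / (2 * delta),
#
# and check_grad later compares this numeric estimate against the analytic
# gradient produced by the registered grad op.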


def skip_check_grad_ci(reason=None):
    """Decorator to skip check_grad CI.

    Check_grad is required for Op test cases. However, there are some special
    cases that do not need check_grad. This decorator is used to skip the
    check_grad check for those cases.

    Note: the execution of the unit test will not be skipped. It just avoids the
    check_grad check in the tearDownClass method by setting a `no_need_check_grad` flag.

    Example:
        @skip_check_grad_ci(reason="For inference, check_grad is not required.")
        class TestInference(OpTest):
    """
    if not isinstance(reason, str):
        raise AssertionError("The reason for skipping check_grad is required.")

    def wrapper(cls):
        cls.no_need_check_grad = True
        return cls

    return wrapper


def skip_check_inplace_ci(reason=None):
    if not isinstance(reason, str):
        raise AssertionError(
            "The reason for skipping check_inplace is required."
        )

    def wrapper(cls):
        cls.no_need_check_inplace = True
        return cls

    return wrapper
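
# Illustrative usage sketch (the class name below is hypothetical and not
# defined in this file): the decorator only sets a `no_need_check_inplace`
# flag, so the unit test itself still runs.
#
#     @skip_check_inplace_ci(reason="inplace is not supported for this op.")
#     class TestSomeOpWithoutInplace(OpTest):
#         ...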


def copy_bits_from_float_to_uint16(f):
    return struct.unpack('<I', struct.pack('<f', f))[0] >> 16


def convert_float_to_uint16(float_list, data_format="NCHW"):
    if data_format == "NHWC":
        float_list = np.transpose(float_list, [0, 3, 1, 2])

    new_output = []
    for x in np.nditer(float_list):
        new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
    new_output = np.reshape(new_output, float_list.shape).view(np.uint16)

    if data_format == "NHWC":
        new_output = np.transpose(new_output, [0, 2, 3, 1])
    return new_output


def convert_uint16_to_float(in_list):
    in_list = np.asarray(in_list)
    out = np.vectorize(
        lambda x: struct.unpack(
            '<f', struct.pack('<I', np.uint32(x) << np.uint32(16))
        )[0],
        otypes=[np.float32],
    )(in_list.flat)
    return np.reshape(out, in_list.shape)
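
# Illustrative round-trip sketch (added for clarity): bfloat16 values are stored
# as the upper 16 bits of their float32 representation, so converting a float32
# array to uint16 (bf16 bit patterns) and back preserves the exponent but
# truncates the mantissa to 8 bits:
#
#     x = np.random.uniform(0.1, 1.0, [2, 3]).astype('float32')
#     x_bf16 = convert_float_to_uint16(x)        # uint16 array of bf16 bit patterns
#     x_back = convert_uint16_to_float(x_bf16)   # float32 approximation of x
#     np.testing.assert_allclose(x, x_back, rtol=1e-2)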


class OpTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        '''Fix random seeds to remove randomness from tests'''
        cls._np_rand_state = np.random.get_state()
        cls._py_rand_state = random.getstate()
        cls.call_once = False
        cls.dtype = None
        cls.outputs = {}
        cls.input_shape_is_large = True

        np.random.seed(123)
        random.seed(124)

        if paddle.is_compiled_with_npu():
            cls._use_system_allocator = _set_use_system_allocator(False)
        else:
            cls._use_system_allocator = _set_use_system_allocator(True)

    @classmethod
    def tearDownClass(cls):
        """Restore random seeds"""
        np.random.set_state(cls._np_rand_state)
        random.setstate(cls._py_rand_state)

        _set_use_system_allocator(cls._use_system_allocator)

        def is_empty_grad_op(op_type):
            all_op_kernels = core._get_all_register_op_kernels()
            grad_op = op_type + '_grad'
            if grad_op in all_op_kernels.keys():
                if is_mkldnn_op_test():
                    grad_op_kernels = all_op_kernels[grad_op]
                    for grad_op_kernel in grad_op_kernels:
                        if 'MKLDNN' in grad_op_kernel:
                            return False
                else:
                    return False
            return True

        def is_xpu_op_test():
            return hasattr(cls, "use_xpu") and cls.use_xpu == True

        def is_mkldnn_op_test():
            return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True

        def is_rocm_op_test():
            return core.is_compiled_with_rocm()

        def is_npu_op_test():
            return hasattr(cls, "use_npu") and cls.use_npu == True

        def is_mlu_op_test():
            return hasattr(cls, "use_mlu") and cls.use_mlu == True

        def is_custom_device_op_test():
            return (
                hasattr(cls, "use_custom_device")
                and cls.use_custom_device == True
            )

        if not hasattr(cls, "op_type"):
            raise AssertionError(
                "This test does not have op_type in class attrs, "
                "please set self.__class__.op_type=the_real_op_type manually."
            )

        # case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
        if not hasattr(cls, "no_need_check_grad") and not is_empty_grad_op(
            cls.op_type
        ):
            if cls.dtype is None or (
                cls.dtype == np.float16
                and cls.op_type
                not in op_accuracy_white_list.NO_FP16_CHECK_GRAD_OP_LIST
                and not hasattr(cls, "exist_check_grad")
            ):
                raise AssertionError(
                    "This test of %s op needs check_grad." % cls.op_type
                )

            # check op tests with fp64 precision; mkldnn op tests are not checked for now
            if (
                cls.dtype in [np.float32, np.float64]
                and cls.op_type
                not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST
                and not hasattr(cls, 'exist_fp64_check_grad')
                and not is_xpu_op_test()
                and not is_mkldnn_op_test()
                and not is_rocm_op_test()
                and not is_npu_op_test()
                and not is_mlu_op_test()
                and not is_custom_device_op_test()
            ):
                raise AssertionError(
                    "This test of %s op needs check_grad with fp64 precision."
                    % cls.op_type
                )

            if (
                not cls.input_shape_is_large
                and cls.op_type
                not in check_shape_white_list.NEED_TO_FIX_OP_LIST
            ):
                raise AssertionError(
                    "Input's shape should be larger than or equal to 100 for "
                    + cls.op_type
                    + " Op."
                )

    def try_call_once(self, data_type):
        if not self.call_once:
            self.call_once = True
            self.dtype = data_type

    def is_bfloat16_op(self):
        # self.dtype is the dtype of inputs, and is set in infer_dtype_from_inputs_outputs.
        # Make sure this function is called after calling infer_dtype_from_inputs_outputs.
        return (
            self.dtype == np.uint16
            or (
                hasattr(self, 'output_dtype') and self.output_dtype == np.uint16
            )
            or (
                hasattr(self, 'mkldnn_data_type')
                and getattr(self, 'mkldnn_data_type') == "bfloat16"
            )
            or (
                hasattr(self, 'attrs')
                and 'mkldnn_data_type' in self.attrs
                and self.attrs['mkldnn_data_type'] == 'bfloat16'
            )
        )

    def is_mkldnn_op(self):
        return (hasattr(self, "use_mkldnn") and self.use_mkldnn == True) or (
            hasattr(self, "attrs")
            and "use_mkldnn" in self.attrs
            and self.attrs["use_mkldnn"] == True
        )

    def is_xpu_op(self):
        return (hasattr(self, "use_xpu") and self.use_xpu == True) or (
            hasattr(self, "attrs")
            and "use_xpu" in self.attrs
            and self.attrs["use_xpu"] == True
        )

    # set self.output_dtype.
    def infer_dtype_from_inputs_outputs(self, inputs, outputs):
        def is_np_data(input):
            return isinstance(input, (np.ndarray, np.generic))

        def infer_dtype(numpy_dict, dtype_set):
            assert isinstance(
                numpy_dict, dict
            ), "self.inputs, self.outputs must be numpy_dict"
            # the inputs are as follows:
            # case 1: inputs = {'X': x}
            # case 2: inputs = {'X': (x, x_lod)}
            # case 3: inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
            # case 4: inputs = {'X': [("x1", (x1, [x1_lod1])), ("x2", (x2, [x2_.lod2]))]}
            # TODO(juncaipeng) infer dtype from inputs maybe obtain wrong type.
            for _, var_value in numpy_dict.items():
                if is_np_data(var_value):  # case 1
                    dtype_set.add(var_value.dtype)
                elif isinstance(var_value, (list, tuple)):  # case 2, 3, 4
                    for sub_val_value in var_value:
                        if is_np_data(sub_val_value):  # case 2
                            dtype_set.add(sub_val_value.dtype)
                        elif len(sub_val_value) > 1 and is_np_data(
                            sub_val_value[1]
                        ):  # case 3
                            dtype_set.add(sub_val_value[1].dtype)
                        elif (
                            len(sub_val_value) > 1
                            and isinstance(sub_val_value[1], (list, tuple))
                            and is_np_data(sub_val_value[1][0])
                        ):  # case 4
                            dtype_set.add(sub_val_value[1][0].dtype)

        # infer dtype from inputs, and dtype means the precision of the test
        # collect dtype of all inputs
        input_dtype_set = set()
        infer_dtype(inputs, input_dtype_set)
        dtype_list = [
            np.dtype(np.float64),
            np.dtype(np.float32),
            np.dtype(np.float16),
            np.dtype(np.int64),
            np.dtype(np.int32),
            np.dtype(np.uint16),
            np.dtype(np.int16),
            np.dtype(np.int8),
            np.dtype(np.uint8),
            np.dtype(np.bool_),
        ]
        # check the dtypes in dtype_list in order, and select the first dtype that is in dtype_set
        for dtype in dtype_list:
            if dtype in input_dtype_set:
                self.dtype = dtype
                break
        # save input dtype in class attr
        self.__class__.dtype = self.dtype

        # infer dtype of outputs
        output_dtype_set = set()
        infer_dtype(outputs, output_dtype_set)
        for dtype in dtype_list:
            if dtype in output_dtype_set:
                self.output_dtype = dtype
                break

    def feed_var(self, input_vars, place):
        feed_map = {}
        for var_name in input_vars:
            if isinstance(input_vars[var_name], list):
                for name, np_value in self.inputs[var_name]:
                    tensor = core.LoDTensor()
                    if isinstance(np_value, tuple):
                        tensor.set(np_value[0], place)
                        tensor.set_recursive_sequence_lengths(np_value[1])
                    else:
                        tensor.set(np_value, place)
                    feed_map[name] = tensor
            else:
                tensor = core.LoDTensor()
                if isinstance(self.inputs[var_name], tuple):
                    tensor.set(self.inputs[var_name][0], place)
                    tensor.set_recursive_sequence_lengths(
                        self.inputs[var_name][1]
                    )
                else:
                    tensor.set(self.inputs[var_name], place)
                feed_map[var_name] = tensor

        return feed_map

    def _append_ops(self, block):
        self.__class__.op_type = (
            self.op_type
        )  # for ci check, please do not delete it for now
        if self.is_mkldnn_op():
            self.__class__.use_mkldnn = True

        if self.is_xpu_op():
            self.__class__.use_xpu = True

        op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
        # infer datatype from inputs and outputs for this test case
        if self.is_bfloat16_op():
            self.dtype = np.uint16
            self.__class__.dtype = self.dtype
            self.output_dtype = np.uint16
        else:
            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
        inputs = append_input_output(
            block, op_proto, self.inputs, True, self.dtype
        )
        outputs = append_input_output(
            block, op_proto, self.outputs, False, self.dtype
        )

        if hasattr(self, "cache_name_list"):
            for name in self.cache_name_list:
                inputs[name] = block.create_var(
                    name=name,
                    persistable=True,
                    type=core.VarDesc.VarType.RAW,
                    stop_gradient=True,
                )

        op = block.append_op(
            type=self.op_type,
            inputs=inputs,
            outputs=outputs,
            attrs=copy(self.attrs) if hasattr(self, "attrs") else dict(),
        )
        # infer variable type and infer shape in compile-time
        op.desc.infer_var_type(block.desc)
        op.desc.infer_shape(block.desc)

        return op

    def _get_io_vars(self, block, numpy_inputs):
        inputs = {}
        for name, value in numpy_inputs.items():
            if isinstance(value, list):
                var_list = [
                    block.var(sub_name) for sub_name, sub_value in value
                ]
                inputs[name] = var_list
            else:
                inputs[name] = block.var(name)
        return inputs

    def _get_inputs(self, block):
        return self._get_io_vars(block, self.inputs)

    def _get_outputs(self, block):
        return self._get_io_vars(block, self.outputs)

    def calc_output(self, place):
        outs, _ = self._calc_output(place)
        return outs

    def _create_var_from_numpy(self, value):
        if isinstance(value, tuple):
            data = value[0]
            lod = value[1]
            v = fluid.dygraph.base.to_variable(value=data)
            v.value().get_tensor().set_recursive_sequence_lengths(lod)
            return v
        else:
            return fluid.dygraph.base.to_variable(value)

    def get_sequence_batch_size_1_input(self, lod=None, shape=None):
        """Get LoD input data whose batch size is 1.
        All sequence related OP unittests should call this function to contain the case of batch size = 1.
        Args:
            lod (list[list of int], optional): Length-based LoD, length of lod[0] should be 1. Default: [[13]].
            shape (list, optional): Shape of input, shape[0] should be equal to lod[0][0]. Default: [13, 23].
        Returns:
            tuple (ndarray, lod) : LoD input data whose batch size is 1.
        """
        if lod is None:
            lod = [[13]]
        if shape is None:
            shape = [13, 23]
        assert len(lod[0]) == 1
        assert lod[0][0] == shape[0]
        x = np.random.uniform(0.1, 1, shape).astype('float32')
        return (x, lod)
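
    # Illustrative note (added for clarity): the returned (data, lod) tuple has
    # the form OpTest expects for a LoD input, e.g.
    #
    #     x, lod = self.get_sequence_batch_size_1_input()
    #     self.inputs = {'X': (x, lod)}   # a single sequence of length 13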

    def lod_has_single_zero(self, lod):
        for i in range(len(lod) - 2):
            if lod[i] != 0 and lod[i + 1] == 0 and lod[i + 2] != 0:
                return True
        return False

    def lod_has_continuous_zero(self, lod):
        for i in range(len(lod) - 3):
            if (
                lod[i] != 0
                and lod[i + 1] == 0
                and lod[i + 2] == 0
                and lod[i + 3] != 0
            ):
                return True
        return False

    def get_sequence_instance_size_0_input(self, lod=None, shape=None):
        """Get LoD input data whose instance size is 0.
        All sequence related OP unittests should call this function to contain the case of instance size is 0.
        Args:
            lod (list[list of int], optional): Length-based LoD, lod[0]'s size must be at least eight, lod[0] must have at least two zeros at the beginning and at least two zeros at the end, and the middle of lod[0] must contain both a single zero and consecutive zeros. Default: [[0, 0, 4, 0, 3, 0, 0, 5, 0, 0]].
            shape (list, optional): Shape of input, shape[0] should be equal to sum(lod[0]). Default: [12, 10].
        Returns:
            tuple (ndarray, lod): LoD input data whose instance size is 0.
        """
        if lod is None:
            lod = [[0, 0, 4, 0, 3, 0, 0, 5, 0, 0]]
        if shape is None:
            shape = [12, 10]
        assert len(lod[0]) >= 8
        assert (
            lod[0][0] == 0
            and lod[0][1] == 0
            and lod[0][-1] == 0
            and lod[0][-2] == 0
        )
        assert self.lod_has_single_zero(lod[0]) is True
        assert self.lod_has_continuous_zero(lod[0]) is True
        assert sum(lod[0]) == shape[0]

        x = np.random.uniform(0.1, 1, shape).astype('float32')
        return (x, lod)

    def append_input_output_for_dygraph(
        self, op_proto, np_list, is_input, if_return_inputs_grad_dict, block
    ):
        def create_var(np_value, name, is_input, if_return_inputs_grad_dict):
            np_value_temp = np_value
            has_lod = False
            lod_temp = None
            if isinstance(np_value, tuple):
                np_value_temp = np_value[0]
                has_lod = True
                lod_temp = np_value[1]

            if is_input:
                v = self._create_var_from_numpy(np_value_temp)

                if if_return_inputs_grad_dict:
                    v.stop_gradient = False
                    if not _in_legacy_dygraph():
                        v.retain_grads()

                if has_lod:
                    v.value().get_tensor().set_recursive_sequence_lengths(
                        lod_temp
                    )
            else:
                v = block.create_var(
                    name=name,
                    dtype=np_value_temp.dtype,
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    persistable=False,
                    stop_gradient=False,
                )
            return v

        # prepare variable for input or output
        var_dict = defaultdict(list)
        if if_return_inputs_grad_dict:
            inputs_grad_dict = defaultdict()
        proto_list = op_proto.inputs if is_input else op_proto.outputs
        for var_proto in proto_list:
            name = var_proto.name
            if (name not in np_list) and var_proto.dispensable:
                continue
            if name not in np_list:
                assert var_proto.intermediate, "{} not found".format(name)
                v = block.create_var(
                    dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR
                )
                var_dict[name].append(v)
                if if_return_inputs_grad_dict:
                    inputs_grad_dict[name] = v
                continue
            if var_proto.duplicable:
                assert isinstance(
                    np_list[name], list
                ), "Duplicable {} should be set as list".format(name)
                var_list = []
                slot_name = name
                for (name, np_value) in np_list[name]:
                    v = create_var(
                        np_value, name, is_input, if_return_inputs_grad_dict
                    )
                    var_list.append(v)
                    if if_return_inputs_grad_dict:
                        inputs_grad_dict[name] = v
                var_dict[slot_name] = var_list
            else:
                nplist_value_temp = None
                name_temp = None
                if isinstance(np_list[name], list):
                    nplist_value_temp = np_list[name][0]
                    name_temp = name
                else:
                    nplist_value_temp = np_list[name]
                    name_temp = unique_name.generate("%s_out" % (name))
                v = create_var(
                    nplist_value_temp,
                    name_temp,
                    is_input,
                    if_return_inputs_grad_dict,
                )
                var_dict[name].append(v)
                if if_return_inputs_grad_dict:
                    inputs_grad_dict[name] = v

        if if_return_inputs_grad_dict:
            return var_dict, inputs_grad_dict
        else:
            return var_dict

    def _check_api_outs_by_dygraph_outs(self, api_outs, dygraph_outs, place):
        """for quick verification, we take the simplest strategy here:
        1. we only check variables in api_outs.
        2. we simply compare the numpy (tensor) values.
        3. we set atol and rtol to 1e-5, because they are unrelated to dtype.
        """
        for name in api_outs:
            np_api = np.array(api_outs[name])
            np_dyg = np.array(dygraph_outs[name])
            np.testing.assert_allclose(
                np_api,
                np_dyg,
                rtol=1e-05,
                equal_nan=False,
                err_msg='Output ('
                + name
                + ') has diff at '
                + str(place)
                + '\nExpect '
                + str(np_dyg)
                + '\n'
                + 'But Got'
                + str(np_api)
                + ' in class '
                + self.__class__.__name__,
            )

    def _calc_python_api_output(self, place, egr_inps=None, egr_oups=None):
        """set egr_inps and egr_oups = None if you want to create them by yourself."""

        def prepare_python_api_arguments(
            api, op_proto_ins, op_proto_attrs, kernel_sig
        ):
            """map from `op proto inputs and attrs` to `api input list and api attrs dict`

            NOTE: op_proto_attrs and op_proto_ins are default dicts; the default value is []
            """

            class Empty:
                pass

            def is_empty(a):
                return isinstance(a, Empty)

            def get_default(idx, defaults):
                assert not isinstance(defaults[idx], Empty), (
                    "%d-th param of the python api doesn't have a default value." % idx
                )
                return defaults[idx]

            def to_defaults_list(params, defaults):
                return [defaults[p] for p in params if p in defaults]

            def parse_attri_value(name, op_inputs, op_attrs):
                """parse the real value from inputs and attrs; if the name is not passed by OpTest, return Empty
                1. if the name is in op_attrs, use op_attrs[name]
                2. if the name is in op_inputs, convert op_inputs to [type of default value]
                3. if the name is in neither op_attrs nor op_inputs, return Empty. (this will use the default value from the python api)
                """
                if name in op_proto_attrs:
                    return op_proto_attrs[name]
                elif name in op_inputs:
                    if len(op_inputs[name]) == 1:
                        # why we don't use numpy().item(): if the Tensor is float64, it would be changed to a python float, where we lose accuracy: [allclose_op]
                        # why we reconstruct a tensor: because we want the tensor in cpu.
                        return paddle.to_tensor(
                            op_inputs[name][0].numpy(), place='cpu'
                        )
                    else:
                        # if this is a list (test_unsqueeze2_op): we just pass it into the python api.
                        return op_inputs[name]
                else:
                    return Empty()

            # NOTE(xiongkun): the logic of constructing parameters:
            # for example:
            #    python api: cumprod(x, dim, dtype=None, name=None)
            #    kernel sig: [["x"], ["dim"], ["out"]]"
            #
            # we will construct a lot of list with the same length : len == len(api_params), here is 4
            #    api_params = ["x", "dim", "dtype", "name"]
            #    api_defaults = [Empty, Empty, None, None]; empty means no defaults.
            #    inputs_and_attrs = ["x", "dim"] , the length may shorter or longer than api_params
            #    input_arguments = [RealValue in self.inputs and self.attrs]
            # then ,we will loop for the api_params, construct a result list:
            #    if the name in ['name', 'dtype', 'out', 'output'], we will use the default value
            #    else, we will consume one of the input_arguments. (because the names do not correspond, we only rely on the order)

            api_params, api_defaults = parse_arg_and_kwargs(api)
            api_defaults = to_defaults_list(api_params, api_defaults)
            api_defaults = [
                Empty() for i in range(len(api_params) - len(api_defaults))
            ] + api_defaults
            assert len(api_defaults) == len(
                api_params
            ), "Error happens. contact xiongkun03 to solve."
            inputs_sig, attrs_sig, outputs_sig = kernel_sig
            inputs_and_attrs = inputs_sig + attrs_sig
            input_arguments = [
                op_proto_ins.get(name, Empty()) for name in inputs_sig
            ] + [
                parse_attri_value(name, op_proto_ins, op_proto_attrs)
                for name in attrs_sig
            ]
            results = []
            api_ignore_param_list = set(['name', 'dtype', 'out', 'output'])
            idx_of_op_proto_arguments = 0
            for idx, arg_name in enumerate(api_params):
                if arg_name in api_ignore_param_list:
                    results.append(get_default(idx, api_defaults))
                else:
                    if idx_of_op_proto_arguments < len(input_arguments):
                        tmp = input_arguments[idx_of_op_proto_arguments]
                        idx_of_op_proto_arguments += 1
                    else:
                        tmp = Empty()  # use the default value

                    if isinstance(tmp, Empty):
                        results.append(get_default(idx, api_defaults))
                    else:
                        results.append(tmp)
            assert len(results) == len(api_params)
            return results

        def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
            if hasattr(self, "python_out_sig"):
                output_sig = self.python_out_sig
            if not isinstance(ret_tuple, (tuple, list)):
                ret_tuple = [ret_tuple]
            if len(output_sig) == len(ret_tuple):
                # [assumption]: we assume {"Out": [Tensor]}
                return {a: [b] for a, b in zip(output_sig, ret_tuple)}
            else:
                # [assumption]: return multi-Tensor in a single output. such as paddle.split()
                assert (
                    len(output_sig) == 1
                ), "Don't support multi-output with multi-tensor output. (Maybe you can set `python_out_sig`; see `test_squeeze2_op` as an example.)"
                return {output_sig[0]: ret_tuple}

        def assumption_assert_and_transform(args, inp_num):
            """
            transform inputs by the following rules:
                1. [Tensor] -> Tensor
                2. [Tensor, Tensor, ...] -> list of Tensors
                3. None -> None
                4. Others: raise Error

            only supports "X" being a list of Tensors; other structures such as dict are currently not supported.
            """
            inp_args = [
                [inp] if inp is None else inp for inp in args[:inp_num]
            ]  # convert None -> [None]
            for inp in inp_args:
                assert isinstance(
                    inp, list
                ), "currently only support `X` as [Tensor]; other structures are not supported."
            args = [
                inp[0] if len(inp) == 1 else inp for inp in inp_args
            ] + args[inp_num:]
            return args

        def _get_kernel_signature(
            eager_tensor_inputs, eager_tensor_outputs, attrs_outputs
        ):
            try:
                kernel_sig = _dygraph_tracer()._get_kernel_signature(
                    self.op_type,
                    eager_tensor_inputs,
                    eager_tensor_outputs,
                    attrs_outputs,
                )
            except RuntimeError as re:
                """we think the kernel_sig is missing."""
                kernel_sig = None
                print(
                    "[Warning: op_test.py] Kernel Signature is not found for %s, fall back to intermediate state."
                    % self.op_type
                )
            return kernel_sig

        def cal_python_api(python_api, args, kernel_sig):
            inputs_sig, attrs_sig, outputs_sig = kernel_sig
            args = assumption_assert_and_transform(args, len(inputs_sig))
            ret_tuple = python_api(*args)
            return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)

        with fluid.dygraph.base.guard(place=place):
            block = fluid.default_main_program().global_block()
            op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
            # prepare input variable
            eager_tensor_inputs = (
                egr_inps
                if egr_inps
                else self.append_input_output_for_dygraph(
                    op_proto, self.inputs, True, False, block
                )
            )
            # prepare output variable
            eager_tensor_outputs = (
                egr_oups
                if egr_oups
                else self.append_input_output_for_dygraph(
                    op_proto, self.outputs, False, False, block
                )
            )

            # prepare attributes
            attrs_outputs = {}
            if hasattr(self, "attrs"):
                for attrs_name in self.attrs:
                    if self.attrs[attrs_name] is not None:
                        attrs_outputs[attrs_name] = self.attrs[attrs_name]

            kernel_sig = _get_kernel_signature(
                eager_tensor_inputs, eager_tensor_outputs, attrs_outputs
            )
            if not kernel_sig:
                return None
            assert hasattr(self, "python_api"), (
                "Detected KernelSignature for `%s` op, please set `self.python_api` if you set check_eager = True"
                % self.op_type
            )
            args = prepare_python_api_arguments(
                self.python_api, eager_tensor_inputs, attrs_outputs, kernel_sig
            )
            """ we directly return the cal_python_api value because the value is already a tensor.
            """
            return cal_python_api(self.python_api, args, kernel_sig)

    def _calc_dygraph_output(self, place, parallel=False, no_check_set=None):
        self.__class__.op_type = (
            self.op_type
        )  # for ci check, please do not delete it for now
        with fluid.dygraph.base.guard(place=place):
            block = fluid.default_main_program().global_block()

            op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)

            # prepare input variable
            inputs = self.append_input_output_for_dygraph(
                op_proto, self.inputs, True, False, block
            )
            # prepare output variable
            outputs = self.append_input_output_for_dygraph(
                op_proto, self.outputs, False, False, block
            )

            # prepare attributes
            attrs_outputs = {}
            if hasattr(self, "attrs"):
                for attrs_name in self.attrs:
                    if self.attrs[attrs_name] is not None:
                        attrs_outputs[attrs_name] = self.attrs[attrs_name]

            block.append_op(
                type=self.op_type,
                inputs=inputs,
                outputs=outputs,
                attrs=attrs_outputs if hasattr(self, "attrs") else None,
            )
            return outputs

    def _calc_output(
        self,
        place,
        parallel=False,
        no_check_set=None,
        loss=None,
        enable_inplace=None,
        for_inplace_test=None,
    ):
        program = Program()
        block = program.global_block()
        op = self._append_ops(block)

        inputs = self._get_inputs(block)
        outputs = self._get_outputs(block)
        feed_map = self.feed_var(inputs, place)

1080
        if for_inplace_test:
C
cc 已提交
1081 1082
            # Some variables' tensors hold no buffer (tensor's _holder is NULL), like XShape in reshape2 op,
            # and the shapes of those variables contain 0 (eg. Xshape.shape = [0, 2, 5]).
1083 1084
            # Set persistable for those variables in order to get them from global_scope for inplace grad test directly other than feed them,
            # since feed op calls check_memory_size() which fails when tensor's holder_ is NULL.
1085 1086
            for out_name in op.output_arg_names:
                var = block.var(out_name)
1087 1088
                if 0 in var.shape:
                    var.persistable = True
1089
        original_program = program
1090 1091
        if parallel:
            use_cuda = False
1092
            if isinstance(place, fluid.CUDAPlace):
1093
                use_cuda = True
1094
            compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
1095 1096
                loss_name=loss.name if loss else None, places=place
            )
1097
            program = compiled_prog
1098 1099 1100 1101
        fetch_list = getattr(self, "fetch_list", [])
        # if the fetch_list is customized by user, we use it directly.
        # if not, fill the fetch_list by the user configured outputs in test.
        if len(fetch_list) == 0:
1102
            for var_name, var in outputs.items():
1103 1104
                if no_check_set is not None and var_name in no_check_set:
                    continue
Y
                    for v in var:
1107
                        fetch_list.append(v.name)
Y
1109
                    fetch_list.append(var.name)
1110 1111 1112 1113
        # if the fetch_list still empty, fill the fetch_list by the operator output.
        if len(fetch_list) == 0:
            for out_name, out_dup in Operator.get_op_outputs(self.op_type):
                fetch_list.append(str(out_name))
1114 1115 1116 1117 1118 1119

        if enable_inplace is not None:
            build_strategy = fluid.BuildStrategy()
            build_strategy.enable_inplace = enable_inplace

            compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
1120 1121
                build_strategy=build_strategy, places=place
            )
1122 1123
            program = compiled_prog

1124
        executor = Executor(place)
1125 1126 1127
        outs = executor.run(
            program, feed=feed_map, fetch_list=fetch_list, return_numpy=False
        )
1128 1129
        self.op = op
        self.program = original_program
1130 1131 1132 1133
        if for_inplace_test:
            return outs, fetch_list, feed_map, original_program, op.desc
        else:
            return outs, fetch_list
Y
1135 1136 1137
    def _compare_expect_and_actual_outputs(
        self, place, fetch_list, expect_outs, actual_outs, inplace_atol=None
    ):
1138 1139 1140
        """Compare expect outs and actual outs of an tested op.

        Args:
C
cc 已提交
1141
            place (CPUPlace | CUDAPlace): The place where the op runs.
1142 1143 1144 1145 1146 1147 1148 1149 1150 1151
            fetch_list (list): The outputs of tested op.
            expect_outs (list): The expect outs of tested op.
            actual_outs (list): The actual outs of tested op.
            inplace_atol (float): The tolerable error, only set when tested op doesn't ensure computational consistency, like group_norm op.

        Returns:
            None.
        """
        # compare expect_outs and actual_outs
        for i, name in enumerate(fetch_list):
            # Note(zhiqiu): inplace_atol should be only set when op doesn't ensure
            # computational consistency.
            # When inplace_atol is not None, the inplace check uses numpy.allclose
            # to check inplace result instead of numpy.array_equal.
            expect_out = np.array(expect_outs[i])
            actual_out = np.array(actual_outs[i])
            if inplace_atol is not None:
                np.testing.assert_allclose(
                    expect_out,
                    actual_out,
                    rtol=1e-05,
                    atol=inplace_atol,
                    err_msg='Output ('
                    + name
                    + ') has diff at '
                    + str(place)
                    + ' when using and not using inplace'
                    + '\nExpect '
                    + str(expect_out)
                    + '\n'
                    + 'But Got'
                    + str(actual_out)
                    + ' in class '
                    + self.__class__.__name__,
                )
            else:
                np.testing.assert_array_equal(
                    expect_out,
                    actual_out,
                    err_msg='Output ('
                    + name
                    + ') has diff at '
                    + str(place)
                    + ' when using and not using inplace'
                    + '\nExpect '
                    + str(expect_out)
                    + '\n'
                    + 'But Got'
                    + str(actual_out)
                    + ' in class '
                    + self.__class__.__name__
                    + '\n',
                )

    def _construct_grad_program_from_forward(
        self, fwd_program, grad_op_desc, op_grad_to_var
    ):
        """Generate grad_program which contains the grad_op.

        Args:
            fwd_program (tuple): The program that contains grad_op_desc's corresponding forward op.
            grad_op_desc (OpDesc): The OpDesc of grad op.
            op_grad_to_var (dict): The relation of variables in grad op and its forward op.

        Returns:
            grad_program (program): The program which contains the grad_op.
        """
        grad_program = Program()
        grad_block = grad_program.global_block()
        new_op_desc = grad_block.desc.append_op()
        new_op_desc.copy_from(grad_op_desc)
        grad_program._sync_with_cpp()

        # Create grad vars based on fwd vars (shape and dtype)
        for arg in (
            grad_op_desc.input_arg_names() + grad_op_desc.output_arg_names()
        ):
            fwd_var_name = op_grad_to_var.get(arg, None)
            if fwd_var_name is None:
                fwd_var_name = arg
            fwd_var = fwd_program.global_block().vars.get(fwd_var_name)
            assert fwd_var is not None, "{} cannot be found".format(
                fwd_var_name
            )
            grad_var = grad_block.create_var(
                name=arg,
                dtype=fwd_var.dtype,
                shape=fwd_var.shape,
                type=fwd_var.type,
                persistable=False,
            )

            # Some variables' tensors hold no buffer (tensor's _holder is NULL), like XShape in reshape2 op,
            # and the shapes of those variables contain 0 (eg. Xshape.shape = [0, 2, 5]).
            # Set persistable for those variables in order to get them from global_scope for inplace grad test directly other than feed them,
            # since feed op calls check_memory_size() which fails when tensor's holder_ is NULL.
            if 0 in grad_var.shape:
                grad_var.persistable = True
        grad_program._sync_with_cpp()
        return grad_program

    def _construct_grad_feed_map_from_forward(
        self, place, fwd_res, grad_op_desc, op_grad_to_var
    ):
        """Generate grad_feed_map for grad_program.

        since we don't really check gradient accuracy, but check the consistency when using and not using inplace,
        we use fwd outs (also inputs sometimes) to construct grad inputs.

        Args:
            place (CPUPlace | CUDAPlace): The place where the op runs.
            fwd_res (tuple): The outputs of its forward op, in the same form as returns of _calc_outputs() when for_inplace_test is True.
                i.e., tuple(fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc)
            grad_op_desc (OpDesc): The OpDesc of grad op.
            op_grad_to_var (dict): The relation of variables in grad op and its fwd_op.

        Returns:
            grad_feed_map (dict): The feed_map of grad_op.
        """
        (
            fwd_outs,
            fwd_fetch_list,
            fwd_feed_map,
            fwd_program,
            fwd_op_desc,
        ) = fwd_res
        p = core.Place()
        p.set_place(place)
        grad_feed_map = {}
        for arg in grad_op_desc.input_arg_names():
            if arg in fwd_feed_map.keys():
                grad_feed_map[arg] = fwd_feed_map[arg]._copy(p)
            else:
                fwd_var_name = op_grad_to_var.get(arg, None)
                if fwd_var_name is None:
                    fwd_var_name = arg

                for i, out_name in enumerate(fwd_fetch_list):
                    if out_name == fwd_var_name:
                        # don't feed variables whose tensors hold no buffer (shape contains 0 like shape = [0,2,5] and holder_ is NULL), like XShape in reshape2 op.
                        # get them from global_scope directly since we have set them persistable in fwd execution
                        if 0 in fwd_program.global_block().var(out_name).shape:
                            continue
                        else:
                            grad_feed_map[arg] = fwd_outs[i]._copy(p)

        return grad_feed_map

    def _get_need_run_ops(self, op_desc, fwd_op_desc=None):
        """Postorder traversal of the 'grad' tree to get all ops that need to run during inplace test.
        An op needs to run druing inplace check if,
        (1) it has infer_inplace,
        (2) it has infer_inplace in its grad descendants. (since we need its outputs as to construct its grad's inputs)

        Args:
            op_desc (OpDesc): The op_desc of current op.
            fwd_op_desc (OpDesc): The op_desc of current op's forward op, None if current op has no forward op.
                E.g. relu's fwd_op is None, relu_grad's fwd_op is relu, relu_grad_grad's fwd_op is relu_grad, etc.

        Returns:
            need_run_ops (list[(op_desc, fwd_op_desc)]): The ops that need to run during inplace test.
        """
        need_run_ops = []
        visited_ops = []

        def _dfs_grad_op(op_desc, fwd_op_desc=None):
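            # Depth-first search over op -> grad op -> grad_grad op ...; returns True if this op or any
            # of its grad descendants has infer_inplace, collecting such ops into need_run_ops.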
            visited_ops.append(op_desc.type())
            has_infer_inplace = fluid.core.has_infer_inplace(op_desc.type())
            has_grad_op_maker = fluid.core.has_grad_op_maker(op_desc.type())
            has_infer_inplace_in_grad_descendants = False
            if not has_grad_op_maker:
                has_infer_inplace_in_grad_descendants = False
            else:
                # get grad_op_desc
                grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
                    op_desc, set(), []
                )
                if not grad_op_desc_list:
                    has_infer_inplace_in_grad_descendants = False
                else:
                    for i, grad_op_desc in enumerate(grad_op_desc_list):
                        if (
                            grad_op_desc.type() not in visited_ops
                            and _dfs_grad_op(grad_op_desc, fwd_op_desc=op_desc)
                        ):
                            has_infer_inplace_in_grad_descendants = True
            if has_infer_inplace or has_infer_inplace_in_grad_descendants:
                need_run_ops.append((op_desc, fwd_op_desc))
                return True
            else:
                return False

        _dfs_grad_op(op_desc, fwd_op_desc=fwd_op_desc)
        return need_run_ops

    def _check_forward_inplace(
        self, place, no_check_set=None, inplace_atol=None
    ):
        """Check the inplace correctness of given op (self.op_type).
        Run the op twice with the same inputs, once with inplace enabled and once disabled, and compare their outputs.

        Args:
            place (CPUPlace | CUDAPlace): The place where the op runs.
            no_check_set (list): The names of outputs that needn't be checked, like XShape of the reshape op.
            inplace_atol (float): The tolerable error, only set when op doesn't ensure computational consistency, like group_norm op.

        Returns:
            expect_res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given op.
                We return this to construct grad_program and grad_feed_map for grad inplace check.
        """
        # _calc_output() returns in the form tuple(outs, fetch_list, feed_map, program, op_desc) when for_inplace_test=True.
        expect_res = self._calc_output(
            place,
            no_check_set=no_check_set,
            enable_inplace=False,
            for_inplace_test=True,
        )
        actual_res = self._calc_output(
            place,
            no_check_set=no_check_set,
            enable_inplace=True,
            for_inplace_test=True,
        )
        # compare expect_outs and actual_outs
        self._compare_expect_and_actual_outputs(
            place,
            expect_res[1],
            expect_res[0],
            actual_res[0],
            inplace_atol=inplace_atol,
        )
        return expect_res

    def _calc_grad_output(
        self, place, fwd_res, grad_op_desc, enable_inplace=None
    ):
        """Calculate grad_output for given grad_op_desc.

        Since we don't really check gradient accuracy here, but rather the consistency between inplace and
        non-inplace execution, we use the forward outputs (and sometimes inputs) to construct the grad inputs.

        Args:
            place (CPUPlace | CUDAPlace): The place where the op runs.
            fwd_res (tuple): The outputs of its forward op, in the same form as the return value of _calc_output() when for_inplace_test is True.
                i.e., tuple(fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc).
            grad_op_desc (OpDesc): The OpDesc of grad op.
            enable_inplace (bool): Enable inplace or not.

        Returns:
            res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
        """
        (
            fwd_outs,
            fwd_fetch_list,
            fwd_feed_map,
            fwd_program,
            fwd_op_desc,
        ) = fwd_res
        grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
            fwd_op_desc, set(), []
        )
        grad_program = self._construct_grad_program_from_forward(
            fwd_program, grad_op_desc, op_grad_to_var
        )
        grad_feed_map = self._construct_grad_feed_map_from_forward(
            place, fwd_res, grad_op_desc, op_grad_to_var
        )
        grad_fetch_list = grad_op_desc.output_arg_names()
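        # Fetch every grad-op output; results are kept as LoDTensors (return_numpy=False below) so
        # deeper grad ops can _copy() them into their own feed maps.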
        exe = Executor(place)
        program = grad_program
        if enable_inplace is not None:
            build_strategy = fluid.BuildStrategy()
            build_strategy.enable_inplace = enable_inplace
            compiled_program = fluid.CompiledProgram(
                grad_program
            ).with_data_parallel(
                loss_name="", build_strategy=build_strategy, places=place
            )
            program = compiled_program

        outs = exe.run(
            program,
            feed=grad_feed_map,
            fetch_list=grad_fetch_list,
            return_numpy=False,
        )
        return outs, grad_fetch_list, grad_feed_map, grad_program, grad_op_desc

    def _check_grad_inplace(
        self, place, fwd_res, grad_op_desc, inplace_atol=None
    ):
        """Check the inplace correctness of given grad_op_desc.

        Run the grad op twice with the same inputs, once with inplace enabled and once disabled, and compare their outputs.
        It works like _check_forward_inplace, but the program and feed_map are constructed differently,
        so we define a separate function for grad, grad_grad, etc.

        Args:
            place (CPUPlace | CUDAPlace): The place where the op runs.
            fwd_res (tuple): The outputs of its forward op, in the same form as the return value of _calc_output() when for_inplace_test is True.
                i.e., tuple(fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc).
            grad_op_desc (OpDesc): The OpDesc of grad op.
            inplace_atol (float): The tolerable error, only set when op doesn't ensure computational consistency, like group_norm op.

        Returns:
            expect_res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given op.
                We return this to construct grad_program and grad_feed_map for grad inplace check.
        """
        expect_res = self._calc_grad_output(
            place, fwd_res, grad_op_desc, enable_inplace=False
        )
        actual_res = self._calc_grad_output(
            place, fwd_res, grad_op_desc, enable_inplace=True
        )

        self._compare_expect_and_actual_outputs(
            place,
            expect_res[1],
            expect_res[0],
            actual_res[0],
            inplace_atol=inplace_atol,
        )
        return expect_res

    def check_inplace_output_with_place(
        self, place, no_check_set=None, inplace_atol=None
    ):
        """Chech the inplace correctness of given op, its grad op, its grad_grad op, etc.

        (1) Get all ops need to run. (see conditions in _get_need_run_ops())
        (2) Run op in need_run_ops, and do inplace check if it has infer_inplace.

        Args:
            place (CPUPlace | CUDAPlace): The place where the op runs.
            no_check_set (list): The names of outputs that needn't be checked, like XShape of the reshape op.
            inplace_atol (float): The tolerable error, only set when op doesn't ensure computational consistency, like group_norm op.

        Returns:
            None
        """
        if getattr(self, "no_need_check_inplace", False):
            return

        has_infer_inplace = fluid.core.has_infer_inplace(self.op_type)
        has_grad_op_maker = fluid.core.has_grad_op_maker(self.op_type)

        fwd_res = self._calc_output(
            place, no_check_set=no_check_set, for_inplace_test=True
        )
        op_desc = fwd_res[4]
        need_run_ops = self._get_need_run_ops(op_desc)

        res = {}
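        # Cache each op's results keyed by its OpDesc; a grad op later reuses its parent
        # (father_op_desc) entry as the forward result it builds grad inputs from.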
        if hasattr(self, 'attrs') and bool(self.attrs.get('use_xpu', False)):
            return
        for op_desc, father_op_desc in reversed(need_run_ops):
            # The first one is the forward op
            has_infer_inplace = fluid.core.has_infer_inplace(op_desc.type())
            if op_desc.type() == self.op_type:
                if has_infer_inplace:
                    res[op_desc] = self._check_forward_inplace(
                        place,
                        no_check_set=no_check_set,
                        inplace_atol=inplace_atol,
                    )
                else:
                    res[op_desc] = self._calc_output(
                        place, no_check_set=no_check_set, for_inplace_test=True
                    )
            else:
                # TODO(zhiqiu): enhance inplace_grad test for ops (sum and activation) using mkldnn
                # skip ops that use mkldnn currently
                flags_use_mkldnn = fluid.core.globals()["FLAGS_use_mkldnn"]
                attrs_use_mkldnn = hasattr(self, 'attrs') and bool(
                    self.attrs.get('use_mkldnn', False)
                )
                if flags_use_mkldnn or attrs_use_mkldnn:
                    warnings.warn(
                        "check inplace_grad for ops using mkldnn is not supported"
                    )
                    continue
                if has_infer_inplace:
                    fwd_res = res[father_op_desc]
                    res[op_desc] = self._check_grad_inplace(
                        place, fwd_res, op_desc, inplace_atol=inplace_atol
                    )
                else:
                    res[op_desc] = self._calc_grad_output(
                        place, fwd_res, op_desc
                    )

    def check_output_with_place(
        self,
        place,
        atol=0,
        no_check_set=None,
        equal_nan=False,
        check_dygraph=True,
        inplace_atol=None,
        check_eager=False,
    ):

        # disable legacy dygraph check when check_eager is True
        if check_eager == True:
            check_dygraph = False

        def find_imperative_actual(target_name, dygraph_outs, place):
            for name in dygraph_outs:
                if name == target_name:
                    return dygraph_outs[name][0]
                var_list = dygraph_outs[name]
                for i, var in enumerate(var_list):
                    if var.name == target_name:
                        return dygraph_outs[name][i]
            self.assertTrue(
                False,
                "Found failed {} {}".format(dygraph_outs.keys(), target_name),
            )

        def find_actual(target_name, fetch_list):
            found = [
                i
                for i, var_name in enumerate(fetch_list)
                if var_name == target_name
            ]
            self.assertTrue(
                len(found) == 1, "Found {} {}".format(len(found), target_name)
            )
            return found[0]

        class Checker(object):
            """base class for check with self.outputs.
            currently don't support check between checkers.
1575 1576 1577
            """

            def __init__(self, op_test, expect_dict):
                """expect_dict is the self.outputs
                support : {str: [numpy]} and {str: [(str, numpy), (str, numpy)]}
                """
                self.expects = expect_dict
                self.checker_name = "checker"
                self.op_test = op_test  # store the op_test object.
                self.op_type = op_test.op_type

            def init(self):
                pass

            def convert_uint16_to_float(self, actual_np, expect_np):
                raise NotImplementedError("base class, not implement!")

            def calculate_output(self):
                """
                Calculate the outputs of the op under test and store them for comparison.
                Implemented by subclasses.
                """

            def _is_skip_name(self, name):
                if name not in self.expects:
                    return True
                if no_check_set is not None and name in no_check_set:
                    return True
                return False

            def find_actual_value(self, name):
                """return: (actual_tensor(var_base), actual_numpy)"""
                raise NotImplementedError("base class, not implement!")

            def _compare_numpy(self, name, actual_np, expect_np):
                self.op_test.assertTrue(
                    np.allclose(
                        actual_np,
                        expect_np,
                        atol=atol,
                        rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
                        equal_nan=equal_nan,
                    ),
                    "Output ("
                    + name
                    + ") has diff at "
                    + str(place)
                    + " in "
                    + self.checker_name,
                )

            def _compare_list(self, name, actual, expect):
                """if expect is a tuple, we need to compare list."""
                raise NotImplementedError("base class, not implement!")

            def compare_single_output_with_expect(self, name, expect):
                actual, actual_np = self.find_actual_value(name)
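                # An expected value given as a tuple is (ndarray, lod); compare the dense data here
                # and the LoD separately in _compare_list below.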
                expect_np = expect[0] if isinstance(expect, tuple) else expect
                actual_np, expect_np = self.convert_uint16_to_float_ifneed(
                    actual_np, expect_np
                )
                # NOTE(zhiqiu): np.allclose([], [1.]) returns True
                # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
                if expect_np.size == 0:
                    self.op_test.assertTrue(actual_np.size == 0)
                self._compare_numpy(name, actual_np, expect_np)
                if isinstance(expect, tuple):
                    self._compare_list(name, actual, expect)

            def compare_outputs_with_expects(self):
                for out_name, out_dup in Operator.get_op_outputs(self.op_type):
                    if self._is_skip_name(out_name):
                        continue
                    if out_dup:
                        # if self.output = {'name': [(subname, Tensor), (subname, Tensor)]}
                        sub_out = self.expects[out_name]
                        if not isinstance(sub_out, list):
                            raise AssertionError(
                                "sub_out type %s is not list", type(sub_out)
                            )
                        for item in sub_out:
                            sub_out_name, expect = item[0], item[1]
                            self.compare_single_output_with_expect(
                                sub_out_name, expect
                            )
                    else:
                        expect = self.expects[out_name]
                        self.compare_single_output_with_expect(out_name, expect)

            def check(self):
                """
                Main entry point of the Checker class.

                Returns None on success; raises an error on failure.
                """
                self.init()
                self.calculate_output()
                self.compare_outputs_with_expects()

        class StaticChecker(Checker):
            def init(self):
                self.checker_name = "static checker"

            def calculate_output(self):
                outs, fetch_list = self.op_test._calc_output(
                    place, no_check_set=no_check_set
                )
                self.outputs = outs
                self.fetch_list = fetch_list

            def find_actual_value(self, name):
                idx = find_actual(name, self.fetch_list)
                actual = self.outputs[idx]
                actual_t = np.array(actual)
                return actual, actual_t

            def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
                """
                Convert uint16 (bfloat16) outputs to float32 when needed and relax the tolerances accordingly.
                Return the (possibly converted) actual and expected arrays.
                """
                if actual_np.dtype == np.uint16 and expect_np.dtype in [
                    np.float32,
                    np.float64,
                ]:
                    actual_np = convert_uint16_to_float(actual_np)
                    self.rtol = 1.0e-2
                else:
                    self.rtol = 1.0e-5
                if (
                    expect_np.dtype == np.uint16
                    and actual_np.dtype == np.uint16
                ):
                    nonlocal atol
                    expect_np = convert_uint16_to_float(expect_np)
                    actual_np = convert_uint16_to_float(actual_np)
                    atol = max(atol, 0.03)
                return actual_np, expect_np

            def _compare_list(self, name, actual, expect):
                """if expect is a tuple, we need to compare list."""
                self.op_test.assertListEqual(
                    actual.recursive_sequence_lengths(),
                    expect[1],
                    "Output (" + name + ") has different lod at " + str(place),
                )

        class DygraphChecker(Checker):
            def init(self):
                self.checker_name = "dygraph checker"

            def calculate_output(self):
                self.outputs = self.op_test._calc_dygraph_output(
                    place, no_check_set=no_check_set
                )

            def find_actual_value(self, name):
                with fluid.dygraph.base.guard(place=place):
                    imperative_actual = find_imperative_actual(
                        name, self.outputs, place
                    )
                    imperative_actual_t = np.array(
                        imperative_actual.value().get_tensor()
                    )
                    return imperative_actual, imperative_actual_t

            def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
                if actual_np.dtype == np.uint16 and expect_np.dtype in [
                    np.float32,
                    np.float64,
                ]:
                    self.rtol = 1.0e-2
                else:
                    self.rtol = 1.0e-5
                if self.op_test.is_bfloat16_op():
                    if actual_np.dtype == np.uint16:
                        actual_np = convert_uint16_to_float(actual_np)
                    if expect_np.dtype == np.uint16:
                        expect_np = convert_uint16_to_float(expect_np)
                return actual_np, expect_np

            def _compare_list(self, name, actual, expect):
                """if expect is a tuple, we need to compare list."""
                with fluid.dygraph.base.guard(place=place):
                    self.op_test.assertListEqual(
                        actual.value()
                        .get_tensor()
                        .recursive_sequence_lengths(),
                        expect[1],
                        "Output ("
                        + name
                        + ") has different lod at "
                        + str(place)
                        + " in dygraph mode",
                    )

            def _compare_numpy(self, name, actual_np, expect_np):
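                # Skip the value comparison when both the actual and expected tensors are empty.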
                if (
                    functools.reduce(lambda x, y: x * y, actual_np.shape, 1)
                    == 0
                    and functools.reduce(lambda x, y: x * y, expect_np.shape, 1)
                    == 0
                ):
                    pass
                else:
                    self.op_test.assertTrue(
                        np.allclose(
                            actual_np,
                            expect_np,
                            atol=atol,
                            rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
                            equal_nan=equal_nan,
                        ),
                        "Output ("
                        + name
                        + ") has diff at "
                        + str(place)
                        + " in "
                        + self.checker_name,
                    )

        class EagerChecker(DygraphChecker):
            def init(self):
                self.checker_name = "eager checker"

            def calculate_output(self):
                # we only check end2end api when check_eager=True
                with _test_eager_guard():
                    self.is_python_api_test = True
                    eager_dygraph_outs = self.op_test._calc_python_api_output(
                        place
                    )
                    if eager_dygraph_outs is None:
                        self.is_python_api_test = False
                        # missing KernelSignature, fall back to eager middle output.
                        eager_dygraph_outs = self.op_test._calc_dygraph_output(
                            place, no_check_set=no_check_set
                        )
                self.outputs = eager_dygraph_outs

            def _compare_numpy(self, name, actual_np, expect_np):
                with _test_eager_guard():
                    super()._compare_numpy(name, actual_np, expect_np)

            def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
                with _test_eager_guard():
                    return super().convert_uint16_to_float_ifneed(
                        actual_np, expect_np
                    )

            def find_actual_value(self, name):
                with _test_eager_guard():
                    return super().find_actual_value(name)

            def _compare_list(self, name, actual, expect):
                """if expect is a tuple, we need to compare list."""
                with _test_eager_guard():
                    super()._compare_list(name, actual, expect)

            def _is_skip_name(self, name):
                # If this is a python-api (final state) test and the kernel signature does not contain this name, skip it.
                if (
                    self.is_python_api_test
                    and hasattr(self.op_test, "python_out_sig")
                    and name not in self.op_test.python_out_sig
                ):
                    return True
                return super()._is_skip_name(name)

        # set some flags by the combination of arguments.
        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
        if (
            self.dtype == np.float64
            and self.op_type
            not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST
        ):
            atol = 0

        if self.is_bfloat16_op():
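            # bfloat16 kernels are low precision, so relax the absolute tolerance; for oneDNN
            # bf16 kernels the dygraph/eager checks are skipped as well.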
            if self.is_mkldnn_op():
                check_dygraph = False
                check_eager = False
                if hasattr(self, 'force_fp32_output') and getattr(
                    self, 'force_fp32_output'
                ):
                    atol = 1e-2
                else:
                    atol = 2
            else:
                atol = 1e-1

        if no_check_set is not None:
            if (
                self.op_type
                not in no_check_set_white_list.no_check_set_white_list
            ):
                raise AssertionError(
                    "no_check_set of op %s must be set to None." % self.op_type
                )
        static_checker = StaticChecker(self, self.outputs)
        static_checker.check()
        outs, fetch_list = static_checker.outputs, static_checker.fetch_list
        if check_dygraph:
            # always enable legacy dygraph
            g_enable_legacy_dygraph()
            dygraph_checker = DygraphChecker(self, self.outputs)
            dygraph_checker.check()
            dygraph_outs = dygraph_checker.outputs
            # restore the original state
            g_disable_legacy_dygraph()
        if check_eager:
            eager_checker = EagerChecker(self, self.outputs)
            eager_checker.check()
            eager_dygraph_outs = eager_checker.outputs

        # Note(zhiqiu): inplace_atol should only be set when the op doesn't ensure
        # computational consistency. For example, group_norm uses AtomicAdd on
        # CUDAPlace, which does not guarantee the computation order when multiple
        # threads write to the same address, so the result of group_norm is
        # non-deterministic for floating-point dtypes.
        # When inplace_atol is not None, the inplace check uses numpy.allclose
        # instead of numpy.array_equal to compare results.
        if inplace_atol is not None:
            warnings.warn(
                "inplace_atol should only be set when op doesn't ensure computational consistency, please check it!"
            )
        # Check inplace for the given op, its grad op, its grad_grad op, etc.
        # This has no effect on the original OpTest.
        # ParallelExecutor is currently not supported on XPU, NPU, MLU or custom places,
        # so the inplace check is skipped there.
        if (
            not paddle.is_compiled_with_xpu()
            and not paddle.is_compiled_with_npu()
            and not paddle.is_compiled_with_mlu()
            and not isinstance(place, core.CustomPlace)
        ):
            self.check_inplace_output_with_place(
                place, no_check_set=no_check_set, inplace_atol=inplace_atol
            )

        if check_eager:
            assert check_dygraph == False
            return outs, eager_dygraph_outs, fetch_list
        elif check_dygraph:
1918 1919 1920 1921 1922 1923 1924
            return outs, dygraph_outs, fetch_list
        else:
            return outs, fetch_list

    def check_compile_vs_runtime(self, fetch_list, fetch_outs):
        def find_fetch_index(target_name, fetch_list):
            found = [
                i
                for i, var_name in enumerate(fetch_list)
                if var_name == target_name
            ]
            if len(found) == 0:
                return -1
            else:
                self.assertTrue(
                    len(found) == 1,
                    "Found {} {}".format(len(found), target_name),
                )
                return found[0]

        for name in self.op.desc.output_names():
            var_names = self.op.desc.output(name)
            for var_name in var_names:
                i = find_fetch_index(var_name, fetch_list)
                if i == -1:
                    # The output is dispensable or intermediate.
                    break
                out = fetch_outs[i]
                if isinstance(out, core.LoDTensor):
                    lod_level_runtime = len(out.lod())
                else:
                    if isinstance(out, core.LoDTensorArray):
                        warnings.warn(
                            "The check of LoDTensorArray's lod_level is not implemented now!"
                        )
                    lod_level_runtime = 0

                var = self.program.global_block().var(var_name)
                if var.type == core.VarDesc.VarType.LOD_TENSOR:
                    lod_level_compile = var.lod_level
                else:
                    lod_level_compile = 0
                self.assertEqual(
                    lod_level_compile,
                    lod_level_runtime,
                    "The lod_level of Output ("
                    + name
                    + ") is different between compile-time and runtime ("
                    + str(lod_level_compile)
                    + " vs "
                    + str(lod_level_runtime)
                    + ")",
                )

    def _get_places(self):
        if self.dtype == np.float16:
            if core.is_compiled_with_cuda() and core.op_support_gpu(
                self.op_type
            ):
                place = core.CUDAPlace(0)
                if core.is_float16_supported(place):
                    return [place]
                else:
                    return []
            else:
                return []
        places = [fluid.CPUPlace()]
        cpu_only = self._cpu_only if hasattr(self, '_cpu_only') else False
        if (
            core.is_compiled_with_cuda()
            and core.op_support_gpu(self.op_type)
            and not cpu_only
        ):
            places.append(core.CUDAPlace(0))
        return places

    def check_output(
        self,
        atol=1e-5,
        no_check_set=None,
        equal_nan=False,
        check_dygraph=True,
        inplace_atol=None,
        check_eager=False,
    ):

        # disable legacy dygraph check when check_eager is True
        if check_eager == True:
            check_dygraph = False

        self.__class__.op_type = self.op_type
        if self.is_mkldnn_op():
            self.__class__.use_mkldnn = True

        if self.is_xpu_op():
            self.__class__.use_xpu = True

        places = self._get_places()
        for place in places:
            res = self.check_output_with_place(
                place,
                atol,
                no_check_set,
                equal_nan,
                check_dygraph,
                inplace_atol,
                check_eager=check_eager,
            )
            if check_eager:
                assert check_dygraph == False
                outs, eager_dygraph_outs, fetch_list = res
            elif check_dygraph:
                outs, dygraph_outs, fetch_list = res
            else:
                outs, fetch_list = res
            if (
                self.op_type
                not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST
            ):
                self.check_compile_vs_runtime(fetch_list, outs)

    def check_output_customized(self, checker, custom_place=None):
        places = self._get_places()
        if custom_place:
            places.append(custom_place)
        for place in places:
            outs = self.calc_output(place)
            outs = [np.array(out) for out in outs]
            outs.sort(key=len)
            checker(outs)

    def check_output_with_place_customized(self, checker, place):
        outs = self.calc_output(place)
        outs = [np.array(out) for out in outs]
        outs.sort(key=len)
        checker(outs)

    def _assert_is_close(
        self,
        numeric_grads,
        analytic_grads,
        names,
        max_relative_error,
        msg_prefix,
    ):
        for a, b, name in zip(numeric_grads, analytic_grads, names):
            # It asserts np.abs(a - b) / np.abs(a) < max_relative_error, in which
            # max_relative_error is 1e-7. According to the value of np.abs(a), we
            # change np.abs(a) to achieve dynamic threshold. For example, if
            # the value of np.abs(a) is between 1e-10 and 1e-8, we set np.abs(a)*=1e4.
            # Therefore, it asserts np.abs(a - b) / (np.abs(a)*1e4) < max_relative_error,
            # which is the same as np.abs(a - b) / np.abs(a) < max_relative_error*1e4.
            abs_a = np.abs(a)
            if abs_a.ndim > 0:
                if (
                    self.dtype == np.float64
                    and self.op_type
                    not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
                ):
                    abs_a[abs_a < 1e-10] = 1e-3
                    abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
                    abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
                elif self.is_bfloat16_op():
                    abs_a[abs_a < 1e-2] = 1
                else:
                    abs_a[abs_a < 1e-3] = 1
            elif abs_a.ndim == 0:
                if (
                    self.dtype == np.float64
                    and self.op_type
                    not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
                ):
                    if abs_a < 1e-10:
                        abs_a = 1e-3
                    elif abs_a > 1e-10 and abs_a <= 1e-8:
                        abs_a = abs_a * 1e4
                    elif abs_a > 1e-8 and abs_a <= 1e-6:
                        abs_a = abs_a * 1e2
                elif self.is_bfloat16_op():
                    abs_a = 1 if abs_a < 1e-2 else abs_a
                else:
                    abs_a = 1 if abs_a < 1e-3 else abs_a

            diff_mat = np.abs(a - b) / abs_a
            max_diff = np.max(diff_mat)

            def err_msg():
                offset = np.argmax(diff_mat > max_relative_error)
                return (
                    "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff %e over limit %e, "
                    "the first error element is %d, expected %e, but got %e."
                ) % (
                    self.op_type,
                    msg_prefix,
                    name,
                    str(a.shape),
                    self.dtype,
                    max_diff,
                    max_relative_error,
                    offset,
                    a.flatten()[offset],
                    b.flatten()[offset],
                )

            self.assertLessEqual(max_diff, max_relative_error, err_msg())

    def _check_grad_helper(self):
        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
        self.__class__.op_type = self.op_type
        self.__class__.exist_check_grad = True
        if self.dtype == np.float64:
            self.__class__.exist_fp64_check_grad = True

    def check_grad(
        self,
        inputs_to_check,
        output_names,
        no_grad_set=None,
        numeric_grad_delta=0.005,
        in_place=False,
        max_relative_error=0.005,
        user_defined_grads=None,
        user_defined_grad_outputs=None,
        check_dygraph=True,
        check_eager=False,
    ):

        # disable legacy dygraph check when check_eager is True
        if check_eager == True:
            check_dygraph = False

        self._check_grad_helper()
        places = self._get_places()
        for place in places:
            self.check_grad_with_place(
                place,
                inputs_to_check,
                output_names,
                no_grad_set,
                numeric_grad_delta,
                in_place,
                max_relative_error,
                user_defined_grads,
                user_defined_grad_outputs,
                check_dygraph,
                check_eager=check_eager,
            )

    def check_grad_with_place(
        self,
        place,
        inputs_to_check,
        output_names,
        no_grad_set=None,
        numeric_grad_delta=0.005,
        in_place=False,
        max_relative_error=0.005,
        user_defined_grads=None,
        user_defined_grad_outputs=None,
        check_dygraph=True,
        numeric_place=None,
        check_eager=False,
    ):

        # disable legacy dygraph check when check_eager is True
        if check_eager == True:
            check_dygraph = False

        self.scope = core.Scope()
        op_inputs = self.inputs if hasattr(self, "inputs") else dict()
        op_outputs = self.outputs if hasattr(self, "outputs") else dict()
        op_attrs = self.attrs if hasattr(self, "attrs") else dict()

        self._check_grad_helper()
        if self.is_bfloat16_op() and self.is_mkldnn_op():
            check_dygraph = False
            check_eager = False

        if (
            self.dtype == np.float64
            and self.op_type
            not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
        ):
            numeric_grad_delta = 1e-5
            max_relative_error = 1e-7

        cache_list = None
        if hasattr(self, "cache_name_list"):
            cache_list = self.cache_name_list

        # oneDNN numeric gradient should use CPU kernel
        use_onednn = False
        if "use_mkldnn" in op_attrs and op_attrs["use_mkldnn"] == True:
            op_attrs["use_mkldnn"] = False
            use_onednn = True

        self.op = create_op(
            self.scope,
            self.op_type,
            op_inputs,
            op_outputs,
            op_attrs,
            cache_list=cache_list,
        )

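        # create_op above received use_mkldnn=False so the numeric gradient is computed with the
        # CPU reference kernel; restore the flag so the analytic-gradient program still uses oneDNN.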
        if use_onednn:
            op_attrs["use_mkldnn"] = True

        if no_grad_set is None:
            no_grad_set = set()
        else:
            if (
                (self.op_type not in no_grad_set_white_list.NEED_TO_FIX_OP_LIST)
                and (
                    self.op_type not in no_grad_set_white_list.NOT_CHECK_OP_LIST
                )
                and (not self.is_bfloat16_op())
            ):
                raise AssertionError(
                    "no_grad_set must be None, op_type is "
                    + self.op_type
                    + " Op."
                )

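        # Record whether any checked input tensor has fewer than 100 elements; 0-D tensors are exempt.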
        for input_to_check in inputs_to_check:
            set_input(self.scope, self.op, self.inputs, place)
            tensor_to_check = self.scope.find_var(input_to_check).get_tensor()
            tensor_size = functools.reduce(
                lambda a, b: a * b, tensor_to_check.shape(), 1
            )
            tensor_ndim = len(tensor_to_check.shape())
            # A 0-D tensor is an additional special case for the op, so do not raise an error for it.
            if tensor_ndim > 0 and tensor_size < 100:
                self.__class__.input_shape_is_large = False

        if not type(output_names) is list:
            output_names = [output_names]

        if numeric_place is None:
            numeric_place = place

        numeric_grads = user_defined_grads or [
            get_numeric_gradient(
                numeric_place,
                self.scope,
                self.op,
                self.inputs,
                input_to_check,
                output_names,
                delta=numeric_grad_delta,
                in_place=in_place,
            )
            for input_to_check in inputs_to_check
        ]
        analytic_grads = self._get_gradient(
            inputs_to_check,
            place,
            output_names,
            no_grad_set,
            user_defined_grad_outputs,
        )
        # comparison of bf16 results will happen as fp32
        # loop over list of grads and convert bf16 to fp32
        fp32_analytic_grads = []
        for grad in analytic_grads:
            if grad.dtype == np.uint16:
                grad = convert_uint16_to_float(grad)
                max_relative_error = (
                    0.04 if max_relative_error < 0.04 else max_relative_error
                )
            fp32_analytic_grads.append(grad)
        analytic_grads = fp32_analytic_grads

        fp32_numeric_grads = []
        for grad in numeric_grads:
            if grad.dtype == np.uint16:
                grad = convert_uint16_to_float(grad)
                max_relative_error = (
                    0.04 if max_relative_error < 0.04 else max_relative_error
                )
            fp32_numeric_grads.append(grad)
        numeric_grads = fp32_numeric_grads

        self._assert_is_close(
            numeric_grads,
            analytic_grads,
            inputs_to_check,
            max_relative_error,
            "Gradient Check On %s" % str(place),
        )

        if check_dygraph:
            # ensure switch into legacy dygraph
            g_enable_legacy_dygraph()

            dygraph_grad = self._get_dygraph_grad(
                inputs_to_check,
                place,
                output_names,
                user_defined_grad_outputs,
                no_grad_set,
                False,
            )
            fp32_grads = []
            for grad in dygraph_grad:
                if grad.dtype == np.uint16:
                    grad = convert_uint16_to_float(grad)
                    max_relative_error = (
                        0.03
                        if max_relative_error < 0.03
                        else max_relative_error
                    )
                fp32_grads.append(grad)
            dygraph_grad = fp32_grads
            self._assert_is_close(
                numeric_grads,
                dygraph_grad,
                inputs_to_check,
                max_relative_error,
                "Gradient Check On %s" % str(place),
            )
            # switch back to eager dygraph
            g_disable_legacy_dygraph()

        if check_eager:
            with fluid.dygraph.base.guard(place):
                with _test_eager_guard():
                    eager_dygraph_grad = self._get_dygraph_grad(
                        inputs_to_check,
                        place,
                        output_names,
                        user_defined_grad_outputs,
                        no_grad_set,
                        check_eager,
                    )
                    fp32_grads = []
                    for grad in eager_dygraph_grad:
                        if grad.dtype == np.uint16:
                            grad = convert_uint16_to_float(grad)
                            max_relative_error = (
                                0.03
                                if max_relative_error < 0.03
                                else max_relative_error
                            )
                        fp32_grads.append(grad)
                    eager_dygraph_grad = fp32_grads
                    self._assert_is_close(
                        numeric_grads,
                        eager_dygraph_grad,
                        inputs_to_check,
                        max_relative_error,
                        "Gradient Check On %s" % str(place),
                    )

    def _find_var_in_dygraph(self, output_vars, name):
        if name in output_vars:
            return output_vars[name]
        else:
            for output_vars_index in output_vars:
                for output_vars_selected in output_vars[output_vars_index]:
                    if output_vars_selected.name == name:
                        return output_vars_selected

    def _get_dygraph_grad(
        self,
        inputs_to_check,
        place,
        output_names,
        user_defined_grad_outputs=None,
        no_grad_set=None,
        check_eager=False,
    ):
        with fluid.dygraph.base.guard(place=place):
            block = fluid.default_main_program().global_block()

            op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)

            # prepare input variable
            inputs, inputs_grad_dict = self.append_input_output_for_dygraph(
                op_proto, self.inputs, True, True, block
            )

            # prepare output variable
            outputs = self.append_input_output_for_dygraph(
                op_proto, self.outputs, False, False, block
            )

            # prepare attributes
            attrs_outputs = {}
            if hasattr(self, "attrs"):
                for attrs_name in self.attrs:
                    if self.attrs[attrs_name] is not None:
                        attrs_outputs[attrs_name] = self.attrs[attrs_name]

            if check_eager:
                eager_outputs = self._calc_python_api_output(
                    place, inputs, outputs
                )
            # If outputs is None, the kernel signature is empty or some other error happened.
            if not check_eager or eager_outputs is None:
                block.append_op(
                    type=self.op_type,
                    inputs=inputs,
                    outputs=outputs,
                    attrs=attrs_outputs if hasattr(self, "attrs") else None,
                )
            else:
                outputs = eager_outputs

            if self.dtype == np.uint16:
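                # bfloat16 outputs are stored as uint16, so cast them to float32 first and build
                # the scalar loss from the cast result.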
                cast_inputs = self._find_var_in_dygraph(
                    outputs, output_names[0]
                )
                cast_outputs = block.create_var(
                    dtype="float32", shape=cast_inputs[0].shape
                )
                cast_op = block.append_op(
                    inputs={"X": cast_inputs},
                    outputs={"Out": cast_outputs},
                    type="cast",
                    attrs={
                        "in_dtype": core.VarDesc.VarType.BF16,
                        "out_dtype": core.VarDesc.VarType.FP32,
                    },
                )
                outputs = {output_names[0]: cast_outputs}

            outputs_valid = {}
            for output_name in output_names:
                outputs_valid[output_name] = self._find_var_in_dygraph(
                    outputs, output_name
                )

            if user_defined_grad_outputs is None:
                if len(outputs_valid) == 1:
                    loss = block.create_var(
                        dtype=self.dtype,
                        type=core.VarDesc.VarType.LOD_TENSOR,
                        persistable=False,
                        stop_gradient=False,
                        shape=[1],
                    )
                    for outputs_valid_key in outputs_valid:
                        block.append_op(
                            type="mean",
                            inputs={"X": outputs_valid[outputs_valid_key]},
                            outputs={"Out": [loss]},
                            attrs=None,
                        )
                else:
                    avg_sum = []
                    for cur_loss in outputs_valid:
                        cur_avg_loss = block.create_var(
                            dtype=self.dtype,
                            type=core.VarDesc.VarType.LOD_TENSOR,
                            persistable=False,
                            stop_gradient=False,
                        )
                        block.append_op(
                            type="mean",
                            inputs={"X": outputs_valid[cur_loss]},
                            outputs={"Out": [cur_avg_loss]},
                            attrs=None,
                        )
                        avg_sum.append(cur_avg_loss)
                    loss_sum = block.create_var(
                        dtype=self.dtype,
                        type=core.VarDesc.VarType.LOD_TENSOR,
                        persistable=False,
                        stop_gradient=False,
                        shape=[1],
                    )
                    block.append_op(
                        type='sum',
                        inputs={"X": avg_sum},
                        outputs={"Out": loss_sum},
                        attrs=None,
                    )
                    loss = block.create_var(
                        dtype=self.dtype,
                        type=core.VarDesc.VarType.LOD_TENSOR,
                        persistable=False,
                        stop_gradient=False,
                        shape=[1],
                    )
                    block.append_op(
                        type='scale',
                        inputs={"X": loss_sum},
                        outputs={"Out": loss},
                        attrs={'scale': 1.0 / float(len(avg_sum))},
                    )
                loss.backward()

                fetch_list_grad = []
                for inputs_to_check_name in inputs_to_check:
                    a = inputs_grad_dict[inputs_to_check_name].gradient()
                    fetch_list_grad.append(a)
                return fetch_list_grad
            else:
                # user_defined_grad_outputs here are numpy arrays
                if not isinstance(user_defined_grad_outputs, list):
                    user_defined_grad_outputs = [user_defined_grad_outputs]
                grad_outputs = []
                for grad_out_value in user_defined_grad_outputs:
                    grad_outputs.append(paddle.to_tensor(grad_out_value))
                # delete the inputs for which no gradient needs to be computed
                for no_grad_val in no_grad_set:
                    del inputs[no_grad_val]

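                # In eager mode run the backward pass directly and read .grad from the inputs;
                # in legacy dygraph fall back to paddle.grad.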
                if not _in_legacy_dygraph():
                    core.eager.run_backward(
                        fluid.layers.utils.flatten(outputs), grad_outputs, False
                    )
                    grad_inputs = []
                    for inputs_list in inputs.values():
                        for inp in inputs_list:
                            grad_inputs.append(inp.grad.numpy())
                    return grad_inputs
                else:
                    grad_inputs = paddle.grad(
                        outputs=fluid.layers.utils.flatten(outputs),
                        inputs=fluid.layers.utils.flatten(inputs),
                        grad_outputs=grad_outputs,
                    )
                    return [grad.numpy() for grad in grad_inputs]

    @staticmethod
    def _numpy_to_lod_tensor(np_value, lod, place):
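        """Wrap a numpy array into a LoDTensor on `place`, attaching the LoD
        information when it is provided."""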
        tensor = core.LoDTensor()
        tensor.set(np_value, place)
        if lod is not None:
            tensor.set_recursive_sequence_lengths(lod)
        return tensor

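    # The three helpers below are identity pass-throughs; individual op tests
    # can override them when a dtype or value conversion is required.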
    @staticmethod
    def np_dtype_to_fluid_dtype(input):
        return input

    @staticmethod
    def fluid_dtype_to_np_dtype(dtype):
        return dtype

    @staticmethod
    def np_value_to_fluid_value(input):
        return input

    def _get_gradient(
        self,
        input_to_check,
        place,
        output_names,
        no_grad_set,
        user_defined_grad_outputs=None,
        parallel=False,
    ):
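        """Compute the gradients of `output_names` w.r.t. `input_to_check` in
        static graph mode: rebuild the op in a fresh Program, append backward
        ops (or call paddle.static.gradients when `user_defined_grad_outputs`
        is given), then run an Executor and fetch the gradient values.
        """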
        prog = Program()
        scope = core.Scope()
        block = prog.global_block()
        self._append_ops(block)

        inputs = self._get_inputs(block)
        outputs = self._get_outputs(block)
        feed_dict = self.feed_var(inputs, place)

        if user_defined_grad_outputs is None:
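            # np.uint16 is used here to carry bfloat16 data, so cast the
            # outputs to FP32 before appending the loss ops.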
            if self.dtype == np.uint16:
                cast_inputs = list(map(block.var, output_names))
                cast_outputs = block.create_var(
                    dtype="float32", shape=cast_inputs[0].shape
                )
                cast_op = block.append_op(
                    inputs={"X": cast_inputs},
                    outputs={"Out": cast_outputs},
                    type="cast",
                    attrs={
                        "in_dtype": core.VarDesc.VarType.BF16,
                        "out_dtype": core.VarDesc.VarType.FP32,
                    },
                )
                cast_op.desc.infer_var_type(block.desc)
                cast_op.desc.infer_shape(block.desc)
                output_names = [cast_outputs.name]
            loss = append_loss_ops(block, output_names)
            param_grad_list = append_backward(
                loss=loss,
                parameter_list=input_to_check,
                no_grad_set=no_grad_set,
            )
            fetch_list = [g for p, g in param_grad_list]
        else:
            assert (
                parallel is False
            ), "unsupported parallel mode when giving custom grad outputs."
            # user_defined_grad_outputs here are numpy arrays
            if not isinstance(user_defined_grad_outputs, list):
                user_defined_grad_outputs = [user_defined_grad_outputs]
            grad_outputs = []
            for grad_out_value in user_defined_grad_outputs:
                # `persistable` is set so that the executor does not create a new var in the local scope
                var = block.create_var(
                    shape=grad_out_value.shape,
                    dtype=grad_out_value.dtype,
                    persistable=True,
                )
                true_var = scope.var(var.name)
                tensor = true_var.get_tensor()
                tensor.set(grad_out_value, place)
                grad_outputs.append(var)
            targets = [
                outputs[name] for name in outputs if name in output_names
            ]
            inputs = [inputs[name] for name in input_to_check if name in inputs]
            grad_inputs = paddle.static.gradients(
                targets, inputs, grad_outputs, no_grad_set
            )
            fetch_list = grad_inputs

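        # Optionally compile the program for data-parallel execution before
        # running it.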
        if parallel:
            use_cuda = False
            if isinstance(place, fluid.CUDAPlace):
                use_cuda = True
            compiled_prog = fluid.CompiledProgram(prog).with_data_parallel(
                loss_name=loss.name, places=place
            )
            prog = compiled_prog
        executor = fluid.Executor(place)
        return list(
            map(
                np.array,
                executor.run(
                    prog, feed_dict, fetch_list, scope=scope, return_numpy=False
                ),
            )
        )


class OpTestTool:
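    """Small helpers that wrap unittest.skipIf with skip conditions commonly
    used by op tests."""
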
    @classmethod
    def skip_if(cls, condition: object, reason: str):
        return unittest.skipIf(condition, reason)
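    # Illustrative usage (the decorated test class below is hypothetical):
    #
    #     @OpTestTool.skip_if(not core.is_compiled_with_cuda(), "requires CUDA")
    #     class TestSomeOpCUDA(OpTest):
    #         ...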

    @classmethod
    def skip_if_not_cpu_bf16(cls):
        return OpTestTool.skip_if(
            not (
                isinstance(_current_expected_place(), core.CPUPlace)
                and core.supports_bfloat16()
            ),
            "Place does not support BF16 evaluation",
        )

    @classmethod
    def skip_if_not_cpu(cls):
        return OpTestTool.skip_if(
            not isinstance(_current_expected_place(), core.CPUPlace),
            "OneDNN supports only CPU for now",
        )