op_test.py 102.1 KB
Newer Older
1
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
D
dzhwinter 已提交
2
#
D
dzhwinter 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
D
dzhwinter 已提交
6
#
D
dzhwinter 已提交
7
#     http://www.apache.org/licenses/LICENSE-2.0
D
dzhwinter 已提交
8
#
D
dzhwinter 已提交
9 10 11 12 13 14
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

B
baojun 已提交
15
import os
16
import sys
17
import unittest
18
import warnings
19
import numpy as np
20
import random
21
import functools
22
import struct
M
minqiyang 已提交
23
from collections import defaultdict
24
from copy import copy
25

26
import paddle
27
import paddle.fluid as fluid
28
from paddle.fluid.framework import _dygraph_tracer
29
import paddle.fluid.core as core
30 31 32 33 34 35
from paddle.fluid.framework import (
    _in_legacy_dygraph,
    _enable_legacy_dygraph,
    _in_eager_without_dygraph_check,
    _disable_legacy_dygraph,
)
36
from paddle.fluid.framework import _test_eager_guard
37 38 39
from paddle.fluid.backward import append_backward
from paddle.fluid.op import Operator
from paddle.fluid.executor import Executor
40 41 42 43 44
from paddle.fluid.framework import (
    OpProtoHolder,
    Program,
    _current_expected_place,
)
45 46 47 48 49
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs

sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from testsuite import (
50 51 52
    create_op,
    set_input,
    append_input_output,
53 54
    append_loss_ops,
)
55
from white_list import (
56 57 58 59 60
    op_accuracy_white_list,
    check_shape_white_list,
    compile_vs_runtime_white_list,
    no_check_set_white_list,
    op_threshold_white_list,
61 62
    no_grad_set_white_list,
)
63

64 65
# Global switches for the new eager mode: when eager mode is active these
# names point at the real legacy-dygraph toggles, otherwise they are no-ops.
g_is_in_eager = _in_eager_without_dygraph_check()
if g_is_in_eager:
    g_enable_legacy_dygraph = _enable_legacy_dygraph
    g_disable_legacy_dygraph = _disable_legacy_dygraph
else:
    g_enable_legacy_dygraph = lambda: None
    g_disable_legacy_dygraph = lambda: None
72

73

74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
def check_out_dtype(api_fn, in_specs, expect_dtypes, target_index=0, **configs):
    """
    Determines whether dtype of output tensor is as expected.

    For every dtype in ``expect_dtypes`` a fresh static program is built,
    ``api_fn`` is called on inputs created from ``in_specs`` and the dtype of
    its output is compared with the expected one.

    Note: this calls ``paddle.enable_static()`` and does not switch paddle
    back to dynamic mode afterwards.

    Args:
        api_fn(callable): paddle api function
        in_specs(list[tuple]): list of shape and dtype information for
            constructing input tensors of api_fn, such as
            [(shape, dtype), (shape, dtype)]. A spec holding only ``(shape,)``
            uses the dtype currently being checked when its index equals
            ``target_index`` and 'float32' otherwise.
        expect_dtypes(list[str]): expected dtypes of the output tensor; each
            one is checked in its own program.
        target_index(int): indicate which one from in_specs to infer the
            dtype of output.
        **configs: other keyword arguments forwarded to api_fn.

    Raises:
        ValueError: if a spec does not have one or two elements, or if the
            output dtype differs from the expected one.

    Example:
        check_out_dtype(fluid.layers.pad_constant_like, [([2,3,2,3], 'float64'), ([1, 3, 1,3], )], ['float32', 'float64', 'int64'], target_index=1, pad_value=0.)

    """
    paddle.enable_static()
    for expect_dtype in expect_dtypes:
        with paddle.static.program_guard(paddle.static.Program()):
            input_t = []
            for index, spec in enumerate(in_specs):
                if len(spec) == 1:
                    shape = spec[0]
                    dtype = expect_dtype if target_index == index else 'float32'
                elif len(spec) == 2:
                    shape, dtype = spec
                else:
                    raise ValueError(
                        "Value of in_specs[{}] should contains two elements: [shape, dtype]".format(
                            index
                        )
                    )
                input_t.append(
                    paddle.static.data(
                        name='data_%s' % index, shape=shape, dtype=dtype
                    )
                )

            out = api_fn(*input_t, **configs)
            out_dtype = fluid.data_feeder.convert_dtype(out.dtype)

            if out_dtype != expect_dtype:
                raise ValueError(
                    "Expected out.dtype is {}, but got {} from {}.".format(
                        expect_dtype, out_dtype, api_fn.__name__
                    )
                )
120 121


122 123 124 125 126 127 128 129
def _set_use_system_allocator(value=None):
    """Set ``FLAGS_use_system_allocator`` and return the value it had before.

    With ``value=None`` the flag is written back unchanged, so the call only
    reads the current setting.
    """
    flag_name = "FLAGS_use_system_allocator"
    previous = core.globals()[flag_name]
    core.globals()[flag_name] = previous if value is None else value
    return previous


130
def randomize_probability(batch_size, class_num, dtype='float32'):
    """Return a random ``(batch_size, class_num)`` matrix whose rows sum to ~1.

    Entries are drawn from U(0.1, 1.0) with ``np.random`` (so the global numpy
    seed controls them), cast to ``dtype``, and each row is divided by its own
    sum.
    """
    prob = np.random.uniform(0.1, 1.0, size=(batch_size, class_num)).astype(
        dtype
    )
    # Normalize every row at once; keepdims makes the (batch_size, 1) sums
    # broadcast across the class axis.
    prob /= prob.sum(axis=1, keepdims=True)
    return prob


140 141 142 143 144 145 146 147 148 149
def get_numeric_gradient(
    place,
    scope,
    op,
    inputs,
    input_to_check,
    output_names,
    delta=0.005,
    in_place=False,
):
    """Estimate the gradient of the op's outputs w.r.t. one input numerically.

    Every element of the tensor named ``input_to_check`` is perturbed by
    ``+delta`` and ``-delta``. After each perturbation ``op`` is run, and the
    means of the tensors named in ``output_names`` are averaged into a scalar
    ``y``. The element's gradient is the central difference
    ``(y_pos - y_neg) / delta / 2``.

    Args:
        place: place the op is run on (passed to ``op.run`` / ``tensor.set``).
        scope: scope holding the op's variables.
        op: operator; its ``run(scope, place)`` is called for every evaluation.
        inputs: input data fed with ``set_input``.
        input_to_check (str): name of the input whose gradient is estimated.
        output_names (list[str]): outputs whose means are averaged.
        delta (float): perturbation size; cast to float16 for FP16 inputs.
        in_place (bool): when True, ``set_input`` is called again before each
            perturbation (presumably because the op overwrites its inputs).

    Returns:
        numpy.ndarray: estimated gradient, reshaped to the checked tensor's
        shape.

    Raises:
        ValueError: if the checked tensor's dtype is not FP16/BF16/FP32/FP64/
            COMPLEX64/COMPLEX128.
        TypeError: raised by the element accessors for dtypes they do not
            handle.

    NOTE(review): COMPLEX64/COMPLEX128 are accepted by the dtype mapping, but
    ``__get_elem__``/``__set_elem__`` have no complex branch, so complex
    tensors with at least one element end in TypeError.
    """
    # FIXME: change this method by compile time concepts
    set_input(scope, op, inputs, place)

    def product(dim):
        return functools.reduce(lambda a, b: a * b, dim, 1)

    tensor_to_check = scope.find_var(input_to_check).get_tensor()
    tensor_size = product(tensor_to_check.shape())
    tensor_to_check_dtype = tensor_to_check._dtype()
    if tensor_to_check_dtype == core.VarDesc.VarType.FP32:
        tensor_to_check_dtype = np.float32
    elif tensor_to_check_dtype == core.VarDesc.VarType.FP64:
        tensor_to_check_dtype = np.float64
    elif tensor_to_check_dtype == core.VarDesc.VarType.FP16:
        tensor_to_check_dtype = np.float16
        # keep delta as np.float16; numpy promotes it when mixed with
        # float32/float64 values
        delta = np.array(delta).astype(np.float16)
    elif tensor_to_check_dtype == core.VarDesc.VarType.BF16:
        # BF16 elements are perturbed and differentiated as float32 values.
        tensor_to_check_dtype = np.float32
    elif tensor_to_check_dtype == core.VarDesc.VarType.COMPLEX64:
        tensor_to_check_dtype = np.complex64
    elif tensor_to_check_dtype == core.VarDesc.VarType.COMPLEX128:
        tensor_to_check_dtype = np.complex128
    else:
        raise ValueError(
            "Not supported data type "
            + str(tensor_to_check_dtype)
            + ", tensor name : "
            + str(input_to_check)
        )

    def get_output():
        # Run the op once and return the average of the outputs' means.
        output_means = []
        op.run(scope, place)
        for output_name in output_names:
            output_numpy = np.array(scope.find_var(output_name).get_tensor())
            # numpy.dtype does not have bfloat16, thus we use numpy.uint16 to
            # store bfloat16 data, and need to be converted to float to check
            # the floating precision.
            if tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
                output_numpy = convert_uint16_to_float(output_numpy)
            output_means.append(
                output_numpy.astype(tensor_to_check_dtype).mean()
            )
        return tensor_to_check_dtype(
            np.array(output_means).sum() / len(output_names)
        )

    gradient_flat = np.zeros(shape=(tensor_size,), dtype=tensor_to_check_dtype)

    def __get_elem__(tensor, i):
        # FP16/BF16 elements are read from a flattened numpy copy; FP32/FP64
        # use the tensor's element accessors.
        if tensor_to_check_dtype == np.float16:
            numpy_tensor = np.array(tensor).astype(np.float16)
            numpy_tensor = numpy_tensor.flatten()
            return numpy_tensor[i]
        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
            numpy_tensor = np.array(tensor).astype(np.uint16)
            numpy_tensor = numpy_tensor.flatten()
            # Widen the 16 stored bits to a float32 bit pattern.
            return struct.unpack(
                '<f',
                struct.pack('<I', np.uint32(numpy_tensor[i]) << np.uint32(16)),
            )[0]
        elif tensor_to_check_dtype == np.float32:
            return tensor._get_float_element(i)
        elif tensor_to_check_dtype == np.float64:
            return tensor._get_double_element(i)
        else:
            raise TypeError(
                "Unsupported test data type %s." % tensor_to_check_dtype
            )

    def __set_elem__(tensor, i, e):
        # FP16/BF16 writes copy the whole tensor to numpy, patch one element
        # and set it back; FP32/FP64 use the tensor's element setters.
        if tensor_to_check_dtype == np.float16:
            numpy_tensor = np.array(tensor).astype(np.float16)
            shape = numpy_tensor.shape
            numpy_tensor = numpy_tensor.flatten()
            numpy_tensor[i] = e
            numpy_tensor = numpy_tensor.reshape(shape)
            tensor.set(numpy_tensor, place)
        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
            numpy_tensor = np.array(tensor).astype(np.uint16)
            shape = numpy_tensor.shape
            numpy_tensor = numpy_tensor.flatten()
            numpy_tensor[i] = np.uint16(copy_bits_from_float_to_uint16(e))
            numpy_tensor = numpy_tensor.reshape(shape)
            tensor.set(numpy_tensor, place)
        elif tensor_to_check_dtype == np.float32:
            tensor._set_float_element(i, e)
        elif tensor_to_check_dtype == np.float64:
            tensor._set_double_element(i, e)
        else:
            raise TypeError(
                "Unsupported test data type %s." % tensor_to_check_dtype
            )

    # we only compute gradient of one element each time.
    # we use a for loop to compute the gradient of every element.
    for i in range(tensor_size):
        if in_place:
            set_input(scope, op, inputs, place)

        # get one input element by its index i.
        origin = __get_elem__(tensor_to_check, i)
        # add delta to it, run op and then get the sum of the result tensor.
        x_pos = origin + delta
        __set_elem__(tensor_to_check, i, x_pos)
        y_pos = get_output()

        if in_place:
            set_input(scope, op, inputs, place)

        x_neg = origin - delta
        __set_elem__(tensor_to_check, i, x_neg)
        y_neg = get_output()

        # restore the element before moving on to the next one
        __set_elem__(tensor_to_check, i, origin)
        gradient_flat[i] = (y_pos - y_neg) / delta / 2

    return gradient_flat.reshape(tensor_to_check.shape())
265 266


267 268
def skip_check_grad_ci(reason=None):
    """Class decorator exempting an OpTest subclass from the check_grad CI rule.

    Op test cases are normally required to run check_grad. Some special cases
    do not need it; decorating them with this sets ``no_need_check_grad =
    True`` on the class, which makes tearDownClass skip its check_grad checks.
    The unit test itself still runs.

    Args:
        reason (str): why check_grad is not needed; mandatory.

    Raises:
        AssertionError: if ``reason`` is not a string.

    Example:
        @skip_check_grad_ci(reason="For inference, check_grad is not required.")
        class TestInference(OpTest):
    """
    if isinstance(reason, str):

        def mark_no_grad_check(test_cls):
            test_cls.no_need_check_grad = True
            return test_cls

        return mark_no_grad_check
    raise AssertionError("The reason for skipping check_grad is required.")


291 292 293
def skip_check_inplace_ci(reason=None):
    """Class decorator that sets ``no_need_check_inplace = True`` on a test class.

    Args:
        reason (str): why the inplace check is skipped; mandatory.

    Raises:
        AssertionError: if ``reason`` is not a string.
    """
    if isinstance(reason, str):

        def mark_no_inplace_check(test_cls):
            test_cls.no_need_check_inplace = True
            return test_cls

        return mark_no_inplace_check
    raise AssertionError(
        "The reason for skipping check_inplace is required."
    )


304 305 306 307
def copy_bits_from_float_to_uint16(f):
    """Return the upper 16 bits of ``f`` encoded as a little-endian float32.

    The low 16 bits are simply dropped (truncation, no rounding), giving the
    bit pattern this file stores as numpy.uint16 for bfloat16 data.
    """
    (float32_bits,) = struct.unpack('<I', struct.pack('<f', f))
    return float32_bits >> 16


308 309 310 311
def convert_float_to_uint16(float_list, data_format="NCHW"):
    """Convert float data to bfloat16 bit patterns stored as ``numpy.uint16``.

    Each element is packed as an IEEE-754 float32 and only its upper 16 bits
    are kept (truncation, the same bit logic as
    ``copy_bits_from_float_to_uint16``).

    Args:
        float_list (array_like): float values of any shape (lists accepted).
        data_format (str): layout tag of ``float_list`` ("NCHW" or "NHWC").
            The conversion is element-wise, so the result always has the same
            shape and layout as the input; the argument is kept for backward
            compatibility.

    Returns:
        numpy.ndarray: C-contiguous uint16 array with ``float_list``'s shape.
    """
    # The NHWC path used to transpose to NCHW and walk the result with
    # np.nditer. nditer defaults to order='K' (memory order), which for a
    # transposed view differs from the C order used when reshaping back, so
    # NHWC values ended up at the wrong positions. A purely element-wise
    # conversion needs no transpose and cannot mix positions up.
    to_bf16_bits = np.vectorize(
        lambda x: struct.unpack('<I', struct.pack('<f', x))[0] >> 16,
        otypes=[np.uint16],
    )
    return to_bf16_bits(np.asarray(float_list))
320 321


322 323
def convert_uint16_to_float(in_list):
    """Expand bfloat16 bit patterns (stored as uint16) back to float32 values.

    Each value is shifted into the upper half of a 32-bit word and
    reinterpreted as a little-endian float32. The result keeps the input's
    shape.
    """
    arr = np.asarray(in_list)

    def _bits_to_float(bits):
        widened = np.uint32(bits) << np.uint32(16)
        return struct.unpack('<f', struct.pack('<I', widened))[0]

    to_float32 = np.vectorize(_bits_to_float, otypes=[np.float32])
    return np.reshape(to_float32(arr.flat), arr.shape)
331 332


333
class OpTest(unittest.TestCase):
334 335 336 337 338
    @classmethod
    def setUpClass(cls):
        """Save and fix the global RNG state and reset per-class test state.

        The numpy and Python ``random`` states are stored so tearDownClass can
        restore them, then both RNGs are seeded. The system allocator flag is
        turned off when paddle is compiled with NPU and on otherwise; its
        previous value is kept in ``cls._use_system_allocator``.
        """
        cls._np_rand_state = np.random.get_state()
        cls._py_rand_state = random.getstate()

        # Per-class bookkeeping that the CI checks in tearDownClass read.
        cls.call_once, cls.dtype = False, None
        cls.outputs = {}
        cls.input_shape_is_large = True

        np.random.seed(123)
        random.seed(124)

        use_system_allocator = not paddle.is_compiled_with_npu()
        cls._use_system_allocator = _set_use_system_allocator(
            use_system_allocator
        )
351

352 353
    @classmethod
    def tearDownClass(cls):
        """Restore global state, then enforce the OpTest CI coverage rules.

        The RNG states and allocator flag saved in setUpClass are restored.
        The class must define ``op_type``. Unless it is marked
        ``no_need_check_grad`` or the op has no usable grad kernel, an
        AssertionError is raised when:

        * check_grad was not exercised (``cls.dtype`` is None, or float16
          without ``exist_check_grad`` for ops outside the FP16 white list);
        * a float32/float64 test lacks an fp64 check_grad (not enforced for
          XPU/MKLDNN/ROCm/NPU/MLU/custom-device tests or white-listed ops);
        * ``input_shape_is_large`` is False for ops outside the shape list.
        """
        np.random.set_state(cls._np_rand_state)
        random.setstate(cls._py_rand_state)

        _set_use_system_allocator(cls._use_system_allocator)

        def enabled(flag_name):
            # Truthy only when the class defines the flag and it is set.
            return hasattr(cls, flag_name) and getattr(cls, flag_name)

        def is_empty_grad_op(op_type):
            all_op_kernels = core._get_all_register_op_kernels()
            grad_op = op_type + '_grad'
            if grad_op not in all_op_kernels.keys():
                return True
            if not enabled("use_mkldnn"):
                return False
            # For MKLDNN tests only an MKLDNN grad kernel counts.
            return not any(
                'MKLDNN' in grad_op_kernel
                for grad_op_kernel in all_op_kernels[grad_op]
            )

        if not hasattr(cls, "op_type"):
            raise AssertionError(
                "This test do not have op_type in class attrs, "
                "please set self.__class__.op_type=the_real_op_type manually."
            )

        # case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
        if hasattr(cls, "no_need_check_grad") or is_empty_grad_op(
            cls.op_type
        ):
            return

        fp16_without_grad_check = (
            cls.dtype == np.float16
            and cls.op_type
            not in op_accuracy_white_list.NO_FP16_CHECK_GRAD_OP_LIST
            and not hasattr(cls, "exist_check_grad")
        )
        if cls.dtype is None or fp16_without_grad_check:
            raise AssertionError(
                "This test of %s op needs check_grad." % cls.op_type
            )

        # check for op test with fp64 precision, but not check mkldnn op test for now
        if (
            cls.dtype in [np.float32, np.float64]
            and cls.op_type
            not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST
            and not hasattr(cls, 'exist_fp64_check_grad')
            and not (
                enabled("use_xpu")
                or enabled("use_mkldnn")
                or core.is_compiled_with_rocm()
                or enabled("use_npu")
                or enabled("use_mlu")
                or enabled("use_custom_device")
            )
        ):
            raise AssertionError(
                "This test of %s op needs check_grad with fp64 precision."
                % cls.op_type
            )

        if (
            not cls.input_shape_is_large
            and cls.op_type not in check_shape_white_list.NEED_TO_FIX_OP_LIST
        ):
            raise AssertionError(
                "Input's shape should be large than or equal to 100 for "
                + cls.op_type
                + " Op."
            )
439

440 441 442 443 444
    def try_call_once(self, data_type):
        if not self.call_once:
            self.call_once = True
            self.dtype = data_type

445
    def is_bfloat16_op(self):
Y
Yiqun Liu 已提交
446 447
        # self.dtype is the dtype of inputs, and is set in infer_dtype_from_inputs_outputs.
        # Make sure this function is called after calling infer_dtype_from_inputs_outputs.
448 449 450 451 452 453
        return (
            self.dtype == np.uint16
            or (
                hasattr(self, 'output_dtype') and self.output_dtype == np.uint16
            )
            or (
454
                hasattr(self, 'mkldnn_data_type')
455 456 457 458 459 460 461 462
                and getattr(self, 'mkldnn_data_type') == "bfloat16"
            )
            or (
                hasattr(self, 'attrs')
                and 'mkldnn_data_type' in self.attrs
                and self.attrs['mkldnn_data_type'] == 'bfloat16'
            )
        )
Y
Yiqun Liu 已提交
463 464

    def is_mkldnn_op(self):
465
        return (hasattr(self, "use_mkldnn") and self.use_mkldnn) or (
466 467
            hasattr(self, "attrs")
            and "use_mkldnn" in self.attrs
468
            and self.attrs["use_mkldnn"]
469
        )
Y
Yiqun Liu 已提交
470 471

    def is_xpu_op(self):
472
        return (hasattr(self, "use_xpu") and self.use_xpu) or (
473 474
            hasattr(self, "attrs")
            and "use_xpu" in self.attrs
475
            and self.attrs["use_xpu"]
476
        )
477

478
    # set the self.output_dtype .
479
    def infer_dtype_from_inputs_outputs(self, inputs, outputs):
J
juncaipeng 已提交
480 481 482 483
        def is_np_data(input):
            return isinstance(input, (np.ndarray, np.generic))

        def infer_dtype(numpy_dict, dtype_set):
484
            assert isinstance(
485 486
                numpy_dict, dict
            ), "self.inputs, self.outputs must be numpy_dict"
J
juncaipeng 已提交
487 488 489 490 491 492
            # the inputs are as follows:
            # case 1: inputs = {'X': x}
            # case 2: inputs = {'X': (x, x_lod)}
            # case 3: inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]}
            # case 4: inputs = {'X': [("x1", (x1, [x1_lod1])), ("x2", (x2, [x2_.lod2]))]}
            # TODO(juncaipeng) infer dtype from inputs maybe obtain wrong type.
493
            for _, var_value in numpy_dict.items():
J
juncaipeng 已提交
494 495 496 497 498 499 500
                if is_np_data(var_value):  # case 1
                    dtype_set.add(var_value.dtype)
                elif isinstance(var_value, (list, tuple)):  # case 2, 3, 4
                    for sub_val_value in var_value:
                        if is_np_data(sub_val_value):  # case 2
                            dtype_set.add(sub_val_value.dtype)
                        elif len(sub_val_value) > 1 and is_np_data(
501 502
                            sub_val_value[1]
                        ):  # case 3
J
juncaipeng 已提交
503
                            dtype_set.add(sub_val_value[1].dtype)
504 505 506 507 508
                        elif (
                            len(sub_val_value) > 1
                            and isinstance(sub_val_value[1], (list, tuple))
                            and is_np_data(sub_val_value[1][0])
                        ):  # case 4
J
juncaipeng 已提交
509 510 511 512
                            dtype_set.add(sub_val_value[1][0].dtype)

        # infer dtype from inputs, and dtype means the precision of the test
        # collect dtype of all inputs
Y
Yiqun Liu 已提交
513 514
        input_dtype_set = set()
        infer_dtype(inputs, input_dtype_set)
J
juncaipeng 已提交
515
        dtype_list = [
516 517 518 519 520 521 522 523 524
            np.dtype(np.float64),
            np.dtype(np.float32),
            np.dtype(np.float16),
            np.dtype(np.int64),
            np.dtype(np.int32),
            np.dtype(np.uint16),
            np.dtype(np.int16),
            np.dtype(np.int8),
            np.dtype(np.uint8),
525
            np.dtype(np.bool_),
J
juncaipeng 已提交
526 527 528
        ]
        # check the dtype in dtype_list in order, select the first dtype that in dtype_set
        for dtype in dtype_list:
Y
Yiqun Liu 已提交
529
            if dtype in input_dtype_set:
J
juncaipeng 已提交
530 531
                self.dtype = dtype
                break
Y
Yiqun Liu 已提交
532
        # save input dtype in class attr
533
        self.__class__.dtype = self.dtype
534

Y
Yiqun Liu 已提交
535 536 537 538 539 540 541 542
        # infer dtype of outputs
        output_dtype_set = set()
        infer_dtype(outputs, output_dtype_set)
        for dtype in dtype_list:
            if dtype in output_dtype_set:
                self.output_dtype = dtype
                break

Y
Yang Yang(Tony) 已提交
543 544 545 546 547 548
    def feed_var(self, input_vars, place):
        """Build a feed map of LoDTensors from ``self.inputs``.

        A slot whose entry in ``input_vars`` is a list is expanded into one
        feed entry per ``(name, value)`` pair of ``self.inputs[slot]``. Any
        other slot becomes a single entry keyed by the slot name. A
        ``(data, lod)`` tuple also sets the tensor's recursive sequence
        lengths.
        """

        def to_lod_tensor(value):
            tensor = core.LoDTensor()
            if isinstance(value, tuple):
                tensor.set(value[0], place)
                tensor.set_recursive_sequence_lengths(value[1])
            else:
                tensor.set(value, place)
            return tensor

        feed_map = {}
        for var_name in input_vars:
            if isinstance(input_vars[var_name], list):
                for name, np_value in self.inputs[var_name]:
                    feed_map[name] = to_lod_tensor(np_value)
            else:
                feed_map[var_name] = to_lod_tensor(self.inputs[var_name])
        return feed_map

568
    def _append_ops(self, block):
        """Append ``self.op_type`` with this test's inputs/outputs/attrs to ``block``.

        Steps:
        * Record ``op_type`` (and ``use_mkldnn`` / ``use_xpu`` when enabled)
          on the class for the CI checks in tearDownClass.
        * Set ``self.dtype`` / ``self.output_dtype``: uint16 for bfloat16
          tests, otherwise inferred from the numpy inputs/outputs.
        * Create RAW persistable vars for the names in ``cache_name_list``.
        * Run compile-time var-type and shape inference on the new op.

        Returns:
            The appended operator.
        """
        # for ci check, please do not delete it for now
        self.__class__.op_type = self.op_type
        if self.is_mkldnn_op():
            self.__class__.use_mkldnn = True

        if self.is_xpu_op():
            self.__class__.use_xpu = True

        op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
        # infer datatype from inputs and outputs for this test case
        if self.is_bfloat16_op():
            self.dtype = np.uint16
            self.__class__.dtype = self.dtype
            self.output_dtype = np.uint16
        else:
            self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
        inputs = append_input_output(
            block, op_proto, self.inputs, True, self.dtype
        )
        outputs = append_input_output(
            block, op_proto, self.outputs, False, self.dtype
        )

        if hasattr(self, "cache_name_list"):
            for name in self.cache_name_list:
                inputs[name] = block.create_var(
                    name=name,
                    persistable=True,
                    type=core.VarDesc.VarType.RAW,
                    stop_gradient=True,
                )

        op = block.append_op(
            type=self.op_type,
            inputs=inputs,
            outputs=outputs,
            # copy so the op does not share the test's attrs dict
            attrs=copy(self.attrs) if hasattr(self, "attrs") else {},
        )
        # infer variable type and infer shape in compile-time
        op.desc.infer_var_type(block.desc)
        op.desc.infer_shape(block.desc)

        return op

614 615
    def _get_io_vars(self, block, numpy_inputs):
        inputs = {}
616
        for name, value in numpy_inputs.items():
617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635
            if isinstance(value, list):
                var_list = [
                    block.var(sub_name) for sub_name, sub_value in value
                ]
                inputs[name] = var_list
            else:
                inputs[name] = block.var(name)
        return inputs

    def _get_inputs(self, block):
        return self._get_io_vars(block, self.inputs)

    def _get_outputs(self, block):
        return self._get_io_vars(block, self.outputs)

    def calc_output(self, place):
        """Return the first element of the pair produced by ``_calc_output(place)``."""
        outputs, _ignored = self._calc_output(place)
        return outputs

M
minqiyang 已提交
636 637 638 639
    def _create_var_from_numpy(self, value):
        """Wrap numpy data in a dygraph variable via ``to_variable``.

        ``value`` is either the data itself or a ``(data, lod)`` tuple. In the
        tuple case the lod is applied with ``set_recursive_sequence_lengths``.
        """
        if not isinstance(value, tuple):
            return fluid.dygraph.base.to_variable(value)
        data, lod = value[0], value[1]
        var = fluid.dygraph.base.to_variable(value=data)
        var.value().get_tensor().set_recursive_sequence_lengths(lod)
        return var
M
minqiyang 已提交
645

646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663
    def get_sequence_batch_size_1_input(self, lod=None, shape=None):
        """Get LoD input data whose batch size is 1.
        All sequence related OP unittests should call this function to contain the case of batch size = 1.
        Args:
            lod (list[list of int], optional): Length-based LoD, length of lod[0] should be 1. Default: [[13]].
            shape (list, optional): Shape of input, shape[0] should be equals to lod[0][0]. Default: [13, 23].
        Returns:
            tuple (ndarray, lod) : LoD input data whose batch size is 1.
        """
        if lod is None:
            lod = [[13]]
        if shape is None:
            shape = [13, 23]
        assert len(lod[0]) == 1
        assert lod[0][0] == shape[0]
        x = np.random.uniform(0.1, 1, shape).astype('float32')
        return (x, lod)

664 665 666 667 668 669 670 671
    def lod_has_single_zero(self, lod):
        for i in range(len(lod) - 2):
            if lod[i] != 0 and lod[i + 1] == 0 and lod[i + 2] != 0:
                return True
        return False

    def lod_has_continuous_zero(self, lod):
        for i in range(len(lod) - 3):
672 673 674 675 676 677
            if (
                lod[i] != 0
                and lod[i + 1] == 0
                and lod[i + 2] == 0
                and lod[i + 3] != 0
            ):
678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694
                return True
        return False

    def get_sequence_instance_size_0_input(self, lod=None, shape=None):
        """Get LoD input data whose instance size is 0.
        All sequence related OP unittests should call this function to contain the case of instance size is 0.
        Args:
            lod (list[list of int], optional): Length-based LoD, lod[0]'s size must at least eight, lod[0] must at least two zeros at the beginning and at least two zeros at the end, the middle position of lod[0] contains a single zero and multiple zero. Default: [[0, 0, 4, 0, 3, 0, 0, 5, 0, 0]].
            shape (list, optional): Shape of input, shape[0] should be equals to lod[0][0]. Default: [13, 23].
        Returns:
            tuple (ndarray, lod): LoD input data whose instance size is 0.
        """
        if lod is None:
            lod = [[0, 0, 4, 0, 3, 0, 0, 5, 0, 0]]
        if shape is None:
            shape = [12, 10]
        assert len(lod[0]) >= 8
695 696 697 698 699 700
        assert (
            lod[0][0] == 0
            and lod[0][1] == 0
            and lod[0][-1] == 0
            and lod[0][-2] == 0
        )
701 702 703 704 705 706 707
        assert self.lod_has_single_zero(lod[0]) is True
        assert self.lod_has_continuous_zero(lod[0]) is True
        assert sum(lod[0]) == shape[0]

        x = np.random.uniform(0.1, 1, shape).astype('float32')
        return (x, lod)

708 709 710
    def append_input_output_for_dygraph(
        self, op_proto, np_list, is_input, if_return_inputs_grad_dict, block
    ):
        """Create dygraph variables for every input or output slot of ``op_proto``.

        Args:
            op_proto: op proto whose ``inputs`` (when ``is_input``) or
                ``outputs`` list the slots to fill.
            np_list (dict): slot name -> numpy data, a ``(data, lod)`` tuple,
                or (for duplicable slots) a list of ``(name, data)`` pairs.
            is_input (bool): inputs are converted from numpy with
                ``_create_var_from_numpy``; outputs are created as LOD_TENSOR
                vars in ``block`` with the numpy data's dtype.
            if_return_inputs_grad_dict (bool): also return a name -> variable
                dict; input vars then get ``stop_gradient = False`` (and
                ``retain_grads()`` outside legacy dygraph).
            block: block in which output and missing intermediate vars are
                created.

        Returns:
            ``var_dict`` (slot name -> list of vars), or
            ``(var_dict, inputs_grad_dict)`` when
            ``if_return_inputs_grad_dict`` is True.

        Raises:
            AssertionError: if a slot that is neither dispensable nor
                intermediate is missing from ``np_list``, or a duplicable slot
                is not given as a list.
        """

        def create_var(np_value, name, is_input, if_return_inputs_grad_dict):
            np_value_temp = np_value
            has_lod = False
            lod_temp = None
            if isinstance(np_value, tuple):
                np_value_temp = np_value[0]
                has_lod = True
                lod_temp = np_value[1]

            if is_input:
                v = self._create_var_from_numpy(np_value_temp)

                if if_return_inputs_grad_dict:
                    v.stop_gradient = False
                    if not _in_legacy_dygraph():
                        v.retain_grads()

                if has_lod:
                    v.value().get_tensor().set_recursive_sequence_lengths(
                        lod_temp
                    )
            else:
                v = block.create_var(
                    name=name,
                    dtype=np_value_temp.dtype,
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    persistable=False,
                    stop_gradient=False,
                )
            return v

        # prepare variable for input or output
        var_dict = defaultdict(list)
        if if_return_inputs_grad_dict:
            # name -> variable; a plain dict (no default values are needed).
            inputs_grad_dict = {}
        proto_list = op_proto.inputs if is_input else op_proto.outputs
        for var_proto in proto_list:
            name = var_proto.name
            if (name not in np_list) and var_proto.dispensable:
                continue
            if name not in np_list:
                # Missing slots must be intermediate; give them a placeholder.
                assert var_proto.intermediate, "{} not found".format(name)
                v = block.create_var(
                    dtype='float32', type=core.VarDesc.VarType.LOD_TENSOR
                )
                var_dict[name].append(v)
                if if_return_inputs_grad_dict:
                    inputs_grad_dict[name] = v
                continue
            if var_proto.duplicable:
                assert isinstance(
                    np_list[name], list
                ), "Duplicable {} should be set as list".format(name)
                var_list = []
                for sub_name, np_value in np_list[name]:
                    v = create_var(
                        np_value, sub_name, is_input, if_return_inputs_grad_dict
                    )
                    var_list.append(v)
                    if if_return_inputs_grad_dict:
                        inputs_grad_dict[sub_name] = v
                var_dict[name] = var_list
            else:
                nplist_value_temp = None
                name_temp = None
                if isinstance(np_list[name], list):
                    nplist_value_temp = np_list[name][0]
                    name_temp = name
                else:
                    nplist_value_temp = np_list[name]
                    name_temp = unique_name.generate("%s_out" % (name))
                v = create_var(
                    nplist_value_temp,
                    name_temp,
                    is_input,
                    if_return_inputs_grad_dict,
                )
                var_dict[name].append(v)
                if if_return_inputs_grad_dict:
                    inputs_grad_dict[name] = v

        if if_return_inputs_grad_dict:
            return var_dict, inputs_grad_dict
        else:
            return var_dict

    def _check_api_outs_by_dygraph_outs(self, api_outs, dygraph_outs, place):
799 800 801 802
        """for quick verify, here we take a simplest strategy:
        1. we only check variable in api_outs.
        2. we simply check the numpy (tensor) .
        3. we set atol and rtol as 1e-5, because they are unrelated to dtype.
803 804 805 806
        """
        for name in api_outs:
            np_api = np.array(api_outs[name])
            np_dyg = np.array(dygraph_outs[name])
807 808 809 810 811
            np.testing.assert_allclose(
                np_api,
                np_dyg,
                rtol=1e-05,
                equal_nan=False,
812 813 814 815 816 817 818 819 820 821 822 823
                err_msg='Output ('
                + name
                + ') has diff at '
                + str(place)
                + '\nExpect '
                + str(np_dyg)
                + '\n'
                + 'But Got'
                + str(np_api)
                + ' in class '
                + self.__class__.__name__,
            )
824

825
    def _calc_python_api_output(self, place, egr_inps=None, egr_oups=None):
        """Run ``self.python_api`` in dygraph mode and collect its outputs.

        Set ``egr_inps`` and ``egr_oups`` to None if you want them created
        here from ``self.inputs`` / ``self.outputs``. Note that any falsy value
        (e.g. an empty dict) is treated the same as None.

        Returns:
            dict | None: output name -> list of tensors produced by
            ``self.python_api``. Names come from the kernel signature's output
            names, or from ``self.python_out_sig`` when that attribute is set.
            None is returned when no kernel signature is found for
            ``self.op_type``.
        """

        def prepare_python_api_arguments(
            api, op_proto_ins, op_proto_attrs, kernel_sig
        ):
            """Map op-proto inputs and attrs onto the positional args of ``api``.

            NOTE: when not supplied by the caller, op_proto_ins is the
            defaultdict(list) built by append_input_output_for_dygraph, while
            op_proto_attrs is a plain dict of the non-None entries of
            self.attrs. Only ``in`` and ``.get`` are used on them here.
            """

            class Empty:
                """Marker meaning "not supplied": the API default is used."""

            def get_default(idx, defaults):
                assert not isinstance(defaults[idx], Empty), (
                    "%d-th params of python api don't have default value." % idx
                )
                return defaults[idx]

            def to_defaults_list(params, defaults):
                return [defaults[p] for p in params if p in defaults]

            def parse_attri_value(name, op_inputs, op_attrs):
                """Resolve the real value of attribute ``name``.

                1. if ``name`` is in ``op_attrs``, use ``op_attrs[name]``;
                2. elif ``name`` is in ``op_inputs``: a single tensor is
                   re-created on CPU, several tensors are passed on as a list;
                3. otherwise return ``Empty()`` so that the default value of
                   the python api is used.
                """
                if name in op_attrs:
                    return op_attrs[name]
                if name in op_inputs:
                    if len(op_inputs[name]) == 1:
                        # why don't use numpy().item() : if the Tensor is float64, we will change it to python.float32, where we loss accuracy: [allclose_op]
                        # why we reconstruct a tensor: because we want the tensor in cpu.
                        return paddle.to_tensor(
                            op_inputs[name][0].numpy(), place='cpu'
                        )
                    # if this is a list (test_unsqueeze2_op): we just pass it into the python api.
                    return op_inputs[name]
                return Empty()

            # NOTE(xiongkun): the logic of constructing parameters:
            # for example:
            #    python api: cumprod(x, dim, dtype=None, name=None)
            #    kernel sig: [["x"], ["dim"], ["out"]]
            #
            # we construct several lists with the same length: len == len(api_params), here 4
            #    api_params = ["x", "dim", "dtype", "name"]
            #    api_defaults = [Empty, Empty, None, None]; Empty means no default.
            #    input_arguments = [real values from self.inputs and self.attrs],
            #        ordered as the kernel signature's inputs followed by its attrs;
            #        it may be shorter or longer than api_params.
            # then we loop over api_params and build the result list:
            #    if the name is in ['name', 'dtype', 'out', 'output'], we use the default value;
            #    else we consume the next input_argument (names do not correspond, so only order is used).

            api_params, api_defaults = parse_arg_and_kwargs(api)
            api_defaults = to_defaults_list(api_params, api_defaults)
            api_defaults = [
                Empty() for _ in range(len(api_params) - len(api_defaults))
            ] + api_defaults
            assert len(api_defaults) == len(
                api_params
            ), "Error happens. contack xiongkun03 to solve."
            inputs_sig, attrs_sig, _ = kernel_sig
            input_arguments = [
                op_proto_ins.get(name, Empty()) for name in inputs_sig
            ] + [
                parse_attri_value(name, op_proto_ins, op_proto_attrs)
                for name in attrs_sig
            ]
            results = []
            api_ignore_param_list = {'name', 'dtype', 'out', 'output'}
            idx_of_op_proto_arguments = 0
            for idx, arg_name in enumerate(api_params):
                if arg_name in api_ignore_param_list:
                    results.append(get_default(idx, api_defaults))
                else:
                    if idx_of_op_proto_arguments < len(input_arguments):
                        tmp = input_arguments[idx_of_op_proto_arguments]
                        idx_of_op_proto_arguments += 1
                    else:
                        tmp = Empty()  # use the default value

                    if isinstance(tmp, Empty):
                        results.append(get_default(idx, api_defaults))
                    else:
                        results.append(tmp)
            assert len(results) == len(api_params)
            return results

        def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
            if hasattr(self, "python_out_sig"):
                output_sig = self.python_out_sig
            if not isinstance(ret_tuple, (tuple, list)):
                ret_tuple = [ret_tuple]
            if len(output_sig) == len(ret_tuple):
                # [assumption]: we assume {"Out": [Tensor]}
                return {a: [b] for a, b in zip(output_sig, ret_tuple)}
            else:
                # [assumption]: return multi-Tensor in a single output. such as paddle.split()
                assert (
                    len(output_sig) == 1
                ), "Don't support multi-output with multi-tensor output. (May be you can use set `python_out_sig`, see `test_squeeze2_op` as a example.)"
                return {output_sig[0]: ret_tuple}

        def assumption_assert_and_transform(args, inp_num):
            """
            transform inputs by the following rules:
                1. [Tensor] -> Tensor
                2. [Tensor, Tensor, ...] -> list of Tensors
                3. None -> None
                4. Others: raise Error

            only support "X" is list of Tensor, currently don't support other structure like dict.
            """
            inp_args = [
                [inp] if inp is None else inp for inp in args[:inp_num]
            ]  # convert None -> [None]
            for inp in inp_args:
                assert isinstance(
                    inp, list
                ), "currently only support `X` is [Tensor], don't support other structure."
            args = [
                inp[0] if len(inp) == 1 else inp for inp in inp_args
            ] + args[inp_num:]
            return args

        def _get_kernel_signature(
            eager_tensor_inputs, eager_tensor_outputs, attrs_outputs
        ):
            try:
                kernel_sig = _dygraph_tracer()._get_kernel_signature(
                    self.op_type,
                    eager_tensor_inputs,
                    eager_tensor_outputs,
                    attrs_outputs,
                )
            except RuntimeError:
                # A RuntimeError here is taken to mean the kernel signature is missing.
                kernel_sig = None
                print(
                    "[Warning: op_test.py] Kernel Signature is not found for %s, fall back to intermediate state."
                    % self.op_type
                )
            return kernel_sig

        def cal_python_api(python_api, args, kernel_sig):
            inputs_sig, _, outputs_sig = kernel_sig
            args = assumption_assert_and_transform(args, len(inputs_sig))
            ret_tuple = python_api(*args)
            return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)

        with fluid.dygraph.base.guard(place=place):
            block = fluid.default_main_program().global_block()
            op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)
            # prepare input variable
            eager_tensor_inputs = (
                egr_inps
                if egr_inps
                else self.append_input_output_for_dygraph(
                    op_proto, self.inputs, True, False, block
                )
            )
            # prepare output variable
            eager_tensor_outputs = (
                egr_oups
                if egr_oups
                else self.append_input_output_for_dygraph(
                    op_proto, self.outputs, False, False, block
                )
            )

            # prepare attributes
            attrs_outputs = {}
            if hasattr(self, "attrs"):
                for attrs_name in self.attrs:
                    if self.attrs[attrs_name] is not None:
                        attrs_outputs[attrs_name] = self.attrs[attrs_name]

            kernel_sig = _get_kernel_signature(
                eager_tensor_inputs, eager_tensor_outputs, attrs_outputs
            )
            if not kernel_sig:
                return None
            assert hasattr(self, "python_api"), (
                "Detect there is KernelSignature for `%s` op, please set the `self.python_api` if you set check_eager = True"
                % self.op_type
            )
            args = prepare_python_api_arguments(
                self.python_api, eager_tensor_inputs, attrs_outputs, kernel_sig
            )
            # The python api already returns tensors, so its result is returned directly.
            return cal_python_api(self.python_api, args, kernel_sig)

    def _calc_dygraph_output(self, place, parallel=False, no_check_set=None):
        """Append and run ``self.op_type`` once in dygraph mode.

        ``parallel`` and ``no_check_set`` are accepted for signature parity
        with ``_calc_output`` but are not used here.

        Returns:
            The output variable dict (built by
            ``append_input_output_for_dygraph``) that was passed to
            ``append_op`` as the op's outputs.
        """
        # for ci check, please not delete it for now
        self.__class__.op_type = self.op_type

        with fluid.dygraph.base.guard(place=place):
            global_block = fluid.default_main_program().global_block()
            proto = OpProtoHolder.instance().get_op_proto(self.op_type)

            # Eager variables for the op's inputs, then its outputs.
            in_vars = self.append_input_output_for_dygraph(
                proto, self.inputs, True, False, global_block
            )
            out_vars = self.append_input_output_for_dygraph(
                proto, self.outputs, False, False, global_block
            )

            # Only attributes whose value is not None are forwarded; without
            # any self.attrs the op receives attrs=None.
            has_attrs = hasattr(self, "attrs")
            op_attrs = (
                {
                    key: self.attrs[key]
                    for key in self.attrs
                    if self.attrs[key] is not None
                }
                if has_attrs
                else {}
            )

            global_block.append_op(
                type=self.op_type,
                inputs=in_vars,
                outputs=out_vars,
                attrs=op_attrs if has_attrs else None,
            )
            return out_vars

    def _calc_output(
        self,
        place,
        parallel=False,
        no_check_set=None,
        loss=None,
        enable_inplace=None,
        for_inplace_test=None,
    ):
        """Build a static program holding the tested op, run it and fetch outputs.

        Args:
            place: the place to run the program on.
            parallel (bool): compile the program with data parallelism.
            no_check_set (list|None): output names left out of the default
                fetch list.
            loss: optional variable whose ``name`` is passed as ``loss_name``
                to data-parallel compilation.
            enable_inplace (bool|None): when not None, compile with a
                BuildStrategy whose ``enable_inplace`` is this value.
            for_inplace_test (bool|None): when truthy, op outputs with a 0 in
                their shape are made persistable and extra items are returned.

        Returns:
            ``(outs, fetch_list)``, or
            ``(outs, fetch_list, feed_map, program, op_desc)`` when
            ``for_inplace_test`` is truthy.
        """
        program = Program()
        block = program.global_block()
        op = self._append_ops(block)

        inputs = self._get_inputs(block)
        outputs = self._get_outputs(block)
        feed_map = self.feed_var(inputs, place)

        if for_inplace_test:
            # Some variables' tensors hold no buffer (tensor's _holder is NULL), like XShape in reshape2 op,
            # and the shapes of those variables contain 0 (eg. Xshape.shape = [0, 2, 5]).
            # Set persistable for those variables in order to get them from global_scope for inplace grad test directly other than feed them,
            # since feed op calls check_memory_size() which fails when tensor's holder_ is NULL.
            for out_name in op.output_arg_names:
                var = block.var(out_name)
                if 0 in var.shape:
                    var.persistable = True
        original_program = program
        if parallel:
            compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
                loss_name=loss.name if loss else None, places=place
            )
            program = compiled_prog
        # Copy, so that filling in defaults below never mutates a user-set
        # (possibly empty) self.fetch_list that would then leak into later calls.
        fetch_list = list(getattr(self, "fetch_list", []))
        # if the fetch_list is customized by user, we use it directly.
        # if not, fill the fetch_list by the user configured outputs in test.
        if not fetch_list:
            for var_name, var in outputs.items():
                if no_check_set is not None and var_name in no_check_set:
                    continue
                if isinstance(var, list):
                    for v in var:
                        fetch_list.append(v.name)
                else:
                    fetch_list.append(var.name)
        # if the fetch_list still empty, fill the fetch_list by the operator output.
        if not fetch_list:
            for out_name, _ in Operator.get_op_outputs(self.op_type):
                fetch_list.append(str(out_name))

        if enable_inplace is not None:
            build_strategy = fluid.BuildStrategy()
            build_strategy.enable_inplace = enable_inplace

            compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
                build_strategy=build_strategy, places=place
            )
            program = compiled_prog

        executor = Executor(place)
        outs = executor.run(
            program, feed=feed_map, fetch_list=fetch_list, return_numpy=False
        )
        self.op = op
        self.program = original_program
        if for_inplace_test:
            return outs, fetch_list, feed_map, original_program, op.desc
        else:
            return outs, fetch_list

    def _compare_expect_and_actual_outputs(
        self, place, fetch_list, expect_outs, actual_outs, inplace_atol=None
    ):
1135 1136 1137
        """Compare expect outs and actual outs of an tested op.

        Args:
C
cc 已提交
1138
            place (CPUPlace | CUDAPlace): The place where the op runs.
1139 1140 1141 1142 1143 1144 1145 1146 1147 1148
            fetch_list (list): The outputs of tested op.
            expect_outs (list): The expect outs of tested op.
            actual_outs (list): The actual outs of tested op.
            inplace_atol (float): The tolerable error, only set when tested op doesn't ensure computational consistency, like group_norm op.

        Returns:
            None.
        """
        # compare expect_outs and actual_outs
        for i, name in enumerate(fetch_list):
C
cc 已提交
1149
            # Note(zhiqiu): inplace_atol should be only set when op doesn't ensure
L
Leo Chen 已提交
1150 1151 1152
            # computational consistency.
            # When inplace_atol is not None, the inplace check uses numpy.allclose
            # to check inplace result instead of numpy.array_equal.
1153 1154
            expect_out = np.array(expect_outs[i])
            actual_out = np.array(actual_outs[i])
1155
            if inplace_atol is not None:
1156 1157 1158 1159 1160
                np.testing.assert_allclose(
                    expect_out,
                    actual_out,
                    rtol=1e-05,
                    atol=inplace_atol,
1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173
                    err_msg='Output ('
                    + name
                    + ') has diff at '
                    + str(place)
                    + ' when using and not using inplace'
                    + '\nExpect '
                    + str(expect_out)
                    + '\n'
                    + 'But Got'
                    + str(actual_out)
                    + ' in class '
                    + self.__class__.__name__,
                )
1174
            else:
1175 1176 1177
                np.testing.assert_array_equal(
                    expect_out,
                    actual_out,
1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195
                    err_msg='Output ('
                    + name
                    + ') has diff at '
                    + str(place)
                    + ' when using and not using inplace'
                    + '\nExpect '
                    + str(expect_out)
                    + '\n'
                    + 'But Got'
                    + str(actual_out)
                    + ' in class '
                    + self.__class__.__name__
                    + '\n',
                )

    def _construct_grad_program_from_forward(
        self, fwd_program, grad_op_desc, op_grad_to_var
    ):
        """Build a standalone program whose only op is ``grad_op_desc``.

        Each input/output argument of the grad op gets a variable whose dtype,
        shape and type are copied from the corresponding forward variable.
        That variable is found through ``op_grad_to_var``, or by the same name
        when the argument has no mapping.

        Args:
            fwd_program (Program): program holding the forward op that
                grad_op_desc belongs to.
            grad_op_desc (OpDesc): the OpDesc of the grad op.
            op_grad_to_var (dict): grad-op variable name -> forward variable name.

        Returns:
            grad_program (Program): the program containing the grad op.

        Raises:
            AssertionError: if a forward variable cannot be found.
        """
        grad_program = Program()
        grad_block = grad_program.global_block()
        grad_block.desc.append_op().copy_from(grad_op_desc)
        grad_program._sync_with_cpp()

        fwd_vars = fwd_program.global_block().vars
        grad_args = (
            grad_op_desc.input_arg_names() + grad_op_desc.output_arg_names()
        )
        for arg in grad_args:
            mapped_name = op_grad_to_var.get(arg, None)
            fwd_var_name = arg if mapped_name is None else mapped_name
            fwd_var = fwd_vars.get(fwd_var_name)
            assert fwd_var is not None, "{} cannot be found".format(
                fwd_var_name
            )
            grad_var = grad_block.create_var(
                name=arg,
                dtype=fwd_var.dtype,
                shape=fwd_var.shape,
                type=fwd_var.type,
                persistable=False,
            )
            # Tensors with a 0 in their shape (e.g. XShape of reshape2) hold no
            # buffer, and feeding them fails in check_memory_size(). Making them
            # persistable lets the grad run read them from global_scope instead.
            if 0 in grad_var.shape:
                grad_var.persistable = True

        grad_program._sync_with_cpp()
        return grad_program

    def _construct_grad_feed_map_from_forward(
        self, place, fwd_res, grad_op_desc, op_grad_to_var
    ):
        """Build the feed map for a grad program from forward results.

        Gradient accuracy is not checked here, only consistency between
        inplace and non-inplace runs. Grad inputs are therefore taken from the
        forward feeds, or else from the fetched forward outputs they map to.

        Args:
            place (CPUPlace | CUDAPlace): the place the copied tensors go to.
            fwd_res (tuple): the 5-tuple returned by _calc_output() with
                for_inplace_test=True, i.e.
                (fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc).
            grad_op_desc (OpDesc): the OpDesc of the grad op.
            op_grad_to_var (dict): grad-op variable name -> forward variable name.

        Returns:
            grad_feed_map (dict): input name -> tensor copied to ``place``.
        """
        fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, _ = fwd_res
        target_place = core.Place()
        target_place.set_place(place)
        fwd_block = fwd_program.global_block()

        grad_feed_map = {}
        for arg in grad_op_desc.input_arg_names():
            if arg in fwd_feed_map:
                grad_feed_map[arg] = fwd_feed_map[arg]._copy(target_place)
                continue

            mapped_name = op_grad_to_var.get(arg, None)
            fwd_var_name = arg if mapped_name is None else mapped_name
            for pos, fetched_name in enumerate(fwd_fetch_list):
                if fetched_name != fwd_var_name:
                    continue
                # Variables whose shape contains 0 (no buffer, e.g. XShape)
                # are not fed. They were made persistable in the forward run
                # and are read from global_scope instead.
                if 0 not in fwd_block.var(fetched_name).shape:
                    grad_feed_map[arg] = fwd_outs[pos]._copy(target_place)

        return grad_feed_map

    def _get_need_run_ops(self, op_desc, fwd_op_desc=None):
        """Postorder traversal of the 'grad' tree to get all ops that need to run during inplace test.

        An op needs to run during the inplace check if
        (1) it has infer_inplace, or
        (2) it has infer_inplace in its grad descendants (its outputs are
            needed to construct its grad's inputs).
        Op types already visited are not explored again.

        Args:
            op_desc (OpDesc): The op_desc of current op.
            fwd_op_desc (OpDesc): The op_desc of current op's forward op, None if current op has no forward op.
                Eg. relu's fwd_op is None, relu_grad's fwd_op is relu, relu_grad_grad's fwd_op is relu_grad, etc.

        Returns:
            need_run_ops (list[(op_desc, fwd_op_desc)]): The ops that need to run during inplace test,
                in postorder (descendants before their ancestors).
        """
        need_run_ops = []
        # Visited op types; a set gives O(1) membership tests.
        visited_ops = set()

        def _dfs_grad_op(op_desc, fwd_op_desc=None):
            visited_ops.add(op_desc.type())
            has_infer_inplace = fluid.core.has_infer_inplace(op_desc.type())
            has_grad_op_maker = fluid.core.has_grad_op_maker(op_desc.type())
            # The previous code also assigned a misspelled, never-read flag
            # (has_infer_inplace_in_descendants); only this one matters.
            has_infer_inplace_in_grad_descendants = False
            if has_grad_op_maker:
                # get grad_op_desc
                grad_op_desc_list, _ = core.get_grad_op_desc(
                    op_desc, set(), []
                )
                for grad_op_desc in grad_op_desc_list or []:
                    # Every unvisited grad op is explored, even after the flag
                    # is set, so that all needed descendants are collected.
                    if grad_op_desc.type() not in visited_ops and _dfs_grad_op(
                        grad_op_desc, fwd_op_desc=op_desc
                    ):
                        has_infer_inplace_in_grad_descendants = True
            if has_infer_inplace or has_infer_inplace_in_grad_descendants:
                need_run_ops.append((op_desc, fwd_op_desc))
                return True
            return False

        _dfs_grad_op(op_desc, fwd_op_desc=fwd_op_desc)
        return need_run_ops

    def _check_forward_inplace(
        self, place, no_check_set=None, inplace_atol=None
    ):
1337
        """Check the inplace correctness of given op (self.op_type).
1338
        Run the op twice with same inputs, one enable inplace and another disable, compare their outputs.
C
cc 已提交
1339

1340
        Args:
C
cc 已提交
1341
            place (CPUPlace | CUDAPlace): The place where the op runs.
1342 1343 1344 1345
            no_check_set (list): The names of outputs that needn't check, like XShape of reshape op.
            inplace_atol (float): The tolerable error, only set when op doesn't ensure computational consistency, like group_norm op.

        Returns:
C
cc 已提交
1346 1347
            expect_res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given op.
                We return this to construct grad_program and grad_feed_map for grad inplace check.
1348 1349
        """
        # _calc_output() returns in the form tuple(outs, fetch_list, feed_map, program, op_desc) when for_inplace_test=True.
1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361
        expect_res = self._calc_output(
            place,
            no_check_set=no_check_set,
            enable_inplace=False,
            for_inplace_test=True,
        )
        actual_res = self._calc_output(
            place,
            no_check_set=no_check_set,
            enable_inplace=True,
            for_inplace_test=True,
        )
1362
        # compare expect_outs and actual_outs
1363 1364 1365 1366 1367 1368 1369
        self._compare_expect_and_actual_outputs(
            place,
            expect_res[1],
            expect_res[0],
            actual_res[0],
            inplace_atol=inplace_atol,
        )
1370 1371
        return expect_res

1372 1373 1374
    def _calc_grad_output(
        self, place, fwd_res, grad_op_desc, enable_inplace=None
    ):
        """Run ``grad_op_desc`` alone and fetch its outputs.

        Gradient accuracy is not checked, only consistency between inplace and
        non-inplace runs. The grad inputs are therefore built from forward
        outputs (and sometimes forward inputs).

        Args:
            place (CPUPlace | CUDAPlace): The place where the op runs.
            fwd_res (tuple): the 5-tuple returned by _calc_output() with
                for_inplace_test=True, i.e.
                (fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc).
            grad_op_desc (OpDesc): The OpDesc of grad op.
            enable_inplace (bool): when not None, compile the grad program with
                BuildStrategy.enable_inplace set to this value.

        Returns:
            res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
        """
        _, _, _, fwd_program, fwd_op_desc = fwd_res
        _, op_grad_to_var = core.get_grad_op_desc(fwd_op_desc, set(), [])

        grad_program = self._construct_grad_program_from_forward(
            fwd_program, grad_op_desc, op_grad_to_var
        )
        grad_feed_map = self._construct_grad_feed_map_from_forward(
            place, fwd_res, grad_op_desc, op_grad_to_var
        )
        grad_fetch_list = grad_op_desc.output_arg_names()
        exe = Executor(place)

        # Run the plain program unless an inplace strategy was requested.
        program_to_run = grad_program
        if enable_inplace is not None:
            strategy = fluid.BuildStrategy()
            strategy.enable_inplace = enable_inplace
            program_to_run = fluid.CompiledProgram(
                grad_program
            ).with_data_parallel(
                loss_name="", build_strategy=strategy, places=place
            )

        outs = exe.run(
            program_to_run,
            feed=grad_feed_map,
            fetch_list=grad_fetch_list,
            return_numpy=False,
        )
        return outs, grad_fetch_list, grad_feed_map, grad_program, grad_op_desc

    def _check_grad_inplace(
        self, place, fwd_res, grad_op_desc, inplace_atol=None
    ):
1430
        """Check the inplace correctness of given grad_op_desc.
1431 1432 1433 1434 1435 1436

        Run the grad op twice with same inputs, one enable inplace and another disable, compare their outputs.
        It works like _check_forward_inplace, but the way to construct program and feed_map differs.
        So we define a new function for grad, grad_grad, etc.

        Args:
C
cc 已提交
1437
            place (CPUPlace | CUDAPlace): The place where the op runs.
1438 1439 1440 1441 1442 1443
            fwd_res (tuple): The outputs of its forward op, in the same form as returns of _calc_outputs() when for_inplace_test is True.
                i.e., tuple(fwd_outs, fwd_fetch_list, fwd_feed_map, fwd_program, fwd_op_desc).
            grad_op_desc (OpDesc): The OpDesc of grad op.
            inplace_atol (float): The tolerable error, only set when op doesn't ensure computational consistency, like group_norm op.

        Returns:
C
cc 已提交
1444 1445
            expect_res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given op.
                We return this to construct grad_program and grad_feed_map for grad inplace check.
1446
        """
1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460
        expect_res = self._calc_grad_output(
            place, fwd_res, grad_op_desc, enable_inplace=False
        )
        actual_res = self._calc_grad_output(
            place, fwd_res, grad_op_desc, enable_inplace=True
        )

        self._compare_expect_and_actual_outputs(
            place,
            expect_res[1],
            expect_res[0],
            actual_res[0],
            inplace_atol=inplace_atol,
        )
1461
        return expect_res
1462

1463 1464 1465
    def check_inplace_output_with_place(
        self, place, no_check_set=None, inplace_atol=None
    ):
        """Check the inplace correctness of the op, its grad op, its grad_grad op, etc.

        (1) Get all ops that need to run (see conditions in _get_need_run_ops()).
        (2) Run each op in need_run_ops. An op that has infer_inplace is run
            with inplace both disabled and enabled, and the outputs are
            compared. Any other op is simply run so that its results can feed
            the grad op derived from it.

        Args:
            place (CPUPlace | CUDAPlace): The place where the op runs.
            no_check_set (list): The names of outputs that needn't check, like XShape of reshape op.
            inplace_atol (float): The tolerable error, only set when op doesn't ensure computational consistency, like group_norm op.

        Returns:
            None
        """
        if getattr(self, "no_need_check_inplace", False):
            return

        fwd_res = self._calc_output(
            place, no_check_set=no_check_set, for_inplace_test=True
        )
        op_desc = fwd_res[4]
        need_run_ops = self._get_need_run_ops(op_desc)

        # Maps each executed OpDesc to its results tuple
        # (outs, fetch_list, feed_map, program, op_desc). A grad op looks up
        # the results of its father op here to build its own inputs.
        res = {}
        if hasattr(self, 'attrs') and bool(self.attrs.get('use_xpu', False)):
            return
        # Iterated in reverse so that the forward op comes first. The lookups
        # of res[father_op_desc] below presumably rely on every father op
        # being processed before the grad op derived from it.
        for op_desc, father_op_desc in reversed(need_run_ops):
            has_infer_inplace = fluid.core.has_infer_inplace(op_desc.type())
            if op_desc.type() == self.op_type:
                # The forward op.
                if has_infer_inplace:
                    res[op_desc] = self._check_forward_inplace(
                        place,
                        no_check_set=no_check_set,
                        inplace_atol=inplace_atol,
                    )
                else:
                    res[op_desc] = self._calc_output(
                        place, no_check_set=no_check_set, for_inplace_test=True
                    )
                continue

            # TODO(zhiqiu): enhance inplace_grad test for ops (sum and activation) using mkldnn
            # skip op that use_mkldnn currently
            flags_use_mkldnn = fluid.core.globals()["FLAGS_use_mkldnn"]
            attrs_use_mkldnn = hasattr(self, 'attrs') and bool(
                self.attrs.get('use_mkldnn', False)
            )
            if flags_use_mkldnn or attrs_use_mkldnn:
                warnings.warn(
                    "check inplace_grad for ops using mkldnn is not supported"
                )
                continue

            # Both branches must build the grad inputs from the father op's
            # results, not from whichever op happened to run last.
            fwd_res = res[father_op_desc]
            if has_infer_inplace:
                res[op_desc] = self._check_grad_inplace(
                    place, fwd_res, op_desc, inplace_atol=inplace_atol
                )
            else:
                res[op_desc] = self._calc_grad_output(place, fwd_res, op_desc)
1529

1530 1531 1532 1533 1534 1535 1536 1537 1538 1539
    def check_output_with_place(
        self,
        place,
        atol=0,
        no_check_set=None,
        equal_nan=False,
        check_dygraph=True,
        inplace_atol=None,
        check_eager=False,
    ):
        """Check the op outputs at ``place`` against ``self.outputs``.

        Outputs are always checked in static graph mode. They are additionally
        checked in legacy dygraph mode when ``check_dygraph`` is set, or in
        eager mode when ``check_eager`` is set (this disables the legacy
        dygraph check). Finally, the inplace consistency of the op and its grad
        ops is checked, except on XPU/NPU/MLU builds and on CustomPlace.

        Args:
            place (CPUPlace | CUDAPlace): The place where the op runs.
            atol (float): Absolute tolerance. It is forced to 0 for float64 ops
                that are not in NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST, and
                it is overridden for bfloat16 ops.
            no_check_set (list): Output names to skip. Only allowed for ops in
                no_check_set_white_list.
            equal_nan (bool): Whether NaNs compare as equal.
            check_dygraph (bool): Also check outputs in legacy dygraph mode.
            inplace_atol (float): Tolerance for the inplace check. Only set it
                when the op doesn't ensure computational consistency.
            check_eager (bool): Also check outputs in eager mode.

        Returns:
            One of the following:
            - (outs, eager_dygraph_outs, fetch_list) if the eager check ran;
            - (outs, dygraph_outs, fetch_list) if the legacy dygraph check ran;
            - (outs, fetch_list) otherwise.
            Both checks may be switched off internally for bfloat16 mkldnn ops.

        Raises:
            AssertionError: If an output differs from its expectation, or if
                no_check_set is used by an op outside the white list.
        """
        # disable legacy dygraph check when check_eager is True
        if check_eager:
            check_dygraph = False

        def find_imperative_actual(target_name, dygraph_outs, place):
            """Return the dygraph output for ``target_name``.

            ``target_name`` may be an output slot name, in which case the
            slot's first variable is returned, or the name of a variable
            inside one of the slots.
            """
            for name in dygraph_outs:
                if name == target_name:
                    return dygraph_outs[name][0]
                var_list = dygraph_outs[name]
                for i, var in enumerate(var_list):
                    if var.name == target_name:
                        return dygraph_outs[name][i]
            self.assertTrue(
                False,
                "Found failed {} {}".format(dygraph_outs.keys(), target_name),
            )

        def find_actual(target_name, fetch_list):
            """Return the unique index of ``target_name`` in ``fetch_list``."""
            found = [
                i
                for i, var_name in enumerate(fetch_list)
                if var_name == target_name
            ]
            self.assertTrue(
                len(found) == 1, "Found {} {}".format(len(found), target_name)
            )
            return found[0]

        class Checker(object):
            """Base class that checks actual outputs against self.outputs.

            Subclasses decide how outputs are calculated and how an actual
            value is looked up. Checking one checker against another is not
            supported.
            """

            def __init__(self, op_test, expect_dict):
                """expect_dict is the self.outputs of the OpTest.

                Supported forms: {str: numpy}, {str: (numpy, lod)} and, for
                duplicable outputs, {str: [(str, numpy), (str, numpy)]}.
                """
                self.expects = expect_dict
                self.checker_name = "checker"
                self.op_test = op_test  # the OpTest instance being checked
                self.op_type = op_test.op_type

            def init(self):
                """Hook run first by check(); subclasses set checker_name here."""
                pass

            def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
                """Convert uint16 (bfloat16) arrays to float when needed.

                return: (actual_np, expect_np)
                """
                raise NotImplementedError("base class, not implement!")

            def calculate_output(self):
                """Run the op and store its outputs on this checker."""
                raise NotImplementedError("base class, not implement!")

            def _is_skip_name(self, name):
                if name not in self.expects:
                    return True
                if no_check_set is not None and name in no_check_set:
                    return True
                return False

            def find_actual_value(self, name):
                """return: (actual_tensor(var_base), actual_numpy)"""
                raise NotImplementedError("base class, not implement!")

            def _compare_numpy(self, name, actual_np, expect_np):
                self.op_test.assertTrue(
                    np.allclose(
                        actual_np,
                        expect_np,
                        atol=atol,
                        rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
                        equal_nan=equal_nan,
                    ),
                    "Output ("
                    + name
                    + ") has diff at "
                    + str(place)
                    + " in "
                    + self.checker_name,
                )

            def _compare_list(self, name, actual, expect):
                """if expect is a tuple, we need to compare list."""
                raise NotImplementedError("base class, not implement!")

            def compare_single_output_with_expect(self, name, expect):
                actual, actual_np = self.find_actual_value(name)
                expect_np = expect[0] if isinstance(expect, tuple) else expect
                actual_np, expect_np = self.convert_uint16_to_float_ifneed(
                    actual_np, expect_np
                )
                # NOTE(zhiqiu): np.allclose([], [1.]) returns True
                # see details: https://stackoverflow.com/questions/38331703/why-does-numpys-broadcasting-sometimes-allow-comparing-arrays-of-different-leng
                if expect_np.size == 0:
                    self.op_test.assertTrue(actual_np.size == 0)
                self._compare_numpy(name, actual_np, expect_np)
                if isinstance(expect, tuple):
                    self._compare_list(name, actual, expect)

            def compare_outputs_with_expects(self):
                for out_name, out_dup in Operator.get_op_outputs(self.op_type):
                    if self._is_skip_name(out_name):
                        continue
                    if out_dup:
                        # if self.output = {'name': [(subname, Tensor), (subname, Tensor)]}
                        sub_out = self.expects[out_name]
                        if not isinstance(sub_out, list):
                            raise AssertionError(
                                "sub_out type %s is not list" % type(sub_out)
                            )
                        for item in sub_out:
                            sub_out_name, expect = item[0], item[1]
                            self.compare_single_output_with_expect(
                                sub_out_name, expect
                            )
                    else:
                        expect = self.expects[out_name]
                        self.compare_single_output_with_expect(out_name, expect)

            def check(self):
                """
                return None means ok, raise Error means failed.

                the main enter point of Checker class
                """
                self.init()
                self.calculate_output()
                self.compare_outputs_with_expects()

        class StaticChecker(Checker):
            def init(self):
                self.checker_name = "static checker"

            def calculate_output(self):
                outs, fetch_list = self.op_test._calc_output(
                    place, no_check_set=no_check_set
                )
                self.outputs = outs
                self.fetch_list = fetch_list

            def find_actual_value(self, name):
                idx = find_actual(name, self.fetch_list)
                actual = self.outputs[idx]
                actual_t = np.array(actual)
                return actual, actual_t

            def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
                """Convert uint16 (bfloat16) arrays to float when needed.

                Also selects the rtol used by _compare_numpy and, when both
                sides are uint16, widens the shared atol to at least 0.03.

                return: (actual_np, expect_np)
                """
                if actual_np.dtype == np.uint16 and expect_np.dtype in [
                    np.float32,
                    np.float64,
                ]:
                    actual_np = convert_uint16_to_float(actual_np)
                    self.rtol = 1.0e-2
                else:
                    self.rtol = 1.0e-5
                if (
                    expect_np.dtype == np.uint16
                    and actual_np.dtype == np.uint16
                ):
                    # Rebinds atol of check_output_with_place, so later
                    # comparisons (including other checkers) see it too.
                    nonlocal atol
                    expect_np = convert_uint16_to_float(expect_np)
                    actual_np = convert_uint16_to_float(actual_np)
                    atol = max(atol, 0.03)
                return actual_np, expect_np

            def _compare_list(self, name, actual, expect):
                """if expect is a tuple, we need to compare list."""
                self.op_test.assertListEqual(
                    actual.recursive_sequence_lengths(),
                    expect[1],
                    "Output (" + name + ") has different lod at " + str(place),
                )

        class DygraphChecker(Checker):
            def init(self):
                self.checker_name = "dygraph checker"

            def calculate_output(self):
                self.outputs = self.op_test._calc_dygraph_output(
                    place, no_check_set=no_check_set
                )

            def find_actual_value(self, name):
                with fluid.dygraph.base.guard(place=place):
                    imperative_actual = find_imperative_actual(
                        name, self.outputs, place
                    )
                    imperative_actual_t = np.array(
                        imperative_actual.value().get_tensor()
                    )
                    return imperative_actual, imperative_actual_t

            def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
                if actual_np.dtype == np.uint16 and expect_np.dtype in [
                    np.float32,
                    np.float64,
                ]:
                    self.rtol = 1.0e-2
                else:
                    self.rtol = 1.0e-5
                if self.op_test.is_bfloat16_op():
                    if actual_np.dtype == np.uint16:
                        actual_np = convert_uint16_to_float(actual_np)
                    if expect_np.dtype == np.uint16:
                        expect_np = convert_uint16_to_float(expect_np)
                return actual_np, expect_np

            def _compare_list(self, name, actual, expect):
                """if expect is a tuple, we need to compare list."""
                with fluid.dygraph.base.guard(place=place):
                    self.op_test.assertListEqual(
                        actual.value()
                        .get_tensor()
                        .recursive_sequence_lengths(),
                        expect[1],
                        "Output ("
                        + name
                        + ") has different lod at "
                        + str(place)
                        + " in dygraph mode",
                    )

            def _compare_numpy(self, name, actual_np, expect_np):
                # Two empty tensors are considered equal regardless of shape.
                if actual_np.size == 0 and expect_np.size == 0:
                    return
                self.op_test.assertTrue(
                    np.allclose(
                        actual_np,
                        expect_np,
                        atol=atol,
                        rtol=self.rtol if hasattr(self, 'rtol') else 1e-5,
                        equal_nan=equal_nan,
                    ),
                    "Output ("
                    + name
                    + ") has diff at "
                    + str(place)
                    + " in "
                    + self.checker_name,
                )

        class EagerChecker(DygraphChecker):
            def init(self):
                self.checker_name = "eager checker"

            def calculate_output(self):
                # we only check end2end api when check_eager=True
                with _test_eager_guard():
                    self.is_python_api_test = True
                    eager_dygraph_outs = self.op_test._calc_python_api_output(
                        place
                    )
                    if eager_dygraph_outs is None:
                        self.is_python_api_test = False
                        # missing KernelSignature, fall back to eager middle output.
                        eager_dygraph_outs = self.op_test._calc_dygraph_output(
                            place, no_check_set=no_check_set
                        )
                self.outputs = eager_dygraph_outs

            def _compare_numpy(self, name, actual_np, expect_np):
                with _test_eager_guard():
                    super()._compare_numpy(name, actual_np, expect_np)

            def convert_uint16_to_float_ifneed(self, actual_np, expect_np):
                with _test_eager_guard():
                    return super().convert_uint16_to_float_ifneed(
                        actual_np, expect_np
                    )

            def find_actual_value(self, name):
                with _test_eager_guard():
                    return super().find_actual_value(name)

            def _compare_list(self, name, actual, expect):
                """if expect is a tuple, we need to compare list."""
                with _test_eager_guard():
                    super()._compare_list(name, actual, expect)

            def _is_skip_name(self, name):
                # if in final state and kernel signature don't have name, then skip it.
                if (
                    self.is_python_api_test
                    and hasattr(self.op_test, "python_out_sig")
                    and name not in self.op_test.python_out_sig
                ):
                    return True
                return super()._is_skip_name(name)

        # set some flags by the combination of arguments.
        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
        if (
            self.dtype == np.float64
            and self.op_type
            not in op_threshold_white_list.NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST
        ):
            atol = 0

        if self.is_bfloat16_op():
            if self.is_mkldnn_op():
                check_dygraph = False
                check_eager = False
                if hasattr(self, 'force_fp32_output') and getattr(
                    self, 'force_fp32_output'
                ):
                    atol = 1e-2
                else:
                    atol = 2
            else:
                atol = 1e-1

        if no_check_set is not None:
            if (
                self.op_type
                not in no_check_set_white_list.no_check_set_white_list
            ):
                raise AssertionError(
                    "no_check_set of op %s must be set to None." % self.op_type
                )
        static_checker = StaticChecker(self, self.outputs)
        static_checker.check()
        outs, fetch_list = static_checker.outputs, static_checker.fetch_list
        if check_dygraph:
            # always enable legacy dygraph
            g_enable_legacy_dygraph()
            try:
                dygraph_checker = DygraphChecker(self, self.outputs)
                dygraph_checker.check()
                dygraph_outs = dygraph_checker.outputs
            finally:
                # Restore the original state even if the check fails, so a
                # failing op does not leave legacy dygraph mode switched on.
                g_disable_legacy_dygraph()
        if check_eager:
            eager_checker = EagerChecker(self, self.outputs)
            eager_checker.check()
            eager_dygraph_outs = eager_checker.outputs

        # Note(zhiqiu): inplace_atol should be only set when op doesn't ensure
        # computational consistency.
        # For example, group_norm uses AtomicAdd on CUDAPlace, which do not ensure
        # computation order when multiple threads write the same address. So the
        # result of group_norm is non-deterministic when datatype is float.
        # When inplace_atol is not None, the inplace check uses numpy.allclose
        # to check inplace result instead of numpy.array_equal.
        if inplace_atol is not None:
            warnings.warn(
                "inplace_atol should only be set when op doesn't ensure computational consistency, please check it!"
            )
        # Check inplace for given op, its grad op, its grad_grad op, etc.
        # No effect on original OpTest
        # Currently not support ParallelExecutor on XPUPlace.
        if (
            not paddle.is_compiled_with_xpu()
            and not paddle.is_compiled_with_npu()
            and not paddle.is_compiled_with_mlu()
            and not isinstance(place, core.CustomPlace)
        ):
            self.check_inplace_output_with_place(
                place, no_check_set=no_check_set, inplace_atol=inplace_atol
            )

        if check_eager:
            assert not check_dygraph
            return outs, eager_dygraph_outs, fetch_list
        elif check_dygraph:
            return outs, dygraph_outs, fetch_list
        else:
            return outs, fetch_list

    def check_compile_vs_runtime(self, fetch_list, fetch_outs):
        """Assert that fetched outputs have the same lod_level at compile time and runtime.

        For each output slot of ``self.op``, its variables are looked up in
        ``fetch_list``. The first variable that was not fetched ends the scan
        of that slot.

        Args:
            fetch_list (list[str]): Names of the fetched variables.
            fetch_outs (list): Runtime values fetched, in the same order as
                ``fetch_list``.
        """

        def _index_in_fetch_list(wanted):
            # -1 when not fetched; otherwise the name must appear exactly once.
            hits = [
                idx
                for idx, fetched_name in enumerate(fetch_list)
                if fetched_name == wanted
            ]
            if not hits:
                return -1
            self.assertTrue(
                len(hits) == 1,
                "Found {} {}".format(len(hits), wanted),
            )
            return hits[0]

        for slot in self.op.desc.output_names():
            for var_name in self.op.desc.output(slot):
                idx = _index_in_fetch_list(var_name)
                if idx == -1:
                    # The output is dispensiable or intermediate.
                    break

                fetched = fetch_outs[idx]
                if isinstance(fetched, core.LoDTensor):
                    runtime_level = len(fetched.lod())
                else:
                    if isinstance(fetched, core.LoDTensorArray):
                        warnings.warn(
                            "The check of LoDTensorArray's lod_level is not implemented now!"
                        )
                    runtime_level = 0

                compile_var = self.program.global_block().var(var_name)
                compile_level = (
                    compile_var.lod_level
                    if compile_var.type == core.VarDesc.VarType.LOD_TENSOR
                    else 0
                )
                self.assertEqual(
                    compile_level,
                    runtime_level,
                    "The lod_level of Output ("
                    + slot
                    + ") is different between compile-time and runtime ("
                    + str(compile_level)
                    + " vs "
                    + str(runtime_level)
                    + ")",
                )
1968

1969
    def _get_places(self):
        """Return the places this op should be tested on.

        float16 ops only run on CUDAPlace(0). That place is returned only when
        this build has CUDA, the op supports GPU and the device supports
        float16; otherwise the list is empty. Every other dtype runs on
        CPUPlace, plus CUDAPlace(0) when CUDA and GPU support are available and
        the test does not set ``_cpu_only``.
        """
        gpu_usable = core.is_compiled_with_cuda() and core.op_support_gpu(
            self.op_type
        )

        if self.dtype == np.float16:
            if not gpu_usable:
                return []
            cuda_place = core.CUDAPlace(0)
            return [cuda_place] if core.is_float16_supported(cuda_place) else []

        places = [fluid.CPUPlace()]
        if gpu_usable and not getattr(self, '_cpu_only', False):
            places.append(core.CUDAPlace(0))
        return places

1991 1992 1993 1994 1995 1996 1997 1998 1999
    def check_output(
        self,
        atol=1e-5,
        no_check_set=None,
        equal_nan=False,
        check_dygraph=True,
        inplace_atol=None,
        check_eager=False,
    ):
        """Run check_output_with_place() on every place from _get_places().

        Also records op_type (and the mkldnn/xpu flags) on the test class, and
        checks compile-time vs runtime lod_level for ops not in
        COMPILE_RUN_OP_WHITE_LIST.

        Args:
            atol (float): Absolute tolerance forwarded to check_output_with_place.
            no_check_set (list): Output names that needn't be checked.
            equal_nan (bool): Whether NaNs compare as equal.
            check_dygraph (bool): Also check outputs in legacy dygraph mode.
            inplace_atol (float): Tolerance for the inplace check.
            check_eager (bool): Also check outputs in eager mode; disables the
                legacy dygraph check.
        """
        # disable legacy dygraph check when check_eager is True
        if check_eager:
            check_dygraph = False

        self.__class__.op_type = self.op_type
        if self.is_mkldnn_op():
            self.__class__.use_mkldnn = True

        if self.is_xpu_op():
            self.__class__.use_xpu = True

        places = self._get_places()
        for place in places:
            res = self.check_output_with_place(
                place,
                atol,
                no_check_set,
                equal_nan,
                check_dygraph,
                inplace_atol,
                check_eager=check_eager,
            )
            # check_output_with_place returns (outs, [dygraph_outs,] fetch_list).
            # Its length depends on which checks actually ran there; bf16
            # mkldnn ops turn both dygraph checks off internally, so the local
            # flags can't predict it. Only outs and fetch_list are needed here.
            outs, fetch_list = res[0], res[-1]
            if (
                self.op_type
                not in compile_vs_runtime_white_list.COMPILE_RUN_OP_WHITE_LIST
            ):
                self.check_compile_vs_runtime(fetch_list, outs)
Q
qijun 已提交
2035

P
pangyoki 已提交
2036
    def check_output_customized(self, checker, custom_place=None):
        """Run the op on every test place and hand its outputs to ``checker``.

        Args:
            checker (callable): Called once per place with the op outputs
                converted to numpy arrays and sorted by len().
            custom_place: Extra place appended to _get_places() when truthy.
        """
        candidate_places = self._get_places()
        if custom_place:
            candidate_places.append(custom_place)
        for run_place in candidate_places:
            arrays = sorted(
                (np.array(raw) for raw in self.calc_output(run_place)),
                key=len,
            )
            checker(arrays)

2046 2047 2048 2049 2050 2051
    def check_output_with_place_customized(self, checker, place):
        """Run the op on ``place`` and pass its outputs to ``checker``.

        The outputs are converted to numpy arrays and sorted by len() first.
        """
        checker(sorted(map(np.array, self.calc_output(place)), key=len))

2052 2053 2054 2055 2056 2057 2058 2059
    def _assert_is_close(
        self,
        numeric_grads,
        analytic_grads,
        names,
        max_relative_error,
        msg_prefix,
    ):
2060
        for a, b, name in zip(numeric_grads, analytic_grads, names):
2061 2062 2063 2064 2065 2066
            # It asserts np.abs(a - b) / np.abs(a) < max_relative_error, in which
            # max_relative_error is 1e-7. According to the value of np.abs(a), we
            # change np.abs(a) to achieve dynamic threshold. For example, if
            # the value of np.abs(a) is between 1e-10 and 1e-8, we set np.abs(a)*=1e4.
            # Therefore, it asserts np.abs(a - b) / (np.abs(a)*1e4) < max_relative_error,
            # which is the same as np.abs(a - b) / np.abs(a) < max_relative_error*1e4.
2067
            abs_a = np.abs(a)
2068
            if abs_a.ndim > 0:
2069 2070 2071 2072 2073
                if (
                    self.dtype == np.float64
                    and self.op_type
                    not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
                ):
2074 2075 2076 2077 2078 2079 2080 2081
                    abs_a[abs_a < 1e-10] = 1e-3
                    abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
                    abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
                elif self.is_bfloat16_op():
                    abs_a[abs_a < 1e-2] = 1
                else:
                    abs_a[abs_a < 1e-3] = 1
            elif abs_a.ndim == 0:
2082 2083 2084 2085 2086
                if (
                    self.dtype == np.float64
                    and self.op_type
                    not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
                ):
2087 2088 2089 2090 2091 2092 2093 2094 2095 2096
                    if abs_a < 1e-10:
                        abs_a = 1e-3
                    elif abs_a > 1e-10 and abs_a <= 1e-8:
                        abs_a = abs_a * 1e4
                    elif abs_a > 1e-8 and abs_a <= 1e-6:
                        abs_a = abs_a * 1e2
                elif self.is_bfloat16_op():
                    abs_a = 1 if abs_a < 1e-2 else abs_a
                else:
                    abs_a = 1 if abs_a < 1e-3 else abs_a
2097 2098 2099 2100 2101 2102

            diff_mat = np.abs(a - b) / abs_a
            max_diff = np.max(diff_mat)

            def err_msg():
                offset = np.argmax(diff_mat > max_relative_error)
2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117
                return (
                    "Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff %e over limit %e, "
                    "the first error element is %d, expected %e, but got %e."
                ) % (
                    self.op_type,
                    msg_prefix,
                    name,
                    str(a.shape),
                    self.dtype,
                    max_diff,
                    max_relative_error,
                    offset,
                    a.flatten()[offset],
                    b.flatten()[offset],
                )
2118 2119 2120

            self.assertLessEqual(max_diff, max_relative_error, err_msg())

2121 2122 2123 2124 2125 2126 2127
    def _check_grad_helper(self):
        self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
        self.__class__.op_type = self.op_type
        self.__class__.exist_check_grad = True
        if self.dtype == np.float64:
            self.__class__.exist_fp64_check_grad = True

2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140
    def check_grad(
        self,
        inputs_to_check,
        output_names,
        no_grad_set=None,
        numeric_grad_delta=0.005,
        in_place=False,
        max_relative_error=0.005,
        user_defined_grads=None,
        user_defined_grad_outputs=None,
        check_dygraph=True,
        check_eager=False,
    ):
        """Run ``check_grad_with_place`` on every place from ``_get_places()``.

        All arguments are forwarded unchanged, except that ``check_dygraph``
        is forced to False when ``check_eager`` is truthy. The eager check
        replaces the legacy-dygraph check in that case.
        """
        check_dygraph = False if check_eager else check_dygraph

        self._check_grad_helper()
        for target_place in self._get_places():
            self.check_grad_with_place(
                target_place,
                inputs_to_check,
                output_names,
                no_grad_set,
                numeric_grad_delta,
                in_place,
                max_relative_error,
                user_defined_grads,
                user_defined_grad_outputs,
                check_dygraph,
                check_eager=check_eager,
            )

    def check_grad_with_place(
        self,
        place,
        inputs_to_check,
        output_names,
        no_grad_set=None,
        numeric_grad_delta=0.005,
        in_place=False,
        max_relative_error=0.005,
        user_defined_grads=None,
        user_defined_grad_outputs=None,
        check_dygraph=True,
        numeric_place=None,
        check_eager=False,
    ):
        """Check the op's analytic gradients against numeric ones on ``place``.

        The numeric gradients come from ``get_numeric_gradient``, or from
        ``user_defined_grads`` when that is provided. They are compared with
        the gradients from the static graph (``_get_gradient``). Optionally
        they are also compared with legacy dygraph (``check_dygraph``) and
        eager dygraph (``check_eager``).

        Args:
            place: device on which the analytic gradients are computed.
            inputs_to_check: names of the inputs whose gradients are checked.
            output_names: an output name, or a list of output names.
            no_grad_set: input names excluded from backward. Only allowed for
                ops in the ``no_grad_set_white_list`` lists, and for bf16 ops.
            numeric_grad_delta: finite-difference step. Forced to 1e-5 for
                float64 ops not in the fp64 threshold white list.
            in_place: forwarded to ``get_numeric_gradient``.
            max_relative_error: comparison tolerance. Forced to 1e-7 for
                float64 ops not in the white list. Raised to at least 0.04
                (static graph) or 0.03 (dygraph) when bf16 gradients are
                compared.
            user_defined_grads: precomputed numeric gradients, if any.
            user_defined_grad_outputs: gradients w.r.t. the outputs, used
                instead of building a loss from the outputs.
            check_dygraph: also compare legacy-dygraph gradients. Ignored
                when ``check_eager`` is set.
            numeric_place: device for the numeric gradient. Defaults to
                ``place``.
            check_eager: also compare eager-dygraph gradients.

        Raises:
            AssertionError: on a gradient mismatch, or if ``no_grad_set`` is
                given for an op that is not white-listed.
        """
        # Disable the legacy dygraph check when check_eager is True.
        if check_eager:
            check_dygraph = False

        self.scope = core.Scope()
        op_inputs = self.inputs if hasattr(self, "inputs") else dict()
        op_outputs = self.outputs if hasattr(self, "outputs") else dict()
        op_attrs = self.attrs if hasattr(self, "attrs") else dict()

        self._check_grad_helper()
        if self.is_bfloat16_op() and self.is_mkldnn_op():
            check_dygraph = False
            check_eager = False

        if (
            self.dtype == np.float64
            and self.op_type
            not in op_threshold_white_list.NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST
        ):
            numeric_grad_delta = 1e-5
            max_relative_error = 1e-7

        cache_list = getattr(self, "cache_name_list", None)

        # The oneDNN numeric gradient should use the CPU kernel, so
        # `use_mkldnn` is switched off while the op is created. op_attrs may
        # be self.attrs itself, so the flag is restored in `finally` even if
        # create_op raises.
        use_onednn = False
        if "use_mkldnn" in op_attrs and op_attrs["use_mkldnn"]:
            op_attrs["use_mkldnn"] = False
            use_onednn = True
        try:
            self.op = create_op(
                self.scope,
                self.op_type,
                op_inputs,
                op_outputs,
                op_attrs,
                cache_list=cache_list,
            )
        finally:
            if use_onednn:
                op_attrs["use_mkldnn"] = True

        if no_grad_set is None:
            no_grad_set = set()
        elif (
            self.op_type not in no_grad_set_white_list.NEED_TO_FIX_OP_LIST
            and self.op_type not in no_grad_set_white_list.NOT_CHECK_OP_LIST
            and not self.is_bfloat16_op()
        ):
            raise AssertionError(
                "no_grad_set must be None, op_type is " + self.op_type + " Op."
            )

        for input_to_check in inputs_to_check:
            set_input(self.scope, self.op, self.inputs, place)
            tensor_shape = (
                self.scope.find_var(input_to_check).get_tensor().shape()
            )
            tensor_size = functools.reduce(lambda a, b: a * b, tensor_shape, 1)
            # 0-D tensors are an additional case for the op, so they do not
            # mark the input shape as too small.
            if len(tensor_shape) > 0 and tensor_size < 100:
                self.__class__.input_shape_is_large = False

        if not isinstance(output_names, list):
            output_names = [output_names]

        if numeric_place is None:
            numeric_place = place

        numeric_grads = user_defined_grads or [
            get_numeric_gradient(
                numeric_place,
                self.scope,
                self.op,
                self.inputs,
                input_to_check,
                output_names,
                delta=numeric_grad_delta,
                in_place=in_place,
            )
            for input_to_check in inputs_to_check
        ]
        analytic_grads = self._get_gradient(
            inputs_to_check,
            place,
            output_names,
            no_grad_set,
            user_defined_grad_outputs,
        )

        # bf16 results are compared as fp32 with a looser tolerance.
        analytic_grads, max_relative_error = self._bf16_grads_to_fp32(
            analytic_grads, max_relative_error, 0.04
        )
        numeric_grads, max_relative_error = self._bf16_grads_to_fp32(
            numeric_grads, max_relative_error, 0.04
        )

        check_msg = "Gradient Check On %s" % str(place)
        self._assert_is_close(
            numeric_grads,
            analytic_grads,
            inputs_to_check,
            max_relative_error,
            check_msg,
        )

        if check_dygraph:
            # Switch into legacy dygraph for this check. Always switch back,
            # even if the comparison fails, so later tests are not left in
            # legacy mode.
            g_enable_legacy_dygraph()
            try:
                dygraph_grad = self._get_dygraph_grad(
                    inputs_to_check,
                    place,
                    output_names,
                    user_defined_grad_outputs,
                    no_grad_set,
                    False,
                )
                dygraph_grad, max_relative_error = self._bf16_grads_to_fp32(
                    dygraph_grad, max_relative_error, 0.03
                )
                self._assert_is_close(
                    numeric_grads,
                    dygraph_grad,
                    inputs_to_check,
                    max_relative_error,
                    check_msg,
                )
            finally:
                g_disable_legacy_dygraph()

        if check_eager:
            with fluid.dygraph.base.guard(place):
                with _test_eager_guard():
                    eager_dygraph_grad = self._get_dygraph_grad(
                        inputs_to_check,
                        place,
                        output_names,
                        user_defined_grad_outputs,
                        no_grad_set,
                        check_eager,
                    )
                    (
                        eager_dygraph_grad,
                        max_relative_error,
                    ) = self._bf16_grads_to_fp32(
                        eager_dygraph_grad, max_relative_error, 0.03
                    )
                    self._assert_is_close(
                        numeric_grads,
                        eager_dygraph_grad,
                        inputs_to_check,
                        max_relative_error,
                        check_msg,
                    )

    @staticmethod
    def _bf16_grads_to_fp32(grads, max_relative_error, bf16_min_error):
        """Convert bf16 gradients (stored as ``np.uint16``) to float32.

        Returns ``(converted_grads, max_relative_error)``. The tolerance is
        raised to at least ``bf16_min_error`` if any gradient was bf16.
        Non-bf16 gradients are passed through unchanged.
        """
        converted = []
        for grad in grads:
            if grad.dtype == np.uint16:
                grad = convert_uint16_to_float(grad)
                max_relative_error = max(max_relative_error, bf16_min_error)
            converted.append(grad)
        return converted, max_relative_error
2368

2369 2370 2371 2372 2373 2374 2375 2376 2377
    def _find_var_in_dygraph(self, output_vars, name):
        if name in output_vars:
            return output_vars[name]
        else:
            for output_vars_index in output_vars:
                for output_vars_selected in output_vars[output_vars_index]:
                    if output_vars_selected.name == name:
                        return output_vars_selected

2378 2379 2380 2381 2382 2383 2384 2385 2386
    def _get_dygraph_grad(
        self,
        inputs_to_check,
        place,
        output_names,
        user_defined_grad_outputs=None,
        no_grad_set=None,
        check_eager=False,
    ):
        """Compute the gradients of ``inputs_to_check`` by running the op in dygraph.

        The op is built inside ``fluid.dygraph.base.guard(place)``. When
        ``check_eager`` is set, the Python API output from
        ``_calc_python_api_output`` is used. If that returns None, or when
        ``check_eager`` is not set, the raw op is appended to the block.

        When ``user_defined_grad_outputs`` is None, a mean loss over the
        outputs is built (averaged when there are several outputs) and
        back-propagated. The result of ``.gradient()`` for each checked
        input is returned. Otherwise the given output gradients seed the
        backward pass, and ``.numpy()`` arrays are returned.

        Args:
            inputs_to_check: names of the inputs whose gradients are returned.
            place: device to run on.
            output_names: names of the outputs the loss is built from.
            user_defined_grad_outputs: optional output gradient value(s),
                converted with ``paddle.to_tensor``.
            no_grad_set: input names removed from ``inputs`` before the
                user-defined backward pass. None means an empty set.
            check_eager: whether to run through the eager Python API.
        """
        with fluid.dygraph.base.guard(place=place):
            block = fluid.default_main_program().global_block()

            op_proto = OpProtoHolder.instance().get_op_proto(self.op_type)

            # prepare input variable
            inputs, inputs_grad_dict = self.append_input_output_for_dygraph(
                op_proto, self.inputs, True, True, block
            )

            # prepare output variable
            outputs = self.append_input_output_for_dygraph(
                op_proto, self.outputs, False, False, block
            )

            # prepare attributes (None-valued attributes are dropped)
            attrs_outputs = {}
            if hasattr(self, "attrs"):
                for attrs_name in self.attrs:
                    if self.attrs[attrs_name] is not None:
                        attrs_outputs[attrs_name] = self.attrs[attrs_name]

            eager_outputs = None
            if check_eager:
                eager_outputs = self._calc_python_api_output(
                    place, inputs, outputs
                )
            # If eager_outputs is None, the kernel signature is empty or
            # another error happened, so fall back to appending the raw op.
            if eager_outputs is None:
                block.append_op(
                    type=self.op_type,
                    inputs=inputs,
                    outputs=outputs,
                    attrs=attrs_outputs if hasattr(self, "attrs") else None,
                )
            else:
                outputs = eager_outputs

            if self.dtype == np.uint16:
                # Cast the bf16 output to fp32 before building the loss.
                cast_inputs = self._find_var_in_dygraph(
                    outputs, output_names[0]
                )
                cast_outputs = block.create_var(
                    dtype="float32", shape=cast_inputs[0].shape
                )
                block.append_op(
                    inputs={"X": cast_inputs},
                    outputs={"Out": cast_outputs},
                    type="cast",
                    attrs={
                        "in_dtype": core.VarDesc.VarType.BF16,
                        "out_dtype": core.VarDesc.VarType.FP32,
                    },
                )
                outputs = {output_names[0]: cast_outputs}

            outputs_valid = {}
            for output_name in output_names:
                outputs_valid[output_name] = self._find_var_in_dygraph(
                    outputs, output_name
                )

            if user_defined_grad_outputs is None:
                if len(outputs_valid) == 1:
                    loss = block.create_var(
                        dtype=self.dtype,
                        type=core.VarDesc.VarType.LOD_TENSOR,
                        persistable=False,
                        stop_gradient=False,
                        shape=[1],
                    )
                    for outputs_valid_key in outputs_valid:
                        block.append_op(
                            type="mean",
                            inputs={"X": outputs_valid[outputs_valid_key]},
                            outputs={"Out": [loss]},
                            attrs=None,
                        )
                else:
                    # loss = sum(mean(output) for each output) / num_outputs
                    avg_sum = []
                    for cur_loss in outputs_valid:
                        cur_avg_loss = block.create_var(
                            dtype=self.dtype,
                            type=core.VarDesc.VarType.LOD_TENSOR,
                            persistable=False,
                            stop_gradient=False,
                        )
                        block.append_op(
                            type="mean",
                            inputs={"X": outputs_valid[cur_loss]},
                            outputs={"Out": [cur_avg_loss]},
                            attrs=None,
                        )
                        avg_sum.append(cur_avg_loss)
                    loss_sum = block.create_var(
                        dtype=self.dtype,
                        type=core.VarDesc.VarType.LOD_TENSOR,
                        persistable=False,
                        stop_gradient=False,
                        shape=[1],
                    )
                    block.append_op(
                        type='sum',
                        inputs={"X": avg_sum},
                        outputs={"Out": loss_sum},
                        attrs=None,
                    )
                    loss = block.create_var(
                        dtype=self.dtype,
                        type=core.VarDesc.VarType.LOD_TENSOR,
                        persistable=False,
                        stop_gradient=False,
                        shape=[1],
                    )
                    block.append_op(
                        type='scale',
                        inputs={"X": loss_sum},
                        outputs={"Out": loss},
                        attrs={'scale': 1.0 / float(len(avg_sum))},
                    )
                loss.backward()

                return [
                    inputs_grad_dict[inputs_to_check_name].gradient()
                    for inputs_to_check_name in inputs_to_check
                ]

            # user_defined_grad_outputs here are numpy arrays
            if not isinstance(user_defined_grad_outputs, list):
                user_defined_grad_outputs = [user_defined_grad_outputs]
            grad_outputs = [
                paddle.to_tensor(grad_out_value)
                for grad_out_value in user_defined_grad_outputs
            ]

            # Delete the inputs that need no gradient. The default
            # no_grad_set=None is treated as an empty set instead of
            # failing with "NoneType is not iterable".
            for no_grad_val in no_grad_set or ():
                del inputs[no_grad_val]

            if not _in_legacy_dygraph():
                core.eager.run_backward(
                    fluid.layers.utils.flatten(outputs), grad_outputs, False
                )
                grad_inputs = []
                for inputs_list in inputs.values():
                    for inp in inputs_list:
                        grad_inputs.append(inp.grad.numpy())
                return grad_inputs

            grad_inputs = paddle.grad(
                outputs=fluid.layers.utils.flatten(outputs),
                inputs=fluid.layers.utils.flatten(inputs),
                grad_outputs=grad_outputs,
            )
            return [grad.numpy() for grad in grad_inputs]
2540

Y
Yu Yang 已提交
2541 2542 2543 2544 2545
    @staticmethod
    def _numpy_to_lod_tensor(np_value, lod, place):
        """Wrap ``np_value`` in a ``core.LoDTensor`` placed on ``place``.

        When ``lod`` is not None, it is applied as the tensor's recursive
        sequence lengths.
        """
        result = core.LoDTensor()
        result.set(np_value, place)
        if lod is None:
            return result
        result.set_recursive_sequence_lengths(lod)
        return result

K
Kexin Zhao 已提交
2549
    @staticmethod
K
Kexin Zhao 已提交
2550 2551
    def np_dtype_to_fluid_dtype(input):
        return input
K
Kexin Zhao 已提交
2552

D
dzhwinter 已提交
2553 2554 2555 2556 2557 2558 2559 2560
    @staticmethod
    def fluid_dtype_to_np_dtype(self, dtype):
        return dtype

    @staticmethod
    def np_value_to_fluid_value(input):
        return input

2561 2562 2563 2564 2565 2566 2567 2568 2569
    def _get_gradient(
        self,
        input_to_check,
        place,
        output_names,
        no_grad_set,
        user_defined_grad_outputs=None,
        parallel=False,
    ):
        """Compute analytic gradients by building and running a static program.

        A fresh ``Program`` and ``core.Scope`` are created, the op is appended
        (``_append_ops``), and the inputs are fed via ``feed_var``.

        - Without ``user_defined_grad_outputs``: a loss is appended over
          ``output_names`` (after casting a bf16 output to fp32), then
          ``append_backward`` is applied.
        - With ``user_defined_grad_outputs`` (numpy arrays): the values are
          written into persistable variables in the scope and passed to
          ``paddle.static.gradients``.

        Args:
            input_to_check: names of the inputs whose gradients are fetched.
            place: device to run on.
            output_names: names of the outputs that define the loss.
            no_grad_set: variables excluded from backward.
            user_defined_grad_outputs: optional output gradient value(s).
            parallel: run through ``CompiledProgram.with_data_parallel``.
                Not supported together with ``user_defined_grad_outputs``.

        Returns:
            A list with each fetched gradient converted via ``np.array``.

        Raises:
            AssertionError: if ``parallel`` is requested together with
                ``user_defined_grad_outputs``.
        """
        prog = Program()
        scope = core.Scope()
        block = prog.global_block()
        self._append_ops(block)

        inputs = self._get_inputs(block)
        outputs = self._get_outputs(block)
        feed_dict = self.feed_var(inputs, place)

        if user_defined_grad_outputs is None:
            if self.dtype == np.uint16:
                # Cast the bf16 output to fp32 so the loss is computed in fp32.
                cast_inputs = list(map(block.var, output_names))
                cast_outputs = block.create_var(
                    dtype="float32", shape=cast_inputs[0].shape
                )
                cast_op = block.append_op(
                    inputs={"X": cast_inputs},
                    outputs={"Out": cast_outputs},
                    type="cast",
                    attrs={
                        "in_dtype": core.VarDesc.VarType.BF16,
                        "out_dtype": core.VarDesc.VarType.FP32,
                    },
                )
                cast_op.desc.infer_var_type(block.desc)
                cast_op.desc.infer_shape(block.desc)
                output_names = [cast_outputs.name]
            loss = append_loss_ops(block, output_names)
            param_grad_list = append_backward(
                loss=loss,
                parameter_list=input_to_check,
                no_grad_set=no_grad_set,
            )
            fetch_list = [g for p, g in param_grad_list]
        else:
            # Explicit raise instead of `assert` so the check survives `-O`.
            # Any falsy `parallel` (False, None, 0) means "not parallel",
            # matching the `if parallel:` test below.
            if parallel:
                raise AssertionError(
                    "unsupported parallel mode when giving custom grad outputs."
                )
            # user_defined_grad_outputs here are numpy arrays
            if not isinstance(user_defined_grad_outputs, list):
                user_defined_grad_outputs = [user_defined_grad_outputs]
            grad_outputs = []
            for grad_out_value in user_defined_grad_outputs:
                # `persistable` is used to avoid the executor creating a new
                # var in the local scope.
                var = block.create_var(
                    shape=grad_out_value.shape,
                    dtype=grad_out_value.dtype,
                    persistable=True,
                )
                scope.var(var.name).get_tensor().set(grad_out_value, place)
                grad_outputs.append(var)
            targets = [
                outputs[name] for name in outputs if name in output_names
            ]
            input_vars = [
                inputs[name] for name in input_to_check if name in inputs
            ]
            fetch_list = paddle.static.gradients(
                targets, input_vars, grad_outputs, no_grad_set
            )

        if parallel:
            prog = fluid.CompiledProgram(prog).with_data_parallel(
                loss_name=loss.name, places=place
            )
        executor = fluid.Executor(place)
        fetched = executor.run(
            prog, feed_dict, fetch_list, scope=scope, return_numpy=False
        )
        return [np.array(value) for value in fetched]
A
arlesniak 已提交
2649 2650 2651 2652 2653 2654 2655 2656 2657 2658


class OpTestTool:
    @classmethod
    def skip_if(cls, condition: object, reason: str):
        return unittest.skipIf(condition, reason)

    @classmethod
    def skip_if_not_cpu_bf16(cls):
        """Skip unless the expected place is a CPU with bfloat16 support."""
        expected_place = _current_expected_place()
        bf16_cpu = (
            isinstance(expected_place, core.CPUPlace)
            and core.supports_bfloat16()
        )
        return OpTestTool.skip_if(
            not bf16_cpu, "Place does not support BF16 evaluation"
        )
2665 2666 2667 2668 2669

    @classmethod
    def skip_if_not_cpu(cls):
        """Skip unless the expected place is a CPU (the only oneDNN target)."""
        on_cpu = isinstance(_current_expected_place(), core.CPUPlace)
        return OpTestTool.skip_if(not on_cpu, "OneDNN supports only CPU for now")