Unverified commit 457defe7, authored by Charles-hit, committed by GitHub

[OpTest] support prim test in OpTest (#50509)

* support prim test in OpTest

* fix cmake

* fix op test

* fix test_input_spec

* disable cinn in reduce_sum unit test

* add bfloat16 dtype for sum

* polish code

* add clear jit program function

* convert grad out from tensor to numpy

* remove unnecessary code

* add only_prim flag

* fix flag

* fix op test

* fix optest comp inplace error

* fix op test

* fix op test with guard

* add initialization of check_comp flag

* fix comp inplace error in op test

* rename check_comp with check_prim and add bfloat16 dtype convert

* rename comp_op_type to prim_op_type

* rename comp to prim

* remove useless code

* skip ci check for only prim

* add no_grad_vars and grad_outputs in prim test

* fix var_dict

* fix op test for only_prim

* fix dy2static bugs

* polish some code
Parent 2135020a
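For context, a unit test opts into the new primitive checks through the `check_prim` flag added to `check_output`/`check_grad` and the `prim_op_type` attribute asserted by `PrimForwardChecker`. A minimal sketch follows; the op, shapes, dtype, and class name are illustrative and not taken from this commit:

import numpy as np
import paddle
from op_test import OpTest  # import path assumes the unittest directory layout used by these files


class TestExpPrimSketch(OpTest):
    def setUp(self):
        self.op_type = "exp"
        self.prim_op_type = "prim"      # "comp" or "prim"; required by PrimForwardChecker
        self.python_api = paddle.exp    # required when check_prim is used
        self.enable_cinn = False        # optional: skip the CINN variant
        x = np.random.random((4, 8)).astype("float32")
        self.inputs = {'X': x}
        self.outputs = {'Out': np.exp(x)}

    def test_check_output(self):
        self.check_output(check_prim=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_prim=True)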
......@@ -1204,6 +1204,14 @@ if($ENV{USE_STANDALONE_EXECUTOR})
PROPERTIES ENVIRONMENT FLAGS_USE_STANDALONE_EXECUTOR=0)
endif()
set(TEST_CINN_OPS test_softmax_op test_expand_v2_op test_reduce_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
if(WITH_CINN)
set_tests_properties(${TEST_CINN_OPS} PROPERTIES LABELS "RUN_TYPE=CINN")
endif()
endforeach()
if(WITH_CINN AND WITH_TESTING)
set_tests_properties(
test_resnet50_with_cinn
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
TOLERANCE = {
np.dtype('float64'): {
"jit_comp": {"rtol": 1e-15, "atol": 1e-15},
"fw_comp": {"rtol": 1e-15, "atol": 1e-15},
"rev_comp": {"rtol": 1e-15, "atol": 1e-15},
"cinn": {"rtol": 1e-14, "atol": 1e-14},
},
np.dtype('float32'): {
"jit_comp": {"rtol": 1e-6, "atol": 1e-6},
"fw_comp": {"rtol": 1e-6, "atol": 1e-6},
"rev_comp": {"rtol": 1e-6, "atol": 1e-6},
"cinn": {"rtol": 1e-5, "atol": 1e-5},
},
np.dtype('float16'): {
"jit_comp": {"rtol": 1e-3, "atol": 1e-3},
"fw_comp": {"rtol": 1e-3, "atol": 1e-3},
"rev_comp": {"rtol": 1e-3, "atol": 1e-3},
"cinn": {"rtol": 1e-2, "atol": 1e-2},
},
np.dtype('uint16'): {
"jit_comp": {"rtol": 1e-2, "atol": 1e-2},
"fw_comp": {"rtol": 1e-2, "atol": 1e-2},
"rev_comp": {"rtol": 1e-2, "atol": 1e-2},
"cinn": {"rtol": 1e-1, "atol": 1e-1},
},
}
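These per-dtype thresholds are consumed by PrimForwardChecker.init_checker_threshold further down in this commit; the uint16 row covers bfloat16, which the test utilities store as uint16. A short sketch of the lookup, assuming (as the checkers themselves do) that the test directory is on sys.path so `config` is importable:

import numpy as np
from config import TOLERANCE

dtype = np.dtype('float32')
fw_comp_rtol = TOLERANCE[dtype]['fw_comp']['rtol']   # 1e-6
cinn_atol = TOLERANCE[dtype]['cinn']['atol']         # 1e-5
# A test case can still override any threshold, e.g. by setting self.fw_comp_atol
# in setUp(); the checker prefers that attribute when it is present.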
......@@ -34,13 +34,12 @@ from paddle.fluid.framework import (
OpProtoHolder,
Program,
_current_expected_place,
_dygraph_tracer,
in_dygraph_mode,
)
from paddle.fluid.op import Operator
from paddle.jit.dy2static.utils import parse_arg_and_kwargs
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker
from testsuite import append_input_output, append_loss_ops, create_op, set_input
from white_list import (
check_shape_white_list,
......@@ -321,6 +320,7 @@ class OpTest(unittest.TestCase):
cls.dtype = None
cls.outputs = {}
cls.input_shape_is_large = True
cls.check_prim = False
np.random.seed(123)
random.seed(124)
......@@ -401,6 +401,7 @@ class OpTest(unittest.TestCase):
and not is_npu_op_test()
and not is_mlu_op_test()
and not is_custom_device_op_test()
and not cls.check_prim
):
raise AssertionError(
"This test of %s op needs check_grad with fp64 precision."
......@@ -579,7 +580,6 @@ class OpTest(unittest.TestCase):
type=core.VarDesc.VarType.RAW,
stop_gradient=True,
)
op = block.append_op(
type=self.op_type,
inputs=inputs,
......@@ -806,100 +806,6 @@ class OpTest(unittest.TestCase):
def _calc_python_api_output(self, place, egr_inps=None, egr_oups=None):
"""set egr_inps and egr_oups = None if you want to create it by yourself."""
def prepare_python_api_arguments(
api, op_proto_ins, op_proto_attrs, kernel_sig
):
"""map from `op proto inputs and attrs` to `api input list and api attrs dict`
NOTE: the op_proto_attrs and op_proto_ins is a default dict. default value is []
"""
class Empty:
pass
def is_empty(a):
return isinstance(a, Empty)
def get_default(idx, defaults):
assert not isinstance(defaults[idx], Empty), (
"%d-th params of python api don't have default value." % idx
)
return defaults[idx]
def to_defaults_list(params, defaults):
return [defaults[p] for p in params if p in defaults]
def parse_attri_value(name, op_inputs, op_attrs):
"""parse true value from inputs and attrs, if there is no name passed by OpTest, return Empty
1. if the name in op_attrs, use the op_attrs[name]
2. if the name in op_inputs, convert the op_inputs to [type of default value]
3. if the name not in op_attrs ans op_inputs, return Empty. (this will use the default value from python api)
"""
if name in op_proto_attrs:
return op_proto_attrs[name]
elif name in op_inputs:
if len(op_inputs[name]) == 1:
# why don't use numpy().item() : if the Tensor is float64, we will change it to python.float32, where we loss accuracy: [allclose_op]
# why we reconstruct a tensor: because we want the tensor in cpu.
return paddle.to_tensor(
op_inputs[name][0].numpy(), place='cpu'
)
else:
# if this is a list (test_unsqueeze2_op): we just pass it into the python api.
return op_inputs[name]
else:
return Empty()
# NOTE(xiongkun): the logic of constructing parameters:
# for example:
# python api: cumprod(x, dim, dtype=None, name=None)
# kernel sig: [["x"], ["dim"], ["out"]]"
#
# we will construct a lot of list with the same length : len == len(api_params), here is 4
# api_params = ["x", "dim", "dtype", "name"]
# api_defaults = [Empty, Empty, None, None]; empty means no defaults.
# inputs_and_attrs = ["x", "dim"] , the length may shorter or longer than api_params
# input_arguments = [RealValue in self.inputs and self.attrs]
# then ,we will loop for the api_params, construct a result list:
# if the name in ['name', 'dtype', 'out', 'output'], we will use the default value
# else, we will consume a input_arguments. (because the name is not corresponding, so we only use the order)
api_params, api_defaults = parse_arg_and_kwargs(api)
api_defaults = to_defaults_list(api_params, api_defaults)
api_defaults = [
Empty() for i in range(len(api_params) - len(api_defaults))
] + api_defaults
assert len(api_defaults) == len(
api_params
), "Error happens. contack xiongkun03 to solve."
inputs_sig, attrs_sig, outputs_sig = kernel_sig
inputs_and_attrs = inputs_sig + attrs_sig
input_arguments = [
op_proto_ins.get(name, Empty()) for name in inputs_sig
] + [
parse_attri_value(name, op_proto_ins, op_proto_attrs)
for name in attrs_sig
]
results = []
api_ignore_param_list = set(['name', 'dtype', 'out', 'output'])
idx_of_op_proto_arguments = 0
for idx, arg_name in enumerate(api_params):
if arg_name in api_ignore_param_list:
results.append(get_default(idx, api_defaults))
else:
if idx_of_op_proto_arguments < len(input_arguments):
tmp = input_arguments[idx_of_op_proto_arguments]
idx_of_op_proto_arguments += 1
else:
tmp = Empty() # use the default value
if isinstance(tmp, Empty):
results.append(get_default(idx, api_defaults))
else:
results.append(tmp)
assert len(results) == len(api_params)
return results
def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if hasattr(self, "python_out_sig"):
output_sig = self.python_out_sig
......@@ -915,50 +821,11 @@ class OpTest(unittest.TestCase):
), "Don't support multi-output with multi-tensor output. (Maybe you can set `python_out_sig`; see `test_squeeze2_op` for an example.)"
return {output_sig[0]: ret_tuple}
def assumption_assert_and_transform(args, inp_num):
"""
transform inputs by the following rules:
1. [Tensor] -> Tensor
2. [Tensor, Tensor, ...] -> list of Tensors
3. None -> None
4. Others: raise Error
only support "X" is list of Tensor, currently don't support other structure like dict.
"""
inp_args = [
[inp] if inp is None else inp for inp in args[:inp_num]
] # convert None -> [None]
for inp in inp_args:
assert isinstance(
inp, list
), "currently only support `X` is [Tensor], don't support other structure."
args = [
inp[0] if len(inp) == 1 else inp for inp in inp_args
] + args[inp_num:]
return args
def _get_kernel_signature(
dygraph_tensor_inputs, dygraph_tensor_outputs, attrs_outputs
):
try:
kernel_sig = _dygraph_tracer()._get_kernel_signature(
self.op_type,
dygraph_tensor_inputs,
dygraph_tensor_outputs,
attrs_outputs,
)
except RuntimeError as re:
"""we think the kernel_sig is missing."""
kernel_sig = None
print(
"[Warning: op_test.py] Kernel Signature is not found for %s, fall back to intermediate state."
% self.op_type
)
return kernel_sig
def cal_python_api(python_api, args, kernel_sig):
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = assumption_assert_and_transform(args, len(inputs_sig))
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret_tuple = python_api(*args)
return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
......@@ -989,8 +856,11 @@ class OpTest(unittest.TestCase):
if self.attrs[attrs_name] is not None:
attrs_outputs[attrs_name] = self.attrs[attrs_name]
kernel_sig = _get_kernel_signature(
dygraph_tensor_inputs, dygraph_tensor_outputs, attrs_outputs
kernel_sig = OpTestUtils._get_kernel_signature(
self.op_type,
dygraph_tensor_inputs,
dygraph_tensor_outputs,
attrs_outputs,
)
if not kernel_sig:
return None
......@@ -998,7 +868,7 @@ class OpTest(unittest.TestCase):
"Detect there is KernelSignature for `%s` op, please set the `self.python_api` if you set check_dygraph = True"
% self.op_type
)
args = prepare_python_api_arguments(
args = OpTestUtils.prepare_python_api_arguments(
self.python_api,
dygraph_tensor_inputs,
attrs_outputs,
......@@ -1050,6 +920,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None,
for_inplace_test=None,
):
with paddle.fluid.framework._dygraph_guard(None):
program = Program()
block = program.global_block()
op = self._append_ops(block)
......@@ -1072,7 +943,9 @@ class OpTest(unittest.TestCase):
use_cuda = False
if isinstance(place, fluid.CUDAPlace):
use_cuda = True
compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
compiled_prog = fluid.CompiledProgram(
program
).with_data_parallel(
loss_name=loss.name if loss else None, places=place
)
program = compiled_prog
......@@ -1097,14 +970,19 @@ class OpTest(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = enable_inplace
compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
compiled_prog = fluid.CompiledProgram(
program
).with_data_parallel(
build_strategy=build_strategy, places=place
)
program = compiled_prog
executor = Executor(place)
outs = executor.run(
program, feed=feed_map, fetch_list=fetch_list, return_numpy=False
program,
feed=feed_map,
fetch_list=fetch_list,
return_numpy=False,
)
self.op = op
self.program = original_program
......@@ -1371,6 +1249,7 @@ class OpTest(unittest.TestCase):
Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
"""
with paddle.fluid.framework._dygraph_guard(None):
(
fwd_outs,
fwd_fetch_list,
......@@ -1465,7 +1344,6 @@ class OpTest(unittest.TestCase):
has_infer_inplace = fluid.core.has_infer_inplace(self.op_type)
has_grad_op_maker = fluid.core.has_grad_op_maker(self.op_type)
fwd_res = self._calc_output(
place, no_check_set=no_check_set, for_inplace_test=True
)
......@@ -1518,8 +1396,11 @@ class OpTest(unittest.TestCase):
no_check_set=None,
equal_nan=False,
check_dygraph=True,
check_prim=False,
inplace_atol=None,
):
core._set_prim_all_enabled(False)
def find_imperative_actual(target_name, dygraph_outs, place):
for name in dygraph_outs:
if name == target_name:
......@@ -1785,6 +1666,15 @@ class OpTest(unittest.TestCase):
return True
return super()._is_skip_name(name)
if check_prim:
prim_checker = PrimForwardChecker(self, place)
prim_checker.check()
# Operators not in the NO_FP64_CHECK_GRAD_OP_LIST list can have their prim tests run with fp32
setattr(self.__class__, 'check_prim', True)
self.__class__.op_type = self.op_type
if prim_checker.is_only_check_prim():
self.only_prim = True
return
# set some flags by the combination of arguments.
self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs)
if (
......@@ -1930,6 +1820,7 @@ class OpTest(unittest.TestCase):
no_check_set=None,
equal_nan=False,
check_dygraph=True,
check_prim=False,
inplace_atol=None,
):
......@@ -1948,8 +1839,11 @@ class OpTest(unittest.TestCase):
no_check_set,
equal_nan,
check_dygraph=check_dygraph,
check_prim=check_prim,
inplace_atol=inplace_atol,
)
if hasattr(self, 'only_prim') and self.only_prim:
continue
if check_dygraph:
outs, dygraph_dygraph_outs, fetch_list = res
else:
......@@ -2063,8 +1957,8 @@ class OpTest(unittest.TestCase):
user_defined_grads=None,
user_defined_grad_outputs=None,
check_dygraph=True,
check_prim=False,
):
self._check_grad_helper()
places = self._get_places()
for place in places:
......@@ -2079,6 +1973,7 @@ class OpTest(unittest.TestCase):
user_defined_grads,
user_defined_grad_outputs,
check_dygraph=check_dygraph,
check_prim=check_prim,
)
def check_grad_with_place(
......@@ -2093,9 +1988,26 @@ class OpTest(unittest.TestCase):
user_defined_grads=None,
user_defined_grad_outputs=None,
check_dygraph=True,
check_prim=False,
numeric_place=None,
):
core._set_prim_all_enabled(False)
if check_prim:
prim_grad_checker = PrimGradChecker(
self,
place,
inputs_to_check,
output_names,
no_grad_set,
user_defined_grad_outputs,
)
prim_grad_checker.check()
# Operators not in the NO_FP64_CHECK_GRAD_OP_LIST list can have their prim tests run with fp32
setattr(self.__class__, 'check_prim', True)
self._check_grad_helper()
if prim_grad_checker.is_only_check_prim():
self.only_prim = True
return
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
......@@ -2448,6 +2360,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
parallel=False,
):
with paddle.fluid.framework._dygraph_guard(None):
prog = Program()
scope = core.Scope()
block = prog.global_block()
......@@ -2504,7 +2417,9 @@ class OpTest(unittest.TestCase):
targets = [
outputs[name] for name in outputs if name in output_names
]
inputs = [inputs[name] for name in input_to_check if name in inputs]
inputs = [
inputs[name] for name in input_to_check if name in inputs
]
grad_inputs = paddle.static.gradients(
targets, inputs, grad_outputs, no_grad_set
)
......@@ -2519,14 +2434,19 @@ class OpTest(unittest.TestCase):
)
prog = compiled_prog
executor = fluid.Executor(place)
return list(
res = list(
map(
np.array,
executor.run(
prog, feed_dict, fetch_list, scope=scope, return_numpy=False
prog,
feed_dict,
fetch_list,
scope=scope,
return_numpy=False,
),
)
)
return res
class OpTestTool:
......
......@@ -35,16 +35,15 @@ from paddle.fluid.framework import (
Program,
_current_expected_place,
_disable_legacy_dygraph,
_dygraph_tracer,
_enable_legacy_dygraph,
_in_eager_without_dygraph_check,
_test_eager_guard,
in_dygraph_mode,
)
from paddle.fluid.op import Operator
from paddle.jit.dy2static.utils import parse_arg_and_kwargs
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from prim_op_test import OpTestUtils, PrimForwardChecker, PrimGradChecker
from testsuite import append_input_output, append_loss_ops, create_op, set_input
from white_list import (
check_shape_white_list,
......@@ -334,6 +333,7 @@ class OpTest(unittest.TestCase):
cls.dtype = None
cls.outputs = {}
cls.input_shape_is_large = True
cls.check_prim = False
np.random.seed(123)
random.seed(124)
......@@ -414,6 +414,7 @@ class OpTest(unittest.TestCase):
and not is_npu_op_test()
and not is_mlu_op_test()
and not is_custom_device_op_test()
and not cls.check_prim
):
raise AssertionError(
"This test of %s op needs check_grad with fp64 precision."
......@@ -819,100 +820,6 @@ class OpTest(unittest.TestCase):
def _calc_python_api_output(self, place, egr_inps=None, egr_oups=None):
"""set egr_inps and egr_oups = None if you want to create it by yourself."""
def prepare_python_api_arguments(
api, op_proto_ins, op_proto_attrs, kernel_sig
):
"""map from `op proto inputs and attrs` to `api input list and api attrs dict`
NOTE: the op_proto_attrs and op_proto_ins is a default dict. default value is []
"""
class Empty:
pass
def is_empty(a):
return isinstance(a, Empty)
def get_default(idx, defaults):
assert not isinstance(defaults[idx], Empty), (
"%d-th params of python api don't have default value." % idx
)
return defaults[idx]
def to_defaults_list(params, defaults):
return [defaults[p] for p in params if p in defaults]
def parse_attri_value(name, op_inputs, op_attrs):
"""parse true value from inputs and attrs, if there is no name passed by OpTest, return Empty
1. if the name in op_attrs, use the op_attrs[name]
2. if the name in op_inputs, convert the op_inputs to [type of default value]
3. if the name not in op_attrs ans op_inputs, return Empty. (this will use the default value from python api)
"""
if name in op_proto_attrs:
return op_proto_attrs[name]
elif name in op_inputs:
if len(op_inputs[name]) == 1:
# why don't use numpy().item() : if the Tensor is float64, we will change it to python.float32, where we loss accuracy: [allclose_op]
# why we reconstruct a tensor: because we want the tensor in cpu.
return paddle.to_tensor(
op_inputs[name][0].numpy(), place='cpu'
)
else:
# if this is a list (test_unsqueeze2_op): we just pass it into the python api.
return op_inputs[name]
else:
return Empty()
# NOTE(xiongkun): the logic of constructing parameters:
# for example:
# python api: cumprod(x, dim, dtype=None, name=None)
# kernel sig: [["x"], ["dim"], ["out"]]"
#
# we will construct a lot of list with the same length : len == len(api_params), here is 4
# api_params = ["x", "dim", "dtype", "name"]
# api_defaults = [Empty, Empty, None, None]; empty means no defaults.
# inputs_and_attrs = ["x", "dim"] , the length may shorter or longer than api_params
# input_arguments = [RealValue in self.inputs and self.attrs]
# then ,we will loop for the api_params, construct a result list:
# if the name in ['name', 'dtype', 'out', 'output'], we will use the default value
# else, we will consume a input_arguments. (because the name is not corresponding, so we only use the order)
api_params, api_defaults = parse_arg_and_kwargs(api)
api_defaults = to_defaults_list(api_params, api_defaults)
api_defaults = [
Empty() for i in range(len(api_params) - len(api_defaults))
] + api_defaults
assert len(api_defaults) == len(
api_params
), "Error happens. contack xiongkun03 to solve."
inputs_sig, attrs_sig, outputs_sig = kernel_sig
inputs_and_attrs = inputs_sig + attrs_sig
input_arguments = [
op_proto_ins.get(name, Empty()) for name in inputs_sig
] + [
parse_attri_value(name, op_proto_ins, op_proto_attrs)
for name in attrs_sig
]
results = []
api_ignore_param_list = set(['name', 'dtype', 'out', 'output'])
idx_of_op_proto_arguments = 0
for idx, arg_name in enumerate(api_params):
if arg_name in api_ignore_param_list:
results.append(get_default(idx, api_defaults))
else:
if idx_of_op_proto_arguments < len(input_arguments):
tmp = input_arguments[idx_of_op_proto_arguments]
idx_of_op_proto_arguments += 1
else:
tmp = Empty() # use the default value
if isinstance(tmp, Empty):
results.append(get_default(idx, api_defaults))
else:
results.append(tmp)
assert len(results) == len(api_params)
return results
def construct_output_dict_by_kernel_sig(ret_tuple, output_sig):
if hasattr(self, "python_out_sig"):
output_sig = self.python_out_sig
......@@ -928,50 +835,11 @@ class OpTest(unittest.TestCase):
), "Don't support multi-output with multi-tensor output. (Maybe you can set `python_out_sig`; see `test_squeeze2_op` for an example.)"
return {output_sig[0]: ret_tuple}
def assumption_assert_and_transform(args, inp_num):
"""
transform inputs by the following rules:
1. [Tensor] -> Tensor
2. [Tensor, Tensor, ...] -> list of Tensors
3. None -> None
4. Others: raise Error
only support "X" is list of Tensor, currently don't support other structure like dict.
"""
inp_args = [
[inp] if inp is None else inp for inp in args[:inp_num]
] # convert None -> [None]
for inp in inp_args:
assert isinstance(
inp, list
), "currently only support `X` is [Tensor], don't support other structure."
args = [
inp[0] if len(inp) == 1 else inp for inp in inp_args
] + args[inp_num:]
return args
def _get_kernel_signature(
eager_tensor_inputs, eager_tensor_outputs, attrs_outputs
):
try:
kernel_sig = _dygraph_tracer()._get_kernel_signature(
self.op_type,
eager_tensor_inputs,
eager_tensor_outputs,
attrs_outputs,
)
except RuntimeError as re:
"""we think the kernel_sig is missing."""
kernel_sig = None
print(
"[Warning: op_test.py] Kernel Signature is not found for %s, fall back to intermediate state."
% self.op_type
)
return kernel_sig
def cal_python_api(python_api, args, kernel_sig):
inputs_sig, attrs_sig, outputs_sig = kernel_sig
args = assumption_assert_and_transform(args, len(inputs_sig))
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret_tuple = python_api(*args)
return construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
......@@ -994,7 +862,6 @@ class OpTest(unittest.TestCase):
op_proto, self.outputs, False, False, block
)
)
# prepare attributes
attrs_outputs = {}
if hasattr(self, "attrs"):
......@@ -1002,8 +869,11 @@ class OpTest(unittest.TestCase):
if self.attrs[attrs_name] is not None:
attrs_outputs[attrs_name] = self.attrs[attrs_name]
kernel_sig = _get_kernel_signature(
eager_tensor_inputs, eager_tensor_outputs, attrs_outputs
kernel_sig = OpTestUtils._get_kernel_signature(
self.op_type,
eager_tensor_inputs,
eager_tensor_outputs,
attrs_outputs,
)
if not kernel_sig:
return None
......@@ -1011,7 +881,7 @@ class OpTest(unittest.TestCase):
"Detect there is KernelSignature for `%s` op, please set the `self.python_api` if you set check_eager = True"
% self.op_type
)
args = prepare_python_api_arguments(
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, kernel_sig
)
""" we directly return the cal_python_api value because the value is already tensor.
......@@ -1060,6 +930,7 @@ class OpTest(unittest.TestCase):
enable_inplace=None,
for_inplace_test=None,
):
with paddle.fluid.framework._dygraph_guard(None):
program = Program()
block = program.global_block()
op = self._append_ops(block)
......@@ -1082,7 +953,9 @@ class OpTest(unittest.TestCase):
use_cuda = False
if isinstance(place, fluid.CUDAPlace):
use_cuda = True
compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
compiled_prog = fluid.CompiledProgram(
program
).with_data_parallel(
loss_name=loss.name if loss else None, places=place
)
program = compiled_prog
......@@ -1107,14 +980,19 @@ class OpTest(unittest.TestCase):
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = enable_inplace
compiled_prog = fluid.CompiledProgram(program).with_data_parallel(
compiled_prog = fluid.CompiledProgram(
program
).with_data_parallel(
build_strategy=build_strategy, places=place
)
program = compiled_prog
executor = Executor(place)
outs = executor.run(
program, feed=feed_map, fetch_list=fetch_list, return_numpy=False
program,
feed=feed_map,
fetch_list=fetch_list,
return_numpy=False,
)
self.op = op
self.program = original_program
......@@ -1381,6 +1259,7 @@ class OpTest(unittest.TestCase):
Returns:
res (tuple(outs, fetch_list, feed_map, program, op_desc)): The results of given grad_op_desc.
"""
with paddle.fluid.framework._dygraph_guard(None):
(
fwd_outs,
fwd_fetch_list,
......@@ -1530,8 +1409,18 @@ class OpTest(unittest.TestCase):
check_dygraph=True,
inplace_atol=None,
check_eager=False,
check_prim=False,
):
core._set_prim_all_enabled(False)
if check_prim:
prim_checker = PrimForwardChecker(self, place)
prim_checker.check()
# Operators not in the NO_FP64_CHECK_GRAD_OP_LIST list can have their prim tests run with fp32
setattr(self.__class__, 'check_prim', True)
self.__class__.op_type = self.op_type
if prim_checker.is_only_check_prim():
self.only_prim = True
return
# disable legacy dygraph check when check_eager is True
if check_eager:
check_dygraph = False
......@@ -1990,6 +1879,7 @@ class OpTest(unittest.TestCase):
check_dygraph=True,
inplace_atol=None,
check_eager=False,
check_prim=False,
):
# disable legacy dygraph check when check_eager is True
......@@ -2013,7 +1903,10 @@ class OpTest(unittest.TestCase):
check_dygraph,
inplace_atol,
check_eager=check_eager,
check_prim=check_prim,
)
if hasattr(self, 'only_prim') and self.only_prim:
continue
if check_eager:
assert not check_dygraph
outs, eager_dygraph_outs, fetch_list = res
......@@ -2131,8 +2024,8 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
check_dygraph=True,
check_eager=False,
check_prim=False,
):
# disable legacy dygraph check when check_eager is True
if check_eager:
check_dygraph = False
......@@ -2152,6 +2045,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs,
check_dygraph,
check_eager=check_eager,
check_prim=check_prim,
)
def check_grad_with_place(
......@@ -2168,8 +2062,25 @@ class OpTest(unittest.TestCase):
check_dygraph=True,
numeric_place=None,
check_eager=False,
check_prim=False,
):
core._set_prim_all_enabled(False)
if check_prim:
prim_grad_checker = PrimGradChecker(
self,
place,
inputs_to_check,
output_names,
no_grad_set,
user_defined_grad_outputs,
)
prim_grad_checker.check()
# Operators not in the NO_FP64_CHECK_GRAD_OP_LIST list can have their prim tests run with fp32
setattr(self.__class__, 'check_prim', True)
self._check_grad_helper()
if prim_grad_checker.is_only_check_prim():
self.only_prim = True
return
# disable legacy dygraph check when check_eager is True
if check_eager:
check_dygraph = False
......@@ -2561,6 +2472,7 @@ class OpTest(unittest.TestCase):
user_defined_grad_outputs=None,
parallel=False,
):
with paddle.fluid.framework._dygraph_guard(None):
prog = Program()
scope = core.Scope()
block = prog.global_block()
......@@ -2617,7 +2529,9 @@ class OpTest(unittest.TestCase):
targets = [
outputs[name] for name in outputs if name in output_names
]
inputs = [inputs[name] for name in input_to_check if name in inputs]
inputs = [
inputs[name] for name in input_to_check if name in inputs
]
grad_inputs = paddle.static.gradients(
targets, inputs, grad_outputs, no_grad_set
)
......@@ -2632,14 +2546,19 @@ class OpTest(unittest.TestCase):
)
prog = compiled_prog
executor = fluid.Executor(place)
return list(
res = list(
map(
np.array,
executor.run(
prog, feed_dict, fetch_list, scope=scope, return_numpy=False
prog,
feed_dict,
fetch_list,
scope=scope,
return_numpy=False,
),
)
)
return res
class OpTestTool:
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import struct
from collections import defaultdict
import config
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import _dygraph_tracer, in_dygraph_mode
from paddle.fluid.layers.utils import map_structure
from paddle.jit.dy2static.utils import parse_arg_and_kwargs
def flatten(nest_list):
out = []
for i in nest_list:
if isinstance(i, (list, tuple)):
tmp_list = flatten(i)
for j in tmp_list:
out.append(j)
else:
out.append(i)
return out
def _as_list(x):
if x is None:
return []
return list(x) if isinstance(x, (list, tuple)) else [x]
def convert_uint16_to_float(in_list):
in_list = np.asarray(in_list)
out = np.vectorize(
lambda x: struct.unpack(
'<f', struct.pack('<I', np.uint32(x) << np.uint32(16))
)[0],
otypes=[np.float32],
)(in_list.flat)
return np.reshape(out, in_list.shape)
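A worked example of the bit manipulation above, for illustration only: a bfloat16 value is the upper 16 bits of an IEEE-754 float32, so shifting it left by 16 and reinterpreting the 32-bit pattern as a float recovers the (truncated) float32 value.

# 0x3F80 << 16 == 0x3F800000 -> 1.0
# 0x4049 << 16 == 0x40490000 -> 3.140625 (bfloat16 rounding of pi)
vals = np.array([0x3F80, 0x4049], dtype=np.uint16)
convert_uint16_to_float(vals)   # array([1.0, 3.140625], dtype=float32)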
# TODO(wanghao107): OpTestUtils will be moved to op_test.py
class OpTestUtils:
@classmethod
def _get_kernel_signature(
cls, op_type, eager_tensor_inputs, eager_tensor_outputs, attrs_outputs
):
try:
kernel_sig = _dygraph_tracer()._get_kernel_signature(
op_type,
eager_tensor_inputs,
eager_tensor_outputs,
attrs_outputs,
)
except RuntimeError as re:
"""we think the kernel_sig is missing."""
kernel_sig = None
print(
"[Warning: op_test.py] Kernel Signature is not found for %s, fall back to intermediate state."
% op_type
)
return kernel_sig
@classmethod
def prepare_python_api_arguments(
cls, api, op_proto_ins, op_proto_attrs, kernel_sig
):
"""map from `op proto inputs and attrs` to `api input list and api attrs dict`
NOTE: the op_proto_attrs and op_proto_ins is a default dict. default value is []
"""
class Empty:
pass
def is_empty(a):
return isinstance(a, Empty)
def get_default(idx, defaults):
assert not isinstance(defaults[idx], Empty), (
"%d-th param of the python api doesn't have a default value." % idx
)
return defaults[idx]
def to_defaults_list(params, defaults):
return [defaults[p] for p in params if p in defaults]
def parse_attri_value(name, op_inputs, op_attrs):
"""parse true value from inputs and attrs, if there is no name passed by OpTest, return Empty
1. if the name in op_attrs, use the op_attrs[name]
2. if the name in op_inputs, convert the op_inputs to [type of default value]
3. if the name is in neither op_attrs nor op_inputs, return Empty (the default value from the python api will be used)
"""
if name in op_proto_attrs:
return op_proto_attrs[name]
elif name in op_inputs:
if len(op_inputs[name]) == 1:
# why we don't use numpy().item(): if the Tensor is float64, we would change it to python float32 and lose accuracy: [allclose_op]
# why we reconstruct a tensor: because we want the tensor on cpu.
if in_dygraph_mode():
return paddle.to_tensor(
op_inputs[name][0].numpy(), place='cpu'
)
else:
return op_inputs[name][0]
else:
# if this is a list (test_unsqueeze2_op): we just pass it into the python api.
return op_inputs[name]
else:
return Empty()
# NOTE(xiongkun): the logic of constructing parameters:
# for example:
# python api: cumprod(x, dim, dtype=None, name=None)
# kernel sig: [["x"], ["dim"], ["out"]]"
#
# we will construct several lists of the same length: len == len(api_params), here 4
# api_params = ["x", "dim", "dtype", "name"]
# api_defaults = [Empty, Empty, None, None]; Empty means no default.
# inputs_and_attrs = ["x", "dim"]; its length may be shorter or longer than api_params
# input_arguments = [real values from self.inputs and self.attrs]
# then we loop over api_params and construct a result list:
# if the name is in ['name', 'dtype', 'out', 'output'], we use the default value
# else, we consume one input_argument (the names do not correspond, so only the order is used)
api_params, api_defaults = parse_arg_and_kwargs(api)
api_defaults = to_defaults_list(api_params, api_defaults)
api_defaults = [
Empty() for i in range(len(api_params) - len(api_defaults))
] + api_defaults
assert len(api_defaults) == len(
api_params
), "Error happens. Contact xiongkun03 to solve."
inputs_sig, attrs_sig, outputs_sig = kernel_sig
inputs_and_attrs = inputs_sig + attrs_sig
input_arguments = [
op_proto_ins.get(name, Empty()) for name in inputs_sig
] + [
parse_attri_value(name, op_proto_ins, op_proto_attrs)
for name in attrs_sig
]
results = []
api_ignore_param_list = set(['name', 'dtype', 'out', 'output'])
idx_of_op_proto_arguments = 0
for idx, arg_name in enumerate(api_params):
if arg_name in api_ignore_param_list:
results.append(get_default(idx, api_defaults))
else:
if idx_of_op_proto_arguments < len(input_arguments):
tmp = input_arguments[idx_of_op_proto_arguments]
idx_of_op_proto_arguments += 1
else:
tmp = Empty() # use the default value
if isinstance(tmp, Empty):
results.append(get_default(idx, api_defaults))
else:
results.append(tmp)
assert len(results) == len(api_params)
return results
@classmethod
def assumption_assert_and_transform(cls, args, inp_num):
"""
transform inputs by the following rules:
1. [Tensor] -> Tensor
2. [Tensor, Tensor, ...] -> list of Tensors
3. None -> None
4. Others: raise Error
only "X" as a list of Tensors is supported; other structures such as dict are not supported yet.
"""
inp_args = [
[inp] if inp is None else inp for inp in args[:inp_num]
] # convert None -> [None]
for inp in inp_args:
assert isinstance(
inp, list
), "currently only `X` as [Tensor] is supported; other structures are not."
args = [inp[0] if len(inp) == 1 else inp for inp in inp_args] + args[
inp_num:
]
return args
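To make the two helpers above concrete, here is a hypothetical walk-through for the cumprod example from the NOTE; the tensor and kernel signature are written out by hand for illustration and are not part of this commit:

import paddle

x = paddle.rand([2, 3])
kernel_sig = (["x"], ["dim"], ["out"])  # inputs_sig, attrs_sig, outputs_sig
args = OpTestUtils.prepare_python_api_arguments(
    paddle.cumprod, {"x": [x]}, {"dim": 1}, kernel_sig
)
# args == [[x], 1, None, None], matching cumprod(x, dim, dtype=None, name=None);
# 'dtype' and 'name' fall back to their python-api defaults.
args = OpTestUtils.assumption_assert_and_transform(args, len(kernel_sig[0]))
# args == [x, 1, None, None]; the single-element input list is unwrapped (rule 1).
out = paddle.cumprod(*args)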
@classmethod
def is_bfloat16_type(cls, np_type):
if np_type == np.dtype('uint16'):
return True
return False
def apply_to_static(net, use_cinn):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = use_cinn
return paddle.jit.to_static(net, build_strategy=build_strategy)
class PrimNet(paddle.nn.Layer):
def __init__(self, python_api):
super(PrimNet, self).__init__()
self.python_api = python_api
def forward(self, args):
out = self.python_api(*args)
return out
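For orientation, the jit-comp checks below push the python api through dynamic-to-static roughly as follows; the api and input are hypothetical, and in the CINN variant `use_cinn` would be `core.is_compiled_with_cinn() and self.enable_cinn`:

import paddle

net = apply_to_static(PrimNet(paddle.tanh), use_cinn=False)
out = net([paddle.rand([2, 3])])      # PrimNet.forward unpacks the argument list
net.forward.program_cache.clear()     # same cleanup the checkers perform afterwards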
class PrimForwardChecker:
def __init__(self, op_test, place):
self.checker_name = "PrimForwardChecker"
self.place = place
self.op_test = op_test
self.save_eager_or_static_status()
self.init()
self.init_checker()
def init(self):
pass
def save_eager_or_static_status(self):
self.eager_mode = in_dygraph_mode()
def recover_eager_or_static_status(self):
if self.eager_mode:
paddle.disable_static()
else:
paddle.enable_static()
def init_checker(self):
assert hasattr(
self.op_test, 'prim_op_type'
), "if you want to test comp op, please set prim_op_type in setUp function."
assert self.op_test.prim_op_type in [
"comp",
"prim",
], "prim_op_type must be comp or prim in setUp function."
assert hasattr(
self.op_test, 'dtype'
), "Please set dtype in setUp function."
self.op_type = self.op_test.op_type
self.prim_op_type = self.op_test.prim_op_type
self.python_api = self.op_test.python_api
self.dtype = np.dtype(self.op_test.dtype)
self.inputs = self.op_test.inputs
self.attrs = (
self.op_test.attrs if hasattr(self.op_test, 'attrs') else {}
)
self.outputs = self.op_test.outputs
self.init_checker_threshold()
self.enable_fw_comp = (
self.op_test.enable_fw_comp
if hasattr(self.op_test, 'enable_fw_comp')
else True
)
self.enable_rev_comp = (
self.op_test.enable_rev_comp
if hasattr(self.op_test, 'enable_rev_comp')
else True
)
self.enable_cinn = (
self.op_test.enable_cinn
if hasattr(self.op_test, 'enable_cinn')
else True
)
self.enable_check_eager_comp = (
self.op_test.enable_check_eager_comp
if hasattr(self.op_test, 'enable_check_eager_comp')
else True
)
self.enable_check_static_comp = (
self.op_test.enable_check_static_comp
if hasattr(self.op_test, 'enable_check_static_comp')
else True
)
self.enable_check_jit_comp = (
self.op_test.enable_check_jit_comp
if hasattr(self.op_test, 'enable_check_jit_comp')
else True
)
self.enable_check_jit_comp_with_cinn = (
self.op_test.enable_check_jit_comp_with_cinn
if hasattr(self.op_test, 'enable_check_jit_comp_with_cinn')
else True
)
self.only_prim = (
self.op_test.only_prim
if hasattr(self.op_test, 'only_prim')
else False
)
self.kernel_sig = self.get_kernel_sig()
def init_checker_threshold(self):
if hasattr(self.op_test, 'jit_comp_rtol'):
self.jit_comp_rtol = self.op_test.jit_comp_rtol
else:
self.jit_comp_rtol = (
config.TOLERANCE[self.dtype]['jit_comp']['rtol']
if self.dtype in config.TOLERANCE
else 0
)
if hasattr(self.op_test, 'jit_comp_atol'):
self.jit_comp_atol = self.op_test.jit_comp_atol
else:
self.jit_comp_atol = (
config.TOLERANCE[self.dtype]['jit_comp']['atol']
if self.dtype in config.TOLERANCE
else 0
)
if hasattr(self.op_test, 'fw_comp_rtol'):
self.fw_comp_rtol = self.op_test.fw_comp_rtol
else:
self.fw_comp_rtol = (
config.TOLERANCE[self.dtype]['fw_comp']['rtol']
if self.dtype in config.TOLERANCE
else 0
)
if hasattr(self.op_test, 'fw_comp_atol'):
self.fw_comp_atol = self.op_test.fw_comp_atol
else:
self.fw_comp_atol = (
config.TOLERANCE[self.dtype]['fw_comp']['atol']
if self.dtype in config.TOLERANCE
else 0
)
if hasattr(self.op_test, 'rev_comp_rtol'):
self.rev_comp_rtol = self.op_test.rev_comp_rtol
else:
self.rev_comp_rtol = (
config.TOLERANCE[self.dtype]['rev_comp']['rtol']
if self.dtype in config.TOLERANCE
else 0
)
if hasattr(self.op_test, 'rev_comp_atol'):
self.rev_comp_atol = self.op_test.rev_comp_atol
else:
self.rev_comp_atol = (
config.TOLERANCE[self.dtype]['rev_comp']['atol']
if self.dtype in config.TOLERANCE
else 0
)
if hasattr(self.op_test, 'cinn_rtol'):
self.cinn_rtol = self.op_test.cinn_rtol
else:
self.cinn_rtol = (
config.TOLERANCE[self.dtype]['cinn']['rtol']
if self.dtype in config.TOLERANCE
else 0
)
if hasattr(self.op_test, 'cinn_atol'):
self.cinn_atol = self.op_test.cinn_atol
else:
self.cinn_atol = (
config.TOLERANCE[self.dtype]['cinn']['atol']
if self.dtype in config.TOLERANCE
else 0
)
def check(self):
self.eager_desire = self.get_eager_desire()
if self.enable_check_static_comp:
self.check_static_comp()
if self.enable_check_jit_comp:
self.check_jit_comp()
if self.enable_check_jit_comp_with_cinn:
self.check_jit_comp_with_cinn()
self.recover_eager_or_static_status()
def get_kernel_sig(self):
paddle.disable_static()
if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
paddle.device.set_device("cpu")
if type(self.place) is paddle.fluid.libpaddle.CUDAPlace:
paddle.device.set_device("gpu:0")
(
eager_tensor_inputs,
attrs_outputs,
_,
) = self.get_eager_input_attr_and_inputdict()
eager_tensor_outputs = self.get_eager_empty_output()
kernel_sig = OpTestUtils._get_kernel_signature(
self.op_type,
eager_tensor_inputs,
eager_tensor_outputs,
attrs_outputs,
)
return kernel_sig
def is_only_check_prim(self):
return self.only_prim
def get_eager_desire(self):
paddle.disable_static()
if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
paddle.device.set_device("cpu")
if type(self.place) is paddle.fluid.libpaddle.CUDAPlace:
paddle.device.set_device("gpu:0")
(
eager_tensor_inputs,
attrs_outputs,
_,
) = self.get_eager_input_attr_and_inputdict()
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
inputs_sig, _, _ = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret = flatten(_as_list(self.python_api(*args)))
ret = map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
return ret
def get_eager_input_attr_and_inputdict(self):
attrs_outputs = {}
for attrs_name in self.attrs:
if self.attrs[attrs_name] is not None:
attrs_outputs[attrs_name] = self.attrs[attrs_name]
input_dict = {}
eager_inputs = defaultdict(list)
for name, item in self.inputs.items():
if isinstance(item, list):
for tup in item:
dtype = (
"bfloat16"
if OpTestUtils.is_bfloat16_type(tup[1].dtype)
else tup[1].dtype
)
x = paddle.to_tensor(
data=tup[1],
place=self.place,
stop_gradient=False,
dtype=dtype,
)
eager_inputs[name].append(x)
input_dict.update({str(tup[0]): x})
else:
dtype = (
"bfloat16"
if OpTestUtils.is_bfloat16_type(item.dtype)
else item.dtype
)
x = paddle.to_tensor(
data=item,
place=self.place,
stop_gradient=False,
dtype=dtype,
)
eager_inputs[name].append(x)
input_dict.update({name: x})
return eager_inputs, attrs_outputs, input_dict
def get_eager_empty_output(self):
eager_outputs = defaultdict(list)
for name, item in self.outputs.items():
if isinstance(item, list):
for tup in item:
dtype = (
"bfloat16"
if OpTestUtils.is_bfloat16_type(tup[1].dtype)
else tup[1].dtype
)
x = paddle.to_tensor(
data=[],
place=self.place,
stop_gradient=False,
dtype=dtype,
)
eager_outputs[name].append(x)
else:
dtype = (
"bfloat16"
if OpTestUtils.is_bfloat16_type(item.dtype)
else item.dtype
)
x = paddle.to_tensor(
data=[], place=self.place, stop_gradient=False, dtype=dtype
)
eager_outputs[name].append(x)
return eager_outputs
def get_static_input_attr_inputdict_and_feed(self):
attrs_outputs = {}
for attrs_name in self.attrs:
if self.attrs[attrs_name] is not None:
attrs_outputs[attrs_name] = self.attrs[attrs_name]
input_dict = {}
static_inputs = defaultdict(list)
feed = {}
for name, item in self.inputs.items():
if isinstance(item, list):
for tup in item:
dtype = (
"bfloat16"
if OpTestUtils.is_bfloat16_type(tup[1].dtype)
else tup[1].dtype
)
x = paddle.static.data(
name=str(tup[0]), shape=tup[1].shape, dtype=dtype
)
x.stop_gradient = False
static_inputs[name].append(x)
feed.update({str(tup[0]): tup[1]})
input_dict.update({str(tup[0]): x})
else:
dtype = (
"bfloat16"
if OpTestUtils.is_bfloat16_type(item.dtype)
else item.dtype
)
x = paddle.static.data(name=name, shape=item.shape, dtype=dtype)
x.stop_gradient = False
static_inputs[name].append(x)
feed.update({name: item})
input_dict.update({name: x})
return static_inputs, attrs_outputs, input_dict, feed
def check_eager_comp(self):
pass
def check_static_comp(self):
# forward comp only for comp op
if self.prim_op_type == "prim":
return
paddle.enable_static()
core._set_prim_forward_enabled(self.enable_fw_comp)
startup_program, main_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, startup_program):
(
static_inputs,
attrs,
input_dict,
feed,
) = self.get_static_input_attr_inputdict_and_feed()
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, static_inputs, attrs, self.kernel_sig
)
inputs_sig, _, _ = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret = flatten(_as_list(self.python_api(*args)))
paddle.incubate.autograd.to_prim(main_program.blocks)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
ret = exe.run(main_program, feed=feed, fetch_list=ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
# check static forward
if len(ret) != len(self.eager_desire):
msg = (
"The static comp forward api out tensor num differs from the eager forward api out tensor num on %s "
'when enable_fw_comp is %s: static comp forward api out tensor nums = %s, eager forward api out tensor nums = %s. \n'
% (
str(self.place),
self.enable_fw_comp,
len(ret),
len(self.eager_desire),
)
)
raise RuntimeError(msg)
for i in range(len(ret)):
if not np.allclose(
ret[i],
self.eager_desire[i],
rtol=self.fw_comp_rtol,
atol=self.fw_comp_atol,
):
msg = (
'Check static comp forward api out failed. Mismatch between static comp '
'and eager on %s, when enable_fw_comp is %s; the forward api out tensor\'s index is %d \n'
'static comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
% (
str(self.place),
self.enable_fw_comp,
i,
ret[i],
self.eager_desire[i],
)
)
raise RuntimeError(msg)
paddle.disable_static()
core._set_prim_forward_enabled(False)
def check_jit_comp(self):
if self.prim_op_type == "prim":
return
paddle.disable_static()
if type(self.place) == paddle.fluid.libpaddle.CPUPlace:
paddle.device.set_device("cpu")
if type(self.place) == paddle.fluid.libpaddle.CUDAPlace:
paddle.device.set_device("gpu:0")
atol = self.fw_comp_atol if self.enable_fw_comp else self.jit_comp_atol
rtol = self.fw_comp_rtol if self.enable_fw_comp else self.jit_comp_rtol
core._set_prim_forward_enabled(self.enable_fw_comp)
(
eager_tensor_inputs,
attrs_outputs,
_,
) = self.get_eager_input_attr_and_inputdict()
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
inputs_sig, _, _ = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
net = PrimNet(self.python_api)
net = apply_to_static(net, False)
ret = flatten(_as_list(net(args)))
ret = map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
# check jit comp forward
if len(ret) != len(self.eager_desire):
msg = (
"The jit comp forward api out tensor num differs from the eager forward api out tensor num on %s "
'when enable_fw_comp is %s: jit comp forward api out tensor nums = %s, eager forward api out tensor nums = %s. \n'
% (
str(self.place),
self.enable_fw_comp,
len(ret),
len(self.eager_desire),
)
)
raise RuntimeError(msg)
for i in range(len(ret)):
if not np.allclose(
ret[i], self.eager_desire[i], rtol=rtol, atol=atol
):
msg = (
'Check jit comp forward api out failed. Mismatch between jit comp '
'and eager on %s, when enable_fw_comp is %s; the forward api out tensor\'s index is %d \n'
'jit comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
% (
str(self.place),
self.enable_fw_comp,
i,
ret[i],
self.eager_desire[i],
)
)
raise RuntimeError(msg)
core._set_prim_forward_enabled(False)
net.forward.program_cache.clear()
def check_jit_comp_with_cinn(self):
if self.prim_op_type == "prim":
return
# cinn doesn't support cpu place
if (
type(self.place) == paddle.fluid.libpaddle.CPUPlace
and self.enable_cinn
and core.is_compiled_with_cinn()
):
return
paddle.disable_static()
atol = (
self.cinn_atol
if self.enable_cinn and core.is_compiled_with_cinn()
else self.fw_comp_atol
)
rtol = (
self.cinn_rtol
if self.enable_cinn and core.is_compiled_with_cinn()
else self.fw_comp_rtol
)
core._set_prim_forward_enabled(self.enable_fw_comp)
if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
paddle.device.set_device("cpu")
if type(self.place) is paddle.fluid.libpaddle.CUDAPlace:
paddle.device.set_device("gpu:0")
(
eager_tensor_inputs,
attrs_outputs,
_,
) = self.get_eager_input_attr_and_inputdict()
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
inputs_sig, _, _ = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
net = PrimNet(self.python_api)
net = apply_to_static(
net, core.is_compiled_with_cinn() and self.enable_cinn
)
ret = flatten(_as_list(net(args)))
ret = map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
# check jit comp forward
if len(ret) != len(self.eager_desire):
msg = (
"The jit comp with cinn forward api out tensor num differs from the eager forward api out tensor num on %s "
'when enable_fw_comp is %s, enable_cinn is %s: jit comp forward api out tensor nums = %s, eager forward api out tensor nums = %s. \n'
% (
str(self.place),
self.enable_fw_comp,
core.is_compiled_with_cinn() and self.enable_cinn,
len(ret),
len(self.eager_desire),
)
)
raise RuntimeError(msg)
for i in range(len(ret)):
if not np.allclose(
ret[i], self.eager_desire[i], rtol=rtol, atol=atol
):
msg = (
'Check jit comp with cinn forward api out failed. Mismatch between jit comp and eager on %s, '
'when enable_fw_comp is %s, enable_cinn is %s, the forward api out tensor\'s index is : %d \n'
'jit comp forward api out tensor:%s\n eager forward api out tensor:%s\n'
% (
str(self.place),
self.enable_fw_comp,
core.is_compiled_with_cinn() and self.enable_cinn,
i,
ret[i],
self.eager_desire[i],
)
)
raise RuntimeError(msg)
core._set_prim_forward_enabled(False)
net.forward.program_cache.clear()
class PrimGradChecker(PrimForwardChecker):
def __init__(
self,
op_test,
place,
inputs_to_check,
output_names,
no_grad_set,
grad_outputs,
):
PrimForwardChecker.__init__(self, op_test, place)
self.inputs_to_check = inputs_to_check
self.output_names = output_names
self.no_grad_set = no_grad_set
self.grad_outputs = grad_outputs
def init(self):
self.checker_name = "PrimGradChecker"
def check(self):
self.eager_desire = self.get_eager_desire()
if self.enable_check_eager_comp:
self.check_eager_comp()
if self.enable_check_static_comp:
self.check_static_comp()
if self.enable_check_jit_comp:
self.check_jit_comp()
if self.enable_check_jit_comp_with_cinn:
self.check_jit_comp_with_cinn()
self.recover_eager_or_static_status()
def get_output_dict(self, np_outputs, api_outputs, outputs_sig):
assert len(api_outputs) == len(outputs_sig), (
"forward api outputs length must be the same as KernelSignature outputs, but received %s and %s"
) % (len(api_outputs), len(outputs_sig))
output_dict = {}
for i, output_name in enumerate(outputs_sig):
if isinstance(np_outputs[output_name], list):
for j, tup in enumerate(np_outputs[output_name]):
output_dict.update({tup[0]: api_outputs[i][j]})
else:
output_dict.update({output_name: api_outputs[i]})
return output_dict
def gen_eager_grad_outputs(self):
if self.grad_outputs is None:
return None
eager_vs = []
for np_v in self.grad_outputs:
eager_vs.append(
paddle.to_tensor(
data=np_v,
place=self.place,
dtype="bfloat16"
if OpTestUtils.is_bfloat16_type(np_v.dtype)
else np_v.dtype,
)
)
return eager_vs
def gen_static_grad_outputs_and_feed(self):
if self.grad_outputs is None:
return None, {}
static_vs = []
feed = {}
for i, np_v in enumerate(self.grad_outputs):
static_vs.append(
paddle.static.data(
name='v_' + str(i),
shape=np_v.shape,
dtype="bfloat16"
if OpTestUtils.is_bfloat16_type(np_v.dtype)
else np_v.dtype,
)
)
feed.update({'v_' + str(i): np_v})
return static_vs, feed
def gen_no_grad_set(self, var_dict):
if self.no_grad_set is None:
return None
no_grad_set = set()
for name in self.no_grad_set:
if name in var_dict:
no_grad_set.add(var_dict[name])
return no_grad_set
def get_eager_desire(self):
paddle.disable_static()
if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
paddle.device.set_device("cpu")
if type(self.place) is paddle.fluid.libpaddle.CUDAPlace:
paddle.device.set_device("gpu:0")
(
eager_tensor_inputs,
attrs_outputs,
inputs_dict,
) = self.get_eager_input_attr_and_inputdict()
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
inputs_sig, _, outputs_sig = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
ret = _as_list(self.python_api(*args))
outputs_dict = self.get_output_dict(self.outputs, ret, outputs_sig)
ys = []
if isinstance(self.output_names, list):
for output_name in self.output_names:
ys.append(outputs_dict[output_name])
else:
ys.append(outputs_dict[self.output_names])
xs = []
if isinstance(self.inputs_to_check, list):
for input_name in self.inputs_to_check:
xs.append(inputs_dict[input_name])
else:
xs.append(inputs_dict[self.inputs_to_check])
vs = self.gen_eager_grad_outputs()
no_grad_vars = self.gen_no_grad_set(
var_dict={**inputs_dict, **outputs_dict}
)
ret = paddle.grad(
ys, xs, vs, allow_unused=True, no_grad_vars=no_grad_vars
)
ret = map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
return ret
def check_eager_comp(self):
if self.prim_op_type == "comp":
return
paddle.disable_static()
if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
paddle.device.set_device("cpu")
if type(self.place) is paddle.fluid.libpaddle.CUDAPlace:
paddle.device.set_device("gpu:0")
atol = self.rev_comp_atol
rtol = self.rev_comp_rtol
core._set_prim_backward_enabled(self.enable_rev_comp)
actual_ret = self.get_eager_desire()
# check eager comp grad out
if len(actual_ret) != len(self.eager_desire):
msg = (
"The eager comp grad out tensor num differs from the eager grad out tensor num on %s "
'when enable_rev_comp is %s: eager comp grad api out tensor nums = %s, eager grad out tensor nums = %s. \n'
% (
str(self.place),
self.enable_rev_comp,
len(actual_ret),
len(self.eager_desire),
)
)
raise RuntimeError(msg)
for i in range(len(actual_ret)):
if not np.allclose(
actual_ret[i],
self.eager_desire[i],
rtol=rtol,
atol=atol,
):
msg = (
'Check eager comp grad out failed. Mismatch between eager comp '
'and eager on %s, when enable_rev_comp is %s; the eager comp grad out tensor\'s index is %d \n'
'eager comp grad out tensor:%s\n eager grad out tensor:%s\n'
% (
str(self.place),
self.enable_rev_comp,
i,
actual_ret[i],
self.eager_desire[i],
)
)
raise RuntimeError(msg)
def check_static_comp(self):
paddle.enable_static()
if self.prim_op_type == "prim":
core._set_prim_backward_enabled(self.enable_rev_comp)
else:
core._set_prim_forward_enabled(self.enable_fw_comp)
core._set_prim_backward_enabled(self.enable_rev_comp)
atol = self.rev_comp_atol if self.enable_rev_comp else self.fw_comp_atol
rtol = self.rev_comp_rtol if self.enable_rev_comp else self.fw_comp_rtol
startup_program, main_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, startup_program):
(
static_inputs,
attrs,
inputs_dict,
feed,
) = self.get_static_input_attr_inputdict_and_feed()
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, static_inputs, attrs, self.kernel_sig
)
inputs_sig, _, outputs_sig = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
fw_outs = _as_list(self.python_api(*args))
outputs_dict = self.get_output_dict(
self.outputs, fw_outs, outputs_sig
)
paddle.incubate.autograd.to_prim(main_program.blocks)
ys = []
if isinstance(self.output_names, list):
for output_name in self.output_names:
ys.append(outputs_dict[output_name])
else:
ys.append(outputs_dict[self.output_names])
xs = []
if isinstance(self.inputs_to_check, list):
for input_name in self.inputs_to_check:
xs.append(inputs_dict[input_name])
else:
xs.append(inputs_dict[self.inputs_to_check])
vs, vs_feed = self.gen_static_grad_outputs_and_feed()
feed.update(vs_feed)
no_grad_vars = self.gen_no_grad_set(
var_dict={**inputs_dict, **outputs_dict}
)
ret = paddle.static.gradients(ys, xs, vs, no_grad_set=no_grad_vars)
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
actual_ret = exe.run(main_program, feed=feed, fetch_list=ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
actual_ret = map_structure(
lambda x: convert_uint16_to_float(x), actual_ret
)
# check static grad out
if len(actual_ret) != len(self.eager_desire):
msg = (
"The static comp grad out tensor num differs from the eager grad out tensor num on %s "
'when enable_fw_comp is %s, enable_rev_comp is %s: static comp grad out tensor nums = %s, eager grad out tensor nums = %s. \n'
% (
str(self.place),
self.enable_fw_comp,
self.enable_rev_comp,
len(actual_ret),
len(self.eager_desire),
)
)
raise RuntimeError(msg)
for i in range(len(actual_ret)):
if not np.allclose(
actual_ret[i], self.eager_desire[i], rtol=rtol, atol=atol
):
msg = (
'Check static comp grad out failed. Mismatch between static comp '
'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s; the forward api out tensor\'s index is %d \n'
'static comp grad out tensor:%s\n eager grad out tensor:%s\n'
% (
str(self.place),
self.enable_fw_comp,
self.enable_rev_comp,
i,
actual_ret[i],
self.eager_desire[i],
)
)
raise RuntimeError(msg)
core._set_prim_forward_enabled(False)
core._set_prim_backward_enabled(False)
paddle.disable_static()
def check_jit_comp(self):
paddle.disable_static()
if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
paddle.device.set_device("cpu")
if type(self.place) is paddle.fluid.libpaddle.CUDAPlace:
paddle.device.set_device("gpu:0")
if self.prim_op_type == "prim":
core._set_prim_backward_enabled(self.enable_rev_comp)
else:
core._set_prim_forward_enabled(self.enable_fw_comp)
core._set_prim_backward_enabled(self.enable_rev_comp)
atol = (
self.fw_comp_atol
if self.enable_fw_comp and not self.enable_rev_comp
else self.jit_comp_atol
)
rtol = (
self.fw_comp_rtol
if self.enable_fw_comp and not self.enable_rev_comp
else self.jit_comp_rtol
)
atol = self.rev_comp_atol if self.enable_rev_comp else atol
rtol = self.rev_comp_rtol if self.enable_rev_comp else rtol
(
eager_tensor_inputs,
attrs_outputs,
inputs_dict,
) = self.get_eager_input_attr_and_inputdict()
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
inputs_sig, _, outputs_sig = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
net = PrimNet(self.python_api)
net = apply_to_static(net, False)
out = _as_list(net(args))
outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig)
ys = []
if isinstance(self.output_names, list):
for output_name in self.output_names:
ys.append(outputs_dict[output_name])
else:
ys.append(outputs_dict[self.output_names])
xs = []
if isinstance(self.inputs_to_check, list):
for input_name in self.inputs_to_check:
xs.append(inputs_dict[input_name])
else:
xs.append(inputs_dict[self.inputs_to_check])
vs = self.gen_eager_grad_outputs()
no_grad_vars = self.gen_no_grad_set(
var_dict={**inputs_dict, **outputs_dict}
)
ret = paddle.grad(
ys, xs, vs, allow_unused=True, no_grad_vars=no_grad_vars
)
ret = map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
# check jit comp grad out
if len(ret) != len(self.eager_desire):
msg = (
"The jit comp grad out tensor nums is different with eager grad out tensor nums on %s."
'when enable_fw_comp is %s, enable_rev_comp is %s, jit comp grad out tensor nums = %s, eager grad out tensor nums = %s. \n'
% (
str(self.place),
self.enable_fw_comp,
self.enable_rev_comp,
len(ret),
len(self.eager_desire),
)
)
raise RuntimeError(msg)
for i in range(len(ret)):
if not np.allclose(
ret[i], self.eager_desire[i], rtol=rtol, atol=atol
):
msg = (
'Check jit comp grad out failed. Mismatch between jit comp '
'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s, the grad out tensor\'s index is %d.\n'
'jit comp grad out tensor: %s\n eager grad out tensor: %s\n'
% (
str(self.place),
self.enable_fw_comp,
self.enable_rev_comp,
i,
ret[i],
self.eager_desire[i],
)
)
raise RuntimeError(msg)
core._set_prim_forward_enabled(False)
core._set_prim_backward_enabled(False)
net.forward.program_cache.clear()
def check_jit_comp_with_cinn(self):
# cinn doesn't support cpu place
if (
type(self.place) is paddle.fluid.libpaddle.CPUPlace
and self.enable_cinn
and core.is_compiled_with_cinn()
):
return
paddle.disable_static()
if type(self.place) is paddle.fluid.libpaddle.CPUPlace:
paddle.device.set_device("cpu")
if type(self.place) is paddle.fluid.libpaddle.CUDAPlace:
paddle.device.set_device("gpu:0")
if self.prim_op_type == "prim":
core._set_prim_backward_enabled(self.enable_rev_comp)
else:
core._set_prim_forward_enabled(self.enable_fw_comp)
core._set_prim_backward_enabled(self.enable_rev_comp)
if self.enable_cinn and core.is_compiled_with_cinn():
atol = self.cinn_atol
rtol = self.cinn_rtol
else:
atol = (
self.fw_comp_atol
if self.enable_fw_comp and not self.enable_rev_comp
else self.jit_comp_atol
)
rtol = (
self.fw_comp_rtol
if self.enable_fw_comp and not self.enable_rev_comp
else self.jit_comp_rtol
)
atol = self.rev_comp_atol if self.enable_rev_comp else atol
rtol = self.rev_comp_rtol if self.enable_rev_comp else rtol
(
eager_tensor_inputs,
attrs_outputs,
inputs_dict,
) = self.get_eager_input_attr_and_inputdict()
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
inputs_sig, _, outputs_sig = self.kernel_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
net = PrimNet(self.python_api)
net = apply_to_static(
net, core.is_compiled_with_cinn() and self.enable_cinn
)
out = _as_list(net(args))
outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig)
ys = []
if isinstance(self.output_names, list):
for output_name in self.output_names:
ys.append(outputs_dict[output_name])
else:
ys.append(outputs_dict[self.output_names])
xs = []
if isinstance(self.inputs_to_check, list):
for input_name in self.inputs_to_check:
xs.append(inputs_dict[input_name])
else:
xs.append(inputs_dict[self.inputs_to_check])
vs = self.gen_eager_grad_outputs()
no_grad_vars = self.gen_no_grad_set(
var_dict={**inputs_dict, **outputs_dict}
)
ret = paddle.grad(
ys, xs, vs, allow_unused=True, no_grad_vars=no_grad_vars
)
ret = map_structure(lambda x: x.numpy(), ret)
if OpTestUtils.is_bfloat16_type(self.dtype):
ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
# check jit comp grad out
if len(ret) != len(self.eager_desire):
msg = (
"The jit comp with cinn grad out tensor nums is different with eager grad out tensor nums on %s."
'when enable_fw_comp is %s, enable_rev_comp is %s, enable_cinn is %s, jit comp grad out tensor nums = %s, eager grad out tensor nums = %s. \n'
% (
str(self.place),
self.enable_fw_comp,
self.enable_rev_comp,
self.enable_cinn and core.is_compiled_with_cinn(),
len(ret),
len(self.eager_desire),
)
)
raise RuntimeError(msg)
for i in range(len(ret)):
if not np.allclose(
ret[i], self.eager_desire[i], rtol=rtol, atol=atol
):
msg = (
'Check jit comp with cinn grad out failed. Mismatch between jit comp with cinn '
'and eager on %s, when enable_fw_comp is %s, enable_rev_comp is %s, enable_cinn is %s, '
'the grad out tensor\'s index is %d.\njit comp with cinn grad out tensor: %s\n eager grad out tensor: %s\n'
% (
str(self.place),
self.enable_fw_comp,
self.enable_rev_comp,
self.enable_cinn and core.is_compiled_with_cinn(),
i,
ret[i],
self.eager_desire[i],
)
)
raise RuntimeError(msg)
core._set_prim_forward_enabled(False)
core._set_prim_backward_enabled(False)
net.forward.program_cache.clear()
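For reference, a minimal sketch of how an OpTest subclass opts into these prim checkers (the op, dtype, shapes, and class name below are illustrative placeholders, not part of this change):

import numpy as np
import paddle
from op_test import OpTest


class TestReduceSumPrimExample(OpTest):
    def setUp(self):
        self.op_type = "reduce_sum"      # static graph op under test
        self.python_api = paddle.sum     # eager/to_static entry point used by the checkers
        self.prim_op_type = "prim"       # decompose the backward op into primitive ops
        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
        self.attrs = {'dim': [0]}
        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
        self.enable_cinn = False         # skip the CINN-compiled jit check for this case

    def test_check_output(self):
        self.check_output(check_prim=True)               # forward check against eager results

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_prim=True)   # grad checks (jit/static/eager)

Setting prim_op_type to "comp" instead also exercises the forward composite rule, as the softmax cases below do.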
......@@ -28,13 +28,14 @@ from paddle.fluid import Program, core, program_guard
class TestExpandV2OpRank1(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "prim"
self.init_data()
self.python_api = paddle.expand
self.inputs = {'X': np.random.random(self.ori_shape).astype("float64")}
self.attrs = {'shape': self.shape}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
self.enable_cinn = False
def init_data(self):
self.ori_shape = [100]
......@@ -42,10 +43,10 @@ class TestExpandV2OpRank1(OpTest):
self.expand_times = [1]
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestExpandV2OpRank2_DimExpanding(TestExpandV2OpRank1):
......@@ -80,6 +81,7 @@ class TestExpandV2OpRank4(TestExpandV2OpRank1):
class TestExpandV2OpRank1_tensor_attr(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "prim"
self.python_api = paddle.expand
self.init_data()
expand_shapes_tensor = []
......@@ -103,10 +105,10 @@ class TestExpandV2OpRank1_tensor_attr(OpTest):
self.infer_expand_shape = [-1]
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestExpandV2OpRank2_Corner_tensor_attr(TestExpandV2OpRank1_tensor_attr):
......@@ -121,6 +123,7 @@ class TestExpandV2OpRank2_Corner_tensor_attr(TestExpandV2OpRank1_tensor_attr):
class TestExpandV2OpRank1_tensor(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "prim"
self.python_api = paddle.expand
self.init_data()
......@@ -148,6 +151,7 @@ class TestExpandV2OpRank1_tensor(OpTest):
class TestExpandV2OpInteger(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "prim"
self.python_api = paddle.expand
self.inputs = {
'X': np.random.randint(10, size=(2, 4, 5)).astype("int32")
......@@ -164,6 +168,7 @@ class TestExpandV2OpInteger(OpTest):
class TestExpandV2OpBoolean(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "prim"
self.python_api = paddle.expand
self.inputs = {'X': np.random.randint(2, size=(2, 4, 5)).astype("bool")}
self.attrs = {'shape': [2, 4, 5]}
......@@ -178,6 +183,7 @@ class TestExpandV2OpBoolean(OpTest):
class TestExpandV2OpInt64_t(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "prim"
self.python_api = paddle.expand
self.inputs = {
'X': np.random.randint(10, size=(2, 4, 5)).astype("int64")
......
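These expand_v2 cases now route their grad check through the prim decomposition as well; for orientation, a hedged sketch of the eager behaviour they are compared against (shapes are illustrative):

import paddle

x = paddle.rand([100])
x.stop_gradient = False
y = paddle.expand(x, shape=[2, 100])   # the same broadcast pattern the tests exercise
y.sum().backward()
print(x.grad.shape)                    # expected: [100], grad reduced back over the expanded axis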
......@@ -76,10 +76,6 @@ class TestInputSpec(unittest.TestCase):
with self.assertRaises(TypeError):
tensor_spec = InputSpec(4, dtype='int8')
# 3. len(shape) should be greater than 0.
with self.assertRaises(ValueError):
tensor_spec = InputSpec([], dtype='int8')
def test_batch_and_unbatch(self):
tensor_spec = InputSpec([10])
# insert batch_size
......@@ -90,15 +86,11 @@ class TestInputSpec(unittest.TestCase):
unbatch_spec = batch_tensor_spec.unbatch()
self.assertEqual(unbatch_spec.shape, (10,))
# 1. `unbatch` requires len(shape) > 1
with self.assertRaises(ValueError):
unbatch_spec.unbatch()
# 2. `batch` requires len(batch_size) == 1
# 1. `batch` requires len(batch_size) == 1
with self.assertRaises(ValueError):
tensor_spec.batch([16, 12])
# 3. `batch` requires type(batch_size) == int
# 2. `batch` requires type(batch_size) == int
with self.assertRaises(TypeError):
tensor_spec.batch('16')
......
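With the `len(shape) == 0` ValueError dropped from InputSpec (see the hunk near the end of this diff), a 0-D spec can now be constructed; a quick hedged illustration, assuming the rest of InputSpec is unchanged:

import paddle

# previously rejected with "`shape` in InputSpec should contain at least 1 element"
scalar_spec = paddle.static.InputSpec(shape=[], dtype='float32')
print(scalar_spec.shape)   # expected: an empty tuple, i.e. a 0-D (scalar) spec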
......@@ -28,36 +28,25 @@ class TestSumOp(OpTest):
def setUp(self):
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
self.attrs = {'dim': [0]}
# reduce doesn't support float64 in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
class TestSumOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.inputs = {'X': np.random.random([]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=None)}
self.attrs = {'dim': [], 'reduce_all': True}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
class TestSumOp_fp16(OpTest):
class TestSumOpFp32(OpTest):
def setUp(self):
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.uniform(0, 0.1, (5, 6, 10)).astype("float16")
}
......@@ -66,6 +55,8 @@ class TestSumOp_fp16(OpTest):
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
self.gradient = self.calc_gradient()
# error occurred in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output(check_eager=True)
......@@ -77,10 +68,33 @@ class TestSumOp_fp16(OpTest):
def test_check_grad(self):
self.check_grad(
['X'], 'Out', user_defined_grads=self.gradient, check_eager=True
['X'],
'Out',
user_defined_grads=self.gradient,
check_eager=True,
check_prim=True,
)
class TestSumOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random([]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=None)}
self.attrs = {'dim': [], 'reduce_all': True}
# reduce doesn't support float64 in cinn.
# 0-D tensors are not supported in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
......@@ -89,6 +103,7 @@ class TestSumOp_bf16(OpTest):
np.random.seed(100)
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.prim_op_type = "prim"
self.dtype = np.uint16
self.x = np.random.uniform(0, 0.1, (2, 5, 10)).astype(np.float32)
self.attrs = {'dim': [0, 1, 2]}
......@@ -98,6 +113,7 @@ class TestSumOp_bf16(OpTest):
self.inputs = {'X': convert_float_to_uint16(self.x)}
self.outputs = {'Out': convert_float_to_uint16(self.out)}
self.gradient = self.calc_gradient()
self.enable_cinn = False
def test_check_output(self):
place = core.CUDAPlace(0)
......@@ -111,6 +127,7 @@ class TestSumOp_bf16(OpTest):
'Out',
user_defined_grads=self.gradient,
check_eager=True,
check_prim=True,
)
def calc_gradient(self):
......@@ -123,6 +140,7 @@ class TestSumOp_fp16_withInt(OpTest):
def setUp(self):
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.prim_op_type = "prim"
self.inputs = {
# ref to https://en.wikipedia.org/wiki/Half-precision_floating-point_format
# precision limitation: only integer values between 0 and 2048 can be exactly represented in fp16
......@@ -133,6 +151,7 @@ class TestSumOp_fp16_withInt(OpTest):
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
self.gradient = self.calc_gradient()
self.enable_cinn = False
def test_check_output(self):
self.check_output(check_eager=True)
......@@ -144,7 +163,11 @@ class TestSumOp_fp16_withInt(OpTest):
def test_check_grad(self):
self.check_grad(
['X'], 'Out', user_defined_grads=self.gradient, check_eager=True
['X'],
'Out',
user_defined_grads=self.gradient,
check_eager=True,
check_prim=True,
)
......@@ -152,34 +175,40 @@ class TestSumOp5D(OpTest):
def setUp(self):
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.random((1, 2, 5, 6, 10)).astype("float64")
}
self.attrs = {'dim': [0]}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
# error occurred in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
class TestSumOp6D(OpTest):
def setUp(self):
self.python_api = paddle.sum
self.op_type = "reduce_sum"
self.prim_op_type = "prim"
self.inputs = {
'X': np.random.random((1, 1, 2, 5, 6, 10)).astype("float64")
}
self.attrs = {'dim': [0]}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
# error occurred in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
class TestSumOp8D(OpTest):
......@@ -193,7 +222,7 @@ class TestSumOp8D(OpTest):
self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 3))}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True)
......@@ -633,72 +662,100 @@ class TestAnyOpError(unittest.TestCase):
class Test1DReduce(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random(120).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
# reduce doesn't support float64 in cinn.
self.enable_cinn = False
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class Test2DReduce0(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.attrs = {'dim': [0]}
self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
# reduce doesn't support float64 in cinn.
self.enable_cinn = False
class Test2DReduce1(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.attrs = {'dim': [1]}
self.inputs = {'X': np.random.random((20, 10)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
# reduce doesn't support float64 in cinn.
self.enable_cinn = False
class Test3DReduce0(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.attrs = {'dim': [1]}
self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
# reduce doesn't support float64 in cinn.
self.enable_cinn = False
class Test3DReduce1(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.attrs = {'dim': [2]}
self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
# reduce doesn't support float64 in cinn.
self.enable_cinn = False
class Test3DReduce2(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.attrs = {'dim': [-2]}
self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
# reduce doesn't support float64 in cinn.
self.enable_cinn = False
class Test3DReduce3(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.attrs = {'dim': [1, 2]}
self.inputs = {'X': np.random.random((5, 6, 7)).astype("float64")}
self.outputs = {
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
# reduce doesn't support float64 in cinn.
self.enable_cinn = False
class Test8DReduce0(Test1DReduce):
......@@ -712,10 +769,18 @@ class Test8DReduce0(Test1DReduce):
'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestKeepDimReduce(Test1DReduce):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': [1], 'keep_dim': True}
self.outputs = {
......@@ -723,6 +788,8 @@ class TestKeepDimReduce(Test1DReduce):
axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim']
)
}
# reduce doesn't support float64 in cinn.
self.enable_cinn = False
class TestKeepDim8DReduce(Test1DReduce):
......@@ -738,6 +805,12 @@ class TestKeepDim8DReduce(Test1DReduce):
)
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
@skip_check_grad_ci(
reason="reduce_max is discontinuous non-derivable function,"
......@@ -782,6 +855,8 @@ class TestReduceMinOpMultiAxises(OpTest):
class TestKeepDimReduceSumMultiAxises(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((5, 6, 10)).astype("float64")}
self.attrs = {'dim': [-2, -1], 'keep_dim': True}
self.outputs = {
......@@ -794,12 +869,15 @@ class TestKeepDimReduceSumMultiAxises(OpTest):
self.check_output()
def test_check_grad(self):
# rev_comp raises an error for this case, so check_prim is not enabled here
self.check_grad(['X'], 'Out')
class TestReduceSumWithDimOne(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")}
self.attrs = {'dim': [1, 2], 'keep_dim': True}
self.outputs = {
......@@ -807,17 +885,21 @@ class TestReduceSumWithDimOne(OpTest):
axis=tuple(self.attrs['dim']), keepdims=True
)
}
# reduce doesn't support float64 in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestReduceSumWithNumelOne(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((100, 1)).astype("float64")}
self.attrs = {'dim': [1], 'keep_dim': False}
self.outputs = {
......@@ -825,45 +907,74 @@ class TestReduceSumWithNumelOne(OpTest):
axis=tuple(self.attrs['dim']), keepdims=False
)
}
# reduce doesn't support float64 in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=False)
class TestReduceAll(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((100, 1, 1)).astype("float64")}
self.attrs = {'reduce_all': True, 'keep_dim': False}
self.outputs = {'Out': self.inputs['X'].sum()}
# reduce doesn't support float64 in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestReduceAllFp32(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((100, 1, 1)).astype("float32")}
self.attrs = {'reduce_all': True, 'keep_dim': False}
self.outputs = {'Out': self.inputs['X'].sum()}
# reduce doesn't support float64 in cinn
self.enable_cinn = False
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
class Test1DReduceWithAxes1(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random(100).astype("float64")}
self.attrs = {'dim': [0], 'keep_dim': False}
self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
self.enable_cinn = False
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestReduceWithDtype(OpTest):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = paddle.sum
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum().astype('float64')}
self.attrs = {'reduce_all': True}
......@@ -873,17 +984,26 @@ class TestReduceWithDtype(OpTest):
'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)),
}
)
self.enable_cinn = False
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
def reduce_sum_wrapper(
x, axis=None, dtype_rename=None, keepdim=False, name=None
):
return paddle.sum(x, axis, "float64", keepdim, name)
class TestReduceWithDtype1(TestReduceWithDtype):
def setUp(self):
self.op_type = "reduce_sum"
self.python_api = reduce_sum_wrapper
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=1)}
self.attrs = {'dim': [1]}
......@@ -893,11 +1013,20 @@ class TestReduceWithDtype1(TestReduceWithDtype):
'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)),
}
)
self.enable_cinn = False
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
class TestReduceWithDtype2(TestReduceWithDtype):
def setUp(self):
self.op_type = "reduce_sum"
self.prim_op_type = "prim"
self.python_api = reduce_sum_wrapper
self.inputs = {'X': np.random.random((6, 2, 10)).astype("float64")}
self.outputs = {'Out': self.inputs['X'].sum(axis=1, keepdims=True)}
self.attrs = {'dim': [1], 'keep_dim': True}
......@@ -907,6 +1036,13 @@ class TestReduceWithDtype2(TestReduceWithDtype):
'out_dtype': int(convert_np_dtype_to_dtype_(np.float64)),
}
)
self.enable_cinn = False
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
class TestReduceSumOpError(unittest.TestCase):
......
......@@ -43,12 +43,6 @@ def ref_softmax(x, axis=None, dtype=None):
return np.apply_along_axis(stable_softmax, axis, x_t)
def softmax_wrapper(
x, axis=-1, dtype=None, name=None, use_cudnn=False, use_mkldnn=False
):
return paddle.nn.functional.softmax(x, axis=axis, dtype=dtype)
class TestSoftmaxOp(OpTest):
def get_x_shape(self):
return [10, 10]
......@@ -58,7 +52,8 @@ class TestSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "softmax"
self.python_api = softmax_wrapper
self.prim_op_type = "comp"
self.python_api = F.softmax
self.use_cudnn = False
self.use_mkldnn = False
# explicitly use float32 for ROCm, as MIOpen does not yet support float64
......@@ -78,6 +73,7 @@ class TestSoftmaxOp(OpTest):
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
}
self.enable_cinn = False
def init_kernel_type(self):
pass
......@@ -86,11 +82,9 @@ class TestSoftmaxOp(OpTest):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(
place, atol=1e-5, check_dygraph=(not self.use_mkldnn)
)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output(check_dygraph=(not self.use_mkldnn))
self.check_output(check_prim=True)
def test_check_grad(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
......@@ -110,13 +104,20 @@ class TestSoftmaxOp(OpTest):
"Out",
max_relative_error=0.01,
check_dygraph=(not self.use_mkldnn),
check_prim=True,
)
class TestSoftmaxOpfp32(TestSoftmaxOp):
def init_kernel_type(self):
self.dtype = np.float32
class TestSoftmaxOp_ZeroDim1(TestSoftmaxOp):
def setUp(self):
self.op_type = "softmax"
self.python_api = softmax_wrapper
self.prim_op_type = "comp"
self.python_api = F.softmax
self.use_cudnn = False
self.use_mkldnn = False
# explicitly use float32 for ROCm, as MIOpen does not yet support float64
......@@ -133,6 +134,15 @@ class TestSoftmaxOp_ZeroDim1(TestSoftmaxOp):
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
}
self.enable_cinn = False
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output(check_prim=True)
@unittest.skipIf(
......@@ -141,7 +151,7 @@ class TestSoftmaxOp_ZeroDim1(TestSoftmaxOp):
class TestSoftmaxOp_ZeroDim2(TestSoftmaxOp):
def setUp(self):
self.op_type = "softmax"
self.python_api = softmax_wrapper
self.python_api = F.softmax
self.use_cudnn = True
self.use_mkldnn = False
# explicitly use float32 for ROCm, as MIOpen does not yet support float64
......@@ -158,6 +168,15 @@ class TestSoftmaxOp_ZeroDim2(TestSoftmaxOp):
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
}
self.enable_cinn = False
def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output(check_prim=True)
class TestSoftmaxOp2(TestSoftmaxOp):
......@@ -375,7 +394,7 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
class TestSoftmaxBF16Op(OpTest):
def setUp(self):
self.op_type = "softmax"
self.python_api = softmax_wrapper
self.python_api = F.softmax
self.use_cudnn = self.init_cudnn()
self.use_mkldnn = False
self.dtype = np.uint16
......
......@@ -1243,6 +1243,9 @@ class ProgramCache:
def concrete_programs(self):
return [cp for key, (cp, _) in self._caches.items()]
def clear(self):
self._caches = collections.OrderedDict()
class ProgramTranslator:
"""
......
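The new ProgramCache.clear() above is what the prim checkers call via net.forward.program_cache.clear() to drop stale to_static programs between configurations; a hedged usage sketch with a placeholder layer:

import paddle

net = paddle.jit.to_static(paddle.nn.Linear(4, 4))
net(paddle.rand([2, 4]))              # traces and caches a concrete program
net.forward.program_cache.clear()     # reset the cache so the next call re-traces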
......@@ -298,12 +298,6 @@ class InputSpec:
type(shape).__name__
)
)
if len(shape) == 0:
raise ValueError(
"`shape` in InputSpec should contain at least 1 element, but received {}.".format(
shape
)
)
for i, ele in enumerate(shape):
if ele is not None:
......
......@@ -1265,6 +1265,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
'x',
[
'bool',
'uint16',
'float16',
'float32',
'float64',
......