diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 666935934235f3bdcea94326adc5fedda490db3d..cc59217e480cccd62a4a9f4d7e0d5bcfffbe40bb 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -479,9 +479,9 @@ if is_compiled_with_cinn(): disable_static() -from .new_ir_utils import IrChange # noqa: F401 +from .new_ir_utils import IrGuard # noqa: F401 -ir_change = IrChange() +ir_change = IrGuard() ir_change._switch_to_new_ir() __all__ = [ # noqa diff --git a/python/paddle/new_ir_utils.py b/python/paddle/new_ir_utils.py index 83c9b5f826d8d5ccbe06d5536415db7061740f4f..88a5415e3ec3e14a4a86b8f2233f22a80353962a 100644 --- a/python/paddle/new_ir_utils.py +++ b/python/paddle/new_ir_utils.py @@ -15,10 +15,8 @@ import paddle -from .fluid.wrapped_decorator import signature_safe_contextmanager - -class IrChange: +class IrGuard: def __init__(self): old_flag = paddle.fluid.framework.get_flags("FLAGS_enable_new_ir_api") paddle.fluid.framework.set_flags({"FLAGS_enable_new_ir_api": False}) @@ -33,6 +31,14 @@ class IrChange: ) paddle.fluid.framework.set_flags(old_flag) + def __enter__(self): + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) + self._switch_to_new_ir() + + def __exit__(self, exc_type, exc_val, exc_tb): + paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) + self._switch_to_old_ir() + def _switch_to_new_ir(self): if paddle.ir.core._use_new_ir_api(): paddle.framework.set_flags( @@ -64,15 +70,3 @@ class IrChange: "IrChange._switch_to_old_ir only work when paddle.ir.core._use_new_ir_api() is false, \ please set FLAGS_enable_new_ir_api = false" ) - - -@signature_safe_contextmanager -def _newir_guard(): - ir_change = IrChange() - paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True}) - ir_change._switch_to_new_ir() - try: - yield - finally: - paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False}) - ir_change._switch_to_old_ir() diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py index 30a853336c976e783b764136143f4adb4be431fb..1382ff591b7f939075202491e86af3375c948f1e 100644 --- a/python/paddle/static/input.py +++ b/python/paddle/static/input.py @@ -99,9 +99,9 @@ def data(name, shape, dtype=None, lod_level=0): """ + if not dtype: + dtype = paddle.get_default_dtype() if paddle.ir.core._use_new_ir_api(): - if not dtype: - dtype = paddle.get_default_dtype() ir_dtype = paddle.ir.core.convert_np_dtype_to_dtype_(dtype) return paddle._ir_ops.data(name, shape, ir_dtype, core.Place()) @@ -115,29 +115,16 @@ def data(name, shape, dtype=None, lod_level=0): if shape[i] is None: shape[i] = -1 - if dtype: - out = helper.create_global_variable( - name=name, - shape=shape, - dtype=dtype, - type=core.VarDesc.VarType.LOD_TENSOR, - stop_gradient=True, - lod_level=lod_level, - is_data=True, - need_check_feed=True, - ) - - else: - out = helper.create_global_variable( - name=name, - shape=shape, - dtype=paddle.get_default_dtype(), - type=core.VarDesc.VarType.LOD_TENSOR, - stop_gradient=True, - lod_level=lod_level, - is_data=True, - need_check_feed=True, - ) + out = helper.create_global_variable( + name=name, + shape=shape, + dtype=dtype, + type=core.VarDesc.VarType.LOD_TENSOR, + stop_gradient=True, + lod_level=lod_level, + is_data=True, + need_check_feed=True, + ) is_new_ir_mode = os.environ.get("FLAGS_enable_new_ir_in_executor", None) if evaluate_flag(is_new_ir_mode): diff --git a/test/legacy_test/eager_op_test.py b/test/legacy_test/eager_op_test.py index 817d39b7d879d9bfe5468b2e355ef3f245bf9735..66a48acc6c5f6c8137fa340967c68ba096e9a79a 100644 --- a/test/legacy_test/eager_op_test.py +++ b/test/legacy_test/eager_op_test.py @@ -1261,7 +1261,7 @@ class OpTest(unittest.TestCase): static_inputs = defaultdict(list) feed = {} for name, item in self.inputs.items(): - if isinstance(item, list): + if isinstance(item, (list, tuple)): for tup in item: dtype = ( "bfloat16" @@ -1355,9 +1355,7 @@ class OpTest(unittest.TestCase): # executor run executor = Executor(place) (outs,) = executor.run( - ir_program, - feed=feed, - fetch_list=fetch_list, + ir_program, feed=feed, fetch_list=[fetch_list] ) return outs @@ -2473,7 +2471,7 @@ class OpTest(unittest.TestCase): or type(place) is paddle.fluid.libpaddle.CUDAPlace ): print("New IR checker begins...........") - with paddle.new_ir_utils._newir_guard(): + with paddle.new_ir_utils.IrGuard(): new_ir_checker = NewIRChecker(self, self.outputs) new_ir_checker.check() @@ -3020,6 +3018,7 @@ class OpTest(unittest.TestCase): "Gradient Check On %s" % str(place), atol=atol, ) + # get new ir gradient if ( self.op_type @@ -3031,7 +3030,7 @@ class OpTest(unittest.TestCase): or type(place) is paddle.fluid.libpaddle.CUDAPlace ): print("New IR gradient begins...........") - with paddle.new_ir_utils._newir_guard(): + with paddle.new_ir_utils.IrGuard(): new_ir_grad = self._get_ir_gradient( inputs_to_check, place, @@ -3042,7 +3041,7 @@ class OpTest(unittest.TestCase): print("New IR gradient ends...........") self._assert_is_close( numeric_grads, - [new_ir_grad], + new_ir_grad, inputs_to_check, max_relative_error, "Gradient Check On %s" % str(place), @@ -3448,23 +3447,38 @@ class OpTest(unittest.TestCase): args = OpTestUtils.assumption_assert_and_transform( args, len(inputs_sig) ) + grad_outputs = [] + if user_defined_grad_outputs is not None: + # user_defined_grad_outputs here are numpy arrays + if not isinstance(user_defined_grad_outputs, list): + user_defined_grad_outputs = [user_defined_grad_outputs] + for grad_out_value, idx in zip( + user_defined_grad_outputs, + range(len(user_defined_grad_outputs)), + ): + grad_val = paddle.static.data( + name='val_grad_%s' % idx, + shape=grad_out_value.shape, + dtype=grad_out_value.dtype, + ) + grad_outputs.append(grad_val) + feed.update({'val_grad_%s' % idx: grad_out_value}) + # delete the inputs which no need to calculate grad + for no_grad_val in no_grad_set: + del static_inputs[no_grad_val] + ret_tuple = self.python_api(*args) - result = construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig) + outputs = construct_output_dict_by_kernel_sig( + ret_tuple, outputs_sig + ) if hasattr(self, "python_out_sig_sub_name"): for key in self.python_out_sig_sub_name.keys(): for i in range(len(self.python_out_sig_sub_name[key])): - result[key][0][i].name = self.python_out_sig_sub_name[ + outputs[key][0][i].name = self.python_out_sig_sub_name[ key ][i] fetch_list = getattr(self, "fetch_list", []) - if len(fetch_list) == 0: - for var in result.items(): - if isinstance(var[1], list): - for v in var[1]: - fetch_list.append(v) - else: - fetch_list.append(var[1]) - outputs = result + outputs_valid = outputs grad_inputs = inputs_to_check if user_defined_grad_outputs is None: @@ -3477,16 +3491,6 @@ class OpTest(unittest.TestCase): grad_outputs=None, ) else: - # user_defined_grad_outputs here are numpy arrays - if not isinstance(user_defined_grad_outputs, list): - user_defined_grad_outputs = [user_defined_grad_outputs] - grad_outputs = [] - for grad_out_value in user_defined_grad_outputs: - grad_outputs.append(paddle.to_tensor(grad_out_value)) - # delete the inputs which no need to calculate grad - for no_grad_val in no_grad_set: - del static_inputs[no_grad_val] - grad_inputs = ir_grad( outputs=paddle.utils.flatten(outputs), inputs=paddle.utils.flatten(static_inputs), @@ -3496,7 +3500,7 @@ class OpTest(unittest.TestCase): # executor run executor = paddle.static.Executor() - (outs,) = executor.run( + outs = executor.run( ir_program, feed=feed, fetch_list=fetch_list,