未验证 提交 6d0ef342 编写于 作者: C Chen Zhiyang 提交者: GitHub

【New IR】New ir op test v1.2 (#56931)

* div passed v1.0

* IrChange->IrGuard & optimize static input dtype

* clean

* remove IrChange and optimize IrGuard
上级 3eafa1fc
...@@ -479,9 +479,9 @@ if is_compiled_with_cinn(): ...@@ -479,9 +479,9 @@ if is_compiled_with_cinn():
disable_static() disable_static()
from .new_ir_utils import IrChange # noqa: F401 from .new_ir_utils import IrGuard # noqa: F401
ir_change = IrChange() ir_change = IrGuard()
ir_change._switch_to_new_ir() ir_change._switch_to_new_ir()
__all__ = [ # noqa __all__ = [ # noqa
......
...@@ -15,10 +15,8 @@ ...@@ -15,10 +15,8 @@
import paddle import paddle
from .fluid.wrapped_decorator import signature_safe_contextmanager
class IrGuard:
class IrChange:
def __init__(self): def __init__(self):
old_flag = paddle.fluid.framework.get_flags("FLAGS_enable_new_ir_api") old_flag = paddle.fluid.framework.get_flags("FLAGS_enable_new_ir_api")
paddle.fluid.framework.set_flags({"FLAGS_enable_new_ir_api": False}) paddle.fluid.framework.set_flags({"FLAGS_enable_new_ir_api": False})
...@@ -33,6 +31,14 @@ class IrChange: ...@@ -33,6 +31,14 @@ class IrChange:
) )
paddle.fluid.framework.set_flags(old_flag) paddle.fluid.framework.set_flags(old_flag)
def __enter__(self):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
self._switch_to_new_ir()
def __exit__(self, exc_type, exc_val, exc_tb):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
self._switch_to_old_ir()
def _switch_to_new_ir(self): def _switch_to_new_ir(self):
if paddle.ir.core._use_new_ir_api(): if paddle.ir.core._use_new_ir_api():
paddle.framework.set_flags( paddle.framework.set_flags(
...@@ -64,15 +70,3 @@ class IrChange: ...@@ -64,15 +70,3 @@ class IrChange:
"IrChange._switch_to_old_ir only work when paddle.ir.core._use_new_ir_api() is false, \ "IrChange._switch_to_old_ir only work when paddle.ir.core._use_new_ir_api() is false, \
please set FLAGS_enable_new_ir_api = false" please set FLAGS_enable_new_ir_api = false"
) )
@signature_safe_contextmanager
def _newir_guard():
ir_change = IrChange()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
ir_change._switch_to_new_ir()
try:
yield
finally:
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
ir_change._switch_to_old_ir()
...@@ -99,9 +99,9 @@ def data(name, shape, dtype=None, lod_level=0): ...@@ -99,9 +99,9 @@ def data(name, shape, dtype=None, lod_level=0):
""" """
if paddle.ir.core._use_new_ir_api():
if not dtype: if not dtype:
dtype = paddle.get_default_dtype() dtype = paddle.get_default_dtype()
if paddle.ir.core._use_new_ir_api():
ir_dtype = paddle.ir.core.convert_np_dtype_to_dtype_(dtype) ir_dtype = paddle.ir.core.convert_np_dtype_to_dtype_(dtype)
return paddle._ir_ops.data(name, shape, ir_dtype, core.Place()) return paddle._ir_ops.data(name, shape, ir_dtype, core.Place())
...@@ -115,7 +115,6 @@ def data(name, shape, dtype=None, lod_level=0): ...@@ -115,7 +115,6 @@ def data(name, shape, dtype=None, lod_level=0):
if shape[i] is None: if shape[i] is None:
shape[i] = -1 shape[i] = -1
if dtype:
out = helper.create_global_variable( out = helper.create_global_variable(
name=name, name=name,
shape=shape, shape=shape,
...@@ -127,18 +126,6 @@ def data(name, shape, dtype=None, lod_level=0): ...@@ -127,18 +126,6 @@ def data(name, shape, dtype=None, lod_level=0):
need_check_feed=True, need_check_feed=True,
) )
else:
out = helper.create_global_variable(
name=name,
shape=shape,
dtype=paddle.get_default_dtype(),
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=True,
)
is_new_ir_mode = os.environ.get("FLAGS_enable_new_ir_in_executor", None) is_new_ir_mode = os.environ.get("FLAGS_enable_new_ir_in_executor", None)
if evaluate_flag(is_new_ir_mode): if evaluate_flag(is_new_ir_mode):
helper = LayerHelper('data', **locals()) helper = LayerHelper('data', **locals())
......
...@@ -1261,7 +1261,7 @@ class OpTest(unittest.TestCase): ...@@ -1261,7 +1261,7 @@ class OpTest(unittest.TestCase):
static_inputs = defaultdict(list) static_inputs = defaultdict(list)
feed = {} feed = {}
for name, item in self.inputs.items(): for name, item in self.inputs.items():
if isinstance(item, list): if isinstance(item, (list, tuple)):
for tup in item: for tup in item:
dtype = ( dtype = (
"bfloat16" "bfloat16"
...@@ -1355,9 +1355,7 @@ class OpTest(unittest.TestCase): ...@@ -1355,9 +1355,7 @@ class OpTest(unittest.TestCase):
# executor run # executor run
executor = Executor(place) executor = Executor(place)
(outs,) = executor.run( (outs,) = executor.run(
ir_program, ir_program, feed=feed, fetch_list=[fetch_list]
feed=feed,
fetch_list=fetch_list,
) )
return outs return outs
...@@ -2473,7 +2471,7 @@ class OpTest(unittest.TestCase): ...@@ -2473,7 +2471,7 @@ class OpTest(unittest.TestCase):
or type(place) is paddle.fluid.libpaddle.CUDAPlace or type(place) is paddle.fluid.libpaddle.CUDAPlace
): ):
print("New IR checker begins...........") print("New IR checker begins...........")
with paddle.new_ir_utils._newir_guard(): with paddle.new_ir_utils.IrGuard():
new_ir_checker = NewIRChecker(self, self.outputs) new_ir_checker = NewIRChecker(self, self.outputs)
new_ir_checker.check() new_ir_checker.check()
...@@ -3020,6 +3018,7 @@ class OpTest(unittest.TestCase): ...@@ -3020,6 +3018,7 @@ class OpTest(unittest.TestCase):
"Gradient Check On %s" % str(place), "Gradient Check On %s" % str(place),
atol=atol, atol=atol,
) )
# get new ir gradient # get new ir gradient
if ( if (
self.op_type self.op_type
...@@ -3031,7 +3030,7 @@ class OpTest(unittest.TestCase): ...@@ -3031,7 +3030,7 @@ class OpTest(unittest.TestCase):
or type(place) is paddle.fluid.libpaddle.CUDAPlace or type(place) is paddle.fluid.libpaddle.CUDAPlace
): ):
print("New IR gradient begins...........") print("New IR gradient begins...........")
with paddle.new_ir_utils._newir_guard(): with paddle.new_ir_utils.IrGuard():
new_ir_grad = self._get_ir_gradient( new_ir_grad = self._get_ir_gradient(
inputs_to_check, inputs_to_check,
place, place,
...@@ -3042,7 +3041,7 @@ class OpTest(unittest.TestCase): ...@@ -3042,7 +3041,7 @@ class OpTest(unittest.TestCase):
print("New IR gradient ends...........") print("New IR gradient ends...........")
self._assert_is_close( self._assert_is_close(
numeric_grads, numeric_grads,
[new_ir_grad], new_ir_grad,
inputs_to_check, inputs_to_check,
max_relative_error, max_relative_error,
"Gradient Check On %s" % str(place), "Gradient Check On %s" % str(place),
...@@ -3448,23 +3447,38 @@ class OpTest(unittest.TestCase): ...@@ -3448,23 +3447,38 @@ class OpTest(unittest.TestCase):
args = OpTestUtils.assumption_assert_and_transform( args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig) args, len(inputs_sig)
) )
grad_outputs = []
if user_defined_grad_outputs is not None:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
for grad_out_value, idx in zip(
user_defined_grad_outputs,
range(len(user_defined_grad_outputs)),
):
grad_val = paddle.static.data(
name='val_grad_%s' % idx,
shape=grad_out_value.shape,
dtype=grad_out_value.dtype,
)
grad_outputs.append(grad_val)
feed.update({'val_grad_%s' % idx: grad_out_value})
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]
ret_tuple = self.python_api(*args) ret_tuple = self.python_api(*args)
result = construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig) outputs = construct_output_dict_by_kernel_sig(
ret_tuple, outputs_sig
)
if hasattr(self, "python_out_sig_sub_name"): if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys(): for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])): for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][i].name = self.python_out_sig_sub_name[ outputs[key][0][i].name = self.python_out_sig_sub_name[
key key
][i] ][i]
fetch_list = getattr(self, "fetch_list", []) fetch_list = getattr(self, "fetch_list", [])
if len(fetch_list) == 0:
for var in result.items():
if isinstance(var[1], list):
for v in var[1]:
fetch_list.append(v)
else:
fetch_list.append(var[1])
outputs = result
outputs_valid = outputs outputs_valid = outputs
grad_inputs = inputs_to_check grad_inputs = inputs_to_check
if user_defined_grad_outputs is None: if user_defined_grad_outputs is None:
...@@ -3477,16 +3491,6 @@ class OpTest(unittest.TestCase): ...@@ -3477,16 +3491,6 @@ class OpTest(unittest.TestCase):
grad_outputs=None, grad_outputs=None,
) )
else: else:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
grad_outputs = []
for grad_out_value in user_defined_grad_outputs:
grad_outputs.append(paddle.to_tensor(grad_out_value))
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]
grad_inputs = ir_grad( grad_inputs = ir_grad(
outputs=paddle.utils.flatten(outputs), outputs=paddle.utils.flatten(outputs),
inputs=paddle.utils.flatten(static_inputs), inputs=paddle.utils.flatten(static_inputs),
...@@ -3496,7 +3500,7 @@ class OpTest(unittest.TestCase): ...@@ -3496,7 +3500,7 @@ class OpTest(unittest.TestCase):
# executor run # executor run
executor = paddle.static.Executor() executor = paddle.static.Executor()
(outs,) = executor.run( outs = executor.run(
ir_program, ir_program,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册