未验证 提交 6d0ef342 编写于 作者: C Chen Zhiyang 提交者: GitHub

【New IR】New ir op test v1.2 (#56931)

* div passed v1.0

* IrChange->IrGuard & optimize static input dtype

* clean

* remove IrChange and optimize IrGuard
上级 3eafa1fc
......@@ -479,9 +479,9 @@ if is_compiled_with_cinn():
disable_static()
from .new_ir_utils import IrChange # noqa: F401
from .new_ir_utils import IrGuard # noqa: F401
ir_change = IrChange()
ir_change = IrGuard()
ir_change._switch_to_new_ir()
__all__ = [ # noqa
......
......@@ -15,10 +15,8 @@
import paddle
from .fluid.wrapped_decorator import signature_safe_contextmanager
class IrChange:
class IrGuard:
def __init__(self):
old_flag = paddle.fluid.framework.get_flags("FLAGS_enable_new_ir_api")
paddle.fluid.framework.set_flags({"FLAGS_enable_new_ir_api": False})
......@@ -33,6 +31,14 @@ class IrChange:
)
paddle.fluid.framework.set_flags(old_flag)
def __enter__(self):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
self._switch_to_new_ir()
def __exit__(self, exc_type, exc_val, exc_tb):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
self._switch_to_old_ir()
def _switch_to_new_ir(self):
if paddle.ir.core._use_new_ir_api():
paddle.framework.set_flags(
......@@ -64,15 +70,3 @@ class IrChange:
"IrChange._switch_to_old_ir only work when paddle.ir.core._use_new_ir_api() is false, \
please set FLAGS_enable_new_ir_api = false"
)
@signature_safe_contextmanager
def _newir_guard():
ir_change = IrChange()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
ir_change._switch_to_new_ir()
try:
yield
finally:
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
ir_change._switch_to_old_ir()
......@@ -99,9 +99,9 @@ def data(name, shape, dtype=None, lod_level=0):
"""
if paddle.ir.core._use_new_ir_api():
if not dtype:
dtype = paddle.get_default_dtype()
if paddle.ir.core._use_new_ir_api():
ir_dtype = paddle.ir.core.convert_np_dtype_to_dtype_(dtype)
return paddle._ir_ops.data(name, shape, ir_dtype, core.Place())
......@@ -115,7 +115,6 @@ def data(name, shape, dtype=None, lod_level=0):
if shape[i] is None:
shape[i] = -1
if dtype:
out = helper.create_global_variable(
name=name,
shape=shape,
......@@ -127,18 +126,6 @@ def data(name, shape, dtype=None, lod_level=0):
need_check_feed=True,
)
else:
out = helper.create_global_variable(
name=name,
shape=shape,
dtype=paddle.get_default_dtype(),
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=True,
)
is_new_ir_mode = os.environ.get("FLAGS_enable_new_ir_in_executor", None)
if evaluate_flag(is_new_ir_mode):
helper = LayerHelper('data', **locals())
......
......@@ -1261,7 +1261,7 @@ class OpTest(unittest.TestCase):
static_inputs = defaultdict(list)
feed = {}
for name, item in self.inputs.items():
if isinstance(item, list):
if isinstance(item, (list, tuple)):
for tup in item:
dtype = (
"bfloat16"
......@@ -1355,9 +1355,7 @@ class OpTest(unittest.TestCase):
# executor run
executor = Executor(place)
(outs,) = executor.run(
ir_program,
feed=feed,
fetch_list=fetch_list,
ir_program, feed=feed, fetch_list=[fetch_list]
)
return outs
......@@ -2473,7 +2471,7 @@ class OpTest(unittest.TestCase):
or type(place) is paddle.fluid.libpaddle.CUDAPlace
):
print("New IR checker begins...........")
with paddle.new_ir_utils._newir_guard():
with paddle.new_ir_utils.IrGuard():
new_ir_checker = NewIRChecker(self, self.outputs)
new_ir_checker.check()
......@@ -3020,6 +3018,7 @@ class OpTest(unittest.TestCase):
"Gradient Check On %s" % str(place),
atol=atol,
)
# get new ir gradient
if (
self.op_type
......@@ -3031,7 +3030,7 @@ class OpTest(unittest.TestCase):
or type(place) is paddle.fluid.libpaddle.CUDAPlace
):
print("New IR gradient begins...........")
with paddle.new_ir_utils._newir_guard():
with paddle.new_ir_utils.IrGuard():
new_ir_grad = self._get_ir_gradient(
inputs_to_check,
place,
......@@ -3042,7 +3041,7 @@ class OpTest(unittest.TestCase):
print("New IR gradient ends...........")
self._assert_is_close(
numeric_grads,
[new_ir_grad],
new_ir_grad,
inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place),
......@@ -3448,23 +3447,38 @@ class OpTest(unittest.TestCase):
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
grad_outputs = []
if user_defined_grad_outputs is not None:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
for grad_out_value, idx in zip(
user_defined_grad_outputs,
range(len(user_defined_grad_outputs)),
):
grad_val = paddle.static.data(
name='val_grad_%s' % idx,
shape=grad_out_value.shape,
dtype=grad_out_value.dtype,
)
grad_outputs.append(grad_val)
feed.update({'val_grad_%s' % idx: grad_out_value})
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]
ret_tuple = self.python_api(*args)
result = construct_output_dict_by_kernel_sig(ret_tuple, outputs_sig)
outputs = construct_output_dict_by_kernel_sig(
ret_tuple, outputs_sig
)
if hasattr(self, "python_out_sig_sub_name"):
for key in self.python_out_sig_sub_name.keys():
for i in range(len(self.python_out_sig_sub_name[key])):
result[key][0][i].name = self.python_out_sig_sub_name[
outputs[key][0][i].name = self.python_out_sig_sub_name[
key
][i]
fetch_list = getattr(self, "fetch_list", [])
if len(fetch_list) == 0:
for var in result.items():
if isinstance(var[1], list):
for v in var[1]:
fetch_list.append(v)
else:
fetch_list.append(var[1])
outputs = result
outputs_valid = outputs
grad_inputs = inputs_to_check
if user_defined_grad_outputs is None:
......@@ -3477,16 +3491,6 @@ class OpTest(unittest.TestCase):
grad_outputs=None,
)
else:
# user_defined_grad_outputs here are numpy arrays
if not isinstance(user_defined_grad_outputs, list):
user_defined_grad_outputs = [user_defined_grad_outputs]
grad_outputs = []
for grad_out_value in user_defined_grad_outputs:
grad_outputs.append(paddle.to_tensor(grad_out_value))
# delete the inputs which no need to calculate grad
for no_grad_val in no_grad_set:
del static_inputs[no_grad_val]
grad_inputs = ir_grad(
outputs=paddle.utils.flatten(outputs),
inputs=paddle.utils.flatten(static_inputs),
......@@ -3496,7 +3500,7 @@ class OpTest(unittest.TestCase):
# executor run
executor = paddle.static.Executor()
(outs,) = executor.run(
outs = executor.run(
ir_program,
feed=feed,
fetch_list=fetch_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册