未验证 提交 ef1aa8ce 编写于 作者: C Charles-hit 提交者: GitHub

fix prim test (#51385)

上级 66b99dc8
...@@ -338,6 +338,9 @@ class OpTest(unittest.TestCase): ...@@ -338,6 +338,9 @@ class OpTest(unittest.TestCase):
_set_use_system_allocator(cls._use_system_allocator) _set_use_system_allocator(cls._use_system_allocator)
if hasattr(cls, 'check_prim') and os.getenv('FLAGS_prim_test_log'):
print("check prim end!")
def is_empty_grad_op(op_type): def is_empty_grad_op(op_type):
all_op_kernels = core._get_all_register_op_kernels() all_op_kernels = core._get_all_register_op_kernels()
grad_op = op_type + '_grad' grad_op = op_type + '_grad'
......
...@@ -352,6 +352,9 @@ class OpTest(unittest.TestCase): ...@@ -352,6 +352,9 @@ class OpTest(unittest.TestCase):
_set_use_system_allocator(cls._use_system_allocator) _set_use_system_allocator(cls._use_system_allocator)
if hasattr(cls, 'check_prim') and os.getenv('FLAGS_prim_test_log'):
print("check prim end!")
def is_empty_grad_op(op_type): def is_empty_grad_op(op_type):
all_op_kernels = core._get_all_register_op_kernels() all_op_kernels = core._get_all_register_op_kernels()
grad_op = op_type + '_grad' grad_op = op_type + '_grad'
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import struct import struct
from collections import defaultdict from collections import defaultdict
...@@ -277,6 +278,8 @@ class PrimForwardChecker: ...@@ -277,6 +278,8 @@ class PrimForwardChecker:
if hasattr(self.op_test, 'enable_cinn') if hasattr(self.op_test, 'enable_cinn')
else True else True
) )
if os.getenv('FLAGS_enable_cinn'):
self.enable_cinn = True
self.enable_check_eager_comp = ( self.enable_check_eager_comp = (
self.op_test.enable_check_eager_comp self.op_test.enable_check_eager_comp
if hasattr(self.op_test, 'enable_check_eager_comp') if hasattr(self.op_test, 'enable_check_eager_comp')
...@@ -398,8 +401,8 @@ class PrimForwardChecker: ...@@ -398,8 +401,8 @@ class PrimForwardChecker:
eager_tensor_inputs, eager_tensor_inputs,
attrs_outputs, attrs_outputs,
_, _,
) = self.get_eager_input_attr_and_inputdict() ) = self.get_eager_input_attr_and_inputdict(stop_gradient=True)
eager_tensor_outputs = self.get_eager_empty_output() eager_tensor_outputs = self.get_eager_empty_output(stop_gradient=True)
kernel_sig = OpTestUtils._get_kernel_signature( kernel_sig = OpTestUtils._get_kernel_signature(
self.op_type, self.op_type,
eager_tensor_inputs, eager_tensor_inputs,
...@@ -418,7 +421,7 @@ class PrimForwardChecker: ...@@ -418,7 +421,7 @@ class PrimForwardChecker:
eager_tensor_inputs, eager_tensor_inputs,
attrs_outputs, attrs_outputs,
_, _,
) = self.get_eager_input_attr_and_inputdict() ) = self.get_eager_input_attr_and_inputdict(stop_gradient=True)
args = OpTestUtils.prepare_python_api_arguments( args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
) )
...@@ -432,7 +435,7 @@ class PrimForwardChecker: ...@@ -432,7 +435,7 @@ class PrimForwardChecker:
ret = map_structure(lambda x: convert_uint16_to_float(x), ret) ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
return ret return ret
def get_eager_input_attr_and_inputdict(self): def get_eager_input_attr_and_inputdict(self, stop_gradient):
attrs_outputs = {} attrs_outputs = {}
for attrs_name in self.attrs: for attrs_name in self.attrs:
if self.attrs[attrs_name] is not None: if self.attrs[attrs_name] is not None:
...@@ -450,7 +453,7 @@ class PrimForwardChecker: ...@@ -450,7 +453,7 @@ class PrimForwardChecker:
x = paddle.to_tensor( x = paddle.to_tensor(
data=tup[1], data=tup[1],
place=self.place, place=self.place,
stop_gradient=False, stop_gradient=stop_gradient,
dtype=dtype, dtype=dtype,
) )
eager_inputs[name].append(x) eager_inputs[name].append(x)
...@@ -464,14 +467,14 @@ class PrimForwardChecker: ...@@ -464,14 +467,14 @@ class PrimForwardChecker:
x = paddle.to_tensor( x = paddle.to_tensor(
data=item, data=item,
place=self.place, place=self.place,
stop_gradient=False, stop_gradient=stop_gradient,
dtype=dtype, dtype=dtype,
) )
eager_inputs[name].append(x) eager_inputs[name].append(x)
input_dict.update({name: x}) input_dict.update({name: x})
return eager_inputs, attrs_outputs, input_dict return eager_inputs, attrs_outputs, input_dict
def get_eager_empty_output(self): def get_eager_empty_output(self, stop_gradient):
eager_outputs = defaultdict(list) eager_outputs = defaultdict(list)
for name, item in self.outputs.items(): for name, item in self.outputs.items():
if isinstance(item, list): if isinstance(item, list):
...@@ -484,7 +487,7 @@ class PrimForwardChecker: ...@@ -484,7 +487,7 @@ class PrimForwardChecker:
x = paddle.to_tensor( x = paddle.to_tensor(
data=[], data=[],
place=self.place, place=self.place,
stop_gradient=False, stop_gradient=stop_gradient,
dtype=dtype, dtype=dtype,
) )
eager_outputs[name].append(x) eager_outputs[name].append(x)
...@@ -495,12 +498,15 @@ class PrimForwardChecker: ...@@ -495,12 +498,15 @@ class PrimForwardChecker:
else item.dtype else item.dtype
) )
x = paddle.to_tensor( x = paddle.to_tensor(
data=[], place=self.place, stop_gradient=False, dtype=dtype data=[],
place=self.place,
stop_gradient=stop_gradient,
dtype=dtype,
) )
eager_outputs[name].append(x) eager_outputs[name].append(x)
return eager_outputs return eager_outputs
def get_static_input_attr_inputdict_and_feed(self): def get_static_input_attr_inputdict_and_feed(self, stop_gradient):
attrs_outputs = {} attrs_outputs = {}
for attrs_name in self.attrs: for attrs_name in self.attrs:
if self.attrs[attrs_name] is not None: if self.attrs[attrs_name] is not None:
...@@ -519,7 +525,7 @@ class PrimForwardChecker: ...@@ -519,7 +525,7 @@ class PrimForwardChecker:
x = paddle.static.data( x = paddle.static.data(
name=str(tup[0]), shape=tup[1].shape, dtype=dtype name=str(tup[0]), shape=tup[1].shape, dtype=dtype
) )
x.stop_gradient = False x.stop_gradient = stop_gradient
static_inputs[name].append(x) static_inputs[name].append(x)
feed.update({str(tup[0]): tup[1]}) feed.update({str(tup[0]): tup[1]})
input_dict.update({str(tup[0]): x}) input_dict.update({str(tup[0]): x})
...@@ -530,7 +536,7 @@ class PrimForwardChecker: ...@@ -530,7 +536,7 @@ class PrimForwardChecker:
else item.dtype else item.dtype
) )
x = paddle.static.data(name=name, shape=item.shape, dtype=dtype) x = paddle.static.data(name=name, shape=item.shape, dtype=dtype)
x.stop_gradient = False x.stop_gradient = stop_gradient
static_inputs[name].append(x) static_inputs[name].append(x)
feed.update({name: item}) feed.update({name: item})
input_dict.update({name: x}) input_dict.update({name: x})
...@@ -555,7 +561,9 @@ class PrimForwardChecker: ...@@ -555,7 +561,9 @@ class PrimForwardChecker:
attrs, attrs,
input_dict, input_dict,
feed, feed,
) = self.get_static_input_attr_inputdict_and_feed() ) = self.get_static_input_attr_inputdict_and_feed(
stop_gradient=True
)
args = OpTestUtils.prepare_python_api_arguments( args = OpTestUtils.prepare_python_api_arguments(
self.python_api, static_inputs, attrs, self.kernel_sig self.python_api, static_inputs, attrs, self.kernel_sig
) )
...@@ -621,7 +629,7 @@ class PrimForwardChecker: ...@@ -621,7 +629,7 @@ class PrimForwardChecker:
eager_tensor_inputs, eager_tensor_inputs,
attrs_outputs, attrs_outputs,
_, _,
) = self.get_eager_input_attr_and_inputdict() ) = self.get_eager_input_attr_and_inputdict(stop_gradient=True)
args = OpTestUtils.prepare_python_api_arguments( args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
) )
...@@ -698,7 +706,7 @@ class PrimForwardChecker: ...@@ -698,7 +706,7 @@ class PrimForwardChecker:
eager_tensor_inputs, eager_tensor_inputs,
attrs_outputs, attrs_outputs,
_, _,
) = self.get_eager_input_attr_and_inputdict() ) = self.get_eager_input_attr_and_inputdict(stop_gradient=True)
args = OpTestUtils.prepare_python_api_arguments( args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
) )
...@@ -794,7 +802,9 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -794,7 +802,9 @@ class PrimGradChecker(PrimForwardChecker):
output_dict = {} output_dict = {}
for i in range(len(api_outputs)): for i in range(len(api_outputs)):
output_name = outputs_sig[i] output_name = outputs_sig[i]
if isinstance(np_outputs[output_name], list): if output_name in np_outputs and isinstance(
np_outputs[output_name], list
):
for j, tup in enumerate(np_outputs[output_name]): for j, tup in enumerate(np_outputs[output_name]):
output_dict.update({tup[0]: api_outputs[i][j]}) output_dict.update({tup[0]: api_outputs[i][j]})
else: else:
...@@ -854,11 +864,13 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -854,11 +864,13 @@ class PrimGradChecker(PrimForwardChecker):
eager_tensor_inputs, eager_tensor_inputs,
attrs_outputs, attrs_outputs,
inputs_dict, inputs_dict,
) = self.get_eager_input_attr_and_inputdict() ) = self.get_eager_input_attr_and_inputdict(stop_gradient=False)
args = OpTestUtils.prepare_python_api_arguments( args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
) )
inputs_sig, _, outputs_sig = self.kernel_sig inputs_sig, _, outputs_sig = self.kernel_sig
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
args = OpTestUtils.assumption_assert_and_transform( args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig) args, len(inputs_sig)
) )
...@@ -954,11 +966,15 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -954,11 +966,15 @@ class PrimGradChecker(PrimForwardChecker):
attrs, attrs,
inputs_dict, inputs_dict,
feed, feed,
) = self.get_static_input_attr_inputdict_and_feed() ) = self.get_static_input_attr_inputdict_and_feed(
stop_gradient=False
)
args = OpTestUtils.prepare_python_api_arguments( args = OpTestUtils.prepare_python_api_arguments(
self.python_api, static_inputs, attrs, self.kernel_sig self.python_api, static_inputs, attrs, self.kernel_sig
) )
inputs_sig, _, outputs_sig = self.kernel_sig inputs_sig, _, outputs_sig = self.kernel_sig
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
args = OpTestUtils.assumption_assert_and_transform( args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig) args, len(inputs_sig)
) )
...@@ -1055,7 +1071,7 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1055,7 +1071,7 @@ class PrimGradChecker(PrimForwardChecker):
eager_tensor_inputs, eager_tensor_inputs,
attrs_outputs, attrs_outputs,
inputs_dict, inputs_dict,
) = self.get_eager_input_attr_and_inputdict() ) = self.get_eager_input_attr_and_inputdict(stop_gradient=False)
args = OpTestUtils.prepare_python_api_arguments( args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
) )
...@@ -1066,6 +1082,8 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1066,6 +1082,8 @@ class PrimGradChecker(PrimForwardChecker):
net = PrimNet(self.python_api) net = PrimNet(self.python_api)
net = apply_to_static(net, False) net = apply_to_static(net, False)
out = _as_list(net(args)) out = _as_list(net(args))
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig) outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig)
ys = [] ys = []
if isinstance(self.output_names, list): if isinstance(self.output_names, list):
...@@ -1163,7 +1181,7 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1163,7 +1181,7 @@ class PrimGradChecker(PrimForwardChecker):
eager_tensor_inputs, eager_tensor_inputs,
attrs_outputs, attrs_outputs,
inputs_dict, inputs_dict,
) = self.get_eager_input_attr_and_inputdict() ) = self.get_eager_input_attr_and_inputdict(stop_gradient=False)
args = OpTestUtils.prepare_python_api_arguments( args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
) )
...@@ -1176,6 +1194,8 @@ class PrimGradChecker(PrimForwardChecker): ...@@ -1176,6 +1194,8 @@ class PrimGradChecker(PrimForwardChecker):
net, core.is_compiled_with_cinn() and self.enable_cinn net, core.is_compiled_with_cinn() and self.enable_cinn
) )
out = _as_list(net(args)) out = _as_list(net(args))
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig) outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig)
ys = [] ys = []
if isinstance(self.output_names, list): if isinstance(self.output_names, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册