未验证 提交 ef1aa8ce 编写于 作者: C Charles-hit 提交者: GitHub

fix prim test (#51385)

上级 66b99dc8
......@@ -338,6 +338,9 @@ class OpTest(unittest.TestCase):
_set_use_system_allocator(cls._use_system_allocator)
if hasattr(cls, 'check_prim') and os.getenv('FLAGS_prim_test_log'):
print("check prim end!")
def is_empty_grad_op(op_type):
all_op_kernels = core._get_all_register_op_kernels()
grad_op = op_type + '_grad'
......
......@@ -352,6 +352,9 @@ class OpTest(unittest.TestCase):
_set_use_system_allocator(cls._use_system_allocator)
if hasattr(cls, 'check_prim') and os.getenv('FLAGS_prim_test_log'):
print("check prim end!")
def is_empty_grad_op(op_type):
all_op_kernels = core._get_all_register_op_kernels()
grad_op = op_type + '_grad'
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import struct
from collections import defaultdict
......@@ -277,6 +278,8 @@ class PrimForwardChecker:
if hasattr(self.op_test, 'enable_cinn')
else True
)
if os.getenv('FLAGS_enable_cinn'):
self.enable_cinn = True
self.enable_check_eager_comp = (
self.op_test.enable_check_eager_comp
if hasattr(self.op_test, 'enable_check_eager_comp')
......@@ -398,8 +401,8 @@ class PrimForwardChecker:
eager_tensor_inputs,
attrs_outputs,
_,
) = self.get_eager_input_attr_and_inputdict()
eager_tensor_outputs = self.get_eager_empty_output()
) = self.get_eager_input_attr_and_inputdict(stop_gradient=True)
eager_tensor_outputs = self.get_eager_empty_output(stop_gradient=True)
kernel_sig = OpTestUtils._get_kernel_signature(
self.op_type,
eager_tensor_inputs,
......@@ -418,7 +421,7 @@ class PrimForwardChecker:
eager_tensor_inputs,
attrs_outputs,
_,
) = self.get_eager_input_attr_and_inputdict()
) = self.get_eager_input_attr_and_inputdict(stop_gradient=True)
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
......@@ -432,7 +435,7 @@ class PrimForwardChecker:
ret = map_structure(lambda x: convert_uint16_to_float(x), ret)
return ret
def get_eager_input_attr_and_inputdict(self):
def get_eager_input_attr_and_inputdict(self, stop_gradient):
attrs_outputs = {}
for attrs_name in self.attrs:
if self.attrs[attrs_name] is not None:
......@@ -450,7 +453,7 @@ class PrimForwardChecker:
x = paddle.to_tensor(
data=tup[1],
place=self.place,
stop_gradient=False,
stop_gradient=stop_gradient,
dtype=dtype,
)
eager_inputs[name].append(x)
......@@ -464,14 +467,14 @@ class PrimForwardChecker:
x = paddle.to_tensor(
data=item,
place=self.place,
stop_gradient=False,
stop_gradient=stop_gradient,
dtype=dtype,
)
eager_inputs[name].append(x)
input_dict.update({name: x})
return eager_inputs, attrs_outputs, input_dict
def get_eager_empty_output(self):
def get_eager_empty_output(self, stop_gradient):
eager_outputs = defaultdict(list)
for name, item in self.outputs.items():
if isinstance(item, list):
......@@ -484,7 +487,7 @@ class PrimForwardChecker:
x = paddle.to_tensor(
data=[],
place=self.place,
stop_gradient=False,
stop_gradient=stop_gradient,
dtype=dtype,
)
eager_outputs[name].append(x)
......@@ -495,12 +498,15 @@ class PrimForwardChecker:
else item.dtype
)
x = paddle.to_tensor(
data=[], place=self.place, stop_gradient=False, dtype=dtype
data=[],
place=self.place,
stop_gradient=stop_gradient,
dtype=dtype,
)
eager_outputs[name].append(x)
return eager_outputs
def get_static_input_attr_inputdict_and_feed(self):
def get_static_input_attr_inputdict_and_feed(self, stop_gradient):
attrs_outputs = {}
for attrs_name in self.attrs:
if self.attrs[attrs_name] is not None:
......@@ -519,7 +525,7 @@ class PrimForwardChecker:
x = paddle.static.data(
name=str(tup[0]), shape=tup[1].shape, dtype=dtype
)
x.stop_gradient = False
x.stop_gradient = stop_gradient
static_inputs[name].append(x)
feed.update({str(tup[0]): tup[1]})
input_dict.update({str(tup[0]): x})
......@@ -530,7 +536,7 @@ class PrimForwardChecker:
else item.dtype
)
x = paddle.static.data(name=name, shape=item.shape, dtype=dtype)
x.stop_gradient = False
x.stop_gradient = stop_gradient
static_inputs[name].append(x)
feed.update({name: item})
input_dict.update({name: x})
......@@ -555,7 +561,9 @@ class PrimForwardChecker:
attrs,
input_dict,
feed,
) = self.get_static_input_attr_inputdict_and_feed()
) = self.get_static_input_attr_inputdict_and_feed(
stop_gradient=True
)
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, static_inputs, attrs, self.kernel_sig
)
......@@ -621,7 +629,7 @@ class PrimForwardChecker:
eager_tensor_inputs,
attrs_outputs,
_,
) = self.get_eager_input_attr_and_inputdict()
) = self.get_eager_input_attr_and_inputdict(stop_gradient=True)
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
......@@ -698,7 +706,7 @@ class PrimForwardChecker:
eager_tensor_inputs,
attrs_outputs,
_,
) = self.get_eager_input_attr_and_inputdict()
) = self.get_eager_input_attr_and_inputdict(stop_gradient=True)
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
......@@ -794,7 +802,9 @@ class PrimGradChecker(PrimForwardChecker):
output_dict = {}
for i in range(len(api_outputs)):
output_name = outputs_sig[i]
if isinstance(np_outputs[output_name], list):
if output_name in np_outputs and isinstance(
np_outputs[output_name], list
):
for j, tup in enumerate(np_outputs[output_name]):
output_dict.update({tup[0]: api_outputs[i][j]})
else:
......@@ -854,11 +864,13 @@ class PrimGradChecker(PrimForwardChecker):
eager_tensor_inputs,
attrs_outputs,
inputs_dict,
) = self.get_eager_input_attr_and_inputdict()
) = self.get_eager_input_attr_and_inputdict(stop_gradient=False)
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
inputs_sig, _, outputs_sig = self.kernel_sig
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
......@@ -954,11 +966,15 @@ class PrimGradChecker(PrimForwardChecker):
attrs,
inputs_dict,
feed,
) = self.get_static_input_attr_inputdict_and_feed()
) = self.get_static_input_attr_inputdict_and_feed(
stop_gradient=False
)
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, static_inputs, attrs, self.kernel_sig
)
inputs_sig, _, outputs_sig = self.kernel_sig
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
args = OpTestUtils.assumption_assert_and_transform(
args, len(inputs_sig)
)
......@@ -1055,7 +1071,7 @@ class PrimGradChecker(PrimForwardChecker):
eager_tensor_inputs,
attrs_outputs,
inputs_dict,
) = self.get_eager_input_attr_and_inputdict()
) = self.get_eager_input_attr_and_inputdict(stop_gradient=False)
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
......@@ -1066,6 +1082,8 @@ class PrimGradChecker(PrimForwardChecker):
net = PrimNet(self.python_api)
net = apply_to_static(net, False)
out = _as_list(net(args))
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig)
ys = []
if isinstance(self.output_names, list):
......@@ -1163,7 +1181,7 @@ class PrimGradChecker(PrimForwardChecker):
eager_tensor_inputs,
attrs_outputs,
inputs_dict,
) = self.get_eager_input_attr_and_inputdict()
) = self.get_eager_input_attr_and_inputdict(stop_gradient=False)
args = OpTestUtils.prepare_python_api_arguments(
self.python_api, eager_tensor_inputs, attrs_outputs, self.kernel_sig
)
......@@ -1176,6 +1194,8 @@ class PrimGradChecker(PrimForwardChecker):
net, core.is_compiled_with_cinn() and self.enable_cinn
)
out = _as_list(net(args))
if hasattr(self.op_test, "python_out_sig"):
outputs_sig = self.op_test.python_out_sig
outputs_dict = self.get_output_dict(self.outputs, out, outputs_sig)
ys = []
if isinstance(self.output_names, list):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册