Commit 9a98d11e authored by lijianshe02, committed by Tao Luo

cherry-pick error info check of Print_op for release1.6 (#21349)

* add input type and input data type check for Print_op test=develop (#21250)

* add input type and input data type check for Print_op test=develop

* cherry-pick error info check of Print_op for release1.6 test=develop

* cherry-pick error info check of Print_op for release1.6 test=develop
Parent c75b162a
@@ -20,6 +20,7 @@ import os
import six
from six.moves import zip, range, xrange
import multiprocessing
import warnings
from .framework import Variable, default_main_program, _current_expected_place
from .framework import _cpu_num, _cuda_ids
@@ -64,6 +65,39 @@ def convert_dtype(dtype):
        "int32, int64, uint8]")


def check_type_and_dtype(input,
                         input_name,
                         expected_type,
                         expected_dtype,
                         op_name,
                         extra_message=''):
    check_type(input, input_name, expected_type, op_name, extra_message)
    check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)


def check_type(input, input_name, expected_type, op_name, extra_message=''):
    if not isinstance(input, expected_type):
        raise TypeError(
            "The type of '%s' in %s must be %s, but received %s. %s" %
            (input_name, op_name, expected_type, type(input), extra_message))


def check_dtype(input_dtype,
                input_name,
                expected_dtype,
                op_name,
                extra_message=''):
    if convert_dtype(input_dtype) in ['float16']:
        warnings.warn(
            "The data type of '%s' in %s only support float16 in GPU now. %s" %
            (input_name, op_name, extra_message))
    if convert_dtype(input_dtype) not in expected_dtype:
        raise TypeError(
            "The data type of '%s' in %s must be %s, but received %s. %s" %
            (input_name, op_name, expected_dtype, convert_dtype(input_dtype),
             extra_message))


class DataToLoDTensorConverter(object):
    def __init__(self, place, lod_level, shape, dtype):
        self.place = place
......
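For orientation, here is a minimal sketch of how these new helpers are intended to be called from a layer wrapper; the wrapper name is hypothetical and the import paths are assumed to match this release's paddle.fluid layout, while the call pattern mirrors what fluid.layers.Print does in the next hunk:

```python
# Illustrative sketch only -- `my_print_like_op` is a hypothetical wrapper,
# not part of this commit; the call pattern mirrors fluid.layers.Print below.
from paddle.fluid.data_feeder import check_type_and_dtype
from paddle.fluid.framework import Variable


def my_print_like_op(x):
    # Raises TypeError when x is not a Variable or its dtype is outside the
    # allowed list; a float16 input additionally triggers a warning before
    # the dtype check fails.
    check_type_and_dtype(x, 'x', Variable,
                         ['float32', 'float64', 'int32', 'int64', 'bool'],
                         'fluid.layers.my_print_like_op')
    # ... build and append the actual op here as usual ...
```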
@@ -25,7 +25,8 @@ from .nn import logical_and, logical_not, logical_or
import numpy
import warnings
import six
from functools import reduce
from functools import reduce, partial
from ..data_feeder import convert_dtype, check_type_and_dtype
__all__ = [
    'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than',
@@ -197,6 +198,10 @@ def Print(input,
            data: 3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,3,
    '''
    check_type_and_dtype(input, 'input', Variable,
                         ['float32', 'float64', 'int32', 'int64', 'bool'],
                         'fluid.layers.Print')

    helper = LayerHelper('print' + "_" + input.name, **locals())
    output = helper.create_variable_for_type_inference(input.dtype)
    helper.append_op(
......
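From the user side, the effect of the added check can be seen with a short, hypothetical snippet (variable names are illustrative): a float32 Variable is accepted as before, while a non-Variable input or a float16 Variable now fails fast with TypeError, which is exactly what the new test below asserts.

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
fluid.layers.Print(x)  # accepted: float32 is in the allowed dtype list

y = fluid.layers.data(name='y', shape=[4], dtype='float16')
# fluid.layers.Print(y)  # would raise TypeError: float16 is not an allowed dtype
```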
@@ -24,6 +24,8 @@ from paddle.fluid.framework import switch_main_program
from paddle.fluid.framework import Program
import numpy as np
from simple_nets import simple_fc_net, init_data
from paddle.fluid import compiler, Program, program_guard
from op_test import OpTest
class TestPrintOpCPU(unittest.TestCase):
@@ -80,6 +82,18 @@ class TestPrintOpCPU(unittest.TestCase):
                          return_numpy=False)


class TestPrintOpError(OpTest):
    def test_errors(self):
        with program_guard(Program(), Program()):
            # The input type of Print_op must be Variable.
            x1 = fluid.create_lod_tensor(
                np.array([[-1]]), [[1]], fluid.CPUPlace())
            self.assertRaises(TypeError, fluid.layers.Print, x1)
            # The input dtype of Print_op must be float32, float64, int32_t, int64_t or bool.
            x2 = fluid.layers.data(name='x2', shape=[4], dtype="float16")
            self.assertRaises(TypeError, fluid.layers.Print, x2)


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestPrintOpGPU(TestPrintOpCPU):
......