未验证 提交 e6d13ae5 编写于 作者: H Huihuang Zheng 提交者: GitHub

Refine Error Message for New Data API (#20204)

Refine error message for new data API
上级 a38d1835
...@@ -22,6 +22,7 @@ import warnings ...@@ -22,6 +22,7 @@ import warnings
import numpy as np import numpy as np
from .wrapped_decorator import signature_safe_contextmanager from .wrapped_decorator import signature_safe_contextmanager
import six import six
from .data_feeder import convert_dtype
from .framework import Program, default_main_program, Variable, convert_np_dtype_to_dtype_ from .framework import Program, default_main_program, Variable, convert_np_dtype_to_dtype_
from . import core from . import core
from . import compiler from . import compiler
...@@ -214,13 +215,18 @@ def check_feed_shape_type(var, feed): ...@@ -214,13 +215,18 @@ def check_feed_shape_type(var, feed):
""" """
if var.desc.need_check_feed(): if var.desc.need_check_feed():
if not dimension_is_compatible_with(feed.shape(), var.shape): if not dimension_is_compatible_with(feed.shape(), var.shape):
raise ValueError('Cannot feed value of shape %r for Variable %r, ' raise ValueError(
'which has shape %r' % 'The feeded Variable %r should have dimensions = %d, shape = '
(feed.shape, var.name, var.shape)) '%r, but received feeded shape %r' %
(var.name, len(var.shape), var.shape, feed.shape()))
if not dtype_is_compatible_with(feed._dtype(), var.dtype): if not dtype_is_compatible_with(feed._dtype(), var.dtype):
raise ValueError('Cannot feed value of type %r for Variable %r, ' var_dtype_format = convert_dtype(var.dtype) if isinstance(
'which has type %r' % var.dtype, core.VarDesc.VarType) else var.dtype
(feed._dtype(), var.name, var.dtype)) feed_dtype_format = convert_dtype(feed._dtype()) if isinstance(
feed._dtype(), core.VarDesc.VarType) else feed._dtype()
raise ValueError(
'The data type of feeded Variable %r must be %r, but received %r'
% (var.name, var_dtype_format, feed_dtype_format))
return True return True
......
...@@ -21,6 +21,7 @@ import paddle ...@@ -21,6 +21,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.compiler as compiler import paddle.fluid.compiler as compiler
import paddle.fluid.core as core import paddle.fluid.core as core
import six
import unittest import unittest
os.environ['CPU_NUM'] = str(4) os.environ['CPU_NUM'] = str(4)
...@@ -65,21 +66,60 @@ class TestFeedData(unittest.TestCase): ...@@ -65,21 +66,60 @@ class TestFeedData(unittest.TestCase):
return in_data, label, loss return in_data, label, loss
def test(self): def test(self):
for use_cuda in [True, False] if core.is_compiled_with_cuda( for use_cuda in [True,
) else [False]: False] if core.is_compiled_with_cuda() else [False]:
for use_parallel_executor in [False, True]: for use_parallel_executor in [False, True]:
print('Test Parameters:'), print('Test Parameters:'),
print({ print({
'use_cuda': use_cuda, 'use_cuda': use_cuda,
'use_parallel_executor': use_parallel_executor, 'use_parallel_executor': use_parallel_executor,
}) })
# Test feeding without error
self._test_feed_data_match_shape_type(use_cuda, self._test_feed_data_match_shape_type(use_cuda,
use_parallel_executor) use_parallel_executor)
self._test_feed_data_contains_neg_one(use_cuda, self._test_feed_data_contains_neg_one(use_cuda,
use_parallel_executor) use_parallel_executor)
with self.assertRaises(ValueError):
# Test exception message when feeding with error
batch_size = self._get_batch_size(use_cuda,
use_parallel_executor)
if six.PY2:
in_shape_tuple = (long(-1), long(3), long(4), long(8))
feed_shape_list = [
long(batch_size), long(3), long(4), long(5)
]
else:
in_shape_tuple = (-1, 3, 4, 8)
feed_shape_list = [batch_size, 3, 4, 5]
with self.assertRaises(ValueError) as shape_mismatch_err:
self._test_feed_data_shape_mismatch(use_cuda, self._test_feed_data_shape_mismatch(use_cuda,
use_parallel_executor) use_parallel_executor)
self.assertEqual(
str(shape_mismatch_err.exception),
"The feeded Variable %r should have dimensions = %r, "
"shape = %r, but received feeded shape %r" %
(u'data', len(in_shape_tuple), in_shape_tuple,
feed_shape_list))
with self.assertRaises(ValueError) as dtype_mismatch_err:
self._test_feed_data_dtype_mismatch(use_cuda,
use_parallel_executor)
self.assertEqual(
str(dtype_mismatch_err.exception),
"The data type of feeded Variable %r must be 'int64', but "
"received 'float64'" % (u'label'))
def _test_feed_data_dtype_mismatch(self, use_cuda, use_parallel_executor):
batch_size = self._get_batch_size(use_cuda, use_parallel_executor)
in_size = [batch_size, 3, 4, 5]
feed_in_data = np.random.uniform(
size=[batch_size, 3, 4, 5]).astype(np.float32)
label_size = [batch_size, 1]
feed_label = np.random.randint(
low=0, high=self.class_num, size=[batch_size, 1]).astype(np.float64)
self._feed_data_in_executor(in_size, label_size, feed_in_data,
feed_label, use_cuda, use_parallel_executor)
def _test_feed_data_shape_mismatch(self, use_cuda, use_parallel_executor): def _test_feed_data_shape_mismatch(self, use_cuda, use_parallel_executor):
batch_size = self._get_batch_size(use_cuda, use_parallel_executor) batch_size = self._get_batch_size(use_cuda, use_parallel_executor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册