Unverified commit e6d13ae5 authored by Huihuang Zheng, committed by GitHub

Refine Error Message for New Data API (#20204)

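For context (not part of the patch): a minimal sketch of how the refined messages surface to a user. It assumes the fluid.data API (presumably the "new data API" the title refers to), which marks a variable so the executor validates the shape and dtype of whatever is fed into it; the names and shapes below are illustrative only.

import numpy as np
import paddle.fluid as fluid

# fluid.data marks 'data' so the feed checker runs on every exe.run() feed.
data = fluid.data(name='data', shape=[None, 3, 4, 8], dtype='float32')
out = fluid.layers.fc(input=data, size=2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Feeding a last dimension of 5 instead of 8 now fails with the refined
# message, roughly:
#   ValueError: The feeded Variable 'data' should have dimensions = 4,
#   shape = (-1, 3, 4, 8), but received feeded shape [2, 3, 4, 5]
exe.run(fluid.default_main_program(),
        feed={'data': np.random.rand(2, 3, 4, 5).astype('float32')},
        fetch_list=[out])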
Parent a38d1835
@@ -22,6 +22,7 @@ import warnings
 import numpy as np
 from .wrapped_decorator import signature_safe_contextmanager
 import six
+from .data_feeder import convert_dtype
 from .framework import Program, default_main_program, Variable, convert_np_dtype_to_dtype_
 from . import core
 from . import compiler
@@ -214,13 +215,18 @@ def check_feed_shape_type(var, feed):
     """
     if var.desc.need_check_feed():
         if not dimension_is_compatible_with(feed.shape(), var.shape):
-            raise ValueError('Cannot feed value of shape %r for Variable %r, '
-                             'which has shape %r' %
-                             (feed.shape, var.name, var.shape))
+            raise ValueError(
+                'The feeded Variable %r should have dimensions = %d, shape = '
+                '%r, but received feeded shape %r' %
+                (var.name, len(var.shape), var.shape, feed.shape()))
         if not dtype_is_compatible_with(feed._dtype(), var.dtype):
-            raise ValueError('Cannot feed value of type %r for Variable %r, '
-                             'which has type %r' %
-                             (feed._dtype(), var.name, var.dtype))
+            var_dtype_format = convert_dtype(var.dtype) if isinstance(
+                var.dtype, core.VarDesc.VarType) else var.dtype
+            feed_dtype_format = convert_dtype(feed._dtype()) if isinstance(
+                feed._dtype(), core.VarDesc.VarType) else feed._dtype()
+            raise ValueError(
+                'The data type of feeded Variable %r must be %r, but received %r'
+                % (var.name, var_dtype_format, feed_dtype_format))
     return True
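A side note on the dtype branch above: routing both dtypes through convert_dtype means the message reads 'float32' instead of the raw enum repr VarType.FP32. A hedged sketch of that guard in isolation (format_dtype is our illustrative helper, not part of the patch):

import paddle.fluid.core as core
from paddle.fluid.data_feeder import convert_dtype

def format_dtype(dtype):
    # Enum values such as core.VarDesc.VarType.FP32 become readable
    # numpy-style strings; plain strings pass through unchanged.
    if isinstance(dtype, core.VarDesc.VarType):
        return convert_dtype(dtype)
    return dtype

print(format_dtype(core.VarDesc.VarType.FP32))  # 'float32'
print(format_dtype('int64'))                    # 'int64'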
@@ -21,6 +21,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.compiler as compiler
 import paddle.fluid.core as core
+import six
 import unittest

 os.environ['CPU_NUM'] = str(4)
@@ -65,21 +66,60 @@ class TestFeedData(unittest.TestCase):
         return in_data, label, loss

     def test(self):
-        for use_cuda in [True, False] if core.is_compiled_with_cuda(
-        ) else [False]:
+        for use_cuda in [True,
+                         False] if core.is_compiled_with_cuda() else [False]:
             for use_parallel_executor in [False, True]:
                 print('Test Parameters:'),
                 print({
                     'use_cuda': use_cuda,
                     'use_parallel_executor': use_parallel_executor,
                 })
+                # Test feeding without error
                 self._test_feed_data_match_shape_type(use_cuda,
                                                       use_parallel_executor)
                 self._test_feed_data_contains_neg_one(use_cuda,
                                                       use_parallel_executor)
-                with self.assertRaises(ValueError):
+
+                # Test exception message when feeding with error
+                batch_size = self._get_batch_size(use_cuda,
+                                                  use_parallel_executor)
+                if six.PY2:
+                    in_shape_tuple = (long(-1), long(3), long(4), long(8))
+                    feed_shape_list = [
+                        long(batch_size), long(3), long(4), long(5)
+                    ]
+                else:
+                    in_shape_tuple = (-1, 3, 4, 8)
+                    feed_shape_list = [batch_size, 3, 4, 5]
+
+                with self.assertRaises(ValueError) as shape_mismatch_err:
                     self._test_feed_data_shape_mismatch(use_cuda,
                                                         use_parallel_executor)
+                self.assertEqual(
+                    str(shape_mismatch_err.exception),
+                    "The feeded Variable %r should have dimensions = %r, "
+                    "shape = %r, but received feeded shape %r" %
+                    (u'data', len(in_shape_tuple), in_shape_tuple,
+                     feed_shape_list))
+
+                with self.assertRaises(ValueError) as dtype_mismatch_err:
+                    self._test_feed_data_dtype_mismatch(use_cuda,
+                                                        use_parallel_executor)
+                self.assertEqual(
+                    str(dtype_mismatch_err.exception),
+                    "The data type of feeded Variable %r must be 'int64', but "
+                    "received 'float64'" % (u'label'))
+
+    def _test_feed_data_dtype_mismatch(self, use_cuda, use_parallel_executor):
+        batch_size = self._get_batch_size(use_cuda, use_parallel_executor)
+        in_size = [batch_size, 3, 4, 5]
+        feed_in_data = np.random.uniform(
+            size=[batch_size, 3, 4, 5]).astype(np.float32)
+        label_size = [batch_size, 1]
+        feed_label = np.random.randint(
+            low=0, high=self.class_num, size=[batch_size, 1]).astype(np.float64)
+        self._feed_data_in_executor(in_size, label_size, feed_in_data,
+                                    feed_label, use_cuda, use_parallel_executor)
+
     def _test_feed_data_shape_mismatch(self, use_cuda, use_parallel_executor):
         batch_size = self._get_batch_size(use_cuda, use_parallel_executor)
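Why the test needs the six.PY2 branch above: the expected strings are built with %r, and repr() of the shape elements differs across Python versions, presumably because the shapes reported by the C++ VarDesc come back as longs on Python 2. A tiny illustration, runnable on either version:

import six

expected_shape = (-1, 3, 4, 8)
if six.PY2:
    # Python 2 reprs long values with an 'L' suffix, e.g. (-1L, 3L, 4L, 8L),
    # so the expected tuple must be built from longs to match.
    expected_shape = tuple(long(x) for x in expected_shape)  # noqa: F821

print(repr(expected_shape))
# Python 2: (-1L, 3L, 4L, 8L)
# Python 3: (-1, 3, 4, 8)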