提交 bcdc5c3a 编写于 作者: J Jianwei Xie 提交者: TensorFlower Gardener

Remove the default value of `shuffle` for input fns.

Change: 150192601
上级 3c5de853
......@@ -18,6 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn
# pylint: enable=unused-import
from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn as core_numpy_input_fn
def numpy_input_fn(x,
y=None,
batch_size=128,
num_epochs=1,
shuffle=True,
queue_capacity=1000,
num_threads=1):
"""This input_fn diffs from the core version with default `shuffle`."""
return core_numpy_input_fn(x=x,
y=y,
batch_size=batch_size,
shuffle=shuffle,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
num_threads=num_threads)
......@@ -19,7 +19,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn # pylint: disable=unused-import
from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn as core_pandas_input_fn
try:
# pylint: disable=g-import-not-at-top
......@@ -47,6 +47,25 @@ PANDAS_DTYPES = {
}
def pandas_input_fn(x,
y=None,
batch_size=128,
num_epochs=1,
shuffle=True,
queue_capacity=1000,
num_threads=1,
target_column='target'):
"""This input_fn diffs from the core version with default `shuffle`."""
return core_pandas_input_fn(x=x,
y=y,
batch_size=batch_size,
shuffle=shuffle,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
num_threads=num_threads,
target_column=target_column)
def extract_pandas_data(data):
"""Extract data from pandas.DataFrame for predictors.
......
......@@ -46,7 +46,7 @@ def numpy_input_fn(x,
y=None,
batch_size=128,
num_epochs=1,
shuffle=True,
shuffle=None,
queue_capacity=1000,
num_threads=1):
"""Returns input function that would feed dict of numpy arrays into the model.
......@@ -68,7 +68,7 @@ def numpy_input_fn(x,
Args:
x: dict of numpy array object.
y: numpy array object.
y: numpy array object. `None` if absent.
batch_size: Integer, size of batches to return.
num_epochs: Integer, number of epochs to iterate over data. If `None` will
run forever.
......@@ -83,9 +83,13 @@ def numpy_input_fn(x,
Raises:
ValueError: if the shape of `y` mismatches the shape of values in `x` (i.e.,
values in `x` have same shape).
TypeError: `x` is not a dict.
TypeError: `x` is not a dict or `shuffle` is not bool.
"""
if not isinstance(shuffle, bool):
raise TypeError('shuffle must be explicitly set as boolean; '
'got {}'.format(shuffle))
def input_fn():
"""Numpy input function."""
if not isinstance(x, dict):
......
......@@ -239,6 +239,15 @@ class NumpyIoTest(test.TestCase):
x, y, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()
def testNumpyInputFnWithNonBoolShuffle(self):
x = np.arange(32, 36)
y = np.arange(4)
with self.test_session():
with self.assertRaisesRegexp(TypeError,
'shuffle must be explicitly set as boolean'):
# Default shuffle is None.
numpy_io.numpy_input_fn(x, y)
def testNumpyInputFnWithTargetKeyAlreadyInX(self):
array = np.arange(32, 36)
x = {'__target_key__': array}
......
......@@ -38,7 +38,7 @@ def pandas_input_fn(x,
y=None,
batch_size=128,
num_epochs=1,
shuffle=True,
shuffle=None,
queue_capacity=1000,
num_threads=1,
target_column='target'):
......@@ -48,7 +48,7 @@ def pandas_input_fn(x,
Args:
x: pandas `DataFrame` object.
y: pandas `Series` object.
y: pandas `Series` object. `None` if absent.
batch_size: int, size of batches to return.
num_epochs: int, number of epochs to iterate over data. If not `None`,
read attempts that would exceed this value will raise `OutOfRangeError`.
......@@ -64,11 +64,16 @@ def pandas_input_fn(x,
Raises:
ValueError: if `x` already contains a column with the same name as `y`, or
if the indexes of `x` and `y` don't match.
TypeError: `shuffle` is not bool.
"""
if not HAS_PANDAS:
raise TypeError(
'pandas_input_fn should not be called without pandas installed')
if not isinstance(shuffle, bool):
raise TypeError('shuffle must be explicitly set as boolean; '
'got {}'.format(shuffle))
x = x.copy()
if y is not None:
if target_column in x:
......
......@@ -65,6 +65,16 @@ class PandasIoTest(test.TestCase):
pandas_io.pandas_input_fn(
x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)
def testPandasInputFn_NonBoolShuffle(self):
if not HAS_PANDAS:
return
x, _ = self.makeTestDataFrame()
y_noindex = pd.Series(np.arange(-32, -28))
with self.assertRaisesRegexp(TypeError,
'shuffle must be explicitly set as boolean'):
# Default shuffle is None
pandas_io.pandas_input_fn(x, y_noindex)
def testPandasInputFn_ProducesExpectedOutputs(self):
if not HAS_PANDAS:
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册